#! /usr/bin/python
# -*- coding: utf-8 -*-

# Folgert Karsdorp <fbkarsdorp@gmail.com>
# 2010

"""
Tree class voor DOP-parser. Class is met name bedoeld voor het maken van
een PCFG-grammatica op basis van DOP-bomen.
"""

#///////////////////////////////////////////////////////////////////////////////
import re, copy, uuid, time
from nltk.probability import ImmutableProbabilisticMixIn
from dop_grammar import WeightedGrammar, Nonterminal, Production, WeightedProduction
from dop_pchart import *
#///////////////////////////////////////////////////////////////////////////////

class DopTree(list):
    """
    Tree class. Op basis van celex_parses worden hierarchische lijsten van
    tree-objecten gemaakt.

    @leaves Geeft alle terminal-leaves terug uit de boom
    @subtrees Geeft alle mogelijke subbomen uit de afleiding
    @treepositions Geeft alle posities in de boom terug
    (terminals en non-terminals).
    @frontier Geeft alle mogelijke permutaties voor een bepaalde subboom.
    """
    def __init__(self, node, rhs=None):
        if rhs==None: return
        
        list.__init__(self, rhs)
        self._node = node

    def node(self): return self._node

    def __getitem__(self, index):
        if isinstance(index, (int, slice)):
            return list.__getitem__(self, index)
        else:
            if len(index) == 0:
                return self
            elif len(index) == 1:
                return self[int(index[0])]
            else:
                return self[int(index[0])][index[1:]]

    def __setitem__(self, index, value):
        if isinstance(index, (int, slice)):
            return list.__setitem__(self, index, value)
        else:
            if len(index) == 0:
                raise IndexError('The tree position () may not be '
                                 'assigned to.')
            elif len(index) == 1:
                self[index[0]] = value
            else:
                self[index[0]][index[1:]] = value

    def __delitem__(self, index):
        if isinstance(index, (int, slice)):
            return list.__delitem__(self, index)
        else:
            if len(index) == 0:
                raise IndexError('The tree position () may not be deleted.')
            elif len(index) == 1:
                del self[index[0]]
            else:
                del self[index[0]][index[1:]]

    def __repr__(self):
        childstr = ", ".join(repr(c) for c in self)
        return '%s(%r, [%s])' % (self.__class__.__name__,
                                 self._node, childstr)
    
    def __str__(self):
        return self.pprint_flat()

    def __eq__(self, other):
        if not isinstance(other, DopTree): return False
        return self._node == other._node and list.__eq__(self, other)

    def __ne__(self, other):
        return not (self==other)
        
    def __cmp__(self, other):
        try:
            return cmp(self, other)
        except:
            return -1

    def leaves(self):
        """
        Geeft alle terminal leaves van een boom terug.
        """
        leaves = []
        for child in self:
            if isinstance(child, DopTree):
                leaves.extend(child.leaves())
            else:
                leaves.append(child)
        return leaves

    def subtrees(self, filter=None):
        """
        Haal alle subbomen uit een parse. Deze subbomen bevatten nog geen
        lege knopen. Uit een parse als:

        (N (N (N post) (N zegel)) (N (V (Aff ver) (V koop)) (Aff er)))

        kunnen we de volgende subbomen halen:

        1. (N (N (N post) (N zegel)) (N (V (Aff ver) (V koop)) (Aff er)))
        2. (N (N post) (N zegel))
        3. (N post)
        4. (N zegel)
        5. (N (V (Aff ver) (V koop)) (Aff er))
        6. (V (Aff ver) (V koop)
        7. (Aff ver)
        8. (V koop)
        9. (Aff er)

        De functie geeft een Generator-object terug.
        
        """
        if not filter or filter(self):
            yield self
        for child in self:
            if isinstance(child, DopTree):
                for subtree in child.subtrees(filter):
                    yield subtree
                    
    def treepositions(self):
        """ De posities van van knopen in de boom in tuple-formaat """
        positions = []
        positions.append( () )
        for i, child in enumerate(self):
            if isinstance(child, DopTree):
                childpos = child.treepositions()
                positions.extend((i,)+p for p in childpos)
            else:
                positions.append( (i,) )
        return positions

    def rec_frontier(self):
        """Hulp functie voor frontier()"""
        for pos in self.treepositions():
            if pos != (): # geen root nodes
                if isinstance(self[pos], (str)): # terminals
                    f = copy.deepcopy(self)
                    f[pos] = ''
                    yield f
                elif isinstance(self[pos], DopTree): # non-terminals
                    f = copy.deepcopy(self)
                    f[pos][:] = ['']
                    yield f

    def frontier(self):
        """ Frontier-operatie. Op basis van de indeces gegeven door
        self.treepositions worden alle permutaties van de boom opgevraagd. Voorbeeld:
        - (N (N loon) (V (Aff be) (V heers) (Aff ing)))
        - (N (N ) (V (Aff be) (V heers) (Aff ing)))
        - (N (N loon) (V ))
        - (N (N loon) (V (Aff ) (V heers) (Aff ing)))
          etc.
        """
        if len(self) == 1: return self
        else:
            i = 0
            trees = [self]
            while i < len(trees):
                trees.extend(t for t in trees[i].rec_frontier() if t not in trees)
                i += 1
            return trees
    
    def print_frontier(self):
        """ Generator voor alle permutaties voor alle subbomen van een bepaalde afleiding """
        t1 = time.time()
        bomen = []
        for tree in self.subtrees():
            trees = copy.deepcopy(tree.frontier())
            if len(trees) > 1:
                bomen.extend(t for t in trees)
            else:
                bomen.append(trees)
        print time.time() - t1, len(bomen)
        return bomen#, 

    def goodman_reduction(self):
        """
        Geeft alle niet-root en nonterminals een unieke code. Dat zijn
        dus alle internal nodes. Een boom als:

        (N (V (Aff ver) (V koop)) (Aff ))
        wordt:
        (N (V1 (Aff2 ver) (V3 koop)) (Aff ))

        Hiermee zorgen we dat in de herschrijfregels (Chomsky normal form)
        (1) de waarschijnlijkheid van een boom behouden blijft (alle unieke
        regels krijgen dezelfde waarschijnlijkheid 1.0 en (2) dat het boomkarakter
        behouden blijft aangezien een unieke regel als Aff2 --> ver alleen in deze
        parse voorkomt.
        """
        for pos in self.treepositions():
            if pos != ():
                if isinstance(self[pos], DopTree):
                    if self[pos][0] != '':
                        self[pos]._node = (self[pos]._node, str(uuid.uuid4()))
        return self

    def productions(self):
        """
        Geeft voor elke subboom met frontieroperatie de Chomsky Normal Form
        herschrijfregels terug.
        """
        prods = []
        if self.leaves() != ['']:
            prods += [Production(Nonterminal(self._node), child_names(self))]
        for child in self:
            if isinstance(child, DopTree):
                prods += child.productions()
        return prods

    def pprint_flat(self, nodesep='', parens='()', quotes=False):
        """
        Print boom in formaat:
        
        (N (N nest) (V (Aff be) (N scherm)) (Aff er))
        
        """
        childstrs = []
        for child in self:
            if isinstance(child, DopTree):
                childstrs.append(child.pprint_flat(nodesep,
                                                   parens, quotes))
            elif isinstance(child, tuple):
                childstrs.append("/".join(child))
            elif isinstance(child, str) and not quotes:
                childstrs.append('%s' % child)
            else:
                childstrs.append('%r' % child)
        if isinstance(self._node, basestring):
            return '%s%s%s %s%s' % (parens[0], self._node, nodesep, 
                                    ' '.join(childstrs), parens[1])
        else:
            return '%s%r%s %s%s' % (parens[0], self._node, nodesep, 
                                    ' '.join(childstrs), parens[1])

    def pprint_latex_qtree(self):
        """
        Functie om een latex-output te geven (Afhankelijkheid LaTeX = qtree)
        """
        return r'\Tree ' + self.pprint_flat(nodesep='', parens=('[.', ' ]'))

class groupby(dict):
    """
    Groepeer-class. Groepeert een lijst als ['a', 'b', 'a', 'c', 'c'] in
    afzonderlijke lijsten met identieke items: [['a', 'a'], ['b'], ['c', 'c']
    """
    def __init__(self, seq, key=lambda x:x):
        for value in seq:
            k = key(value)
            self.setdefault(str(k), []).append(value)
    __iter__ = dict.iteritems

def get_frontiers(productions):
    """Extraheer voor alle parses alle frontier-bomen + subbomen"""
    return [p for production in productions for p in production.print_frontier()]

def get_productions(frontiers):
    """Geef Goodman-trees terug voor alle frontier-bomen"""
    sorted_frontier_lists = [g for k, g in groupby(frontiers)]
    for sorted_frontier_list in sorted_frontier_lists:
        for goodman_frontier in get_goodman(sorted_frontier_list):
            yield goodman_frontier

def get_goodman(sorted_frontier_list):
    """Geef identieke bomen dezelfde unieke codes"""
    frontier_list = sorted_frontier_list[0].goodman_reduction()
    for i in range(len(sorted_frontier_list)):
        sorted_frontier_list[i] = frontier_list
    return sorted_frontier_list

def make_productions(goodman_frontiers):
    """Geef Chomsky herschrijfregels terug van alle bomen."""
    return [goodman for goodman_frontier in goodman_frontiers
            for goodman in goodman_frontier.productions()]

def child_names(tree):
    """Geef alle knoopnamen terug uit een boom."""
    names = []
    for child in tree:
        if isinstance(child, DopTree):
            names.append(Nonterminal(child._node))
        else:
            names.append(child)
    return names

def induce_probabilities(productions):
    """
    Geeft de waarschijnlijkheid van een boom in production-rewrite-formaat.
    Input is lijst van productions per tree.
    """
    t1 = time.time()
    pcount = {}
    lcount = {}

    for prod in productions:
        lcount[prod.lhs()] = lcount.get(prod.lhs(), 0) + 1
        pcount[prod] = pcount.get(prod, 0) + 1
        
    prods = [WeightedProduction(p.lhs(), p.rhs(),
                                prob=float(pcount[p]) / lcount[p.lhs()])
             for p in pcount]

    N = Nonterminal('W') # ROOT NODE
    return WeightedGrammar(N, prods)#, time.time() - t1

def parse(string):
    """
    Maak van string een lijst van tree-classes. Inputstrings zijn van het formaat:

    postzegelverkoper -->
    (N (N (N post) (N zegel)) (N (V (Aff ver) (V koop)) (Aff er)))

    De output van celex_parser.pprint_parse is in dit formaat. 
    """
    open_b, close_b = '()'
    open_pattern, close_pattern = (re.escape(open_b),
                                   re.escape(close_b))
    node_pattern = '[^\s%s%s]+' % (open_pattern, close_pattern)
    leaf_pattern = '[^\s%s%s]+' % (open_pattern, close_pattern)
    token_re = re.compile('%s\s*(%s)?|%s|(%s)' % (
        open_pattern, node_pattern, close_pattern, leaf_pattern))

    stack = [(None, [])] # lijst van (node, children) tuples
    for match in token_re.finditer(string):
        token = match.group()
#        stack.append(('W', []))
        if token[0] == open_b:
            node = token[1:].lstrip() # haal spaties aan linkerkant weg.
            stack.append((node, []))

        elif token == close_b:
            node, children = stack.pop()
            stack[-1][1].append(DopTree(node, children))
        else:
            stack[-1][1].append(token)
    tree = stack[0][1][0]
    return tree


def main():
    oefenwoorden = ['(A (Aff on) (A (V dank) (Aff baar)))',
                    '(A (Aff on) (A (V vind) (Aff baar)))',
                    '(N (V (Aff be) (V vind)) (Aff ing))',
                    '(N (N (N (P aan) (N deel)) (V houd) (Aff er)) (Aff s) (N (V vergader) (Aff ing)))',
                    '(N (N (P aan) (N (Aff ge) (N zicht))) (Aff s) (N (V lig) (Aff ing)))',
                    '(N (N (A (V (P aan) (V spreek)) (Aff elijk)) (Aff heid)) (Aff s) (N (V (Aff ver) (A zeker)) (Aff ing)))']
    from corpus import corpus
    #oefenwoorden = corpus
    oefenwoorden = ['(W (JJ ((((P en) (V konduk)) (V it)) a) j))',
			'(W (NN ((((A mal) (J riĉ)) (A eg)) (A ul)) o))']
    oefenwoorden = """(S (NP John) (VP (V likes) (NP Mary)))
(S (NP Peter) (VP (V hates) (NP Susan)))
(S (NP Harry) (VP (V eats) (NP pizza)))
(S (NP Hermione) (VP (V eats)))""".split('\n')
        #(S (NP Harry) (VP (V likes) (NP Susan) (ADVP (RB very) (RB much))))
    productions = [parse(woord) for woord in oefenwoorden]
    print "parsed"
    
    frontiers = get_frontiers(productions)
    print "frontiers"

    prods = get_productions(frontiers)
    print "got productions"

    produs = make_productions(prods)
    print "made productions"

    grammar = induce_probabilities(produs)
    print "induced probabilities"

    inside_parser = InsideChartParser(grammar)
    print "got parser"

    #inside_parser.trace(3)
    while 1:
        print "sentence:",
	a = raw_input()
	try:
	   for tree in inside_parser.nbest_parse(a.split()):
	       print tree
        except Exception as e:
           print 'error', e

if __name__ == '__main__':
    main()

