"""
A natural language parser for PLCFRS (probabilistic linear context-free
rewriting systems). PLCFRS is an extension of context-free grammar which
rewrites tuples of strings instead of strings; this allows it to produce
parse trees with discontinuous constituents.

Copyright 2011 Andreas van Cranenburgh <andreas@unstable.nl>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

from sys import argv, stderr
from math import exp, log
from array import array
from itertools import chain
from heapq import heappush, heappop, heapify
#from bit import nextset, nextunset

def parse(sent, grammar, tags, start, exhaustive):
    """ parse sentence, a list of tokens, optionally with gold tags, and
    produce a chart, either exhaustive or up until the viterbi parse.

    sent: list of tokens (str).
    grammar: a Grammar object as produced by splitgrammar().
    tags: optional list of gold part-of-speech tags, parallel to sent
        (None to tag from the lexicon alone).
    start: numeric label of the goal (start) nonterminal.
    exhaustive: when True, keep deriving items after the goal is found.
    Returns (chart, goal): chart maps ChartItem -> list of Edge;
    goal is the goal ChartItem, or None when the sentence has no parse.
    """
    unary = grammar.unary
    lbinary = grammar.lbinary
    rbinary = grammar.rbinary
    lexical = grammar.lexical
    toid = grammar.toid
    tolabel = grammar.tolabel
    # the goal covers every word: a bit vector with len(sent) bits set
    goal = ChartItem(start, (1 << len(sent)) - 1)
    maxA = 0                             # peak agenda size (statistics only)
    blocked = 0                          # edges pruned by the score threshold
    Cx = [{} for _ in toid]              # per-label index into the chart
    C = {}                               # chart: item -> list of edges
    A = agenda()                         # priority queue of pending items

    # scan: assign part-of-speech tags
    Epsilon = toid["Epsilon"]
    for i, w in enumerate(sent):
        recognized = False
        for terminal in lexical.get(w, []):
            # when gold tags are given, only accept matching tags,
            # ignoring any "@..." suffix on the grammar label
            if not tags or tags[i] == tolabel[terminal.lhs].split("@")[0]:
                item = ChartItem(terminal.lhs, 1 << i)
                I = ChartItem(Epsilon, i)
                z = terminal.prob
                A[item] = Edge(z, z, z, I, None)
                C[item] = []
                recognized = True
        # unknown word: fall back to the gold tag with probability 1 (0.0)
        if not recognized and tags and tags[i] in toid:
            item = ChartItem(toid[tags[i]], 1 << i)
            I = ChartItem(Epsilon, i)
            A[item] = Edge(0.0, 0.0, 0.0, I, None)
            C[item] = []
            recognized = True
        elif not recognized:
            print "not covered:", tags[i] if tags else w
            return C, None

    # parsing: best-first deduction driven by the agenda
    while A:
        item, edge = A.popitem()
        C[item].append(edge)
        Cx[item.label][item] = edge

        if item == goal:
            if exhaustive: continue
            else: break
        # apply unary rules with this item as daughter
        for rule in unary[item.label]:
            blocked += process_edge(
                ChartItem(rule.lhs, item.vec),
                Edge(edge.inside + rule.prob, edge.inside + rule.prob,
                     rule.prob, item, None), A, C, exhaustive)
        # binary rules with this item as left daughter; spans must be
        # disjoint and combine according to the rule's yield function
        for rule in lbinary[item.label]:
            for sibling in Cx[rule.rhs2]:
                e = Cx[rule.rhs2][sibling]
                if (item.vec & sibling.vec == 0
                    and concat(rule, item.vec, sibling.vec)):
                    blocked += process_edge(
                        ChartItem(rule.lhs, item.vec ^ sibling.vec),
                        Edge(edge.inside + e.inside + rule.prob,
                             edge.inside + e.inside + rule.prob,
                             rule.prob, item, sibling), A, C, exhaustive)
        # binary rules with this item as right daughter
        for rule in rbinary[item.label]:
            for sibling in Cx[rule.rhs1]:
                e = Cx[rule.rhs1][sibling]
                if (sibling.vec & item.vec == 0
                    and concat(rule, sibling.vec, item.vec)):
                    blocked += process_edge(
                        ChartItem(rule.lhs, sibling.vec ^ item.vec),
                        Edge(e.inside + edge.inside + rule.prob,
                             e.inside + edge.inside + rule.prob,
                             rule.prob, sibling, item), A, C, exhaustive)
        if len(A) > maxA: maxA = len(A)
        #if len(A) % 10000 == 0:
        #    print "agenda max %d, now %d, items %d" % (maxA, len(A), len(C))
    stderr.write("agenda max %d, now %d, items %d (%d labels), " % (
                                maxA, len(A), len(C), len(filter(None, Cx))))
    stderr.write("edges %d, blocked %d\n"
							% (sum(map(len, C.values())), blocked))
    if goal not in C: goal = None
    return (C, goal)

def process_edge(newitem, newedge, A, C, exhaustive):
    """ Dispatch a newly derived edge: schedule unseen items on the agenda,
    let better edges displace agenda entries, and relegate the rest to the
    exhaustive chart. Returns 1 when the edge was pruned (blocked), else 0.
    """
    unseen = newitem not in C and newitem not in A
    if unseen:
        # prune improbable edges
        if newedge.score > 300.0:
            return 1
        # first derivation of this item: put it on the agenda
        A[newitem] = newedge
        C[newitem] = []
        return 0
    if newitem in A and newedge.inside < A[newitem].inside:
        # better derivation: displace the agenda entry, but keep the
        # superseded edge in the chart
        C[newitem].append(A[newitem])
        A[newitem] = newedge
    elif exhaustive:
        # item is suboptimal, only add to exhaustive chart
        C[newitem].append(newedge)
    return 0

def concat(rule, lvec, rvec):
    """ Check whether the spans of two daughters can be combined according
    to the yield function of a binary rule.

    rule: the binary Rule; in rule.args[x], bit n tells whether position n
        of argument x comes from the right (1) or the left (0) daughter.
    lvec, rvec: bit vectors of the left/right daughters (assumed disjoint).
    Returns True iff interleaving the two spans matches the yield function
    exactly: all bits consumed, with gaps exactly between arguments. """
    lpos = nextset(lvec, 0)
    rpos = nextset(rvec, 0)
    #this algorithm was taken from rparse, FastYFComposer.
    for x in range(len(rule.args)):
        m = rule.lengths[x] - 1
        for n in range(m + 1):
            if testbit(rule.args[x], n):
                # this position comes from the right daughter:
                # check if there are any bits left, and
                # if any bits on the right should have gone before
                # ones on this side
                if rpos == -1 or (lpos != -1 and lpos <= rpos):
                    return False
                # jump to next gap
                rpos = nextunset(rvec, rpos)
                if lpos != -1 and lpos < rpos:
                    return False
                # there should be a gap if and only if
                # this is the last element of this argument
                if n == m:
                    if testbit(lvec, rpos):
                        return False
                elif not testbit(lvec, rpos):
                    return False
                #jump to next argument
                rpos = nextset(rvec, rpos)
            else:
                # vice versa to the above
                if lpos == -1 or (rpos != -1 and rpos <= lpos):
                    return False
                lpos = nextunset(lvec, lpos)
                if rpos != -1 and rpos < lpos:
                    return False
                if n == m:
                    if testbit(rvec, lpos):
                        return False
                elif not testbit(rvec, lpos):
                    return False
                lpos = nextset(lvec, lpos)
            #else: raise ValueError("non-binary element in yieldfunction")
    # both vectors must have been fully consumed
    if lpos != -1 or rpos != -1:
        return False
    # everything looks all right
    return True

def mostprobablederivation(chart, start, tolabel):
    """ produce a string representation of the viterbi parse in bracket
    notation, together with its inside score. """
    tree = getmpd(chart, start, tolabel)
    best = min(chart[start])
    return tree, best.inside

def getmpd(chart, start, tolabel):       # chart: [dict(ChartItem, list(Edge))], start: [ChartItem], tolabel: [dict(int, str)]
    edge = min(chart[start])             # [Edge]
    if edge.right and edge.right.label:  # [int]
        return "(%s %s %s)" % (tolabel[start.label], # [str]
                    getmpd(chart, edge.left, tolabel), # []
                    getmpd(chart, edge.right, tolabel)) # []
    else: #unary or terminal
        return "(%s %s)" % (tolabel[start.label], # [str]
                    getmpd(chart, edge.left, tolabel) # [str]
                        if edge.left.label else str(edge.left.vec)) # []

def binrepr(a, sent):
    """ Render the bit vector of item `a` as a string of 0s and 1s, padded
    to sentence length, with the first word's bit leftmost. """
    padded = bin(a.vec)[2:].rjust(len(sent), "0")
    return padded[::-1]

def pprint_chart(chart, sent, tolabel):
    """ Pretty-print `chart` to stdout: items sorted by span size, each
    followed by its edges (log probabilities shown as exp(-p)). """
    print "chart:"
    for n, a in sorted((bitcount(a.vec), a) for a in chart):
        if not chart[a]: continue
        print "%s[%s] =>" % (tolabel[a.label], binrepr(a, sent))
        for edge in chart[a]:
            print "%g\t%g" % (exp(-edge.inside), exp(-edge.prob)),
            if edge.left.label:
                print "\t%s[%s]" % (tolabel[edge.left.label],
                                    binrepr(edge.left, sent)),
            else:
                # label 0 is Epsilon: a terminal; vec holds the word index
                print "\t", repr(sent[edge.left.vec]),
            if edge.right:
                print "\t%s[%s]" % (tolabel[edge.right.label],
                                    binrepr(edge.right, sent)),
            print
        print

def do(sent, grammar):
    """ Parse a single space-separated sentence (without gold tags), print
    the chart and the most probable derivation; return True iff the
    sentence was parsed. """
    print "sentence", sent
    chart, start = parse(sent.split(), grammar, None, grammar.toid['S'], False)
    pprint_chart(chart, sent.split(), grammar.tolabel)
    if start:
        t, p = mostprobablederivation(chart, start, grammar.tolabel)
        print exp(-p), t, '\n'
    else:
        print "no parse"
    return start is not None

def read_srcg_grammar(rulefile, lexiconfile):
    """ Reads a grammar as produced by write_srcg_grammar.

    rulefile: tab-separated lines of nonterminals, a comma-separated yield
        function of bit strings, and a probability (see module usage text).
    lexiconfile: tab-separated lines ending in a word field and a probability.
    Returns (rules, lexicon): each rule is ((nonterminals, yieldfunc), prob);
    each lexical entry is ((labels, word), prob). """
    # splitlines() instead of chopping the last character off each line:
    # robust when the file does not end with a newline. Also close the
    # files explicitly instead of leaking the handles.
    with open(rulefile) as f:
        srules = [line.split('\t') for line in f.read().splitlines()]
    with open(lexiconfile) as f:
        slexicon = [line.split('\t') for line in f.read().splitlines()]
    rules = [((tuple(a[:-2]), tuple(tuple(map(int, b))
                    for b in a[-2].split(","))),
                float(a[-1])) for a in srules]
    lexicon = [((tuple(a[:-2]), a[-2]), float(a[-1]))
                    for a in slexicon]
    return rules, lexicon

def splitgrammar(grammar, lexicon):
    """ split the grammar into various lookup tables, mapping nonterminal
    labels to numeric identifiers. Also negates log-probabilities to
    accommodate min-heaps.
    Can only represent ordered SRCG rules (monotone LCFRS).

    grammar: list of ((nonterminals, yieldfunction), weight) tuples.
    lexicon: list of ((labels, word), weight) tuples.
    Returns a Grammar object. """
    # get a list of all nonterminals; make sure Epsilon and ROOT are first,
    # and assign them unique IDs
    nonterminals = list(enumerate(["Epsilon", "ROOT"]
        + sorted(set(nt for (rule, yf), weight in grammar for nt in rule)
            - set(["Epsilon", "ROOT"]))))
    toid = dict((lhs, n) for n, lhs in nonterminals)
    tolabel = dict((n, lhs) for n, lhs in nonterminals)
    bylhs = [[] for _ in nonterminals]
    unary = [[] for _ in nonterminals]
    lbinary = [[] for _ in nonterminals]
    rbinary = [[] for _ in nonterminals]
    lexical = {}
    # fan-out (number of yield-function arguments) per nonterminal,
    # stored as unsigned 8-bit ints
    arity = array('B', [0] * len(nonterminals))
    for (tag, word), w in lexicon:
        t = Terminal(toid[tag[0]], toid[tag[1]], 0, word, abs(w))
        # a part-of-speech tag always has fan-out 1
        assert arity[t.lhs] in (0, 1)
        arity[t.lhs] = 1
        lexical.setdefault(word, []).append(t)
    for (rule, yf), w in grammar:
        args, lengths = yfarray(yf)
        # sanity check: the array encoding must round-trip
        assert yf == arraytoyf(args, lengths)
        #cyclic unary productions
        if len(rule) == 2 and w == 0.0: w += 0.00000001
        r = Rule(toid[rule[0]], toid[rule[1]],
            toid[rule[2]] if len(rule) == 3 else 0, args, lengths, abs(w))
        # every nonterminal must have a consistent fan-out
        if arity[r.lhs] == 0:
            arity[r.lhs] = len(args)
        assert arity[r.lhs] == len(args)
        if len(rule) == 2:
            unary[r.rhs1].append(r)
            bylhs[r.lhs].append(r)
        elif len(rule) == 3:
            lbinary[r.rhs1].append(r)
            rbinary[r.rhs2].append(r)
            bylhs[r.lhs].append(r)
        else: raise ValueError("grammar not binarized: %r" % r)
    #assert 0 not in arity[1:]
    return Grammar(unary, lbinary, rbinary, lexical, bylhs, toid, tolabel)

def yfarray(yf):
    """ convert a yield function represented as a 2D sequence to an array
    object.

    yf: tuple of tuples of 0/1 flags (0 = left daughter, 1 = right).
    Returns (args, lengths): args[x] is the bitmask of argument x,
    lengths[x] its number of positions. """
    # I for 32 bits (int), H for 16 bits (short), B for 8 bits (char)
    vectype = 'I'; vecsize = 32          # bits available per argument mask
    lentype = 'H'; lensize = 16          # bits available per length entry
    assert len(yf) <= lensize
    assert all(len(a) <= vecsize for a in yf)
    initializer = [sum(1 << n for n, b in enumerate(a) if b) for a in yf]
    # use the declared type codes instead of repeating the literals,
    # so type and size stay in sync with the assertions above
    args = array(vectype, initializer)
    lengths = array(lentype, [len(a) for a in yf])
    return args, lengths

def arraytoyf(args, lengths):
    """ Inverse of yfarray(): decode bitmask/length arrays back into a
    2D tuple of 0/1 flags. """
    components = []
    for length, mask in zip(lengths, args):
        components.append(tuple(1 if mask & (1 << bit) else 0
                                for bit in range(length)))
    return tuple(components)

# bit operations
def nextset(a, pos):
    """ First set bit, starting from pos; -1 when no set bit remains. """
    if not a >> pos:
        return -1
    i = pos
    while not (a >> i) & 1:
        i += 1
    return i

def nextunset(a, pos):
    """ First unset bit, starting from pos (always exists, since the
    vector is finite). """
    i = pos
    while (a >> i) & 1:
        i += 1
    return i

def bitcount(a):
    """ Number of set bits (1s) in a (non-negative) bit vector. """
    total = 0
    while a:
        a &= a - 1  # Kernighan's trick: clear the lowest set bit
        total += 1
    return total

def testbit(a, offset):                  # a: [int], offset: [int]
    """ Mask a particular bit, return nonzero if set """
    return a & (1 << offset)             # [int]

# various data types
class Grammar(object):
    """ Bundles the lookup tables produced by splitgrammar() into a single
    object.

    unary/lbinary/rbinary: rules indexed by their (first/second) rhs label;
    bylhs: rules indexed by their lhs label;
    lexical: maps words to lists of Terminal productions;
    toid/tolabel: map nonterminal labels to numeric IDs and back. """
    __slots__ = ('unary', 'lbinary', 'rbinary', 'lexical',
                    'bylhs', 'toid', 'tolabel')
    def __init__(self, unary, lbinary, rbinary, lexical, bylhs, toid, tolabel):
        self.unary = unary
        self.lbinary = lbinary
        self.rbinary = rbinary
        self.lexical = lexical
        self.bylhs = bylhs
        self.toid = toid
        self.tolabel = tolabel

class ChartItem:
    """ A chart item: a nonterminal label covering the set of sentence
    positions given by the bit vector `vec`. Hashable, so it can be used
    as a dictionary key. """
    __slots__ = ("label", "vec")
    def __init__(self, label, vec):
        self.label = label               # numeric nonterminal ID
        self.vec = vec                   # bit vector of covered positions
    def __hash__(self):
        #for some reason this does not work well w/shedskin:
        #h = self.label ^ (self.vec << 31) ^ (self.vec >> 31)
        #the DJB hash function:
        h = ((5381 << 5) + 5381) * 33 ^ self.label
        h = ((h << 5) + h) * 33 ^ self.vec
        # CPython reserves -1 as an error value for __hash__
        return -2 if h == -1 else h
    def __eq__(self, other):
        if other is None: return False
        return self.label == other.label and self.vec == other.vec

class Edge:
    """ An edge in the chart: scores plus backpointers to daughter items.

    score: priority used by the agenda (inside score, or an estimate
        including outside score when added);
    inside: negated inside log probability of this derivation;
    prob: negated log probability of the rule that built this edge;
    left, right: daughter ChartItems (right is None for unary/lexical). """
    __slots__ = ('score', 'inside', 'prob', 'left', 'right')
    def __init__(self, score, inside, prob, left, right):
        self.score = score; self.inside = inside; self.prob = prob
        self.left = left; self.right = right
    def __lt__(self, other):
        # the ordering only depends on inside probability
        # (or on estimate of outside score when added)
        return self.score < other.score
    def __gt__(self, other):
        return self.score > other.score
    def __eq__(self, other):
        # equality ignores score; BUG FIX: the left daughter was compared
        # against other.right instead of other.left
        return (self.inside == other.inside
                and self.prob == other.prob
                and self.left == other.left
                and self.right == other.right)

class Terminal:
    """ A lexical production: `lhs` (the tag) rewrites to `word`, with
    negated log probability `prob`; rhs1/rhs2 mirror Rule's layout. """
    __slots__ = ('lhs', 'rhs1', 'rhs2', 'word', 'prob')
    def __init__(self, lhs, rhs1, rhs2, word, prob):
        self.lhs = lhs
        self.rhs1 = rhs1
        self.rhs2 = rhs2
        self.word = word
        self.prob = prob

class Rule:
    """ A binarized SRCG/LCFRS production. lhs/rhs1/rhs2 are numeric
    nonterminal labels (rhs2 is 0 for unary rules); args/lengths encode
    the yield function as bitmask arrays (see yfarray); prob is the
    negated log probability. """
    # BUG FIX: the slot tuple listed 'lengths' twice and omitted
    # '_lengths', which makes the `self._lengths` assignment below fail
    # on new-style classes.
    __slots__ = ('lhs', 'rhs1', 'rhs2', 'prob',
                'args', 'lengths', '_args', '_lengths')
    def __init__(self, lhs, rhs1, rhs2, args, lengths, prob):
        self.lhs = lhs; self.rhs1 = rhs1; self.rhs2 = rhs2
        self.args = args; self.lengths = lengths; self.prob = prob
        # aliases of the yield-function arrays; purpose not visible in this
        # file — presumably kept for compatibility with an optimized variant
        self._args = self.args; self._lengths = self.lengths

#the agenda (priority queue)
class Entry(object):
    """ A heap entry pairing a chart item (`key`, the task) with its edge
    (`value`); `count` is the insertion counter, used to break ties and to
    mark superseded entries (count == INVALID) for lazy deletion. """
    __slots__ = ('key', 'value', 'count')
    def __init__(self, key, value, count):
        self.key = key
        self.value = value
        self.count = count
    def __eq__(self, other):
        # agenda entries are identified by their insertion counter alone
        return self.count == other.count
    def __lt__(self, other):
        if self.value < other.value:
            return True
        return self.value == other.value and self.count < other.count

INVALID = 0  # sentinel count marking a stale heap entry

class agenda(object):
    """ A priority queue mapping keys to orderable values, supporting
    decrease-key by lazy deletion: superseded heap entries are flagged
    with count == INVALID and skipped when popped. """
    def __init__(self):
        self.heap = []     # heap of Entry objects (may contain stale ones)
        self.mapping = {}  # key -> current (live) Entry
        self.counter = 1   # monotonically increasing insertion counter

    def __setitem__(self, key, value):
        if key in self.mapping:
            # re-use the old insertion count; flag the old entry as stale
            stale = self.mapping[key]
            fresh = Entry(key, value, stale.count)
            stale.count = INVALID
        else:
            fresh = Entry(key, value, self.counter)
            self.counter += 1
        self.mapping[key] = fresh
        heappush(self.heap, fresh)

    def __getitem__(self, key):
        return self.mapping[key].value

    def __contains__(self, key):
        return key in self.mapping

    def __len__(self):
        return len(self.mapping)

    def popitem(self):
        """ Remove and return the (key, value) pair with the least value,
        skipping over stale entries. """
        while True:
            entry = heappop(self.heap)
            if entry.count is not INVALID:
                break
        del self.mapping[entry.key]
        return entry.key, entry.value

def batch(rulefile, lexiconfile, sentfile):
    """ Read a grammar and a file of tagged sentences (space-separated
    word/tag pairs, one sentence per line); print the most probable
    derivation of each sentence. """
    rules, lexicon = read_srcg_grammar(rulefile, lexiconfile)
    # the lhs of the first rule serves as the start symbol
    root = rules[0][0][0][0]
    grammar = splitgrammar(rules, lexicon)
    lines = open(sentfile).read().splitlines()
    sents = [[a.split("/") for a in sent.split()] for sent in lines]
    for wordstags in sents:
        sent = [a[0] for a in wordstags]
        tags = [a[1] for a in wordstags]
        stderr.write("parsing: %s\n" % " ".join(sent))
        chart, start = parse(sent, grammar, tags, grammar.toid[root], False)
        if start:
            t, p = mostprobablederivation(chart, start, grammar.tolabel)
            print "p=%g\n%s\n\n" % (exp(-p), t)
        else: print "no parse\n"

def demo():
    """ Demonstrate parsing with a small hard-coded grammar featuring a
    discontinuous constituent (a German verb cluster: VP2 has fan-out 2).
    """
    rules = [
        ((('S','VP2','VMFIN'),    ((0,1,0),)),   log(1.0)),
        ((('VP2','VP2','VAINF'),  ((0,),(0,1))), log(0.5)),
        ((('VP2','PROAV','VVPP'), ((0,),(1,))),  log(0.5)),
        ((('VP2','VP2'),          ((0,),(0,))),  log(0.1))]
    lexicon = [
        ((('PROAV', 'Epsilon'), 'Darueber'),     0.0),
        ((('VAINF', 'Epsilon'), 'werden'),      0.0),
        ((('VMFIN', 'Epsilon'), 'muss'),        0.0),
        ((('VVPP', 'Epsilon'),  'nachgedacht'), 0.0)]
    grammar = splitgrammar(rules, lexicon)

    chart, start = parse("Darueber muss nachgedacht werden".split(),
          grammar, "PROAV VMFIN VVPP VAINF".split(), grammar.toid['S'], False)
    pprint_chart(chart, "Darueber muss nachgedacht werden".split(),
          grammar.tolabel)
    assert (mostprobablederivation(chart, start, grammar.tolabel) ==
        ('(S (VP2 (VP2 (PROAV 0) (VVPP 2)) (VAINF 3)) (VMFIN 1))', -log(0.25)))
    # the cyclic VP2 -> VP2 rule allows arbitrarily many "werden"
    assert do("Darueber muss nachgedacht werden", grammar)
    assert do("Darueber muss nachgedacht werden werden", grammar)
    assert do("Darueber muss nachgedacht werden werden werden", grammar)
    print "ungrammatical sentence:"
    assert not do("werden nachgedacht muss Darueber", grammar)
    print "(as expected)\n"

if __name__ == '__main__':
    # with grammar, lexicon and sentence files given, run in batch mode;
    # otherwise run the self-test demo and print usage information
    if len(argv) == 4:
        batch(argv[1], argv[2], argv[3])
    else:
        demo()
        print """usage: %s grammar lexicon sentences # [str]

grammar is a tab-separated text file with one rule per line, in this format:

LHS	RHS1	RHS2	YIELD-FUNC	LOGPROB
e.g., S	NP	VP	[01,10]	0.1

LHS, RHS1, and RHS2 are strings specifying the labels of this rule.
The yield function is described by a list of bit vectors such as [01,10],
where 0 is a variable that refers to a contribution by RHS1, and 1 refers to
one by RHS2. Adjacent variables are concatenated, comma-separated components
indicate discontinuities.
The final element of a rule is its log probability.
The LHS of the first rule will be used as the start symbol.

lexicon is also tab-separated, in this format:

WORD	Epsilon	TAG	LOGPROB
e.g., nachgedacht	Epsilon	VVPP	0.1

Finally, sentences is a file with one sentence per line, consisting of a space
separated list of word/tag pairs, for example:

Darueber/PROAV muss/VMFIN nachgedacht/VVPP werden/VAINF

The output consists of Viterbi parse trees where terminals have been replaced
by indices; this makes it possible to express discontinuities in otherwise
context-free trees.""" % argv[0]
