Module dopg
[hide private]
[frames] | [no frames]

Source Code for Module dopg

  1  # -*- coding: utf-8 -*- 
  2  """DOP1 implementation. Andreas van Cranenburgh <andreas@unstable.nl> 
  3  TODOs: unicode support (replace str and repr calls)""" 
  4  #import psyco 
  5  #psyco.full() 
  6   
  7  from collections import defaultdict 
  8  from itertools import chain, count 
  9  #from math import log #do something with logprobs instead? 
 10  from nltk import Production, WeightedProduction, WeightedGrammar, FreqDist 
 11  from nltk import Tree, ImmutableTree, Nonterminal, InsideChartParser, UnsortedChartParser 
 12  from bitpar import BitParChartParser 
 13   
def cartprod(a, b):
    """ lazily yield every pair (x, y) with x taken from a and y from b,
    with the elements of b varying fastest. b is iterated anew for each
    element of a, so it should be a re-iterable sequence. """
    for left in a:
        for right in b:
            yield left, right
19
def cartpi(seq):
    """ produce a flattened fold using cartesian product as operator:
    every tuple that takes one element from each factor in seq, with the
    elements of the first factor varying fastest.

    >>> list(cartpi( ( (1,2), (3,4), (5,6) ) ) )
    [(1, 3, 5), (2, 3, 5), (1, 4, 5), (2, 4, 5), (1, 3, 6), (2, 3, 6),
    (1, 4, 6), (2, 4, 6)]

    @param seq: an iterable of iterables (any iterable works, e.g. the
        result of zip(); it does not need to support len() or slicing).
    @return: a tuple of tuples, so the result can be iterated more than
        once; an empty seq gives ((),)."""
    # materialize each factor once, so one-shot iterators can be re-read
    # for every partial combination built below
    factors = [tuple(options) for options in seq]
    combos = [()]
    # fold from the last factor to the first; prepending the new element
    # inside the inner loop makes the first factor vary fastest, the same
    # order as the former recursive definition
    for options in reversed(factors):
        combos = [(a,) + rest for rest in combos for a in options]
    return tuple(combos)
# NB: equivalent to nltk.Tree.productions, except that unicode node labels
# are accepted as well.
def productions(tree):
    """
    Generate the productions that correspond to the non-terminal nodes of the tree,
    in pre-order (a node's production precedes those of its subtrees).
    For each subtree of the form (P: C1 C2 ... Cn) this produces a production of the
    form P -> C1 C2 ... Cn.
    @raise TypeError: when a node label is neither str nor unicode.
    @rtype: list of C{Production}s
    """
    label = tree.node
    # De Morgan form of the original test: unicode is only consulted for non-str labels
    if not isinstance(label, str) and not isinstance(label, unicode):
        raise TypeError('Productions can only be generated from trees having node labels that are strings')
    subtrees = [child for child in tree if isinstance(child, Tree)]
    return ([Production(Nonterminal(label), tree._child_names())]
            + [prod for child in subtrees for prod in productions(child)])
46
class GoodmanDOP:
    """ A DOP1 model of a treebank, represented through the Goodman
    reduction of the corresponding STSG to a PCFG. """

    def __init__(self, treebank, rootsymbol='S', wrap=False, cnf=True, cleanup=True, parser=InsideChartParser, **parseroptions):
        """ initialize a DOP model given a treebank. uses the Goodman
        reduction of a STSG to a PCFG. after initialization,
        self.parser will contain an instance of the class given as
        `parser' (by default an InsideChartParser).

        >>> tree = Tree("(S (NP mary) (VP walks))")
        >>> d = GoodmanDOP([tree])
        >>> print d.grammar
        Grammar with 12 productions (start state = S)
            NP -> 'mary' [1.0]
            NP@1 -> 'mary' [1.0]
            S -> NP VP [0.25]
            S -> NP VP@2 [0.25]
            S -> NP@1 VP [0.25]
            S -> NP@1 VP@2 [0.25]
            S@0 -> NP VP [0.25]
            S@0 -> NP VP@2 [0.25]
            S@0 -> NP@1 VP [0.25]
            S@0 -> NP@1 VP@2 [0.25]
            VP -> 'walks' [1.0]
            VP@2 -> 'walks' [1.0]
        >>> print d.parser.parse("mary walks".split())
        (S (NP mary) (VP@2 walks)) (p=0.25)

        @param treebank: a list of Tree objects. Caveat lector:
            terminals may not have (non-terminals as) siblings.
            NB: with cnf=True the trees are binarized in place, so the
            caller's trees are modified (also when wrap is set).
        @param rootsymbol: the start symbol of the grammar.
        @param wrap: boolean specifying whether to add the start symbol
            to each tree
        @param cnf: whether to apply chomsky_normal_form() to each tree
            before the reduction.
        @param cleanup: passed on to BitParChartParser; not used by
            other parsers.
        @param parser: a class which will be instantiated with the DOP
            model as its grammar. BitParChartParser (or a subclass) is
            called as BitParChartParser(rules with frequencies, lexicon,
            rootsymbol, cleanup=cleanup, **parseroptions); any other
            class is called as parser(self.grammar, **parseroptions).
        @param parseroptions: extra keyword arguments for the parser.

        instance variables:
         - self.grammar  a WeightedGrammar containing the PCFG reduction
                         (only when not using bitpar)
         - self.fcfg     an iterable of (rule, frequency) pairs containing
                         the PCFG reduction with frequencies instead of
                         probabilities (only when using bitpar; it is a
                         generator, so it can be consumed only once)
         - self.parser   the parser object
         - self.exemplars dictionary of known parse trees (memoization)"""
        nonterminalfd, ids = FreqDist(), count()
        self.exemplars = {}
        # materialize, so a one-shot iterator is not exhausted by the
        # binarization loop before the IDs are added
        treebank = list(treebank)
        if wrap:
            # wrap trees in a common root symbol (eg. for morphology)
            treebank = [Tree(rootsymbol, [a]) for a in treebank]
        if cnf:
            for a in treebank:
                a.chomsky_normal_form()  #todo: sibling annotation necessary?
        # add unique IDs to nodes
        utreebank = [(tree, self.decorate_with_ids(tree, ids)) for tree in treebank]
        # chain.from_iterable flattens in a single level; reduce(chain, ...)
        # nested one chain per tree (slow and deep) and failed on an empty
        # treebank
        lexicon = set(chain.from_iterable(tree.leaves() for tree, _ in utreebank))
        # count node frequencies, both with and without IDs
        for tree, utree in utreebank:
            self.nodefreq(tree, nonterminalfd)
            self.nodefreq(utree, nonterminalfd)
        # compare classes, not metaclasses: type(parser) == type(BitParChartParser)
        # is true for any two classes with the same metaclass
        try:
            usebitpar = issubclass(parser, BitParChartParser)
        except TypeError:  # parser is a factory function rather than a class
            usebitpar = False
        if usebitpar:
            # this takes the most time, produce CFG rules:
            cfg = FreqDist(chain.from_iterable(
                self.goodman(tree, utree) for tree, utree in utreebank))
            # annotate rules with frequencies
            self.fcfg = self.frequencies(cfg, nonterminalfd)
            print("writing grammar")
            self.parser = BitParChartParser(self.fcfg, lexicon, rootsymbol,
                cleanup=cleanup, **parseroptions)
        else:
            cfg = FreqDist(chain.from_iterable(
                self.goodman(tree, utree, False) for tree, utree in utreebank))
            self.grammar = WeightedGrammar(Nonterminal(rootsymbol),
                self.probabilities(cfg, nonterminalfd))
            self.parser = parser(self.grammar, **parseroptions)

        # NOTE: mostconstituentscorrect() expects the following attributes,
        # which are not set yet (self.pcfg, a dictionary from productions
        # to probabilities, is not set either):
        #the highest id
        #self.addresses = next(ids)
        #a list of interior + exterior nodes,
        #ie., non-terminals with and without ids
        #self.nonterminals = nonterminalfd.keys()
        #a mapping of ids to nonterminals without their IDs
        #self.nonterminal = dict(a.split("@")[::-1] for a in
        #       nonterminalfd.keys() if "@" in a)

    def goodmanfd(self, tree, ids, nonterminalfd):
        """ decorate tree with IDs, add the node frequencies of both the
        plain and the decorated tree to nonterminalfd (side effect), and
        return the Goodman rules of the tree in bitpar format.

        @param ids: an iterator yielding a stream of IDs"""
        utree = self.decorate_with_ids(tree, ids)
        self.nodefreq(tree, nonterminalfd)
        self.nodefreq(utree, nonterminalfd)
        return self.goodman(tree, utree)

    def decorate_with_ids(self, tree, ids):
        """ add unique identifiers to each non-terminal of a tree.
        returns a decorated deep copy; the given tree is not modified.

        >>> tree = Tree("(S (NP mary) (VP walks))")
        >>> d = GoodmanDOP([tree])
        >>> d.decorate_with_ids(tree, count())
        Tree('S@0', [Tree('NP@1', ['mary']), Tree('VP@2', ['walks'])])

        @param ids: an iterator yielding a stream of (integer) IDs"""
        utree = tree.copy(True)
        for a in utree.subtrees():
            a.node = "%s@%d" % (a.node, next(ids))
        return utree

    def nodefreq(self, tree, nonterminalfd, leaves=1):
        """count frequencies of nodes by calculating the number of
        subtrees headed by each node. updates "nonterminalfd" as
        a side effect

        >>> fd = FreqDist()
        >>> tree = Tree("(S (NP mary) (VP walks))")
        >>> d = GoodmanDOP([tree])
        >>> d.nodefreq(tree, fd)
        4
        >>> fd.items()
        [('S', 4), ('NP', 1), ('VP', 1)]

        @param tree: a non-empty Tree.
        @param nonterminalfd: the FreqDist to store the counts in.
        @param leaves: unused; kept for backward compatibility.
        @return: the number of subtrees headed by the root of tree.
        @raise ValueError: when tree is not a non-empty Tree, e.g. a
            terminal that has non-terminal siblings."""
        if not isinstance(tree, Tree) or len(tree) == 0:
            raise ValueError("expected a non-empty Tree (terminals may not "
                "have non-terminal siblings), got %r" % (tree,))
        if tree.height() > 2:
            # each child contributes its own subtrees, plus the option of
            # being cut off (a frontier node): hence the + 1
            n = reduce((lambda x, y: x * y),
                (self.nodefreq(child, nonterminalfd) + 1 for child in tree), 1)
            nonterminalfd.inc(tree.node, count=n)
            return n
        # a preterminal heads exactly one subtree, however many terminals
        # it dominates; count it once, consistent with the return value
        nonterminalfd.inc(tree.node, count=1)
        return 1

    def goodman(self, tree, utree, bitparfmt=True):
        """ given a parsetree from a treebank, yield a goodman
        reduction of eight rules per node (in the case of a binary tree).

        >>> tree = Tree("(S (NP mary) (VP walks))")
        >>> d = GoodmanDOP([tree])
        >>> utree = d.decorate_with_ids(tree, count())
        >>> sorted(d.goodman(tree, utree, False))
        [(NP, ('mary',)), (NP@1, ('mary',)), (S, (NP, VP)), (S, (NP, VP@2)),
        (S, (NP@1, VP)), (S, (NP@1, VP@2)), (S@0, (NP, VP)),
        (S@0, (NP, VP@2)), (S@0, (NP@1, VP)), (S@0, (NP@1, VP@2)),
        (VP, ('walks',)), (VP@2, ('walks',))]

        @param utree: tree decorated with IDs (see decorate_with_ids).
        @param bitparfmt: if true, yield tab-separated rule strings;
            otherwise yield (lhs, rhs) pairs.
        @raise ValueError: for a production with an empty right-hand side."""
        # linear: nr of nodes
        sep = "\t"
        for p, up in zip(tree.productions(), utree.productions()):
            if len(p.rhs()) == 0:
                # THIS SHOULD NOT HAPPEN (an empty subtree)
                raise ValueError("production with empty right-hand side: %r" % (p,))
            if len(p.rhs()) == 1 and not isinstance(p.rhs()[0], Nonterminal):
                # lexical rule: the terminal is the same with and without IDs
                rhs = (p.rhs(), )
            elif len(p.rhs()) == 1:
                rhs = (p.rhs(), up.rhs())
            else:
                # every combination of children with and without IDs. this
                # must be a materialized collection: it is iterated once for
                # each of the two left-hand sides below, and a generator would
                # be exhausted by the first. the set also merges the duplicate
                # combinations of terminals, which are identical in both trees.
                rhs = set(cartpi(list(zip(p.rhs(), up.rhs()))))

            # constant factor: 8
            for l, r in cartprod((p.lhs(), up.lhs()), rhs):
                if bitparfmt:
                    yield sep.join((str(l), sep.join(map(str, r))))
                else:
                    yield l, r

    @staticmethod
    def _rhsweight(fd, rhs):
        """ product of the subtree counts of the ID'd symbols (those
        containing '@') in rhs; other symbols count as 1. """
        weight = 1
        for z in rhs:
            symbol = str(z)
            if '@' in symbol:
                weight *= fd[symbol]
        return weight

    def probabilities(self, cfg, fd):
        """merge cfg and frequency distribution into a pcfg with the right
        probabilities.

        @param cfg: a FreqDist of (lhs, rhs) pairs, as yielded by
            goodman(..., bitparfmt=False)
        @param fd: a FreqDist of (non)terminals (with and
            without IDs), as filled by nodefreq
        @return: a list of WeightedProductions"""
        def prob(l, r):
            return self._rhsweight(fd, r) / float(fd[str(l)])
        # merge identical rules: each distinct rule appears once, weighted
        # by how often it was produced
        return [WeightedProduction(l, r, prob=freq * prob(l, r))
                for (l, r), freq in cfg.items()]

    def frequencies(self, cfg, fd):
        """merge cfg and frequency distribution into weighted
        productions with frequencies as weights (as expected by bitpar).

        @param cfg: a FreqDist of tab-separated rule strings, as yielded
            by goodman(..., bitparfmt=True)
        @param fd: a FreqDist of (non)terminals (with and
            without IDs), as filled by nodefreq
        @return: a generator of (rule, frequency) pairs"""
        return ((rule, freq * self._rhsweight(fd, rule.split('\t')[1:]))
                for rule, freq in cfg.items())

    def removeids(self, tree):
        """ remove unique IDs introduced by the Goodman reduction from the
        node labels of tree, in place; terminals are left untouched.
        returns the same tree for convenience. """
        for subtree in tree.subtrees():
            if '@' in subtree.node:
                subtree.node = subtree.node.split('@')[0]
        return tree

    def parse(self, sent):
        """ memoize parse trees. TODO: maybe add option to add every
        parse tree to the set of exemplars, ie., incremental learning.
        this uses the most probable derivation (not very good)."""
        key = tuple(sent)
        if key not in self.exemplars:
            # parse the materialized key: sent may be a one-shot iterator
            # that tuple() has already consumed
            self.exemplars[key] = self.parser.parse(list(key))
        return self.exemplars[key]

    def mostprobableparse(self, sent, sample=None):
        """ warning: this problem is NP-complete. using an unsorted
        chart parser avoids unnecessary sorting (since we need all
        derivations anyway). sums the probabilities of all derivations
        that yield the same tree (once IDs are removed) and returns the
        tree with the highest sum.

        @param sent: a sequence of terminals
        @param sample: None or int; if int then sample that many parses
        @raise ValueError: when there is no parse (NOTE(review): for an
            empty FreqDist this presumably comes from FreqDist.max())"""
        p = FreqDist()
        for a in self.parser.nbest_parse(sent, sample):
            p.inc(ImmutableTree.convert(self.removeids(a)), a.prob())
        best = p.max()
        if best:
            return best
        raise ValueError("no parse")

    def mostconstituentscorrect(self, sent):
        """ not working yet. almost verbatim translation of Goodman's (1996)
        most constituents correct parsing algorithm, except for python's
        zero-based indexing. needs to be modified to return the actual parse
        tree. expects a pcfg in the form of a dictionary from productions to
        probabilities

        NOTE(review): relies on self.pcfg, self.nonterminals, self.addresses
        and self.nonterminal, which __init__ does not set, and on the
        undefined names rootsymbol and n; sumx(max_x) calls a dictionary.
        For sentences of two or more words this currently fails."""
        def g(s, t, x):
            def f(s, t, x):
                return self.pcfg[Production(rootsymbol,
                    sent[1:s] + [x] + sent[s+1:])]
            def e(s, t, x):
                return self.pcfg[Production(x, sent[s:t+1])]
            return f(s, t, x) * e(s, t, x ) / e(1, n, rootsymbol)

        sumx = defaultdict(int) #zero
        maxc = defaultdict(int) #zero
        for length in range(2, len(sent)+1):
            for s in range(1, len(sent) + length):
                t = s + length - 1
                for x in self.nonterminals:
                    sumx[x] = g(s, t, x)
                for k in range(self.addresses):
                    #ordered dictionary here
                    x = self.nonterminal[k]
                    sumx[x] += g(s, t, "%s@%d" % (x, k))
                max_x = max(sumx[x] for x in self.nonterminals)
                #for x in self.nonterminals:
                #    max_x = argmax(sumx, x) #???
                best_split = max(maxc[(s,r)] + maxc[(r+1,t)]
                    for r in range(s, t))
                #for r in range(s, t):
                #    best_split = max(maxc[(s,r)] + maxc[(r+1,t)])
                maxc[(s,t)] = sumx(max_x) + best_split

        return maxc[(1, len(sent) + 1)]
336 -def main():
337 """ a basic REPL for testing """ 338 corpus = """(S (NP John) (VP (V likes) (NP Mary))) 339 (S (NP Peter) (VP (V hates) (NP Susan))) 340 (S (NP Harry) (VP (V eats) (NP pizza))) 341 (S (NP Hermione) (VP (V eats)))""".splitlines() 342 corpus ="""(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP (DT the) (JJ hungry) (NN dog)))) 343 (S (NP (DT The) (JJ little) (NN mouse)) (VP (VBP ate) (NP (DT the) (NN cat))))""".splitlines() 344 #corpus = """(S (NP mary) (VP walks) (AP quickly))""".splitlines() 345 #(S (NP Harry) (VP (V likes) (NP Susan) (ADVP (RB very) (RB much)))) 346 corpus = [Tree(a) for a in corpus] 347 for a in corpus: 348 #continue 349 a.chomsky_normal_form() 350 #d = GoodmanDOP(corpus, rootsymbol='S') 351 d = GoodmanDOP(corpus, rootsymbol='TOP', wrap='TOP', parser=BitParChartParser) 352 #d = GoodmanDOP(corpus, rootsymbol='TOP', wrap='TOP') 353 #print d.grammar 354 from nltk import ImmutableTree 355 print "corpus" 356 for a in corpus: print a 357 w = "foo!" 358 while w: 359 print "sentence:", 360 w = raw_input().split() 361 #print d.parser.prob_parse(w) 362 try: 363 p = FreqDist() 364 for a in d.parser.nbest_parse(w): 365 print a 366 p.inc(ImmutableTree.convert(d.removeids(a)), a.prob()) 367 #for b, a in sorted((b,a) for (a,b) in p.items()): 368 # print a, b 369 print 370 print 'best', p.max() 371 #print d.parse(w) 372 except Exception as e: 373 print "error", e
374 375 if __name__ == '__main__': 376 import doctest 377 # do doctests, but don't be pedantic about whitespace (I suspect it is the 378 # militant anti-tab faction who are behind this obnoxious default) 379 fail, attempted = doctest.testmod(verbose=False, 380 optionflags=doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS) 381 if attempted and not fail: 382 print "%d doctests succeeded!" % attempted 383 main() 384