# -*- coding: utf-8 -*-
# DOP1 (Data-Oriented Parsing) implementation. Not efficient! Fragments and derivations are enumerated exhaustively, so space and time complexity are exponential.
#Andreas van Cranenburgh <andreas@unstable.nl>


import nltk
from itertools import *
from math import log

def subsets(seq):
	""" iterate over the proper subsets of a sequence, as tuples.

	Subsets are yielded from largest (len(seq) - 1 elements) down to the
	smallest non-empty ones, each size in the lexicographic index order of
	itertools.combinations (so elements keep their order from seq); the
	empty tuple comes last. The full sequence itself is never yielded.
	"""
	from itertools import combinations
	for n in range(len(seq) - 1, 0, -1):
		for subset in combinations(seq, n):
			yield subset
	yield ()

class Dop:
	""" a DOP1 model: all fragments of a treebank, with the probability of a
		fragment given by its relative frequency among the fragments that
		share its root label. """
	def __init__(self, corpus, removeinternal=False, PCFG=False):
		""" initialize a DOP model given a treebank.
			corpus: iterable of trees.
			removeinternal: flatten every fragment so that its root directly
				dominates its frontier (terminals and open substitution
				sites), removing internal nodes; fragments that end up
				without children are dropped.
			PCFG: only keep fragments of height 2,
				ie., a standard PCFG read off from the treebank.
			Sets self.fd (count of each distinct fragment), self.fdl (number
			of fragments per root label) and self.corpus (root label -> set
			of distinct fragments with that root). """
		fragments = []
		for tree in corpus:
			fragments.extend(self.subtrees(tree))

		if PCFG:
			fragments = [a for a in fragments if a.height() == 2]

		if removeinternal:
			# flatten once per fragment (the original flattened twice)
			flat = [nltk.Tree(a.node, [a[b] for b in Dop._frontier(a)]).freeze()
					for a in fragments]
			fragments = [a for a in flat if a.height() > 1]

		self.fd = nltk.FreqDist(a.freeze() for a in fragments)
		self.fdl = nltk.FreqDist(a.node for a in fragments)
		# group fragments by root label in a single pass
		self.corpus = {}
		for a in fragments:
			self.corpus.setdefault(a.node, set()).add(a)

	@staticmethod
	def _isopen(node):
		""" True for a nonterminal without children, ie., a substitution site """
		return isinstance(node, nltk.Tree) and len(node) == 0

	@staticmethod
	def _frontier(tree):
		""" tree positions of the terminals and open substitution sites of
			tree, in treepositions() (preorder, ie. left-to-right) order.
			isinstance is used instead of comparing with type("") so that
			unicode leaves count as terminals too. """
		return [a for a in tree.treepositions()
				if not isinstance(tree[a], nltk.Tree) or len(tree[a]) == 0]

	def parse(self, sent, derivation=None, prob=1, fr=None, prefix=False):
		""" return an iterator over all possible parses of a sentence.

			Derivations are built leftmost-first: the leftmost open
			substitution site is replaced by every eligible fragment with the
			same root label. Yields (probability, frozen derivation,
			fragments used) triples; the probability is the product of the
			relative frequencies of the fragments used.

			sent: list of words.
			derivation: partial derivation to extend (a mutable tree);
				defaults to an empty tree labeled S.
			prob: probability of `derivation` so far.
			fr: list of fragments used so far in `derivation`.
			prefix: if true, yield a derivation as soon as its first
				len(sent) frontier items are exactly the words of `sent`
				(used for prefix probabilities).

			NB: a fragment whose frontier starts with an open site carrying
			its own root label (e.g. a unary X -> X) can recurse without
			bound; outside prefix mode the sentence-length check below only
			stops such recursion when the frontier grows. """
		if derivation is None:
			derivation = nltk.Tree('(S )')
		if fr is None:
			fr = []
		sent = list(sent)
		n = len(sent)

		def eligible(fragment, words):
			""" True if the terminals at the start of the fragment's frontier
				match `words`. Only terminals before the fragment's first
				open site are checked: an open site may cover several words,
				so later frontier items cannot be aligned with `words` yet. """
			for pos, word in zip(Dop._frontier(fragment), words):
				item = fragment[pos]
				if Dop._isopen(item):
					return True
				if item != word:
					return False
			return True

		positions = Dop._frontier(derivation)
		# frontier as words; open sites become None so they never equal a
		# word (derivation.leaves() would silently skip them)
		words = [None if Dop._isopen(derivation[a]) else derivation[a]
				for a in positions]
		if words == sent or (prefix and words[:n] == sent):
			yield prob, derivation.freeze(), fr
			return
		# leftmost open substitution site
		site = next((i for i, word in enumerate(words) if word is None), None)
		if site is None:
			return  # complete derivation whose yield differs from sent
		if site >= n or words[:site] != sent[:site]:
			return  # terminals left of the site already rule out a match
		if not prefix and len(words) > n:
			# every open site covers at least one word: yield too long
			return
		label = derivation[positions[site]].node
		remaining = sent[site:]
		for fragment in self.corpus.get(label, ()):
			if not eligible(fragment, remaining):
				continue
			d = Dop.join(derivation, fragment)
			p = prob * (self.fd[fragment.freeze()] / float(self.fdl[fragment.node]))
			for result in self.parse(sent, d, p, fr + [fragment], prefix):
				yield result

	def subtrees(self, tree):
		""" yield all fragments of tree as defined by DOP1: every subtree of
			tree, combined with every valid set of its non-root nonterminal
			nodes turned into open substitution sites (children removed).
			Fragments are yielded frozen. """
		def validdeletion(s):
			""" check whether a set of frontier nodes to be deleted is valid,
				ie., it contains neither the root nor a node together with
				one of its ancestors. `s` lists positions in the order of
				treepositions(), so ancestors precede their descendants and
				only earlier entries need to be searched. """
			if () in s:
				return False
			for n, a in enumerate(s):
				if any(a[:m] in s[:n] for m in range(1, len(a))):
					return False
			return True

		# iterate over all complete subtrees:
		for a in tree.subtrees():
			# and then yield trees with all possible deletions of frontier nodes
			leafpositions = set(a.treepositions(order='leaves'))
			candidates = [x for x in a.treepositions() if x not in leafpositions]
			for b in subsets(candidates):
				if not validdeletion(b):
					continue
				# make deep copies or you'll go crazy.
				copy = a.copy(deep=True)
				for c in b:
					copy[c] = nltk.Tree(copy[c].node, [])
				yield copy.freeze()

	@classmethod
	def join(cls, a, b):
		""" substitute fragment b at the first (preorder) open site of a whose
			label equals b's root label. Returns a deep copy of a with the
			site filled; a itself is not modified. Returns None if either
			argument is not a tree or a has no such open site. """
		if not (isinstance(a, nltk.Tree) and isinstance(b, nltk.Tree)):
			return None
		for c in a.treepositions():
			# non-terminal node, empty, same label as b
			if cls._isopen(a[c]) and a[c].node == b.node:
				d = a.copy(True)
				d[c].extend(b)
				return d
		return None

	def prodprob(self):
		""" print, for each distinct fragment in sorted order, the first
			element of its productions(), its count, and the number of
			fragments sharing its root label """
		for a in sorted(self.fd):
			print('%s %d / %d' % (a.productions()[0],
									self.fd[a], self.fdl[a.node]))

	def prefixprob(self, sent):
		""" sum of the probabilities of all partial derivations whose yield
			starts with the words of sent """
		return sum(p for p, a, f in self.parse(sent, prefix=True))

	def surprisal(self, sent):
		""" surprisal of the last word of sent given the preceding words:
			log P(sent[:-1] as prefix) - log P(sent as prefix).
			The empty prefix has probability 1. Returns None when either
			prefix probability is zero. """
		if len(sent) <= 1:
			p1 = 1
		else:
			p1 = self.prefixprob(sent[:-1])
		p2 = self.prefixprob(sent)
		if p1 and p2:
			return log(p1) - log(p2)
		return None

def altmain():
	corpus = ['(A (Aff on) (A (V dank) (Aff baar)))',
                    '(A (Aff on) (A (V vind) (Aff baar)))',
                    '(N (V (Aff be) (V vind)) (Aff ing))',
                    '(N (N (N (P aan) (N deel)) (V houd) (Aff er)) (Aff s) (N (V vergader) (Aff ing)))',
                    '(N (N (P aan) (N (Aff ge) (N zicht))) (Aff s) (N (V lig) (Aff ing)))',
                    '(N (N (A (V (P aan) (V spreek)) (Aff elijk)) (Aff heid)) (Aff s) (N (V (Aff ver) (A zeker)) (Aff ing)))']
	corpus = [nltk.Tree(a) for a in corpus]
	dop = Dop(corpus)
	for p,d, f in sorted(chain(
		dop.parse('on dank baar'.split(), derivation=nltk.Tree('(A )')),
		dop.parse('on dank baar'.split(), derivation=nltk.Tree('(V )')))):
		print p, d
		for a in f:
			print '\t', a

def main():
	t = nltk.Tree.parse('(S (NP John) (VP (V likes) (NP Mary)))')
	u = nltk.Tree.parse('(S (NP Peter) (VP (V hates) (NP Susan)))')
	w = nltk.Tree('(S (NP Harry) (VP eats (NP pizza)))')
	v = nltk.Tree('(S (NP Hermione) (VP eats))')
	x = nltk.Tree('(S (NP Harry) (VP (V likes) (NP Susan) (RB very much)))')

	#y = nltk.Tree("(S (PP (VP given (NP (NP the complexity) (PP of physics)))), (S (NP i) (VP (would (VP assume (NP the conditions) to (PP always (VP be (ADV sufficiently complex) as to rule out starting anywhere near a global minimum by chance.")
	dop = Dop([t,u,w,v,x], removeinternal=False)
	pcfg = Dop([w, v], PCFG=True)

	print "DOP productions:"
	dop.prodprob()
	print "PCFG productions:"
	pcfg.prodprob()

	s1 = "Hermione eats pizza".split()
	s2 = "pizza eats Harry".split()

	print "PCFG:"
	for a in (s1, s2):
		for b in range(1,4):
			print "prefix prob:", " ".join(a[:b]), pcfg.prefixprob(a[:b])

	for a in (s1, s2):
		for b in range(1,4):
			print "surprisal: ", " ".join(a[:b]), pcfg.surprisal(a[:b])

	print "DOP:"
	for a in (s1, s2):
		for b in range(1,4):
			print "prefix prob:", " ".join(a[:b]), dop.prefixprob(a[:b])

	for a in (s1, s2):
		for b in range(1,4):
			print "surprisal: ", " ".join(a[:b]), dop.surprisal(a[:b])

def test():
	corpus = u"""(NN (N (A (A (A mal) (J riĉ)) (A eg)) (A ul)) o)
(NN ((((V kudr) (A ist)) (A in)) (N edz)) o)
(JJ ((((P en) (V konduk)) (V it)) a) j)
(RB ((A dis (V send)) (V ant)) e)
(NN ((A tra) (V rigard)) o)
(NN ((V dezir) (A eg)) o)
(JJ ((V estim) (V at)) a)
(RB ((V ricev) (V int)) e)
(NN ((A mal) (A diligent)) o)
(VB ((A mal) (V help)) is)
(VB ((J grand) (A iĝ)) os)
(NN (((N fingr) (A ing)) o) n)
(NN (N ŝuld) o)
(JJ (J du (N hor)) a)
(NN (N sfer) o)
(VB (V serĉ) is)
(JJ ((D ĉiu) (N tag)) a)
(RB ((D ĉiu) (N tag)) e)
(NN (N popol) o)
(NN (((N man) o) j) n)
(NN ((N kat) (A in) o) n)
(VB (V ekzamen) i)
(NN (V konduk) ((V ant) o))
(JJ ((J propr) a) n)
(NN (N mebl) ((A ist) o))
(VB (V kresk) i)
(VB (V lum) is)
(VB (V sci) u)
(VB ((V perd) (A iĝ)) os)
(NN (N onkl) o)
(JJ (((V far) (V at)) a) j)
(NN (((N gazet) o) j) n)""".splitlines()
	corpus = """(S (NP John) (VP (V likes) (NP Mary)))
(S (NP Peter) (VP (V hates) (NP Susan)))
(S (NP Harry) (VP (V eats) (NP pizza)))
(S (NP Hermione) (VP (V eats)))""".splitlines()
	corpus = """(S (NP (DT The) (NN cat)) (VP (VBP saw) (NP (DT the) (JJ hungry) (NN dog))))
(S (NP (DT The) (JJ little) (NN mouse)) (VP (VBP ate) (NP (DT the) (NN cat))))""".splitlines()
	#(S (NP Harry) (VP (V likes) (NP Susan) (ADVP (RB very) (RB much))))
	dd = Dop(nltk.Tree(a) for a in corpus)
	print corpus
        w = "foo!"
        while w:
                print "word:",
                w = raw_input()
		fd = nltk.FreqDist()
                for p,t,f in dd.parse(w.split(), nltk.Tree("(S )")):
			fd.inc(nltk.ImmutableTree.convert(t), p)
                	#print p, t, len(f)
		for a,b in fd.items():
			print a,b
                #for root in "RB VB JJ NN".split():
                #        for p,t,f in dd.parse(w.split(), nltk.Tree("(%s )" % root)):
                #                print p, t, len(f)

if __name__ == '__main__':
	# run the interactive demo when executed as a script
	test()
