#DOP1 implementation
#Andreas van Cranenburgh <andreas@unstable.nl>

import nltk
from itertools import *
from math import log

def subsets(seq):
	""" iterate over the proper subsets of a sequence, largest first

	Yields tuples holding elements of `seq` in their original order:
	first every combination of length len(seq) - 1, then the shorter
	lengths down to 1, and finally the empty tuple. The complete
	sequence itself is never yielded. Within one length, combinations
	follow the lexicographic order of their positions in `seq`.

	seq -- a sized iterable (len() is called on it) """
	# function-scope import keeps this block self-contained; the stdlib
	# generator replaces a hand-rolled recursive one that depended on the
	# Python 2-only xrange
	from itertools import combinations
	for size in range(len(seq) - 1, 0, -1):
		for combo in combinations(seq, size):
			yield combo
	# the empty subset always comes last, exactly once
	yield ()

class Dop:
	""" a DOP1 model: a bag of tree fragments extracted from a treebank,
	with relative frequencies used to score derivations

	Attributes set by __init__:
	  fd     -- nltk.FreqDist counting each (frozen) fragment
	  fdl    -- nltk.FreqDist counting fragments per root label
	  corpus -- dict mapping a root label to the set of fragments that
	            have that label at their root """

	def __init__(self, corpus, removeinternal=False, PCFG=False):
		""" initialize a DOP model given a treebank

		corpus         -- iterable of nltk trees
		removeinternal -- if true, flatten every fragment to its root
		                  label plus its words and empty nodes, and drop
		                  fragments whose flattened height() is not > 1
		PCFG           -- if true, keep only fragments whose height()
		                  is exactly 2 """
		fragments = []
		for tree in corpus:
			fragments.extend(self.subtrees(tree))

		if PCFG:
			fragments = [a for a in fragments if a.height() == 2]

		def removeint(a):
			""" return a frozen tree with the root label of a and, as
			direct children, every word and empty node of a """
			return nltk.Tree(a.node, [a[x] for x in a.treepositions()
					if len(a[x]) == 0 or isinstance(a[x], str)]).freeze()
		if removeinternal:
			# flatten each fragment once, then filter on the result
			flattened = [removeint(a) for a in fragments]
			fragments = [a for a in flattened if a.height() > 1]
		self.fd = nltk.FreqDist(a.freeze() for a in fragments)
		self.fdl = nltk.FreqDist(a.node for a in fragments)
		# index the fragments by root label in a single pass
		self.corpus = {}
		for a in fragments:
			self.corpus.setdefault(a.node, set()).add(a)

	def parse(self, sent, derivation=nltk.Tree('(S )'), prob=1, fr=None, prefix=False):
		""" return an iterator over all possible parses of a sentence

		sent       -- list of words
		derivation -- partial tree built so far; its first empty node
		              is the next one to be expanded. The default is
		              created once, at class definition; parse never
		              changes it in place because join works on a copy
		prob       -- probability accumulated so far
		fr         -- list of fragments used so far (None means none)
		prefix     -- if true, also accept derivations of which sent is
		              only a prefix of the words

		Yields (probability, frozen derivation, list of fragments)
		triples. Each fragment contributes its count in fd divided by
		the count of its root label in fdl. """
		# a fresh list per call: the result list is handed to the caller
		# and must not be shared between calls
		if fr is None:
			fr = []
		def firstemptynode(tree):
			""" position of the first node without children """
			return next(a for a in tree.treepositions() if len(tree[a]) == 0)
		def leaves(tree):
			""" the frontier of tree: words as well as empty nodes """
			return (tree[a] for a in tree.treepositions()
					if isinstance(tree[a], str) or len(tree[a]) == 0)
		def firstdiff(a, b):
			""" index of the first position at which a and b differ,
			or None if they agree wherever both have an item """
			return next((n for n, (x, y) in enumerate(zip(a, b)) if x != y), None)
		def eligible(tree, sent):
			""" whether every word on the frontier of tree matches the
			word of sent at the same index (empty nodes match anything) """
			return all(a == b or len(a) == 0 for a, b in zip(leaves(tree), sent))
		words = derivation.leaves()
		if sent == words or (prefix
		and (len(sent) <= len(words)
		and words[:len(sent)] == sent)):
			yield prob, derivation.freeze(), fr
		elif len(words) == len(list(leaves(derivation))):
			# no empty nodes left and the words do not match: dead end
			return
		else:
			start = firstdiff(sent, leaves(derivation))
			if start is None:
				# nothing differs within the compared stretch, so there
				# is no position to expand towards; give up on this
				# derivation. (Under Python 2 the search also ended here,
				# through a StopIteration escaping from .next(); made
				# explicit because PEP 479 turns that into RuntimeError.)
				return
			sentremaining = sent[start:]
			node = firstemptynode(derivation)
			# a label without any fragment simply has no expansions
			fragments = [a for a in self.corpus.get(derivation[node].node, ())
							if eligible(a, sentremaining)]
			for fragment in fragments:
				d = Dop.join(derivation, fragment)
				p = prob * (self.fd[fragment.freeze()] / float(self.fdl[fragment.node]))
				for a in self.parse(sent, d, p, fr + [fragment], prefix):
					yield a

	def subtrees(self, tree):
		""" yield all subtrees as defined by DOP1

		For each subtree a of tree (as given by tree.subtrees()), yields
		a frozen copy of a for every valid set of non-leaf positions,
		with the nodes at those positions replaced by empty nodes that
		keep their label. The candidate sets come from subsets(). """
		def validdeletion(s):
			""" check whether a set of frontier nodes to be deleted is valid,
				ie., test whether it does not contain children of nodes to
				be deleted

				s is a tuple of tree positions. The root position () may
				never be deleted. Only positions earlier in s are checked
				as possible ancestors, so this assumes ancestors precede
				their descendants in s (presumably guaranteed by the
				order of treepositions() -- verify). """
			if () in s: return False
			for n, a in enumerate(s):
				earlier = s[:n]
				# every proper, non-empty prefix of a is an ancestor of a
				for m in range(1, len(a)):
					if a[:m] in earlier:
						return False
			return True

		#iterate over all complete subtrees:
		for a in tree.subtrees():
			#and then yield trees with all possible deletions of frontier nodes
			leaves = set(a.treepositions(order='leaves'))
			internal = [x for x in a.treepositions() if x not in leaves]
			for b in subsets(internal):
				if not validdeletion(b): continue
				#make deep copies or you'll go crazy.
				fragment = a.copy(deep=True)
				for c in b:
					fragment[c] = nltk.Tree(fragment[c].node, [])
				yield fragment.freeze()

	@classmethod
	def join(cls, a, b):
		""" substitute fragment b into tree a

		Looks for the first position in a (in treepositions() order)
		holding an empty node with the same label as the root of b, and
		returns a deep copy of a in which that node has received the
		children of b. Returns None if a or b is not exactly an
		nltk Tree or ImmutableTree, or if no such node exists. """
		if type(a) not in (nltk.tree.Tree, nltk.tree.ImmutableTree): return None
		if type(b) not in (nltk.tree.Tree, nltk.tree.ImmutableTree): return None
		for c in a.treepositions():
			# non-terminal node, same label as b, empty node
			if not isinstance(a[c], str) and a[c].node == b.node and len(a[c]) == 0:
				d = a.copy(True)
				d[c].extend(b)
				return d
		return None

	def prodprob(self):
		""" print one line per distinct fragment, in sorted order: its
		first production, its count, and the count of its root label """
		for a in sorted(self.fd):
			# a single pre-formatted string prints the same with or
			# without print being a function
			print('%s %d / %d' % (a.productions()[0],
									self.fd[a], self.fdl[a.node]))

	def prefixprob(self, sent):
		""" sum of the probabilities of everything parse() yields for
		sent in prefix mode """
		return sum(p for p, a, f in self.parse(sent, prefix=True))

	def surprisal(self, sent):
		""" natural log of the prefix probability of sent without its
		last word, minus the natural log of the prefix probability of
		sent. For sent of length <= 1 the former probability is taken
		to be 1. Returns None if either probability is zero. """
		if len(sent) <= 1: p1 = 1
		else: p1 = self.prefixprob(sent[:-1])
		p2 = self.prefixprob(sent)
		if p1 and p2:
			return log(p1) - log(p2)
		# log(0) is undefined: report no value instead
		return None

def main():
	""" demo: build a DOP model (with internal nodes removed) and a PCFG
	model from two toy trees, print their productions, and print prefix
	probabilities and surprisals for each prefix of two sentences """
	w = nltk.Tree('(S (NP Harry) (VP eats (NP pizza)))')
	v = nltk.Tree('(S (NP Hermione) (VP eats))')

	dop = Dop([w, v], removeinternal=True)
	pcfg = Dop([w, v], PCFG=True)

	print("DOP productions:")
	dop.prodprob()
	print("PCFG productions:")
	pcfg.prodprob()

	s1 = "Hermione eats pizza".split()
	s2 = "pizza eats Harry".split()

	def report(title, model):
		""" print all prefix probabilities, then all surprisals, that
		model assigns to the prefixes of s1 and s2 """
		print(title)
		for label, measure in (("prefix prob:", model.prefixprob),
								("surprisal: ", model.surprisal)):
			for sent in (s1, s2):
				for n in range(1, 4):
					# one space between items, as a comma-separated
					# print statement would produce
					print("%s %s %s" % (label, " ".join(sent[:n]),
										measure(sent[:n])))

	report("PCFG:", pcfg)
	report("DOP:", dop)

# run the demo only when executed as a script, not when imported
if __name__ == '__main__':
	main()
