# -*- coding: utf-8 -*-
"""DOP1 implementation. Andreas van Cranenburgh <andreas@unstable.nl>
TODOs: unicode support (replace str and repr calls)"""
try:
	import psyco
	psyco.full()
except ImportError:
	print "consider installing psyco (32bit only)"

import os
from time import time, sleep
from collections import defaultdict
from itertools import chain, count
from subprocess import Popen, PIPE
from uuid import uuid1
#from math import log #do something with logprobs instead?

from nltk import Production, WeightedProduction, WeightedGrammar
from nltk import Tree, Nonterminal, FreqDist, InsideChartParser, UnsortedChartParser

def cartprod(a, b):
	""" lazily yield every pair (x, y) with x from a and y from b, with the
	element of a varying slowest. b is iterated anew for every element of a. """
	for left in a:
		for pair in ((left, right) for right in b):
			yield pair

def cartpi(seq):
	""" produce a flattened fold using cartesian product as operator;
	the first component varies fastest. seq must support len() and slicing.

	>>> list(cartpi( ( (1,2), (3,4), (5,6) ) ) )
	[(1, 3, 5), (2, 3, 5), (1, 4, 5), (2, 4, 5), (1, 3, 6), (2, 3, 6), 
	(1, 4, 6), (2, 4, 6)] """
	if not len(seq):
		# base case: the product of zero sequences is a single empty tuple
		return ((), )
	head, tail = seq[0], seq[1:]
	return ((choice,) + rest for rest in cartpi(tail) for choice in head)

#NB: the following code is equivalent to nltk.Tree.productions, except for accepting unicode
def productions(tree):
	"""
	Return the productions of all non-terminal nodes of tree, in preorder.
	Each subtree (P: C1 C2 ... Cn) contributes the production P -> C1 ... Cn.
	Unlike (older versions of) nltk.Tree.productions, unicode node labels
	are accepted as well as str labels.

		@raise TypeError: if a node label is neither str nor unicode
		@rtype: list of C{Production}s
	"""
	if not isinstance(tree.node, (str, unicode)):
		raise TypeError('Productions can only be generated from trees having node labels that are strings')

	found = [Production(Nonterminal(tree.node), tree._child_names())]
	for subtree in tree:
		# leaves (terminals) produce no rules of their own
		if not isinstance(subtree, Tree):
			continue
		found.extend(productions(subtree))
	return found

class BitParChartParser:
	def __init__(self, weightedrules, rootsymbol="S"):
		""" Interface to bitpar chart parser. Expects a list of weighted
		productions with frequencies (not probabilities).

		>>> tree = Tree("(S (NP mary) (VP walks))")
		>>> d = GoodmanDOP([tree], parser=BitParChartParser)
		    writing grammar
		>>> d.parser.parse("mary walks".split())
		Tree('S', [Tree('NP@1', ['mary']), Tree('VP@2', ['walks'])])

		should become: (by parsing bitpar's chart output)
		ProbabilisticTree('S', [ProbabilisticTree('NP@1', ['mary'])
		(p=1.0), ProbabilisticTree('VP@2', ['walks']) (p=1.0)])
		(p=0.444444444444)

			@param weightedrules: iterable of ((lhs, rhs), frequency) pairs;
			it is iterated once, by writegrammar
			@param rootsymbol: start symbol passed to bitpar's -s option"""

		self.grammar = weightedrules
		self.rootsymbol = rootsymbol
		self.id = uuid1()
		# bitpar reads its grammar from disk; the files are written once here
		# and re-read each time the bitpar process is (re)started.
		self.writegrammar('/tmp/g%s.pcfg' % self.id, '/tmp/g%s.lex' % self.id)
		self.start()

	def __del__(self):
		""" Stop bitpar and remove the temporary grammar files.

		Best-effort: this may run on a partially initialized object (e.g.
		when starting bitpar failed), so missing attributes are tolerated
		and files that are already gone are ignored. """
		if getattr(self, 'bitpar', None) is not None:
			self.stop()
		ident = getattr(self, 'id', None)
		if ident is None:
			return
		for path in ('/tmp/g%s.pcfg' % ident, '/tmp/g%s.lex' % ident):
			try:
				os.remove(path)
			except OSError:
				pass

	def start(self):
		""" Launch a bitpar process on this parser's grammar files. """
		options = "bitpar -q -b 1 -p -s %s -u unknownwords /tmp/g%s.pcfg /tmp/g%s.lex" % (self.rootsymbol, self.id, self.id)
		self.bitpar = Popen(options.split(), stdin=PIPE, stdout=PIPE, stderr=PIPE)

	def stop(self):
		""" Terminate the bitpar process if it is still running. """
		if self.bitpar.poll() is None:
			self.bitpar.terminate()
			# reap the child so it does not linger as a zombie process
			self.bitpar.wait()

	def parse(self, sent, timeout=2):
		""" Parse a tokenized sentence with bitpar and return its output as
		a Tree.

		communicate() closes bitpar's stdin, so the process exits after each
		sentence; it is restarted here on the next call.

			@param sent: sequence of tokens, sent to bitpar one per line
			@param timeout: currently ignored (kept for backward
			compatibility)
			@raise ValueError: if bitpar's output cannot be read as a Tree"""
		if self.bitpar.poll() is not None:
			self.start()
		result, stderr = self.bitpar.communicate("%s\n\n" % "\n".join(sent))
		try:
			return Tree(result)
		except Exception:
			# bitpar returned some error or didn't produce output
			raise ValueError("no output. stdout: \n%s\nstderr:\n%s " % (result.strip(), stderr.strip()))

	def writegrammar(self, f, l):
		""" write a grammar to files f and l in a format that bitpar 
		understands. f will contain the grammar rules, l the lexicon 
		with pos tags. Rules whose right-hand side is a single terminal go
		to the lexicon; rules with an empty or blank right-hand side are
		skipped.

			@param f: path of the rule file
			@param l: path of the lexicon file"""
		lex = defaultdict(list)
		def process():
			for (lhs, rhs), freq in self.grammar:
				if len(rhs) == 1 and not isinstance(rhs[0], Nonterminal):
					# lexical rule: collect "TAG freq" entries for this word
					lex[rhs[0]].append(" ".join(map(repr, (lhs, freq))))
				elif len(rhs) == 0 or '' in (str(a).strip() for a in rhs):
					# this should NOT happen: skip empty / blank right-hand sides
					continue
				else:
					# prob^Wfrequency	lhs	rhs1	rhs2
					yield "%s\t%s\t%s\n" % (repr(freq), str(lhs), 
								"\t".join(str(x) for x in rhs))
		def proc(lex):
			for word, tags in lex.items():
				# word	POS1 prob^Wfrequency1	POS2 freq2 ...
				yield "%s\t%s\n" % (word, "\t".join(tags))
		# with-blocks close both files even if writing fails
		with open(f, 'w') as rulefile:
			with open(l, 'w') as lexfile:
				# the rule file must be written first: process() fills lex
				rulefile.writelines(process())
				lexfile.writelines(proc(lex))

class GoodmanDOP:
	def __init__(self, treebank, rootsymbol='S', wrap=False, parser=InsideChartParser):
		""" initialize a DOP model given a treebank. uses the Goodman
		reduction of a STSG to a PCFG.  after initialization,
		self.parser will contain a parser for the reduced grammar: a
		BitParChartParser when parser is BitParChartParser, otherwise
		parser(self.grammar) (an InsideChartParser by default).

		>>> tree = Tree("(S (NP mary) (VP walks))")
		>>> d = GoodmanDOP([tree])
		>>> d.parser.parse("mary walks".split())
		ProbabilisticTree('S', [ProbabilisticTree('NP@1', ['mary']) (p=1.0), 
		ProbabilisticTree('VP@2', ['walks']) (p=1.0)]) (p=0.444444444444)

		instance variables:
		- self.grammar a WeightedGrammar containing the PCFG reduction
		  (not set when parser is BitParChartParser)
		- self.fcfg a list of ((lhs, rhs), frequency) pairs containing the
		  PCFG reduction with frequencies instead of probabilities
		- self.parser the parser object (see above)
		- self.exemplars dictionary of known parse trees (memoization)

			@param treebank: iterable of Trees; iterated exactly once
			@param rootsymbol: label of the start symbol
			@param wrap: if true, wrap every tree in a new rootsymbol node
			@param parser: parser class; BitParChartParser is special-cased"""
		nonterminalfd, ids = FreqDist(), count()
		self.exemplars = {}
		if wrap:
			# wrap trees in a common root symbol (eg. for morphology)
			treebank = [Tree(rootsymbol, [a]) for a in treebank]
		utreebank = [(tree, self.decorate_with_ids(tree, ids)) for tree in treebank]
		# chain.from_iterable instead of reduce(chain, ...): the latter nests
		# one chain object per tree (slow, can exhaust the C stack on large
		# treebanks, and fails on an empty treebank). A list is needed anyway
		# because both frequencies() and probabilities() read the rules.
		cfg = list(chain.from_iterable(self.goodman(tree, utree)
				for tree, utree in utreebank))

		for tree, utree in utreebank:
			self.nodefreq(tree, nonterminalfd)
			self.nodefreq(utree, nonterminalfd)
		self.fcfg = self.frequencies(cfg, nonterminalfd)
		if parser == BitParChartParser:
			print("writing grammar")
			self.parser = BitParChartParser(self.fcfg, rootsymbol)
		else:
			self.grammar = WeightedGrammar(Nonterminal(rootsymbol),
				self.probabilities(cfg, nonterminalfd))
			self.parser = parser(self.grammar)
		#stuff for self.mccparse (currently not computed):
		#the highest id
		#self.addresses = ids.next()
		#a list of interior + exterior nodes, 
		#ie., non-terminals with and without ids
		#self.nonterminals = nonterminalfd.keys()
		#a mapping of ids to nonterminals without their IDs
		#self.nonterminal = dict(a.split("@")[::-1] for a in 
		#	nonterminalfd.keys() if "@" in a)

	def goodmanfd(self, tree, ids, nonterminalfd):
		""" decorate tree with IDs, add the subtree counts of both the plain
		and the decorated tree to nonterminalfd, and return the Goodman
		rules of tree (a generator, see goodman). """
		utree = self.decorate_with_ids(tree, ids)
		self.nodefreq(tree, nonterminalfd)
		self.nodefreq(utree, nonterminalfd)
		return self.goodman(tree, utree)
	
	def decorate_with_ids(self, tree, ids):
		""" add unique identifiers to each non-terminal of a tree.
		Returns a decorated deep copy; tree itself is not modified.

		>>> tree = Tree("(S (NP mary) (VP walks))")
		>>> d = GoodmanDOP([tree])
		>>> d.decorate_with_ids(tree, count())
		Tree('S@0', [Tree('NP@1', ['mary']), Tree('VP@2', ['walks'])])

			@param ids: an iterator yielding a stream of IDs"""
		utree = tree.copy(True)
		for a in utree.subtrees():
			a.node = "%s@%d" % (a.node, next(ids))
		return utree
	
	def nodefreq(self, tree, nonterminalfd, leaves=1):
		"""count frequencies of nodes by calculating the number of
		subtrees headed by each node.

		>>> fd = FreqDist()
		>>> tree = Tree("(S (NP mary) (VP walks))")
		>>> d = GoodmanDOP([tree])
		>>> d.nodefreq(tree, fd)
		9
		>>> fd.items()
		[('S', 9), ('NP', 2), ('VP', 2), ('mary', 1), ('walks', 1)]

			@param nonterminalfd: the FreqDist to store the counts in.
			@param leaves: the count assigned to every leaf (terminal) of
			tree, including the leaves of its subtrees.
			@return: the number of subtrees headed by the root of tree"""
		if isinstance(tree, Tree) and len(tree) > 0:
			# a node heads prod(count(child) + 1) subtrees: every child is
			# either cut off (1 way) or expanded as one of its own subtrees
			n = reduce((lambda x, y: x * y), 
				(self.nodefreq(x, nonterminalfd, leaves) + 1 for x in tree))
			nonterminalfd.inc(tree.node, count=n)
			return n
		else:
			nonterminalfd.inc(str(tree), count=leaves)
			return leaves

	def goodman(self, tree, utree):
		""" given a parsetree from a treebank, yield a goodman
		reduction of eight rules per node (in the case of a binary tree).
		Yields (lhs, rhs) pairs combining the plain and ID-decorated
		version of every production's lhs and of each rhs symbol.

		>>> tree = Tree("(S (NP mary) (VP walks))")
		>>> d = GoodmanDOP([tree])
		>>> utree = d.decorate_with_ids(tree, count())
		>>> sorted(list(d.goodman(tree, utree)))
		[(NP, ('mary',)), (NP, ('mary',)), (NP@1, ('mary',)),
		(NP@1, ('mary',)), (S, (NP, VP)), (S, (NP, VP@2)),
		(S, (NP@1, VP)), (S, (NP@1, VP@2)), (S@0, (NP, VP)),
		(S@0, (NP, VP@2)), (S@0, (NP@1, VP)), (S@0, (NP@1, VP@2)),
		(VP, ('walks',)), (VP, ('walks',)), (VP@2, ('walks',)),
		(VP@2, ('walks',))]

			@param utree: tree decorated by decorate_with_ids"""
		# linear: nr of nodes
		for p, up in zip(tree.productions(), utree.productions()): 
			# THIS SHOULD NOT HAPPEN:
			if len(p.rhs()) == 0: continue #raise ValueError
			if len(p.rhs()) == 1: rhs = (p.rhs(), up.rhs())
			# cartpi needs an indexable sequence, so materialize the pairs
			else: rhs = cartpi(list(zip(p.rhs(), up.rhs())))

			# constant factor: 8
			for l, r in cartprod((p.lhs(), up.lhs()), rhs):
				yield l, r

	def _rhsweight(self, rhs, fd):
		""" product of the subtree counts in fd of the ID-decorated
		nonterminals ("X@n") in rhs. Terminals and plain nonterminals, as
		well as decorated nodes with a zero count, contribute a factor of 1.
		Only Nonterminals are considered, so a terminal word that happens to
		contain '@' is not mistaken for a decorated node. """
		weight = 1
		for z in rhs:
			if isinstance(z, Nonterminal) and '@' in str(z):
				weight *= fd[str(z)] or 1
		return weight

	def probabilities(self, cfg, fd):
		"""merge cfg and frequency distribution into a pcfg with the right
		probabilities. Identical rules are merged, their probabilities
		summed.

			@param cfg: a list of (lhs, rhs) pairs
			@param fd: a FreqDist of (non)terminals (with and
			without IDs)
			@return: a list of WeightedProductions""" 
		# merge identical rules:
		cfgfd = FreqDist(cfg)
		return [WeightedProduction(l, r, prob=cfgfd[(l, r)] *
				(self._rhsweight(r, fd) / float(fd[str(l)])))
				for l, r in cfgfd]
	
	def frequencies(self, cfg, fd):
		"""merge cfg and frequency distribution into a list of weighted 
		productions with frequencies as weights (as expected by bitpar).
		Identical rules are NOT merged.

			@param cfg: a list of (lhs, rhs) pairs
			@param fd: a FreqDist of (non)terminals (with and
			without IDs)
			@return: a list of ((lhs, rhs), frequency) pairs""" 
		return [(rule, self._rhsweight(rule[1], fd)) for rule in cfg]

	def parse(self, sent):
		""" memoize parse trees. TODO: maybe add option to add every
		parse tree to the set of exemplars, ie., incremental learning.
		Parse failures raise and are not memoized. """
		key = tuple(sent)
		if key not in self.exemplars:
			self.exemplars[key] = self.parser.parse(sent)
		return self.exemplars[key]

	def mccparse(self, sent):
		""" not working yet. almost verbatim translation of Goodman's (1996)
		most constituents correct parsing algorithm, except for python's
		zero-based indexing. needs to be modified to return the actual parse
		tree. expects a pcfg in the form of a dictionary from productions to
		probabilities.

		NOTE(review): relies on self.pcfg, self.nonterminals,
		self.addresses and self.nonterminal, which __init__ does not set
		(the code computing them is commented out there), and on names
		rootsymbol and n that are not defined in this scope; calling this
		method will fail. """ 
		def g(s, t, x):
			def f(s, t, x):
				return self.pcfg[Production(rootsymbol, 
					sent[1:s] + [x] + sent[s+1:])]
			def e(s, t, x): 
				return self.pcfg[Production(x, sent[s:t+1])]
			return f(s, t, x) * e(s, t, x ) / e(1, n, rootsymbol)

		sumx = defaultdict(int) #zero
		maxc = defaultdict(int) #zero
		for length in range(2, len(sent)+1):
			for s in range(1, len(sent) + length):
				t = s + length - 1
				for x in self.nonterminals:
					sumx[x] = g(s, t, x)
				for k in range(self.addresses):
					#ordered dictionary here
					x = self.nonterminal[k]
					sumx[x] += g(s, t, "%s@%d" % (x, k))
				max_x = max(sumx[x] for x in self.nonterminals)
				#for x in self.nonterminals:
				#	max_x = argmax(sumx, x) #???
				best_split = max(maxc[(s,r)] + maxc[(r+1,t)] 
									for r in range(s, t))
				#for r in range(s, t):
				#	best_split = max(maxc[(s,r)] + maxc[(r+1,t)])
				maxc[(s,t)] = sumx(max_x) + best_split
		return maxc[(1, len(sent) + 1)]
				
def main():
	""" a basic REPL for testing """
	corpus = """(S (NP John) (VP (V likes) (NP Mary)))
(S (NP Peter) (VP (V hates) (NP Susan)))
(S (NP Harry) (VP eats (NP pizza)))
(S (NP Hermione) (VP eats))""".split('\n')
	#corpus = """(NN (N (A (A (A mal) (J riĉ)) (A eg)) (A ul)) o)
	#(NN (N (N (N (V kudr) (A ist)) (A in)) (N edz)) o)
	#(JJ (JJ (V (V (P en) (V konduk)) (V it)) a) j)""".split('\n')
	print d.grammar
	w = "foo!"
	while w:
		print "sentence:",
		w = raw_input().split()
		try:
			print d.parse(w)
		except Exception as e:
			print "error", e

def cnf(tree):
	""" return a copy of tree in which every terminal has a POS tag: a
	terminal whose parent has other children as well is wrapped in a new
	preterminal labeled "parent_word". tree itself is not modified. """
	newtree = tree.copy(True)
	for leafpos in tree.treepositions('leaves'):
		parent = tree[leafpos[:-1]]
		if len(parent) == 1:
			# the parent already acts as the POS tag of this terminal
			continue
		word = tree[leafpos]
		newtree[leafpos] = Tree("%s_%s" % (parent.node, word), [word])
	return newtree

def monato():
	d = GoodmanDOP((Tree(a) for a in open("arbobanko.sexp")), rootsymbol='STA:fcl', parser=BitParChartParser)
	#d = GoodmanDOP((Tree(a) for a in corpus), rootsymbol='S', wrap=True)
	#print d.grammar
	w = "foo!"

	# basic REPL
	while w:
		print "sentence:",
		w = raw_input().split()
		if not w: break	#quit
		try:
			print d.parse(w)
		except Exception as e:
			print "error:", e

def morphology():
	from corpus import corpus
	#corpus = ["(S (NP (NN amiko)) (VP (VB venis)))"]
	d = GoodmanDOP((Tree(a) for a in corpus), rootsymbol='S', parser=BitParChartParser)
	print "built syntax model"

	mcorpus = open("morph.corp.txt").readlines()
	md = GoodmanDOP((cnf(Tree(a)) for a in mcorpus), rootsymbol='W', wrap=True, parser=BitParChartParser)
	print "built morphology model"

	segmentd = dict(("".join(a), tuple(a)) for a in (Tree(a).leaves() for a in mcorpus))
	print "morphology exemplars: ", " ".join(segmentd.keys())
	print "segmentation dictionary size:", len(segmentd),

	def dos(words):
		""" `Data-Oriented Segmentation:' given a sequence of segmented words
		(ie., a sequence of morphemes), produce a dictionary with extrapolated
		segmentations (mapping words to sequences of morphemes). 
		Assumes non-ambiguity. 
		Method: cartesian product of all possible morphemes at position 0..n, where n is maximum word length."""
		l = [len(a) for a in words]
		morph_at = dict((x, set(a[x] for a,n in zip(words, l) if n > x)) 
							for x in range(0, max(l)))
		return dict(("".join(a), a) for a in 
			reduce(chain, (cartpi([morph_at[x] for x in range(n)]) 
				for n in range(min(l), max(l)))))
	def dos1(words):
		""" `Data-Oriented Segmentation:' given a sequence of segmented words
		(ie., a sequence of morphemes), produce a dictionary with extrapolated
		segmentations (mapping words to sequences of morphemes). 
		Discards ambiguous results.
		Method: cartesian product of all words with the same number of morphemes. """
		l = [len(a) for a in words]
		return dict(("".join(a), a) for a in 
			reduce(chain, (cartpi(zip(*(w for w, m in zip(words, l) if m == n)))
				for n in range(min(l), max(l)))))
	def dos2(words):
		#bigram model
		pass
	segmentd = dos1(set(segmentd.values()))
	#restore original original words in case they were overwritten
	for a in (Tree(a).leaves() for a in mcorpus):
		segmentd["".join(a)] = tuple(a)

	print "extrapolated:", len(segmentd) #, " ".join(segmentd.keys())

	def segment(w):
		""" consult segmentation dictionary with fallback to rule-based heuristics. """
		try: return segmentd[w]
		#naive esperanto segmentation (assume root of the appropriate type)
		except KeyError:
			if w[-1] in 'jn': return segment(w[:-1]) + (w[-1],)
			if w[-1] in 'oaeu': return (w[:-1], w[-1])
			if w[-1] == 's': return (w[:-2], w[-2:])
		return (w,)

	def removeids(tree):
		""" remove unique IDs introduced by the Goodman reduction """
		for a in tree.treepositions():
			if '@' in str(tree[a]):
				tree[a].node = tree[a].node.split('@')[0]
		return tree

	def morphmerge(tree):
		""" merge morphology into phrase structure tree """
		copy = tree.copy(True)
		for a in tree.treepositions('leaves'):
			try:
				print tree[a[:-1]][0],
				copy[a[:-1]] = removeids(md.parse(segment(tree[a[:-1]][0]))[0])
			except Exception as e:
				print "word:", tree[a[:-1]][0],
				print "error:", e
		return copy

	print "analyzing morphology of treebank"
	mtreebank = []
	for n, a in enumerate(corpus):
		print '%d / %d:' % (n, len(corpus)-1),
		mtreebank.append(morphmerge(Tree(a)))
		print

	#mtreebank = [m(Tree(a)) for a in corpus]
	#for a in mtreebank: print a
	msd = GoodmanDOP(mtreebank, rootsymbol='S', parser=BitParChartParser)
	print "built combined morphology-syntax model"

	#d.writegrammar('/tmp/syntax.pcfg', '/tmp/syntax.lex')
	#md.writegrammar('/tmp/morph.pcfg', '/tmp/morph.lex')
	#msd.writegrammar('/tmp/morphsyntax.pcfg', '/tmp/morphsyntax.lex')

	#print d.grammar
	w = "foo!"

	# basic REPL
	while w:
		print "sentence:",
		w = raw_input().split()
		if not w: break	#quit

		print "morphology:"
		for a in w:
			try:
				print a, md.parse(segment(a))[0]
			except Exception as e:
				print "error:", e
		
		print "morphology + syntax combined:"
		try:
			sent = list(reduce(chain, (segment(a) for a in w)))
			print sent
			print msd.parse(sent)
			#for tree in d.parser.nbest_parse(w):
			#	print tree
		except Exception as e:
			print "error", e

		try:
			print "syntax:"
			a = d.parse(w + ['\n'])
			print a

			print "syntax & morphology separate:"
			print morphmerge(a)
			#sent = ["".join(a.split('|')) for a in w]
			#for tree in d.parser.nbest_parse(w):
			#	print tree
		except Exception as e:
			print "error:", e

if __name__ == '__main__': 
	import doctest
	# do doctests, but don't be pedantic about whitespace (I suspect it is the
	# militant anti-tab faction who are behind this obnoxious default)
	fail, attempted = doctest.testmod(verbose=False,
	optionflags=doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS)
	if attempted and not fail:
		print "%d doctests succeeded!" % attempted
	morphology()
	#monato()
