import sys, re
from nltk import Tree
sys.path.append("../../disco-dop")
from treetransforms import unbinarize

# Matches a space, a run of ASCII letters and a closing parenthesis,
# e.g. " dog)" in "(NN dog)"; used to count the words in a tree string.
# Raw string so that the backslash reaches the regex engine unchanged.
words = re.compile(r" [A-Za-z]+\)")

def find(sub, str, start):
	""" Yield each index at or after `start` where `sub` occurs in `str`
	(overlapping occurrences included). When there is no occurrence at all,
	yield a single -1 instead. """
	position = str.find(sub, start)
	if position == -1:
		# nothing found: hand the caller one sentinel value and stop
		yield -1
		return
	while True:
		yield position
		# resume one character further so overlapping matches are reported
		position = str.find(sub, position + 1)
		if position == -1:
			return

def contains(tree, frag):
	""" Tell whether the string `tree` contains the string `frag`.

	`frag` is cut at every " )" (its frontier nodes); the resulting pieces
	must occur in `tree` in order, and the text of `tree` skipped between two
	consecutive pieces must have as many "(" as ")". Returns True or False. """
	pieces = frag.split(" )")
	head, rest = pieces[0], pieces[1:]
	for start in find(head, tree, 0):
		if start == -1:
			return False
		pos = start + len(head)
		all_matched = True
		for piece in rest:
			closing = ")" + piece
			piece_matched = False
			for hit in find(closing, tree, pos):
				if hit == -1:
					return False
				skipped = tree[pos:hit]
				# accept the first occurrence preceded by balanced brackets
				if skipped.count("(") == skipped.count(")"):
					pos = hit + len(closing)
					piece_matched = True
					break
			if not piece_matched:
				# no usable occurrence: retry from the next position of head
				all_matched = False
				break
		if all_matched:
			return True
	return False

def f(pair):
	""" Map a two-item sequence (a, b) to (word count of a, b).

	`a` is a tree in string form; its words are counted as the number of
	matches of the module-level `words` pattern. `b` is passed through
	unchanged. Raises the usual unpacking error if `pair` does not hold
	exactly two items. """
	# explicit unpacking replaces the Python-2-only tuple parameter syntax
	a, b = pair
	return len(words.findall(a)), b

# read trees
# Every input line must contain a tab; the text before the first tab is the
# key, whatever follows the tab is ignored and the value is a None
# placeholder. A line without a tab raises ValueError.
# The with-block makes sure the input file is closed once it has been read.
with open(sys.argv[1]) as infile:
	trees = dict(
		(line[:line.index("\t")], None)
		for line in infile)

def contains1(a, b):
	""" Return 1 when contains() reports that a[1] contains b[1],
	and 0 otherwise. """
	return 1 if contains(a[1], b[1]) else 0

#newtrees = [a for a in enumerate(trees)]
#newtrees.sort(cmp=contains1)
#prev = -1
#for m, (n, a) in enumerate(newtrees):
#	if n < prev:
#		newtrees = newtrees[:m]
#		break
#	prev = n
#	a1 = Tree(a)
#	try: unbinarize(a1)
#	except: newtrees[m] = (a, trees[a])
#	else: newtrees[m] = (a1.pprint(margin=9999), trees[a])

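# Algorithm: process the fragments longest first; compare each one against
# the fragments kept so far plus the remaining fragments of the same length;
# if none of them contains it, add it to the list of kept fragments.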

# write output
# remove non-maximal fragments
# unbinarize remaining fragments
newtrees = []
curtrees = []
thetrees = sorted(trees, key=len, reverse=True)
# the with-block closes the output file even if an exception is raised
# somewhere inside the loop
with open("%s.maximal" % sys.argv[1], "w") as out:
	for n, a in enumerate(thetrees):
		# compare against the fragments kept so far, plus the fragments that
		# follow in the sorted list and are not shorter than this one
		candidates = curtrees[:]
		# walk forward by index instead of copying the whole tail of the list
		m = n + 1
		while m < len(thetrees) and len(thetrees[m]) >= len(a):
			candidates.append(thetrees[m])
			m += 1
		if not any(contains(b, a) for b in candidates if a != b):
			curtrees.append(a)
			a1 = Tree(a)
			# best effort: fall back to the original string when unbinarize
			# fails, but let KeyboardInterrupt / SystemExit propagate
			try: unbinarize(a1)
			except Exception: newtrees.append((a, trees[a]))
			else: newtrees.append((a1.pprint(margin=9999), trees[a]))
			# progress: index of the kept fragment in the sorted list
			print(n)
			# NOTE: the original fragment string is written here, not the
			# unbinarized form collected in newtrees
			out.write("%s\n" % a)
print(sys.argv[1] + ".maximal")

