from itertools import count
from random import choice
import glob, re, sys
import numpy as np
from matplotlib.pyplot import *
from nltk import Tree
ops = dict(max=max,
		sum=sum,
		median=np.median,
		mean=np.mean,
		freqtypes=len,
		fragtypes=len)
if len(sys.argv) == 1:
	print 'usage: plot.py <op> [syn|front|lex] [java]'
	exit()
op = ops[sys.argv[1]]
termsre = re.compile(r'[A-Za-z]\)')
convpenn = re.compile(r' ([^ ()"]+)\b')
#folder = "lit3"
folder = "lit2500"
#folder = "bks2500"
files = sorted(glob.glob("%s/*.fragments" % folder))
texts = sorted(glob.glob("%s/*.stp1" % folder))
if "java" in sys.argv:
	files = sorted(glob.glob("%s/*.fragmentsjava/fragments_exactFreq.txt" % folder))
booknames = [file.split('/',1)[1].split(".")[0] for file in files]
figure()
output = {}
exampfrag = {}

print "nodes &\tsentences & words & avg.~sent.~len.&\tfragments & \ttext \\\\ \\midrule"
for m, file, text, book in zip(count(), files, texts, booknames):
	assert file.startswith(text)
	parsetrees = open(text).read()
	book = file.split('/',1)[1].rsplit(".", 3)[0]
	nodes = parsetrees.count("(")
	sents = parsetrees.count("\n")
	words = float(len(termsre.findall(parsetrees)))
	if "java" in sys.argv: # fragment<TAB>freq.
		# convert (NP DT (NN "this")) to (NP (DT ) (NN this))
		freqs = dict((convpenn.sub(r" (\1 )",
			line[:line.index("\t")]).replace('"', ''),
			int(line[line.index("\t")+1:])) for line in open(file))
	else: # fragment<TAB>indices
		freqs = dict((line[:line.index("\t")],
			len(line[line.index("\t")+2:-2].split(",")))
			for line in open(file))
	if "syn" in sys.argv: # syntactic fragments only (no terminals)
		freqs = dict((a, b) for a, b in freqs.iteritems()
			if termsre.search(a) is None)
	elif "front" in sys.argv: # at least one terminal and at least one frontier
		freqs = dict((a, b) for a, b in freqs.iteritems()
			if ' )' in a and termsre.search(a))
	elif "lex" in sys.argv: # fragments with only terminals (no frontiers)
		freqs = dict((a, b) for a, b in freqs.iteritems() if ' )' not in a)
	print "%d &\t%d &\t%d &\t%5.2f &\t%d &\t%s \\\\" % (
			nodes, sents, words, words/float(sents), len(freqs), book)
	nrnodes = dict((frag, frag.count("(")) for frag in freqs)
	h = {}
	#exampfrag[book] = max((frag for frag, n in nrnodes.items() if n == 9), key=freqs.get)
	#exampfrag[book] = choice([frag for frag, n in nrnodes.items() if n == 13 and freqs[frag] > 10])
	#exampfrag[book] = choice([frag for frag, n in nrnodes.items()
	exampfrag[book] = max(
		filter(lambda x: freqs[x] > 2 and len(termsre.findall(x)) > 3
			and ' )' in x, freqs),
		key=lambda x: freqs[x] * len(termsre.findall(x)))

	exampfrag[book] = (exampfrag[book], freqs[exampfrag[book]])
	#	if n > 5 and freqs[frag] > 10 and frag.count(" )") == 0])
	if sys.argv[1] == 'fragtypes':
		for frag in freqs: h.setdefault(nrnodes[frag], []).append(frag)
		#for n in h: print n, max(h[n], key=lambda x: freqs[x])
	else:
		for frag in freqs: h.setdefault(nrnodes[frag], []).append(freqs[frag])
	for n in h: h[n] = op(h[n])
	if 'wordsnorm' in sys.argv:
		for n in h: h[n] /= words
	xvalues = [x for x in sorted(h.keys()) if 5 <= x <= 30]
	#yvalues = [150000.0 * float(x)/nodes for x in h.values()]
	yvalues = [float(h[x]) for x in xvalues]
	# fixme: add null values.
	for a, b in h.items(): output.setdefault(a, [''] * len(texts))[m] = str(b)
	plot(xvalues, yvalues, label=book)
	#hist(h.values(), bins=h.keys(), label=file.split(".")[0])

print
for a,b in exampfrag.values(): print b, a
print
for a, (b, c) in exampfrag.iteritems():
	print Tree(b).pprint_latex_qtree().replace("_",r"\_"),
	print a.replace("_"," "), 'freq: %d\n' % c
with open("%s/data-%s.txt" % (folder, sys.argv[1]), "w") as f:
	f.writelines("#%d\t%s\n" % a for a in enumerate(booknames))
	f.write("#size\t%s\n" % "\t".join(map(str, range(len(booknames)))))
	f.writelines("%d\t%s\n" % (a, "\t".join(output[a]))
			for a in sorted(output))
title('histogram of recurring syntactic fragments ' + ' '.join(sys.argv[1:]))
xlabel('# nodes in fragment')
if sys.argv[1] == 'freqtypes': ylabel('# frequency types')
else: ylabel('%s of frequencies' % (sys.argv[1]))
legend()
show()
