from os import listdir
from nltk import FreqDist, AnnotationTask
from nltk.metrics.distance import masi_distance
from numpy import std, mean
import logging
import inspect
logging.basicConfig(level=50)

# Input files: every file in the current directory whose name ends in ".csv".
# (A plain substring test would also pick up names such as "x.csv.bak".)
files = [name for name in listdir(".") if name.endswith(".csv")]

# Only examples with mm < id < m are processed (both bounds are exclusive).
mm = 0
m = 154

# To calculate data for some/any separately: ids of the examples with "some".
some = [
	143, 144, 145, 146, 147, 148, 149, 150, 151, 152,
	98, 99, 100, 101, 102,
	73, 74, 75, 76, 77,
]

# do_some: include the examples whose id is listed in `some`.
# do_any: include the examples whose id is not listed in `some`.
do_some = True
do_any = True

def read_data(do_some, do_any):
	"""Collect the annotations from every file listed in `files`.

	Each file is read as one header line (skipped) followed by lines of the
	form ``id;label[,label...]``. The file name up to its first "." is used
	as the coder name.

	do_some -- keep examples whose id is listed in `some`
	do_any  -- keep examples whose id is not listed in `some`

	Only ids with mm < id < m are kept. Returns a list of
	(coder, id, frozenset of labels) tuples, in the order they were read.
	Raises ValueError when the first field of a non-blank line is not an
	integer.
	"""
	data = []
	for f in files:
		coder = f.split(".")[0]
		# the context manager closes the file as soon as it has been read
		with open(f) as handle:
			lines = handle.read().splitlines()[1:]  # drop the header line
		for line in lines:
			if not line.strip():
				# tolerate blank lines instead of failing on int("")
				continue
			fields = line.split(";")
			n = int(fields[0])
			if not mm < n < m:
				continue
			in_some = n in some
			if in_some and not do_some:
				continue
			if not in_some and not do_any:
				continue
			# strip every label so that "AM, AA" and "AM,AA" give the same set
			label = frozenset(part.strip() for part in fields[1].split(","))
			data.append((coder, n, label))
	return data

def my_dist(a, b):
	"""Distance function that equates AM/AA/DN and SK/SU.

	a and b are label sets. The rules are tried in order on their union:
	exactly two of AM/AA/DN -> 0, both SK and SU -> 0, exactly two of
	IR/SK/SU -> 0.5 (half correct). If no rule applies the MASI distance
	of a and b is returned.
	"""
	merged = set(a | b)
	rules = (
		(("AM", "AA", "DN"), 0),
		(("SK", "SU"), 0),
		(("IR", "SK", "SU"), 0.5),
	)
	for group, distance in rules:
		if len(merged.intersection(group)) == 2:
			return distance
	return masi_distance(a, b)

def dist_uniform(a, b):
	"""Baseline distance: the MASI distance of label sets a and b, with no
	labels collapsed."""
	uncollapsed = masi_distance(a, b)
	return uncollapsed

def dist_specific_collapsed(a, b):
	"""Collapse SK and SU; every other pair gets its MASI distance.

	a and b are label sets. When their union contains both SK and SU the
	distance is 0.
	NOTE(review): an earlier docstring said the collapse is only meant for
	singleton sets; the test below looks at the union only, so larger sets
	that bring in SK and SU are collapsed as well -- confirm this is intended.
	"""
	has_sk_and_su = len(set(a | b) & set(["SK", "SU"])) == 2
	if not has_sk_and_su:
		return masi_distance(a, b)
	return 0

def dist_haspelmath(a, b):
	"""Distance with three collapsed groups: AM/AA/DN, SK/SU and UFC/FC/GEN.

	Returns 0 when the union of label sets a and b holds exactly two labels
	of any one group, otherwise the MASI distance of a and b.
	"""
	union = set(a | b)
	groups = ("AM AA DN", "SK SU", "UFC FC GEN")
	if any(len(union.intersection(group.split())) == 2 for group in groups):
		return 0
	return masi_distance(a, b)

def dist_haspelmath1(a, b):
	"""Distance with two collapsed groups: AM/AA and UFC/FC/GEN/IND.

	Returns 0 when the union of label sets a and b holds exactly two labels
	of one of the groups, otherwise the MASI distance of a and b.
	"""
	seen = set(a | b)
	collapsed = (
		set(["AM", "AA"]),
		set(["UFC", "FC", "GEN", "IND"]),
	)
	for group in collapsed:
		if len(seen & group) == 2:
			return 0
	return masi_distance(a, b)

def dist_haspelmath2(a, b):
	"""Distance with two collapsed groups: AM/AA and UFC/FC/GEN.

	Returns 0 when the union of label sets a and b holds both AM and AA, or
	exactly two of UFC/FC/GEN; otherwise the MASI distance of a and b.
	"""
	labels = set(a | b)
	am_aa_hits = len(labels.intersection(["AM", "AA"]))
	fc_hits = len(labels.intersection(["UFC", "FC", "GEN"]))
	if am_aa_hits == 2 or fc_hits == 2:
		return 0
	return masi_distance(a, b)

def dist_haspelmath3(a, b):
	"""Distance with three collapsed groups: SK/SU, AM/AA and UFC/FC/GEN/IND.

	Returns 0 when the union of label sets a and b holds exactly two labels
	of any one group, otherwise the MASI distance of a and b.
	"""
	pooled = set(a | b)

	def two_of(names):
		# True when exactly two of the space-separated labels are present
		return len(pooled.intersection(names.split())) == 2

	if two_of("SK SU") or two_of("AM AA") or two_of("UFC FC GEN IND"):
		return 0
	return masi_distance(a, b)

def collapse(label):
	"""Build a distance function that sets `label` against all other labels.

	The returned function takes two label values and gives 1 when exactly
	one of them equals `label`, and 0 otherwise (both equal to it, or both
	different from it), so agreement on this one label can be measured
	against everything else lumped together.
	"""
	def label_vs_rest(a, b):
		a_is_label = (a == label)
		b_is_label = (b == label)
		# a disagreement only counts when the two sides differ about `label`
		return 1 if a_is_label != b_is_label else 0
	return label_vs_rest

# Read data from files: `data` becomes the list of (coder, id, label set)
# tuples returned by read_data for the module-level do_some / do_any flags.

data = read_data(do_some, do_any)

#calculate kappas
#kappas = []
#wkappas = []
#at = AnnotationTask(data, distance=masi_distance)
#print "items:", len(at.I)
#print "all disagreements equal:"
#for n,a in enumerate(at.C):
#	for b in list(at.C)[n+1:]:
#		kappas.append(at.kappa_pairwise(a, b))
#		wkappas.append(at.weighted_kappa_pairwise(a, b))
#		print a,b, kappas[-1]
#print "kappa", at.kappa(), "std dev", std(kappas)
#print "weighted kappa", at.weighted_kappa(), "std dev", std(wkappas)
#print "alpha", at.alpha()
#print

# new stuff for LREC
# Distance functions that are compared; the source code of each one is
# printed before any results.
distance_functions = (dist_uniform,
			dist_specific_collapsed,
			dist_haspelmath,
			dist_haspelmath1,
			dist_haspelmath2,
			dist_haspelmath3)
for func in distance_functions:
	print("".join(inspect.getsourcelines(func)[0]))
print("\n\n\n")

# (title, do_some, do_any) for each subset of the examples that is evaluated.
subsets = (("ANY + SOME", True, True),
		("SOME", True, False),
		("ANY", False, True))
for title, do_some, do_any in subsets:
	data = read_data(do_some, do_any)
	# this task is only used to count the distinct items in the subset
	at = AnnotationTask(data, distance=masi_distance)
	print("%s = %s examples" % (title, len(at.I)))
	for func in distance_functions:
		kappas = []
		wkappas = []
		at = AnnotationTask(data, distance=func)
		print("distance function: %s" % func.__name__)
		# every unordered pair of coders is scored once
		coders = list(at.C)
		for n, first in enumerate(coders):
			for second in coders[n+1:]:
				kappas.append(at.kappa_pairwise(first, second))
				wkappas.append(at.weighted_kappa_pairwise(first, second))
		print("kappa %s std dev %s" % (at.kappa(), std(kappas)))
		print("weighted kappa %s std dev %s" % (at.weighted_kappa(), std(wkappas)))
		print("alpha %s" % at.alpha())
		print("")
	print("")

# The script stops here; nothing below this call is executed.
exit()

# original distance function as in report
# NOTE: an unconditional exit() call earlier in the script means this section
# does not run unless that call is removed.
at = AnnotationTask(data, distance=my_dist)
print("{AM,AA,DN} / {SK,SU} collapsed, d(IR, {SK,SU})=0.5")
# start from empty lists so the std devs below cover this distance function
# only, not values appended by an earlier section
kappas = []
wkappas = []
# every unordered pair of coders is scored once
coders = list(at.C)
for n, first in enumerate(coders):
	for second in coders[n+1:]:
		kappas.append(at.kappa_pairwise(first, second))
		wkappas.append(at.weighted_kappa_pairwise(first, second))
		print("%s %s %s" % (first, second, kappas[-1]))
print("kappa %s std dev %s" % (at.kappa(), std(kappas)))
print("weighted kappa %s std dev %s" % (at.weighted_kappa(), std(wkappas)))
print("alpha %s" % at.alpha())
print("")

# Calculate scores for individual labels.
# NOTE(review): an unconditional exit() call earlier in the script means this
# section does not run unless that call is removed.
# The two FreqDist objects are used as label -> score tables: each label is
# entered once via inc() with the score passed as `count`.
collapsedlabels = FreqDist()
lkappas = FreqDist()
for a in "SK SU IR Q IND CA CO FC AA AM DN UFC GEN".split():
	# only the singleton set {a} is scored; label sets with several members
	# fall under "the rest"
	label = frozenset([a])
	# calculate weighted kappa with a distance function collapsed for a
	# single label
	at=AnnotationTask(data, distance=collapse(label))
	collapsedlabels.inc(a, count=at.weighted_kappa())
	kappas = []
	# calculate kappa for a single label, once per unordered pair of coders
	for n,cA in enumerate(at.C):
		for cB in list(at.C)[n+1:]:
			# A / B: items that coder cA / cB annotated with exactly `label`.
			# The `a` inside the generator expressions is one record of
			# at.data, not the label name of the enclosing loop.
			A = set(a['item'] for a in at.data if a['coder']==cA
				and a['labels'] == label)
			B = set(a['item'] for a in at.data if a['coder']==cB
				and a['labels'] == label)
			# Ao: share of all items where both coders chose `label`
			Ao = len(A.intersection(B)) / float(len(at.I))
			# Ae: product of the two coders' shares for `label`
			# (presumably at.N(c=..., k=...) counts that coder's uses of the
			# label -- confirm against the installed NLTK version)
			Ae = ((at.N(c=cA,k=label) / float(len(at.I))) *
				(at.N(c=cB,k=label) / float(len(at.I))))
			# kappa = (Ao - Ae) / (1 - Ae); raises ZeroDivisionError if Ae is 1
			kappas.append((Ao - Ae) / (1.0 - Ae))
	# mean pairwise kappa for this label; needs at least two coders
	lkappas.inc(a, count=sum(kappas) / float(len(kappas)))

print "label scores with weighted kappas"
for a,k in collapsedlabels.items():
	print "kappa %s-rest" % a, k

print "\nlabel scores with individual kappas"
for a,k in lkappas.items():
	print "kappa %s" % a, k

# stop here; the statements after this call are not executed
exit()

# print an agreement table
# NOTE(review): this section comes after unconditional exit() calls, so it is
# not executed. No module-level assignment to `table`, `hist` or `out` is
# visible in this file; these names would have to be defined before this code
# could run.
# For every entry of `table` that has items, prints its index followed by
# space-separated "label:count" pairs.
for n,a in enumerate(table):
	if a.items(): print n, " ".join("%s:%d" %(b,c) for b,c in a.items())

#print a histogram showing frequency of each label
print
print hist
out.close()
