from oneword import *
from sys import stdin, stdout, argv
"""
The two-word algorithm learns linguistic abstractions/structures (LAs) for
combining two words whose meanings it already knows from the one-word stage.
When it receives a situation (SI) together with the coupled utterance, it
analyzes them and either reinforces already existing linguistic abstractions
or hypothesizes a new linguistic abstraction.

Datastructures:
 - LADict: (linguistic abstraction, score) tuples, with LA hashes as keys
 - MDict: meaning frames, with meaning hashes as keys
 - M2WLADict: dictionary with meaning hashes as keys and, as values, lists of
   (word, linguistic abstraction hash) tuples.
 - W2MLADict: the reverse index, with words as keys and lists of
   (meaning hash, linguistic abstraction hash) tuples as values.
 - associations and frameindex from oneword
"""

# Two-word stage state, filled in by hypothesize() and updated by reinforce().
# (associations, frameindex, meaningcounters and situations are bound by
# twMain() through its own global statement; a module-level `global` is a
# no-op, so none is used here.)
LADict = {}     # LA hash -> (LA frame, score)
MDict = {}      # meaning hash -> meaning frame
M2WLADict = {}  # meaning hash -> list of (word, LA hash)
W2MLADict = {}  # word -> list of (meaning hash, LA hash)

def twMain():
	"""
	Create one-word data. Pass situations in the corpus to analyze using this
	data.
	"""
	def help():
		print "Argument missing. Please choose one of the following modes of operation:"
		print "f : frames to words experiment"
		print "w : words to frames experiment"
		print "o : oneword output test"
		print "c : chat demo."
		print "x : count unique generated utterances"
	
	#decide which reporting method is wanted
	try:
		method = argv[1]
	except:
		help()
		return
	if method not in "fwocx":
		help()
		return
	
	#oneword part: make meaning-word associations
	global associations, frameindex, meaningcounters, situations
	try:
		corpus = argv[2]
	except:
		corpus = "corpus.xml"
	xmldoc = minidom.parse(corpus).documentElement

	associations, frameindex, meaningcounters = oneword(xmldoc)
	
	
	#twoword part: analyze situations, making and reinforcing new LAs
	situations = xmldoc.getElementsByTagName("situation")
	print "twoword: Analyzing"
	for i, a in enumerate(situations):
		analyze(a)
		print i,
		stdout.flush()
	print
	global LADict
	
	if method == "f":
		#report found utterances for all situations
		frames2wordstest()
	elif method == "w":
		#ask for words and report corresponding frames
		words2framestest()
	elif method == "o":
		#do oneword test
		onewordtest()
	elif method == "x":
		uniqueutterancestest()
	elif method == "c":
		chattest()
	else:
		print "unrecognized option:", argv[1]
	xmldoc.unlink()
	#end program

def onewordtest():
	"""
	Ask for a word and show which meanings oneword associates with it.
	"""
	global associations
	while True:
		print "Talk to me (or enter \'quit\'): ",
		text = stdin.readline()
		if text == ['quit']: break
		for word in text.split():
			if word in associations:
				data = associations[word]
				print word
				list = [(data[fhash], fhash) for fhash in data]
				list.sort(reverse=True)
	# temporarily show the complete list for debugging
				for i, match in enumerate(list[:5]):
	#			for i, match in enumerate(list):
					print "match", i + 1, "score:", match[0]
					printframe(frameindex[match[1]])
			else:
				print word, "not in corpus."

def words2framestest():
	"""
	Input words and show a matching linguistic abstraction and meaning for
	each pair of words.
	"""
	global associations
	while True:
		print "corpus lexicon:",
		for a in sorted(associations.keys()): print a,
		print "\nTalk to me (or enter \'quit\'): ",
		text = stdin.readline().split()
		if text == ['quit']: break
		print text
		for w1, w2 in zip(text[:-1], text[1:]):
			if w1 not in associations:
				print w1, "not in corpus"
			elif w2 not in associations:
				print w2, "not in corpus"
			else:
				r = []
				for a in words2frame(w1, w2):
					if not a in r:
						r.append(a)
				if r == []:
					print "no matching linguistic abstraction found."
				else:
					for resultno, (score, LA, meaning) in enumerate(r[:5]):
						print '%i. \"%s %s\"\tla score = %i' % (resultno + 1, w1, w2, score)
						printframe(LA)
						printframe(meaning)

def frames2wordstest():
	"""
	For each situation in the corpus, show the twoword utterances that can
	be generated from it with its linguistic abstraction.
	"""
	global situations
	for sitno, sit in enumerate(situations):
		print 79 * "-"
		print "Situation:", sitno
		r = []
		for a in frame2words(sit):
			if a not in r:
				r.append(a)
		for resultno, (score, words) in enumerate(r[:2]):
			print resultno+1, ':', words, "\tScore =", score

def chattest():
	"""
	Input words and look for a matching linguistic abstraction, then
	try to find a new utterance as a reply, different from the input.
	"""
	global associations
	while True:
		print "corpus lexicon:",
		for a in sorted(associations.keys()): print a,
		print "\nTalk to me (or enter \'quit\'): ",
		text = stdin.readline().split()
		if text == ['quit']: break
		print text
		for w1, w2 in zip(text[:-1], text[1:]):
			if w1 not in associations:
				print w1, "not in corpus"
			elif w2 not in associations:
				print w2, "not in corpus"
			else:
				frames = words2frame(w1, w2)
				if frames == []:
					print "no matching linguistic abstraction found."
				else:
					#use second best result to generate a response sentence:
					response = la2words(frames[1][1], frames[1][2])
					if response == []:
						print "no response found."
					else:
						#for a in response:
						#	if w1 not in a or w2 not in a:
						#		print a
						print response

def uniqueutterancestest():
	"""
	similar to frames2words, but count how many utterances are
	generated which are unique, as in, not found in the corpus.
	"""
	global situations, associations
	utts, unique, total = [], 0, 0
	def mangle(a):
		return a.replace("!", "").replace("?", "").replace(".","").replace(",", "").split()
	def seen(words):
		for sentence in utts:
			if words[0] in sentence and words[1] in sentence:
				if sentence.index(words[0]) < sentence.index(words[1]):
					return True
		return False
	for a in situations:
		utterances = [mangle(a) for a in adultutterances(a)]
		#utterances = [zip(a[:-1], a[1:]) for a in utterances]
		utts.extend(utterances)
	print "utterances:", utts
	print "nr. utts:", len(utts)

	while True:
		print "corpus lexicon:",
		for a in sorted(associations.keys()): print a,
		print "\nHow many results to check per situation (or enter \'quit\'): ",
		nr = stdin.readline()
		if nr == ['quit']: break
		else:
			try: nr = int(nr)
			except: continue
		for sitno, sit in enumerate(situations):
			print 79 * "-"
			print "Situation:", sitno
			r = []
			for a in frame2words(sit):
				if a not in r:
					r.append(a)
			for resultno, (score, words) in enumerate(r[:nr]):
				total += 1
				if seen(tuple(words.split())):
					print resultno+1, ":", words, "in corpus.\tScore =", score
				else:
					print resultno+1, ": not in corpus:", words, "\tScore =", score
					unique += 1
		print "percentage of unique generated utterances:", float(unique) / total * 100, "(", unique, ")"

def words2frame(w1, w2):
	"""
	Searches the linguistic corpus for LAs relevant to two words.

	An LA qualifies when one word indexes it in W2MLADict (it is the LA's
	fixed word), its wordorder has w1 first or w2 second, and a meaning
	associated with the other word fits the LA.
	Return: list of (LA score, LA frame, meaning frame) tuples, highest score
	first; results with equal scores keep the order in which they were found.
	"""
	global W2MLADict, associations, frameindex

	def wordorderfit(LA):
		# w1 must be the LA's first word or w2 its second word
		wo = elementiterator("wordorder", LA).next().childNodes[0].data.split(":")
		return w1 == wo[0] or w2 == wo[1]

	def collect(varword, fixedword):
		# LAs indexed under fixedword into which a meaning of varword fits
		found = []
		if varword not in associations or fixedword not in W2MLADict:
			return found
		meanings = associations[varword].keys()
		for _meaning, abst in W2MLADict[fixedword]:
			LA, score = LADict[abst]
			# the word-order test does not depend on the meaning: do it once per LA
			if not wordorderfit(LA):
				continue
			for m in meanings:
				if fits(frameindex[m], LA):
					found.append((score, LA, frameindex[m]))
		return found

	results = collect(w1, w2) + collect(w2, w1)
	# sort on the score only: comparing whole tuples falls back to comparing
	# DOM nodes on ties, which gives no meaningful order
	results.sort(key=lambda r: r[0], reverse=True)
	return results

def invertassociations(associations):
	"""
	Invert the oneword associations dictionary (word -> {framehash: score}).

	Return: dictionary mapping every framehash to the (word, score) pair with
	the highest score for it; on equal scores the word seen first wins.
	"""
	best = {}
	for word, framescores in associations.items():
		for fh, score in framescores.items():
			current = best.get(fh)
			if current is None or score > current[1]:
				best[fh] = (word, score)
	return best
	
def la2words(la, meaning):
	"""
	Return the utterance matching an LA and a meaning.

	The VAR slot of the LA's wordorder is filled with the word most strongly
	associated with meaning (see invertassociations()).
	Return: the two-word utterance as a string; [] when no word is
	associated with meaning; the split wordorder list unchanged if it has no
	VAR slot.
	"""
	inverted = invertassociations(associations)

	words = la.getElementsByTagName("wordorder")[0].childNodes[0].data.split(":")
	mhash = framehash(meaning)
	if mhash not in inverted:
		# no word for this meaning: signal "no response" instead of a KeyError
		return []
	varword = inverted[mhash][0]

	if words[0] == "VAR":
		return " ".join((varword, words[1]))
	if words[1] == "VAR":
		return " ".join((words[0], varword))
	return words
	
def frame2words(frame):
	"""
	Searches the linguistic corpus for relevant LAs given a meaning frame. Then
	produces a list of 2-word utterances using these LAs.
	Return: list of strings.
	"""
		
	global LADict, M2WLADict, associations, frameindex
	
	inverted = invertassociations(associations)
	matches, results = [], []
	subframes = derivemeanings(frame, frameindex)
	#find all words, LAs and their fitting subframes for the situation:
	for a in subframes:
		if a in M2WLADict:
			wordAbst = M2WLADict[a]
			for wa in wordAbst:
				for b in subframes:
					if b in frameindex and fits(frameindex[b], LADict[wa[1]][0]):
						matches.append((wa[0],wa[1],b))
	
	#find words for the subframes in inverted and set them in the right order:
	for m in matches:
		word1 = m[0]
		if m[2] in inverted:
			word2 = inverted[m[2]][0]
		else:
			print m[2], "not in inverted."
		LA, score = LADict[m[1]]
		wordorder = LA.getElementsByTagName("wordorder")[0].childNodes[0].data
		wo = wordorder.split(":")
		
		if wo[0] == word1:
			utt = score, " ".join([word1, word2])
			results.append(utt)
		elif wo[1] == word1:
			utt = score, " ".join([word2, word1])
			results.append(utt)
		else:
			print "NO MATCH"
	
	results.sort(reverse=True) 
	return results
	
def fits(meaning, LA):
	"""
	Return True if the supplied meaning fits in the linguistic abstraction.

	Both frames are converted to a flat string representation and split into
	lines; the meaning's first (header) line is skipped. The meaning fits when
	some contiguous run of LA lines matches the meaning lines one by one and
	at least one line of that run is a VAR line. An LA line matches a meaning
	line when the LA line, stripped of surrounding whitespace and of a
	trailing "VAR", occurs as a substring of the meaning line.

	Arguments:
	meaning -- a meaning frame, without VAR elements
	LA -- a linguistic abstraction, with VAR elements in properties or id
	"""
	#convert to a list of strings; skip 'meaning' header
	LAstr = frametostr(LA).splitlines()
	mstr = frametostr(meaning).splitlines()[1:]
	mlines = len(mstr)
	if mlines == 0:
		return False

	# (substring to look for, whether the LA line is a VAR line)
	patterns = []
	for a in LAstr:
		if a.endswith("VAR"):
			patterns.append((a[:-3].strip(), True))
		else:
			patterns.append((a.strip(), False))

	# Try every starting line. A counter that is merely reset on a mismatch
	# never re-tests the mismatching line, so it misses runs that start right
	# after a partial match.
	for start in range(len(patterns) - mlines + 1):
		window = patterns[start:start + mlines]
		matched = True
		for (substr, isvar), mline in zip(window, mstr):
			if mline.find(substr) == -1:
				matched = False
				break
		if matched and any(isvar for (substr, isvar) in window):
			return True
	#if we came this far, it didn't work out
	return False

def findLAList(scoresWordsMeanings):
	"""
	Find a list of LAs that fit the meanings in the argument.

	For every ordered pair (swm1, swm2) of (score, word, meaning hash) tuples
	where swm2's meaning is among the meanings derived from swm1's frame, the
	LAs indexed under swm2's meaning in M2WLADict are kept when their word is
	swm1's word and swm2's meaning fits them.
	Return: list of LA hashes (an LA appears once per matching pair).
	"""
	global frameindex, M2WLADict

	LAList = []
	#loop through all pairs of meanings
	for swm1 in scoresWordsMeanings:
		word1 = swm1[1]
		# the subframes of meaning1 depend only on swm1: derive them once
		meanings = derivemeanings(frameindex[swm1[2]], frameindex)
		for swm2 in scoresWordsMeanings:
			mhash2 = swm2[2]
			for meaning in meanings:
				#if meaning2 is a subframe of meaning1:
				if mhash2 != meaning:
					continue
				# NOTE(review): hypothesize() indexes new LAs in M2WLADict under
				# the larger meaning (swm1), while this looks them up under
				# swm2's meaning -- confirm this is intended.
				for word, LAhash in M2WLADict.get(mhash2, []):
					#if a LA compatible with meaning1/word1 is found:
					if word == word1 and fits(frameindex[mhash2], LADict[LAhash][0]):
						LAList.append(LAhash)
	return LAList

def hypothesize(utterance, scoresWordsMeanings):
	"""
	Make new LAs based on the meanings in the second argument.

	For every pair of (score, word, meaning hash) tuples where the second
	meaning is one of the meanings derived from the first one (and differs
	from it), get the order of the two words in the utterance and build an LA
	from the first meaning in which the part matching the second meaning is
	variable. The LA is stored in LADict and indexed in MDict, M2WLADict and
	W2MLADict. An LA that is already known keeps its current score, and an
	index entry that is already present is not added again.
	Return: None.
	"""
	global LADict, MDict, M2WLADict, frameindex, W2MLADict

	def makeLA(meaning1, meaning2, wordorder):
		"""
		Make a linguistic abstraction from first argument. The second
		argument must always be a subframe of the first, this part of
		the first argument will be variable, corresponding to the VAR
		in the wordorder.
		Return: LA frame.
		"""
		def compareframes(frame1, frame2):
			"""
			compare two frames, ignoring the role (their "name" attribute)
			"""
			p = frametostr(frame1).replace(frame1.getAttribute("name"), "FOO")
			q = frametostr(frame2).replace(frame2.getAttribute("name"), "FOO")
			return p == q

		LA = meaning1.cloneNode(deep=1)
		Kreator = minidom.Document()

		for a in meaning2.childNodes:
			if a.nodeName == "frame":
				for b in LA.getElementsByTagName("frame"):
					if compareframes(a, b):
						# id() here is not the builtin: presumably the oneword
						# helper returning b's id element -- TODO confirm
						id(b).childNodes[0].data = "VAR"
						for c in properties(b):
							c.childNodes[0].data = "VAR"

			if a.nodeName == "prop":
				for b in LA.getElementsByTagName("prop"):
					if a.getAttribute("name") == b.getAttribute("name"):
						if a.childNodes[0].data == b.childNodes[0].data:
							b.childNodes[0].data = "VAR"

		woElem = Kreator.createElement("wordorder")
		woElem.appendChild(Kreator.createTextNode(wordorder[0] + ':' + wordorder[1]))
		LA.appendChild(woElem)
		LA.tagName, LA.nodeName = "la", "la"

		return LA

	#find highest scoring frame that has a subframe in scoresWordsMeanings:
	for swm1 in scoresWordsMeanings:
		word1, mhash1 = swm1[1], swm1[2]
		# the subframes of meaning1 depend only on swm1: derive them once
		meanings = derivemeanings(frameindex[mhash1], frameindex)
		for swm2 in scoresWordsMeanings:
			for meaning in meanings:
				#if meaning2 is subframe of meaning1 and they are not same:
				if swm2[2] != meaning or mhash1 == swm2[2]:
					continue
				#make a new LA:
				wordorder = getWordorder(utterance, word1, swm2[1])
				LA = makeLA(frameindex[mhash1], frameindex[meaning], wordorder)
				#insert LA into datastructure, keeping an existing LA's score:
				LAHash = framehash(LA)
				if LAHash not in LADict:
					LADict[LAHash] = (LA, 1)
				MDict[mhash1] = frameindex[mhash1]
				# setdefault on each index separately: a key can be present in
				# one index and missing from the other
				m2w = M2WLADict.setdefault(mhash1, [])
				if (word1, LAHash) not in m2w:
					m2w.append((word1, LAHash))
				w2m = W2MLADict.setdefault(word1, [])
				if (mhash1, LAHash) not in w2m:
					w2m.append((mhash1, LAHash))
					

def getWordorder(u, word1, word2):
	"""
	Return the wordorder of word1 and word2 in the word sequence u, with
	word2 replaced by "VAR".

	Looks at the first occurrence of word1 that has word2 somewhere in u:
	(word1, "VAR") when word2 occurs at or after that position, ("VAR", word1)
	when it occurs only before it. Returns None when word1 or word2 is not in u.
	"""
	for pos, token in enumerate(u):
		if token != word1:
			continue
		if word2 in u[pos:]:
			return (word1, "VAR")
		if word2 in u[:pos]:
			return ("VAR", word1)
	return None

def analyze(situation):
	"""
	Learn from one situation of the corpus.

	For every word of the situation's utterances, collect up to five
	(score, word, meaning hash) candidates whose meaning is among the meanings
	derived from the situation. Then reinforce the LAs that findLAList()
	finds for these candidates or, when it finds none, call hypothesize()
	to make new ones.
	Return: None.
	"""
	global frameindex, meaningcounters
	import operator

	def getWordMeaningList(word):
		"""
		Return at most five (word, meaning, score) tuples for word, highest
		association score first; [] for a word without associations.
		"""
		if word not in associations:
			return []
		scored = associations[word]
		ranked = sorted([(word, meaning, scored[meaning]) for meaning in scored], key=operator.itemgetter(2), reverse=True)
		return ranked[:5]

	situationmeanings = derivemeanings(situation, frameindex)
	utterance = parseutterances(situation, {})

	candidates = []
	for token in utterance:
		# keep meanings that occur in the situation (once per occurrence),
		# then the five best of them
		compatible = [(score, w, meaning)
			for (w, meaning, score) in getWordMeaningList(token)
			for sub in situationmeanings
			if meaning == sub]
		compatible.sort(reverse=True)
		candidates.extend(compatible[:5])

	candidates.sort(reverse=True)
	LAList = findLAList(candidates)
	if LAList:
		reinforce(LAList)
	else:
		hypothesize(utterance, candidates)

def reinforce(LAList):
	"""
	Add one to the score of every LA whose hash is in LAList; a hash listed
	n times gains n points.
	Return: None.
	"""
	global LADict
	for LAhash in LAList:
		la, count = LADict[LAhash]
		LADict[LAhash] = (la, count + 1)

# run the experiment selected on the command line when executed as a script
if __name__ == "__main__":
	twMain()
