from oneword import *
from twowordtest import *
from sys import *
"""
The two-word algorithm learns linguistic abstractions/structures (LAs) for
combining two words whose meanings it already learned in the one-word stage.
When it receives a situation (SI) together with its accompanying utterance, it
analyzes both and either reinforces existing linguistic abstractions or
hypothesizes a new one.

Datastructures:
 - LADict: (linguistic abstraction, score) tuples, keyed by LA hash
 - MDict: meaning frames, keyed by meaning hash
 - M2WLADict: dictionary with meaning hashes as keys and lists of
   (word, LA hash) tuples as values
 - W2MLADict: dictionary with words as keys and lists of
   (meaning hash, LA hash) tuples as values
 - associations and frameindex from oneword
"""

# Module-level state of the two-word stage, filled by hypothesize/reinforce.
# (associations, frameindex and meaningcounter are bound elsewhere; functions
# below declare them global where they need them.)
LADict = {}     # LA hash -> (LA frame, score)
MDict = {}      # meaning hash -> meaning frame
M2WLADict = {}  # meaning hash -> list of (word, LA hash)
W2MLADict = {}  # word -> list of (meaning hash, LA hash)

def twMain():
	"""
	Create one-word data. Pass situations in the corpus to analyze using this
	data.
	"""
	
	#decide which reporting method is wanted
	try:
		method = argv[1]
	except:
		print "Argument missing. Please choose one of the following:"
		print "f : frame to words experiment"
		print "w : words to frames experiment"
		return
	
	#oneword part: make meaning-word associations
	global associations, frameindex, meaningcounter
	xmldoc = minidom.parse("nicecorpus.xml").documentElement
	associations, frameindex = twowordtest(xmldoc)
	
	
	#twoword part: analyze situations, making and reinforcing new LAs
	situations = xmldoc.getElementsByTagName("situation1")

	print "Known frames\n"
        for a in situations:
		print "\"", adultutterances(a).next(), "\"\n"
                printframe(frames(a).next())

#	print "twoword: Analyzing.."
	for a in situations:
		analyze(a)
	global LADict
	
	#do oneword test
	if method == "o":
		print "Talk to me: ",
		text = stdin.readline()
		for word in text.split():
			if word in associations:
				data = associations[word]
				print word
				list = [(data[fhash], fhash) for fhash in data]
				list.sort(reverse=True)
	# temporarily show the complete list for debugging
				for i, match in enumerate(list[:5]):
	#			for i, match in enumerate(list):
					print "match", i + 1, "score:", match[0]
					printframe(frameindex[match[1]])
			else:
				print word, "not in corpus."
	
	#report found utterances for all situations
	if method == "f":
		for sitno, sit in enumerate(situations):
			print 79 * "-"
			print "Situation:", sitno
			for resultno, (score, words) in enumerate(frame2words(sit)[:4]):
				print resultno+1, ':', words, "\tScore =", score
	
	#ask for words and report corresponding frames
	if method == "w":
		print "Talk to me: ",
		text = stdin.readline().split()
		for w1, w2 in zip(text[:-1], text[1:]):
			results = []
			for result in words2frame(w1, w2):
				if not result in results:
					results.append(result)
			
			for resultno, (score, LA, meaning) in enumerate(reversed(results[:5])):
				print 10*'=', '%i. \"%s %s\"' % (resultno + 1, w1, w2), 10*'='
				printframe(LA)
				printframe(meaning)
				stdin.readline()
		
	xmldoc.unlink()
	#end program

def words2frame(w1, w2):
	"""
	Search the learned linguistic abstractions for readings of the two-word
	utterance "w1 w2". Each word is tried in turn as the word that carries
	an LA (looked up in W2MLADict); the other word supplies a one-word
	meaning (from associations) that has to fit into that LA.
	Return: list of (score, LA frame, meaning frame) tuples, highest score
	first.
	"""
	global W2MLADict, associations, frameindex

	def wordorderfit(LA):
		"""
		Return True if the LA's "first:second" wordorder agrees with the
		utterance, i.e. w1 is its first slot or w2 its second one.
		"""
		wo = elementiterator("wordorder", LA).next().childNodes[0].data
		slots = wo.split(":")
		return w1 == slots[0] or w2 == slots[1]

	def collect(meaningword, laword):
		"""
		Pair each LA learned for laword with each one-word meaning of
		meaningword that fits into it.
		"""
		found = []
		if meaningword not in associations or laword not in W2MLADict:
			return found
		meanings = associations[meaningword].keys()
		for _, LAhash in W2MLADict[laword]:
			LA, score = LADict[LAhash]
			for m in meanings:
				if wordorderfit(LA) and fits(frameindex[m], LA):
					found.append((score, LA, frameindex[m]))
		return found

	#first w2 carries the LA and w1 the meaning, then the other way round
	results = collect(w1, w2) + collect(w2, w1)
	results.sort(reverse=True)
	return results
	


def frame2words(frame):
	"""
	Searches the linguistic corpus for relevant LAs given a meaning frame. Then
	produces a list of 2-word utterances using these LAs.
	Return: list of strings.
	"""
	def invertassociations(associations):
		"""
		Make a dictionary with framehashes as keys and their most likely 
		words/scores as values. The data is taken from the associations
		dictionary from the	oneword stage. 
		"""
		inverted = {}
		for word, framescores in associations.items():
			for framehash, score in framescores.items():
				if not framehash in inverted:
					inverted[framehash] = (word, score)
				else:
					if score > inverted[framehash][1]:
						inverted[framehash] = (word, score)
			
		return inverted
		
	global LADict, M2WLADict, associations, frameindex
	
	subframes = derivemeanings(frame, frameindex)
	inverted = invertassociations(associations)
	matches, results = [], []
	
	#find all words, LAs and their fitting subframes for the situation:
	for a in subframes:
		if a in M2WLADict:
			wordAbst = M2WLADict[a]
			for wa in wordAbst:
				for b in subframes:
					if b in frameindex and fits(frameindex[b], LADict[wa[1]][0]):
						matches.append((wa[0],wa[1],b))
	
	#find words for the subframes in inverted and set them in the right order:
	for m in matches:
		word1 = m[0]
		if m[2] in inverted:
			word2 = inverted[m[2]][0]
		else:
			print m[2], "not in inverted."
		LA, score = LADict[m[1]]
		wordorder = LA.getElementsByTagName("wordorder")[0].childNodes[0].data
		wo = wordorder.split(":")
		
		if wo[0] == word1:
			utt = score, " ".join([word1, word2])
			results.append(utt)
		elif wo[1] == word1:
			utt = score, " ".join([word2, word1])
			results.append(utt)
		else:
			print "NO MATCH"
	
	results.sort(reverse=True) 
	return results
	
def fits(meaning, LA):
	"""
	Return True if the supplied meaning fits in the linguistic abstraction.

	Both arguments are converted to their flat string representation (the
	meaning's header line is skipped). The meaning's lines must match a run
	of consecutive LA lines: each LA line, stripped and with a trailing VAR
	removed, has to occur as a substring of the corresponding meaning line.
	The earliest such run decides the result: it is a fit only if at least
	one of its LA lines ended in VAR. Without any matching run (including
	an empty meaning) the result is False.

	Arguments:
	meaning -- a meaning frame, without VAR elements
	LA -- a linguistic abstraction, with VAR elements in properties or id
	"""
	#convert to a list of strings; skip 'meaning' header
	LAstr = frametostr(LA).splitlines()
	mstr = frametostr(meaning).splitlines()[1:]
	mlines = len(mstr)

	#strip trailing VAR markers once, remembering which lines had them
	patterns = []
	for line in LAstr:
		if line.endswith("VAR"):
			patterns.append((line[:-3].strip(), True))
		else:
			patterns.append((line.strip(), False))

	#try every start position: the old scan reset on a mismatch without
	#backtracking, so a run failing halfway could hide a match inside it
	for start in range(len(patterns) - mlines + 1):
		window = patterns[start:start + mlines]
		if all(mline.find(substr) != -1
				for (substr, _), mline in zip(window, mstr)):
			return any(isVAR for _, isVAR in window)
	#if we came this far, it didn't work out
	return False

def findLAList(scoresWordsMeanings):
	"""
	Find the known LAs that cover pairs of meanings from the argument.

	For every ordered pair of (score, word, meaning hash) entries where
	meaning2 is a proper subframe of meaning1, look up the LAs stored for
	meaning1 together with word1 -- the key/word under which hypothesize
	registers them -- and keep those into which meaning2 fits.
	Return: list of LA hashes (one entry per fitting pair, so a hash may
	occur more than once).
	"""
	global frameindex, M2WLADict

	LAList = []
	for _, word1, meaning1 in scoresWordsMeanings:
		#subframes of meaning1 do not depend on the inner entry
		subframes = set(derivemeanings(frameindex[meaning1], frameindex))
		for _, _, meaning2 in scoresWordsMeanings:
			#same pairing rule as hypothesize: proper subframes only
			if meaning2 == meaning1 or meaning2 not in subframes:
				continue
			for word, LAhash in M2WLADict.get(meaning1, []):
				#an LA of word1 whose VAR part meaning2 fills
				if word == word1 and fits(frameindex[meaning2], LADict[LAhash][0]):
					LAList.append(LAhash)
	return LAList

def hypothesize(utterance, scoresWordsMeanings):
	"""
	Make new LAs based on the (score, word, meaning hash) entries in the
	second argument.

	For every ordered pair of entries where meaning2 is a proper subframe of
	meaning1, an LA is built from meaning1 with meaning2's part made VAR,
	using the position of word1 relative to word2 in the utterance as its
	wordorder. The LA is registered in LADict (score 1 when new), MDict,
	M2WLADict (under meaning1) and W2MLADict (under word1). An LA that
	already exists keeps its score, and registry entries are not duplicated.
	Return: None.
	"""
	global LADict, MDict, M2WLADict, frameindex, W2MLADict

	def makeLA(meaning1, meaning2, wordorder):
		"""
		Make a linguistic abstraction from a deep copy of meaning1. meaning2
		must be a subframe of meaning1: each child frame of meaning2 that
		matches a frame in the copy gets that frame's id and properties set
		to VAR, and each child prop of meaning2 with the same name and value
		as a prop in the copy gets that prop's value set to VAR. A
		<wordorder> element "first:second" is appended and the copy's root
		is renamed to "la".
		Return: LA frame.
		"""
		def compareframes(frame1, frame2):
			"""
			Return True if the frames' string forms are equal once every
			occurrence of each frame's name is replaced by a placeholder.
			"""
			p = frametostr(frame1).replace(frame1.getAttribute("name"), "FOO")
			q = frametostr(frame2).replace(frame2.getAttribute("name"), "FOO")
			return p == q

		LA = meaning1.cloneNode(deep=1)
		Kreator = minidom.Document()

		for a in meaning2.childNodes:
			if a.nodeName == "frame":
				for b in LA.getElementsByTagName("frame"):
					if compareframes(a, b):
						#'id'/'properties' presumably come from oneword's star
						#import (note: 'id' shadows the builtin)
						id(b).childNodes[0].data = "VAR"
						for c in properties(b):
							c.childNodes[0].data = "VAR"

			if a.nodeName == "prop":
				for b in LA.getElementsByTagName("prop"):
					if a.getAttribute("name") == b.getAttribute("name"):
						if a.childNodes[0].data == b.childNodes[0].data:
							b.childNodes[0].data = "VAR"

		woElem = Kreator.createElement("wordorder")
		woElem.appendChild(Kreator.createTextNode(wordorder[0] + ':' + wordorder[1]))
		LA.appendChild(woElem)
		LA.tagName, LA.nodeName = "la", "la"

		return LA

	for _, word1, meaning1 in scoresWordsMeanings:
		#subframes of meaning1 do not depend on the inner entry
		subframes = set(derivemeanings(frameindex[meaning1], frameindex))
		for _, word2, meaning2 in scoresWordsMeanings:
			#meaning2 must be a subframe of meaning1 and not the same meaning
			if meaning2 == meaning1 or meaning2 not in subframes:
				continue
			wordorder = getWordorder(utterance, word1, word2)
			LA = makeLA(frameindex[meaning1], frameindex[meaning2], wordorder)
			LAHash = framehash(LA)
			#do not throw away the score of an LA that already exists
			if LAHash not in LADict:
				LADict[LAHash] = (LA, 1)
			MDict[meaning1] = frameindex[meaning1]
			#update each registry independently: a meaning may already be
			#known while the word is not (and vice versa), which the old
			#try/except turned into overwriting the existing list
			mEntries = M2WLADict.setdefault(meaning1, [])
			if (word1, LAHash) not in mEntries:
				mEntries.append((word1, LAHash))
			wEntries = W2MLADict.setdefault(word1, [])
			if (meaning1, LAHash) not in wEntries:
				wEntries.append((meaning1, LAHash))
					

def getWordorder(u, word1, word2):
	"""
	Determine where word2 stands relative to word1 in the utterance u.
	Scans the occurrences of word1 from left to right. At the first one,
	returns (word1, "VAR") if word2 occurs at or after that position.
	Otherwise it returns ("VAR", word1) if word2 occurs before it.
	If neither holds, the scan moves on to the next occurrence of word1.
	Return: (word1, "VAR"), ("VAR", word1), or None when word1 never occurs
	together with word2.
	"""
	for pos, token in enumerate(u):
		if token != word1:
			continue
		if word2 in u[pos:]:
			return (word1, "VAR")
		if word2 in u[:pos]:
			return ("VAR", word1)
	return None

def analyze(situation):
	"""
	Process one situation. Each word of its utterance is linked to the
	situation's meanings it is associated with (at most five per word,
	best first). Known LAs covering these word/meaning entries are then
	reinforced; if none is found, new LAs are hypothesized instead.
	Return: None.
	"""
	global frameindex, meaningcounter

	def getWordMeaningList(word):
		"""
		Return up to five (word, meaning, score) tuples for the word from the
		one-word associations, highest score first; [] for unknown words.
		"""
		import operator
		global associations, meaningcounter
		if word not in associations:
			return []
		scored = associations[word]
		#(normalising the score by meaningcounter[meaning] is disabled)
		candidates = [(word, meaning, scored[meaning]) for meaning in scored]
		candidates.sort(key=operator.itemgetter(2), reverse=True)
		return candidates[:5]

	situationMeanings = derivemeanings(situation, frameindex)
	utts = parseutterances(situation, {})

	candidates = []
	for word in utts:
		#keep only meanings that occur in this situation, as (score, word, meaning)
		compatible = [(score, w, meaning)
			for (w, meaning, score) in getWordMeaningList(word)
			for subframe in situationMeanings
			if meaning == subframe]
		compatible.sort(reverse=True)
		candidates.extend(compatible[:5])
	candidates.sort(reverse=True)

	LAList = findLAList(candidates)
	if LAList:
		reinforce(LAList)
	else:
		hypothesize(utts, candidates)

def reinforce(LAList):
	"""
	Raise by one the score stored in LADict for every LA hash in LAList.
	A hash listed several times is reinforced once per occurrence.
	Return: None.
	"""
	global LADict
	for LAhash in LAList:
		frame, score = LADict[LAhash]
		LADict[LAhash] = (frame, score + 1)

if __name__ == "__main__":
	#run the experiment only when executed as a script
	twMain()
