#!/usr/bin/env python
# Language Acquisition project, June 2007, UvA.
"""
	OVERVIEW

# Important data structures:
# - associations: dictionary with words as keys, and as values dictionaries
#   mapping framehashes to association scores
eg. associations['ball'] = { hash1234 : 0.45, hash5678 : 0.33, ...}

# - Dictionary with framehashes as keys, and the real frames as values
frameindex[hash1234] = _frameXYZ

#Initialization functions
readcorpus()
parseutterances()

#Learn Functions:
createsubframes() <-- formerly known as 'abstractions'
associations() <-- formerly known as 'speech'

#Print functions
printframe()
printprop()
printsituation()
frametostr() # string representation of a frame, ordered canonically

#Demonstration:
done in main() (when running standalone)
"""
from xml.dom import minidom
from sys import stdin
from sys import argv
import math

def main():
	"""
	then read words from stdin and find matching frames
	(if this file is not called directly, main() will be ignored).
	"""
	#banner/silly disclaimer
	#	
	print """Language Acquisition, one-word model. 2nd Year project UvA 2007
This program is not distributed in the hope that it will be useful,
so WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
"""
	try:
		corpus = argv[1]
	except:
		corpus = "corpus.xml"
	xmldoc = minidom.parse(corpus).documentElement
	associations, frameindex, meaningcounters = oneword(xmldoc)
	while True:
		# read words from stdin and find frames
		print "corpus lexicon:",
		for a in sorted(associations.keys()): print a,
		print
		print "Talk to me (or enter \'quit\'): ",
		text = stdin.readline()
		if text == 'quit\n': break
		for word in text.split():
			if word in associations:
				data = associations[word]
				print word
				#list = [(data[fhash]/float(meaningcounters[fhash]), fhash) for fhash in data]
				list = [(data[fhash], fhash) for fhash in data]
				list.sort(reverse=True)
	# temporarily show the complete list for debugging
				for i, match in enumerate(list[:5]):
	#			for i, match in enumerate(list):
					print "match", i + 1, "score:", match[0]
					printframe(frameindex[match[1]])
			else:
				print word, "not in corpus."
	#clean up
	xmldoc.unlink()
	return	#pass on

def oneword(xmldoc):
	"""
	parse corpus and generate derived frames,
	parameter: use: xmldoc = minidom.parse("corpus.xml").documentElement
	return: a tuple (associations, frameindex)
	"""
	#parse XML
	print "oneword: Reading corpus data.."
	situations = xmldoc.getElementsByTagName("situation")
	associations, frameindex, meaningcounters = {}, {}, {}
	#
	#generate derived frames
	print "oneword: Analyzing situations.."
	for sit in situations:
		#possible methods: ignore, repeat, default
		utterances = parseutterances(sit, associations, method='default')
		meanings = derivemeanings(sit, frameindex)
		updatemeaningcounters(meanings, meaningcounters)
		associate(utterances, meanings, associations, method=1)
	#
	#try to correct scores (disabled because twoword doesn't work if this is on)
	#print "oneword: Correcting scores.."
	#correctassociations(associations, len(utterances), meaningcounters)
	return (associations, frameindex, meaningcounters)

def derivemeanings(situation, frameindex):
	"""
	For a given situation, derive a list of candidate meanings and return
	their framehashes (see makehashes()); new meanings are added to
	frameindex.
	The candidates are:
	- the situation's first frame as a whole (or the situation itself
	  when it has no frame),
	- an 'action' frame holding only that frame's id and abstraction,
	- variants of both where the abstraction replaces the id,
	- recursively, every subframe together with its parent's properties
	  (plus its abstraction variant), and every single property.
	"""
	def abstractiontoid(meaning):
		"""
		Return [copy of meaning] in which the id of its first subframe is
		replaced by that subframe's abstraction (and the abstraction
		element is removed), or [] when that subframe, its abstraction or
		its id is missing.
		"""
		aframe = meaning.cloneNode(deep=1)
		try:
			sub = frames(aframe).next()
			abstraction = abst(sub)
			identifier = id(sub)
		except StopIteration:
			return []
		identifier.replaceChild(abstraction.childNodes[0].cloneNode(deep=1),
			identifier.childNodes[0])
		sub.removeChild(abstraction)
		return [aframe]
	#
	def recursiveframes(frame):
		"""
		Recursively derive meanings from the subframes and the properties
		of frame.
		"""
		solutions = []
		for sub in frames(frame):
			newmeaning = Kreator.createElement("meaning")
			newmeaning.appendChild(sub.cloneNode(deep=1))
			# keep the parent's properties as context (document order;
			# frametostr() sorts them by name when hashing)
			for prop in properties(frame):
				newmeaning.appendChild(prop.cloneNode(deep=1))
			solutions.extend(abstractiontoid(newmeaning))
			solutions.append(newmeaning)
			solutions.extend(recursiveframes(sub))
		for prop in properties(frame):
			newmeaning = Kreator.createElement("meaning")
			newmeaning.appendChild(prop.cloneNode(deep=1))
			solutions.append(newmeaning)
		return solutions
	#
	# a throwaway document, only used as a factory for new elements
	Kreator = minidom.Document()
	try:
		originalframe = frames(situation).next()
	except StopIteration:
		# no frame in this situation: derive from the situation itself
		originalframe = situation
	#
	fullmeaning = Kreator.createElement("meaning")
	fullmeaning.appendChild(originalframe.cloneNode(deep=1))
	# an 'action' frame with only the id and abstraction of the original
	emptyframe = Kreator.createElement("frame")
	emptyframe.setAttribute("name", "action")
	try:
		emptyframe.appendChild(id(originalframe).cloneNode(deep=1))
		emptyframe.appendChild(abst(originalframe).cloneNode(deep=1))
	except StopIteration:
		pass # keep whatever could be copied
	emptymeaning = Kreator.createElement("meaning")
	emptymeaning.appendChild(emptyframe)
	solutions = [fullmeaning, emptymeaning]
	solutions.extend(abstractiontoid(fullmeaning))
	solutions.extend(abstractiontoid(emptymeaning))
	# meanings derived from the subframes and properties
	solutions.extend(recursiveframes(originalframe))
	return makehashes(solutions, frameindex)

def updatemeaningcounters(meanings, meaningcounters):
	"""
	Count in place how often every meaning (framehash) occurs: each
	element of meanings increments meaningcounters[element] by one,
	starting from zero for meanings not seen before.
	These counts can be used to correct scores later on.
	"""
	for meaning in meanings:
		meaningcounters[meaning] = meaningcounters.get(meaning, 0) + 1

def associate(utterances, meanings, associations, method=0):
	"""
	Update the scores between every word in utterances and every meaning
	(framehash) in meanings, in place in associations[word][meaning]:
	a new pair starts at 1, an existing pair is increased by 2. A word
	occurring several times is scored once per occurrence.
	Every word must already have an entry in associations
	(parseutterances() creates them), otherwise KeyError is raised.

	If method > 0, the existing associations of a word with meanings not
	in this situation would be decreased by 1 for every occurrence of the
	word. This is currently disabled: those statistics seem to degrade
	performance (at least in some situations, e.g. 'ball'), so the method
	argument is ignored.
	"""
	# deliberately disabled, see above; remove this line to re-enable
	method = 0
	if method > 0:
		# built once: 'in' on the meanings list would be O(n) per test
		present = set(meanings)
	for word in utterances:
		scores = associations[word]
		for meaning in meanings:
			if meaning in scores:
				# TODO: a better formula; for now monotonically increasing
				scores[meaning] += 2
			else:
				scores[meaning] = 1
		if method > 0:
			for meaning in scores:
				if meaning not in present:
					#assume other frames are unrelated to this word:
					scores[meaning] -= 1

def correctassociations(associations, wordcount, meaningcounters):
	"""
	Correct the association scores in place: every score of a meaning
	(framehash) is divided by the number of times that meaning occurred
	according to meaningcounters. The corrected scores are floats.
	Every framehash in associations must be a key of meaningcounters
	(otherwise KeyError is raised). wordcount is currently unused.
	"""
	for data in associations.values():
		for framehash in data:
			data[framehash] = float(data[framehash]) / meaningcounters[framehash]
	# ideas, not implemented:
	# - normalise per word instead: score / sum of that word's scores
	# - remove words that appear too often (assume they bear no semantic
	#   information), e.g. word count divided by number of sentences
	# - logistic function?:
	#   associations[a][b] = 1.0 / (1.0 + math.e ** (-1.0 * associations[a][b]))
	
def parseutterances(situation, associations, method='default', repeat=3):
	"""
	Create a list of single words of all the adult utterances combined,
	after stripping unwanted characters ('?', '.' and ','). If separate
	utterances are needed, change this. As this is the one-word stage,
	sentence boundaries are currently meaningless.

	Words starting with a '!' are emphasized; method selects how to use
	that:
	  'ignore'  keep only the emphasized words (without the leading '!')
	  'repeat'  emphasized words (without the leading '!') appear
	            `repeat` times, other words once
	  otherwise ('default') keep all words, removing every '!'
	Tokens that are empty after removing the '!' are dropped.

	Every returned word gets an (initially empty) entry in associations,
	so associate() can index it directly.
	Returns the list of words.
	"""
	utterance = " ".join(adultutterances(situation))
	utts = utterance.replace("?", "").replace(".", "").replace(",", "").split()
	if method == 'ignore': # use only emphasized words
		correctedwords = [a[1:] for a in utts if a[0] == '!']
	elif method == 'repeat': # repeat emphasized words
		correctedwords = []
		for a in utts:
			if a[0] == '!':
				correctedwords.extend([a[1:]] * repeat)
			else:
				correctedwords.append(a)
	else: # method == 'default'
		correctedwords = [a.replace("!", "") for a in utts]
	# a lone '!' token would otherwise become an empty "word"
	correctedwords = [a for a in correctedwords if a]
	#
	# make sure the words in associations are known beforehand
	for a in correctedwords:
		if a not in associations:
			associations[a] = {}
	return correctedwords
	
### Section Auxiliary Functions
def makehashes(meanings, frameindex):
	"""
	Return the framehash of every meaning, in the same order.
	frameindex (framehash -> meaning) is updated in place; when several
	meanings share a hash, the first one registered is kept.
	"""
	hashes = []
	for meaning in meanings:
		key = framehash(meaning)
		frameindex.setdefault(key, meaning)
		hashes.append(key)
	return hashes

# Print Functions
def printsituation(situation):
	"""
	print a situation's description, frames and utterances
	"""
	description = situation.getElementsByTagName("description")[0]
	print " DESC:",description.childNodes[0].data
	for a in frames(situation):
		printframe(a)
	for a in situation.childNodes:
		if a.nodeName == "adult":
			print "adult:", a.childNodes[0].data
		elif a.nodeName =="child":
			print "child:", a.childNodes[0].data
		else:
			pass

def printframe(frame):
	"""
	Print the string representation of frame (see frametostr()),
	followed by a ruler of 70 dashes.
	"""
	# trailing comma: frametostr() output normally already ends with a
	# newline, so print must not add another one
	print frametostr(frame),
	#end with ruler to signify end of top level frame
	print 70 * '-'

def frametostr(frame, nesting=0, removename=False):
	"""
	Make a 'human readable' string of a frame, for both pretty-printing
	and finding duplicates (framehash() hashes this string).
	The header line depends on the element: 'la', 'meaning', 'frame'
	(with its name) or 'situation'. Then follow, one per line and
	indented by nesting + 1 tabs: ID (if present) and ABSTR (if the ID is
	present too), WORDORDER (if present), the properties sorted by name,
	and finally the subframes sorted by name, rendered recursively with
	nesting + 1.
	If removename is true, a 'frame' element's name is shown as 'void'.
	NOTE(review): removename is not passed on to subframes, so it only
	affects a top-level 'frame' element; kept as is because changing it
	would change which frames framehash() considers equal.
	"""
	indent = (nesting + 1) * '\t'
	result = []
	if frame.nodeName == "la":
		result.append("LINGUISTIC ABSTRACTION:\n")
	elif frame.nodeName == "meaning":
		result.append("MEANING:\n")
	elif frame.nodeName == "frame":
		result.append(nesting * '\t' + "FRAME: ")
		if removename:
			result.append("void\n")
		else:
			result.append("%s\n" % frame.getAttribute("name"))
	elif frame.nodeName == "situation":
		result.append("SITUATION:\n")
	try:
		result.append("%sID: %s\n" % (indent, id(frame).childNodes[0].data))
		result.append("%sABSTR: %s\n" % (indent, abst(frame).childNodes[0].data))
	except StopIteration:
		pass #not all frames need to have an id/abstraction element
	try:
		wordorder = elementiterator("wordorder", frame).next()
	except StopIteration:
		pass
	else:
		result.append("%sWORDORDER: %s\n" % (indent, wordorder.childNodes[0].data))

	def byname(element):
		return element.getAttribute("name")

	for prop in sorted(properties(frame), key=byname):
		result.append("%sPROP: %s = %s\n" % (indent, prop.getAttribute("name"), prop.childNodes[0].data))
	for sub in sorted(frames(frame), key=byname):
		result.append(frametostr(sub, nesting + 1))
	return "".join(result)

# Readability functions
def framehash(frame):
	"""
	Return a hash value for frame based on its canonical string
	representation (frametostr() with removename=True). Going through
	the string compares frames "deeply" by content instead of by object
	reference. Note that distinct frames whose strings hash to the same
	value are treated as equal.
	"""
	canonical = frametostr(frame, removename=True)
	return hash(canonical)

def id(frame):
	"""
	Return the first direct child element of frame named 'id'.
	Raises StopIteration when there is none; callers rely on that.
	(Shadows the builtin id() within this module.)
	"""
	for element in elementiterator("id", frame):
		return element
	raise StopIteration
def abst(frame):
	"""
	Return the first direct child element of frame named 'abstraction'.
	Raises StopIteration when there is none; callers rely on that.
	"""
	for element in elementiterator("abstraction", frame):
		return element
	raise StopIteration
	
def frames(frame):
	"""Iterate over the direct child elements of frame named 'frame'."""
	return elementiterator(tag="frame", frame=frame)
def properties(frame):
	"""Iterate over the direct child elements of frame named 'prop'."""
	return elementiterator(tag="prop", frame=frame)
def adultutterances(frame):
	"""
	Lazily yield the utterance strings of frame: the data of the first
	child node of every direct child element named 'adult'.
	"""
	return (utterance.childNodes[0].data
		for utterance in elementiterator("adult", frame))

def elementiterator(tag, frame):
	"""
	Lazily yield the direct children of frame whose nodeName equals tag,
	in document order (grandchildren are not searched).
	"""
	for child in frame.childNodes:
		if child.nodeName != tag:
			continue
		yield child
### Start the program
if __name__ == "__main__":
	#make sure sorted() is available if using python <2.4
	#(only defined when run as a script, not when imported)
	try:
		sorted([])
	except NameError:
		def sorted(iterable, cmp=None):
			"""minimal stand-in for the builtin sorted(); no key/reverse"""
			# list() instead of slicing: callers also pass generators
			items = list(iterable)
			items.sort(cmp)
			return items
	main()
