#!/usr/bin/env python
from __future__ import print_function
#from oldheapdict import heapdict
#from pq import heapdict
from agenda import Agenda as heapdict
import random
import unittest
import sys
try:
	import test.support as test_support # Python 3
except ImportError:
	import test.test_support as test_support # Python 2
from containers import Edge, SmallChartItem
# Placeholder chart item, passed as the last two arguments of every Edge
# constructed by the tests below.
NONE = SmallChartItem(0, 0)
# Number of entries inserted into each queue under test.
N = 100

def _parent(i):
	return ((i - 1) >> 1)

def getval(entry, h):
	"""Return the value stored in a heap entry, whatever its representation.

	Three access strategies are tried in order:

	1. `entry[0]`, for entries that support indexing;
	2. `entry.value`, for entries exposing a `value` attribute;
	3. `h.getval(entry)`, delegating to the queue itself.

	:param entry: an element taken from the queue's underlying heap.
	:param h: the queue that owns `entry`; only used for strategy 3.
	:returns: the value the heap is ordered by.

	Only the exception that signals "this strategy does not apply" is caught
	at each step (TypeError for indexing, AttributeError for the attribute),
	so interpreter exits and genuine errors are no longer silently swallowed.
	"""
	try:
		return entry[0]
	except TypeError:
		# not subscriptable; fall through to attribute access
		pass
	try:
		return entry.value
	except AttributeError:
		# no accessible `value` attribute; let the queue extract it
		return h.getval(entry)

class TestHeap(unittest.TestCase):
	"""Compare the priority queue imported as `heapdict` against a plain dict
	and against a sorted list of the same (key, value) pairs."""

	def check_invariants(self, h):
		"""Assert the heap property: no entry is smaller than its parent."""
		# The heap array is exposed either as an attribute or via an accessor.
		try:
			heap = h.heap
		except AttributeError:
			heap = h.getheap()
		# Index 0 is the root and has no parent, so start at 1.
		for i in range(1, len(h)):
			# this check is only an implementation detail of heapdict,
			# pq / cpq use this field to break ties.
			#self.assertEqual(h.heap[i][2], i)
			self.assertTrue(getval(heap[_parent(i)], h) <= getval(heap[i], h))

	def make_data(self):
		"""Build a queue and a dict holding the same N random entries.

		:returns: a tuple `(h, pairs, d)` where `h` is the queue, `d` the
			reference dict, and `pairs` the list of (key, Edge) tuples sorted
			by Edge in descending order, so that `pairs.pop()` yields the
			pair with the smallest Edge.
		"""
		pairs = [(random.random(), Edge(random.random(), 0., 0., NONE, NONE))
				for _ in range(N)]
		h = heapdict()
		d = {}
		for k, v in pairs:
			h[k] = v
			d[k] = v

		pairs.sort(key=lambda x: x[1], reverse=True)
		return h, pairs, d

	def test_contains(self):
		"""Membership in the queue agrees with membership in the dict."""
		# Only the pairs of the first data set are kept; its queue and dict
		# are discarded, so these keys come from outside `h` and `d`.
		_, pairs, _ = self.make_data()
		h, pairs2, d = self.make_data()
		for k, _ in pairs + pairs2:
			self.assertEqual(k in h, k in d)

	def test_len(self):
		"""The queue reports the same length as the dict."""
		h, _, d = self.make_data()
		self.assertEqual(len(h), len(d))

	def test_popitem(self):
		"""popitem() returns the pairs in sorted order and removes them."""
		h, pairs, d = self.make_data()
		while pairs:
			v = h.popitem()
			v2 = pairs.pop(-1)
			self.assertEqual(v, v2)
			d.pop(v[0])
			self.assertEqual(len(h), len(d))
			self.assertEqual(set(h.items()), set(d.items()))
		self.assertEqual(len(h), 0)

	def test_popitem_ties(self):
		"""The heap invariant holds while popping entries with equal values."""
		h = heapdict()
		for i in range(N):
			h[i] = Edge(0., 0., 0., NONE, NONE)
		for i in range(N):
			_, v = h.popitem()
			self.assertEqual(v, Edge(0., 0., 0., NONE, NONE))
			self.check_invariants(h)

	def test_popitem_ties_fifo(self):
		"""Entries with equal values are popped in insertion order."""
		h = heapdict()
		for i in range(N):
			h[i] = Edge(0., 0., 0., NONE, NONE)
		for i in range(N):
			k, v = h.popitem()
			self.assertEqual(k, i)
			self.assertEqual(v, Edge(0., 0., 0., NONE, NONE))
			self.check_invariants(h)

	def test_peek(self):
		"""peekitem() shows the key that the next popitem() will remove."""
		h, pairs, _ = self.make_data()
		while pairs:
			v = h.peekitem()[0]
			h.popitem()
			v2 = pairs.pop(-1)
			self.assertEqual(v, v2[0])
		self.assertEqual(len(h), 0)

	def test_iter(self):
		"""Iterating over the queue yields the keys in the dict's order."""
		h, _, d = self.make_data()
		self.assertEqual(list(h), list(d))

	def test_keys(self):
		"""keys() and iterkeys() yield the same keys as the dict."""
		h, _, d = self.make_data()
		expected = sorted(d.keys())
		self.assertEqual(sorted(h.keys()), expected)
		# dict.iterkeys() does not exist on Python 3; only the queue's
		# iterator method is under test, so compare against d.keys().
		self.assertEqual(sorted(h.iterkeys()), expected)

	def test_values(self):
		"""values() and itervalues() yield the same values as the dict."""
		h, _, d = self.make_data()
		expected = sorted(d.values())
		self.assertEqual(sorted(h.values()), expected)
		# dict.itervalues() does not exist on Python 3; see test_keys.
		self.assertEqual(sorted(h.itervalues()), expected)

	def test_items(self):
		"""items() and iteritems() yield the same pairs as the dict."""
		h, _, d = self.make_data()
		expected = sorted(d.items())
		self.assertEqual(sorted(h.items()), expected)
		# dict.iteritems() does not exist on Python 3; see test_keys.
		self.assertEqual(sorted(h.iteritems()), expected)

	def test_del(self):
		"""Deleting arbitrary keys keeps the queue in sync with the dict."""
		h, pairs, d = self.make_data()
		while pairs:
			# remove from the middle of the sorted order, not from an end
			k, _ = pairs.pop(len(pairs)//2)
			del h[k]
			del d[k]
			self.assertEqual(len(h), len(d))
			self.assertEqual(set(h.items()), set(d.items()))
		self.assertEqual(len(h), 0)

	def test_pop(self):
		"""pop(key) returns the stored value and removes the entry."""
		h, pairs, d = self.make_data()
		while pairs:
			k, v = pairs.pop(-1)
			v2 = h.pop(k)
			v3 = d.pop(k)
			self.assertEqual(v, v2)
			self.assertEqual(v, v3)
			self.assertEqual(len(h), len(d))
			self.assertEqual(set(h.items()), set(d.items()))
		self.assertEqual(len(h), 0)

	def test_change(self):
		"""Reassigning the value of an existing key updates the pop order."""
		h, pairs, _ = self.make_data()
		k, _ = pairs[N//2]
		h[k] = Edge(0.5, 0.0, 0.0, NONE, NONE)
		pairs[N//2] = (k, Edge(0.5, 0.0, 0.0, NONE, NONE))
		pairs.sort(key=lambda x: x[1], reverse=True)
		while pairs:
			v = h.popitem()
			v2 = pairs.pop()
			self.assertEqual(v, v2)
		self.assertEqual(len(h), 0)

	def test_init(self):
		"""A queue built from an iterable of pairs pops in sorted order."""
		_, pairs, d = self.make_data()
		h = heapdict(d.items())
		while pairs:
			v = h.popitem()
			v2 = pairs.pop()
			self.assertEqual(v, v2)
			d.pop(v[0])
		self.assertEqual(len(h), len(d))
		self.assertEqual(len(h), 0)

	def test_repr(self):
		"""Evaluating repr() of a queue gives back an equal queue."""
		h, _, _ = self.make_data()
		# eval() only ever sees the repr of an object created in this test.
		self.assertEqual(h, eval(repr(h)))
	

#==============================================================================

def test_main(verbose=None):
	"""Run the TestHeap suite, optionally followed by a reference count check.

	:param verbose: when true and the interpreter provides
		`sys.gettotalrefcount`, the suite is run five more times and the
		total reference count after each run is printed, so that a steadily
		growing sequence can be spotted as a reference leak.
	"""
	test_classes = [TestHeap]
	test_support.run_unittest(*test_classes)

	# verify reference counting
	if verbose and hasattr(sys, "gettotalrefcount"):
		import gc
		# The result list is allocated up front and filled by index.
		counts = [None] * 5
		# `range` instead of `xrange`, which does not exist on Python 3.
		for i in range(len(counts)):
			test_support.run_unittest(*test_classes)
			gc.collect()
			counts[i] = sys.gettotalrefcount()
		print(counts)

# When executed as a script, run the suite with the reference count report
# requested (verbose=True).
if __name__ == "__main__":
	test_main(verbose=True)
