# argv[1] = path of the original grammar file (input)
# argv[2] = path the generalized grammar is written to (output)

import sys
import re

def chomp(s):
	"""Return *s* with every trailing carriage-return / line-feed character removed."""
	line_terminators = '\r\n'
	stripped = s.rstrip(line_terminators)
	return stripped

class PCFGRule(object):
	"""A single weighted grammar rule ``lhs --> rhs``.

	Attributes:
		_lhs:   the left-hand-side symbol, stored as given.
		_rhs:   list of right-hand-side symbols (``rhs`` split on whitespace).
		_count: weight of the rule, stored as a float.

	Equality and hashing are both defined through ``repr``, i.e. two rules
	are equal when their textual form (count, lhs and quote-stripped rhs)
	is identical.  Because ``_count`` and ``_rhs`` are mutable, the hash of
	a rule changes when they are modified; do not mutate a rule while it is
	a member of a set or a key of a dict.
	"""

	def __init__(self, lhs, rhs, count):
		"""Build a rule.

		lhs   -- left-hand-side symbol (string)
		rhs   -- whitespace-separated right-hand-side symbols (string)
		count -- anything ``float()`` accepts (number or numeric string)
		"""
		self._lhs = lhs
		self._rhs = rhs.split()
		self._count = float(count)

	def __repr__(self):
		"""Return ``<count>\\t<lhs> --> <rhs symbols>`` with double quotes removed from the rhs."""
		symbols = [x.replace('"', "") for x in self._rhs]
		return str(self._count) + "\t" + self._lhs + " --> " + " ".join(symbols)

	def __eq__(self, other):
		"""Two rules are equal when their textual representations match."""
		return repr(self) == repr(other)

	def __ne__(self, other):
		# Python 2 does not derive ``!=`` from ``__eq__``; without this method
		# ``a != b`` would fall back to identity and disagree with ``a == b``.
		return not self.__eq__(other)

	def __hash__(self):
		"""Hash consistent with ``__eq__`` (both are based on ``repr``)."""
		return hash(repr(self))
		


class Grammar(object):
	"""A set of weighted rules, grouped by their left-hand-side symbol.

	Rules are expected to expose ``_lhs`` (string), ``_rhs`` (list of
	symbols) and ``_count`` (number), as PCFGRule does.

	Attributes:
		_rules:        dict mapping an lhs symbol to the list of its rules.
		_TerminalToNT: dict mapping a terminal to the generated ``DNT_<n>``
		               non-terminal that introduces it.
		_counter:      index used for the next generated ``DNT_<n>`` name.
	"""

	def __init__(self):
		self._rules = {}
		self._TerminalToNT = {}
		self._counter = 0

	def addRule(self, rule):
		"""Add *rule* to the grammar, merging it with an identical rule if present.

		In a rule with more than one rhs symbol, every symbol whose first
		character is a lowercase letter, a digit or "-" is treated as a
		terminal and replaced by a ``DNT_<n>`` non-terminal; the unary rule
		``DNT_<n> --> terminal`` (count 1.0) is added the first time a
		terminal is seen.  The replacement modifies ``rule._rhs`` in place.

		If a rule with the same lhs and rhs already exists, its count is
		increased by ``rule._count``; otherwise *rule* itself is stored.
		"""
		if len(rule._rhs) > 1:	#make sure that Terminals are introduced by unary rules
			for i, rh in enumerate(rule._rhs):
				if rh[0].islower() or rh[0].isdigit() or rh[0] == "-":	#its a terminal
					if rh not in self._TerminalToNT:
						newNT = "DNT_" + str(self._counter)
						self._TerminalToNT[rh] = newNT
						self.addRule(PCFGRule(newNT, rh, 1.0))
						self._counter += 1
					rule._rhs[i] = self._TerminalToNT[rh]
		sameLhs = self._rules.setdefault(rule._lhs, [])
		updated = False
		for posR in sameLhs:
			if posR._rhs == rule._rhs:
				posR._count += rule._count
				updated = True
		if not updated:
			sameLhs.append(rule)

	def getRules(self, lhs):
		"""Return the list of rules whose left-hand side is *lhs*.

		Raises KeyError if no rule with that lhs has been added.
		"""
		return self._rules[lhs]

	def __getitem__(self, key):
		"""``grammar[lhs]`` is shorthand for ``grammar.getRules(lhs)``."""
		return self.getRules(key)

	def save(self, dest):
		"""Write the grammar to the file *dest*, one ``repr(rule)`` per line.

		Left-hand sides are written in sorted order.  Within one lhs, rules
		are written in the order they were added, and rules that compare
		equal to an already written one are skipped.
		"""
		# ``with`` guarantees the file is closed even if a write fails.
		with open(dest, "w") as output:
			# sorted() works on both a Python 2 key list and a Python 3 key view.
			for lhs in sorted(self._rules):
				written = set()
				for rule in self._rules[lhs]:
					if rule in written:
						continue
					written.add(rule)
					output.write(repr(rule) + "\n")

	def turnIntoPCFG(self):
		"""Normalize the counts of this grammar in place; returns None.

		For every lhs, each rule's count is replaced by its share of the
		total count of that lhs.  Counts are truncated with ``int()`` before
		both the sum and the division, so fractional counts are rounded
		toward zero.  Raises ZeroDivisionError if the truncated counts of an
		lhs sum to zero.
		"""
		for lhs in self._rules:
			totalCount = float(sum(int(x._count) for x in self._rules[lhs]))
			for rule in self._rules[lhs]:
				rule._count = int(rule._count) / totalCount

	def readGrammar(self, src):
		"""Read rules from the file *src* and add them to this grammar.

		Every non-empty line must have the form
		``<count>\\t<LHS> --> <RHS symbols>``.  A line that does not split
		into exactly two parts at the tab or at " --> " raises ValueError.
		"""
		with open(src, "r") as g:
			for line in g:
				line = chomp(line)
				if line == "":
					continue
				prob, rest = line.split("\t")
				LHS, RHS = rest.split(" --> ")
				self.addRule(PCFGRule(LHS, RHS, prob))

	def generalizeGrammar(self):
		"""Return a new Grammar derived from this one.

		* Rules whose lhs starts with "Q" are left out.
		* Rules whose lhs starts with "S_" lose their first rhs symbol;
		  rules that become identical are merged and the merged count is
		  capped at 1.
		* All other rules are added unchanged.  They are added by reference,
		  so the returned grammar shares these rule objects with this one.
		"""
		generalizedG = Grammar()
		for lhs in self._rules:
			if lhs.startswith("Q"):	#ignore	--> those rules will either be dropped (Q) or be dealt with further below
				continue
			elif lhs[0:2] == "S_":	#the S-Terminals need to be generalized
				for rule in self._rules[lhs]:
					newRHS = " ".join(rule._rhs[1:])	#drop the Q-NT
					# merge all identical rules
					updated = False
					for posR in generalizedG._rules.setdefault(lhs, []):
						if " ".join(posR._rhs) == newRHS:
							posR._count += rule._count
							if posR._count > 1:
								posR._count = 1
							updated = True
					if not updated:
						generalizedG.addRule(PCFGRule(lhs, newRHS, rule._count))
			else:
				for rule in self._rules[lhs]:
					generalizedG.addRule(rule)
		return generalizedG

if __name__ == "__main__":
	# Command line: <original grammar file> <generalized grammar output file>
	if len(sys.argv) < 3:
		# Exit with a readable message instead of an IndexError traceback.
		sys.exit("usage: %s ORIGINAL_GRAMMAR GENERALIZED_OUTPUT" % sys.argv[0])
	g = Grammar()
	g.readGrammar(sys.argv[1])	# load the original grammar
	f = g.generalizeGrammar()	# build the generalized grammar
	f.save(sys.argv[2])	# write it to the requested destination
