from __future__ import division
from __future__ import print_function
import numpy as np


class BeamEntry:
	"""Bookkeeping for one beam (one candidate labeling) at a single time-step."""

	def __init__(self):
		# labeling represented by this beam, stored as a tuple
		self.labeling = ()
		# probability of the paths ending with a blank
		self.prBlank = 0
		# probability of the paths ending with a non-blank
		self.prNonBlank = 0
		# combined probability (blank and non-blank)
		self.prTotal = 0
		# language-model score of the labeling
		self.prText = 1
		# set once the LM has been applied to this beam
		self.lmApplied = False


class BeamState:
	"""All beams alive at one time-step, stored in `entries` keyed by labeling."""

	def __init__(self):
		self.entries = {}

	def norm(self):
		"""Length-normalise the LM score (prText) of every entry, in place."""
		for entry in self.entries.values():
			# an empty labeling falls back to exponent 1, leaving its score untouched
			length = len(entry.labeling) or 1.0
			entry.prText = entry.prText ** (1.0 / length)

	def _ranked(self):
		"""Return the entries ordered by prTotal * prText, highest first."""
		def score(entry):
			return entry.prTotal * entry.prText

		return sorted(self.entries.values(), key=score, reverse=True)

	def sort(self):
		"""Return the beam-labelings, most probable first."""
		return [entry.labeling for entry in self._ranked()]

	def sort_get_probs(self):
		"""Return (labeling, prTotal * prText) pairs, most probable first."""
		ranked = self._ranked()
		return [(entry.labeling, entry.prTotal * entry.prText) for entry in ranked]


def applyLM(parentBeam, childBeam, classes, lm, lmFactor=0.01):
	"""Score a child beam with the LM, based on the parent's score and a character bigram.

	The child's prText becomes the parent's prText multiplied by the
	probability (raised to `lmFactor`) of the bigram formed by the parent's
	last character and the child's last character. Nothing happens when no
	LM is given or when the child was already scored.

	Args:
		parentBeam: beam entry that was extended; read for `labeling` and `prText`.
		childBeam: beam entry to score; its `prText` and `lmApplied` are updated.
		classes: indexable sequence mapping a label index to its character. It
			must support `index(' ')`, which is used when the parent labeling
			is empty.
		lm: object providing `getCharBigram(c1, c2)`; a falsy value disables scoring.
		lmFactor: exponent applied to the bigram probability, i.e. the
			influence of the language model. Defaults to 0.01.
	"""
	if not lm or childBeam.lmApplied:
		return

	# an empty parent labeling is treated as if it ended with a space
	if parentBeam.labeling:
		firstIdx = parentBeam.labeling[-1]
	else:
		firstIdx = classes.index(' ')
	c1 = classes[firstIdx] # first char
	c2 = classes[childBeam.labeling[-1]] # second char

	# probability of seeing first and second char next to each other
	bigramProb = lm.getCharBigram(c1, c2) ** lmFactor
	childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence
	childBeam.lmApplied = True # only apply LM once per beam entry


def addBeam(beamState, labeling):
	"""Make sure `labeling` has an entry in `beamState`, creating an empty one if missing."""
	entries = beamState.entries
	if labeling in entries:
		# already tracked: leave the existing entry untouched
		return
	entries[labeling] = BeamEntry()

import time
def ctcBeamSearch(mat, blank_idx, using_begin_index, lm, beamWidth=25, tokenizer=None):
	"""Beam search as described by the paper of Hwang et al. and the paper of Graves et al.

	Args:
		mat: 2-D array indexed as mat[t, c] (time-step, class). Presumably
			holds per-time-step class probabilities rather than logits --
			TODO confirm against the caller.
		blank_idx: column index of the CTC blank in `mat`.
		using_begin_index: not used by this function; kept so existing calls work.
		lm: not used while LM scoring is disabled (see the note in the body).
		beamWidth: number of best beams kept at every time-step.
		tokenizer: not used by this function; kept so existing calls work.

	Returns:
		A list of (labeling, score) tuples sorted by descending score, where
		labeling is a tuple of class indices and score is prTotal * prText.
	"""
	maxT, maxC = mat.shape

	# To save time, only the maxc_num most probable classes of each time-step
	# are tried as extensions instead of the whole vocabulary.
	maxc_num = 20

	# initialise beam state with the empty labeling
	last = BeamState()
	labeling = ()
	last.entries[labeling] = BeamEntry()
	last.entries[labeling].prBlank = 1
	last.entries[labeling].prTotal = 1

	# go over all time-steps
	for t in range(maxT):
		curr = BeamState()

		# get beam-labelings of best beams
		bestLabelings = last.sort()[0:beamWidth]

		# The candidate classes depend only on the time-step, so select them
		# once here rather than once per beam.
		t_prob = mat[t]
		if maxC >= maxc_num:
			t_max_idx = np.argpartition(t_prob, -maxc_num)[-maxc_num:]
		else:
			# fewer classes than maxc_num: argpartition would reject the
			# out-of-range kth, so simply consider every class
			t_max_idx = np.arange(maxC)

		# go over best beams
		for labeling in bestLabelings:
			parent = last.entries[labeling]

			# probability of paths ending with a non-blank
			prNonBlank = 0
			# in case of non-empty beam
			if labeling:
				# probability of paths with repeated last char at the end
				prNonBlank = parent.prNonBlank * mat[t, labeling[-1]]

			# probability of paths ending with a blank
			prBlank = parent.prTotal * mat[t, blank_idx]

			# add beam at current time-step if needed
			addBeam(curr, labeling)

			# fill in data
			entry = curr.entries[labeling]
			entry.labeling = labeling
			entry.prNonBlank += prNonBlank
			entry.prBlank += prBlank
			entry.prTotal += prBlank + prNonBlank
			entry.prText = parent.prText # beam-labeling not changed, therefore LM score is carried over unchanged
			entry.lmApplied = True # LM already applied at previous time-step for this beam-labeling

			# extend current beam-labeling
			for c in t_max_idx:
				# the blank never extends a labeling
				if c == blank_idx:
					continue

				# add new char to current beam-labeling
				newLabeling = labeling + (c,)

				# if new labeling contains duplicate char at the end, only consider paths ending with a blank
				if labeling and labeling[-1] == c:
					prNonBlank = mat[t, c] * parent.prBlank
				else:
					prNonBlank = mat[t, c] * parent.prTotal

				# add beam at current time-step if needed
				addBeam(curr, newLabeling)

				# fill in data
				child = curr.entries[newLabeling]
				child.labeling = newLabeling
				child.prNonBlank += prNonBlank
				child.prTotal += prNonBlank

				# NOTE: LM scoring is currently disabled -- applyLM is not
				# called here, so every beam keeps the default prText.

		# set new beam state
		last = curr

	# normalise LM scores according to beam-labeling-length
	last.norm()

	# sorted by probability; callers currently only need the index labelings
	return last.sort_get_probs()


def testBeamSearch():
	"""Run the decoder on a tiny two-step example and print whether it decodes 'a'.

	The matrix has one column per character of `classes` followed by a blank
	column, so the blank index is len(classes).
	"""
	classes = 'ab'
	mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]])
	print('Test beam search')
	expected = 'a'
	# ctcBeamSearch requires blank_idx, using_begin_index and lm positionally;
	# the last two are not read by the decoder, so placeholders are passed.
	results = ctcBeamSearch(mat, len(classes), False, None)
	# results holds (labeling, score) pairs, best first; map the best labeling to chars
	bestLabeling = results[0][0]
	actual = ''.join(classes[int(idx)] for idx in bestLabeling)
	print('Expected: "' + expected + '"')
	print('Actual: "' + actual + '"')
	print('OK' if expected == actual else 'ERROR')


# run the decoder self-test when this file is executed as a script
if __name__ == '__main__':
	testBeamSearch()
