"""
Functions:
    Calculate recall and candidate-set sizes under different confusion-set generation settings.
"""
import libSoundexImp
import libEditDist
import langToolkit
import lm
import pickle

# Vocabulary used for membership tests on normalised tokens (`norm in dictSet`).
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# a trusted ../data/dict.pickle.
# The file is opened in a `with` block so the handle is closed right after loading.
with open('../data/dict.pickle') as _dictFile:
    dictSet = pickle.load(_dictFile)
# Candidate generator queried via getEditVariants(token, distance).
editDist = libEditDist.editingDistance('../data/dict.dict')
# Candidate generator queried via getSoundVariants / getSoundEditVariants.
soundexImp = libSoundexImp.doubleMetaPhone('../data/soundex.dict')
ratio = 0.1  # fraction of candidates kept after ranking by descending language model score
ngram = 3  # language model order; also the context window size on each side of a token

def shrink(token):
    """
    Truncate runs of identical consecutive characters to at most two.

    For example "coooool" becomes "cool", while "aab" is returned unchanged.
    An empty token yields an empty string.

    Args:
        token: the string to shrink.

    Returns:
        The token with the third and later characters of every run removed.
    """
    kept = []
    previous = None  # no character compares equal to None, so the first one starts a run
    runLength = 0
    for char in token:
        if char == previous:
            runLength += 1
        else:
            runLength = 1
            previous = char
        # Keep only the first two characters of each run.
        if runLength < 3:
            kept.append(char)
    return "".join(kept)

def getTopN(tokens, cp, candSet, correct):
    """
    Calculate recall after language model pruning.

    Every candidate in candSet is substituted at position cp. The window of up
    to `ngram` tokens on each side of cp is joined with spaces and scored with
    lm.getLMScoreBatch. Candidates are ranked by descending score and the top
    int(len(candSet) * ratio) + 1 are kept.

    Args:
        tokens: list of sentence tokens. It is not modified.
        cp: index in tokens of the token being replaced.
        candSet: collection of candidate strings for position cp.
        correct: the expected candidate.

    Returns:
        A tuple (num, hit): num is the size of the kept top list as computed
        above, and hit is True when `correct` is among the kept candidates.
    """
    num = int(len(candSet) * ratio) + 1

    # Substitute into a private copy: writing into the caller's list would
    # leave an arbitrary candidate at position cp after the call returns.
    working = list(tokens)
    start = max(0, cp - ngram)
    stop = min(len(working), cp + ngram + 1)

    candList = list(candSet)
    senList = []
    for cand in candList:
        working[cp] = cand
        senList.append(" ".join(working[start:stop]))

    resultList = lm.getLMScoreBatch(senList)

    # Rank candidates by language model score, best first, and keep the top num.
    ranked = sorted(zip(candList, resultList), key=lambda pair: pair[1], reverse=True)[:num]
    hit = any(cand == correct for cand, _ in ranked)
    return (num, hit)

def _safeDiv(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 when denominator is zero."""
    if not denominator:
        return 0.0
    return float(numerator) / denominator


def getRecallAndNumber(infile):
    """
    Print recall and candidate counts for each confusion generation setting.

    The input file is read as a sequence of records: one line holding an
    integer token count, followed by that many lines of the form
    "text<TAB>norm". Reading stops at the first empty line or at end of file.
    A token is evaluated when its norm differs from its text and the norm is
    in dictSet.

    For each evaluated token the candidate sets are generated from the shrunk
    token, and for every setting the function counts whether the norm is in
    the set (hit) and how many candidates the set contains. The union of
    2-edit and 1-sound-edit candidates is additionally pruned with getTopN.
    Tokens whose norm is missing from that union are printed as they are met.

    Args:
        infile: path of the corpus file.

    Returns:
        None. One summary line per setting is printed to standard output.
    """
    # One entry per setting, in report order. The sixth setting is ed2 | sd1,
    # so it is labelled "2 ed + 1 sd" to match the sets actually combined.
    labels = ["1 ed", "2 ed", "0 sd", "1 sd", "2 sd", "2 ed + 1 sd", "2 ed + 2 sd"]
    hits = [0.0] * len(labels)  # floats, so the hit counts print as e.g. "12.0"
    sizes = [0] * len(labels)
    total = 0.0  # number of evaluated tokens
    lmt = 0  # hits after language model pruning of (2 ed + 1 ed on sound)
    lms = 0  # summed top-N sizes after language model pruning

    with open(infile, 'r') as fin:
        while True:
            header = fin.readline().rstrip()
            if not header:
                break
            tokenCount = int(header)
            tokenList = []
            correctionDict = {}
            for i in range(tokenCount):
                fields = fin.readline().rstrip().split('\t')
                text = fields[0]
                norm = fields[1]
                if norm != text and norm in dictSet:
                    correctionDict[i] = norm
                tokenList.append(text)

            # calculate recall and unique candidate number
            for i in correctionDict:
                correct = correctionDict[i]
                token = shrink(tokenList[i])
                ed1 = editDist.getEditVariants(token, 1)
                ed2 = editDist.getEditVariants(token, 2)
                sd = soundexImp.getSoundVariants(token)
                sd1 = soundexImp.getSoundEditVariants(token, 1)
                sd2 = soundexImp.getSoundEditVariants(token, 2)
                ed2sd1 = ed2 | sd1
                total += 1

                # Hand over a copy so the original tokens stay intact for the
                # report below and for the context of later corrections.
                topN, hit = getTopN(list(tokenList), i, ed2sd1, correct)
                if hit:
                    lmt += 1
                lms += topN

                candSets = [ed1, ed2, sd, sd1, sd2, ed2sd1, sd2 | ed2]
                for k, candSet in enumerate(candSets):
                    if correct in candSet:
                        hits[k] += 1
                    sizes[k] += len(candSet)

                # print out ill-formed tokens that can not be collected by type 6.
                if not (correct in sd1 or correct in ed2):
                    print("%s %s" % (tokenList[i], correct))

    # _safeDiv reports 0 instead of raising when a setting has no hits or the
    # corpus has no evaluated tokens.
    template = "Rec: {0:4.2%}, Hit: {1}, num: {2} avg: {3} -- {4}"
    for label, hitCount, size in zip(labels, hits, sizes):
        print(template.format(_safeDiv(hitCount, total), hitCount, size,
                              _safeDiv(size, hitCount), label))
    print("Rec: {0:4.2%}, Hit: {1}, num: {2} avg: {4} -- 2 ed + 1 sd ratio: {3}".format(
        _safeDiv(lmt, total), lmt, lms, ratio, _safeDiv(lms, lmt)))

if __name__ == "__main__":
    # Script entry point: evaluate the default tweet corpus.
    corpusPath = '../data/corpus.tweet1'
    getRecallAndNumber(corpusPath)
