"""
Functions:
    This module is the major experimental framework. It compares different features and baselines (except SMT)
"""
import langToolkit
import libEditDist
import libSoundexImp
import lm
import time
import math
import pickle
import sys
import logging

def _loadPickle(path):
    """
    Load one pickled resource from `path` and close the file afterwards.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only point
    # this at trusted data files.
    # NOTE(review): the file is opened in the default (text) mode; switch to
    # 'rb' if the pickles turn out to be written with a binary protocol.
    with open(path) as handle:
        return pickle.load(handle)

# init some global modules
editDist = libEditDist.editingDistance('../data/dict.dict') # init. edit distance obj
soundexImp = libSoundexImp.doubleMetaPhone('../data/soundex.dict') # init. double metaphone obj
internetDict = _loadPickle('../data/slang.pickle') # internet dictionary
suppDict = _loadPickle('../data/contextSupport.pickle') # support dictionary from dependencies
cacheDict = dict() # dictionary of (ill-formed word, canonical form) for unsupervised normalization
lmDict = _loadPickle('../data/unigram.pickle') # google unigram dictionary
dictSet = _loadPickle('../data/dict.pickle') # IV word dictionary

# some global settings
ngram = 3 # ngram for language model score
ratio = 0.1 # ratio of language model score pruned candidates
vowels = 'aeiou' # vowels

def shrink(token):
    """
    Collapse every run of a repeated letter to at most two occurrences.

    For example "cooool" becomes "cool". An empty token is returned unchanged.
    """
    kept = []
    previous = None
    run = 0 # length of the current run of identical letters
    for letter in token:
        if letter == previous:
            run += 1
        else:
            run = 1
        # only the first two letters of a run survive
        if run < 3:
            kept.append(letter)
        previous = letter
    return "".join(kept)


def genCandidates(token):
    """
    Build the confusion set of a token from its spelling and its pronunciation.

    The token is first passed through shrink(); the result is the union of the
    variants within edit distance 2 and the variants within soundex edit
    distance 1.
    """
    shrunk = shrink(token)
    spellingVariants = editDist.getEditVariants(shrunk, 2)
    soundVariants = soundexImp.getSoundEditVariants(shrunk, 1)
    return spellingVariants | soundVariants

def pruneCandidates(tokens, pos, candidates):
    """
    Prune candidates by fitness of local context (pruning by language model score)

    Each candidate is placed between up to ngram - 1 tokens on either side of
    position `pos` and scored by lm.getLMScoreBatch. Returns a list of
    (candidate, exp(score)) pairs sorted by descending score, truncated to
    int(len(candidates) * ratio) + 1 entries.
    """
    # fix the iteration order once so candidates and scores stay aligned
    candidates = list(candidates)
    # clamp the window start: a negative slice start would count from the end
    # of the list and silently drop the left context near the message start
    start = max(0, pos - ngram + 1)
    prefix = " ".join(tokens[start:pos]) + " "
    suffix = " " + " ".join(tokens[pos + 1 : pos + ngram])
    reqList = [prefix + candidate + suffix for candidate in candidates]
    # get language model score
    scoreList = [math.exp(score) for score in lm.getLMScoreBatch(reqList)]
    num = int(len(reqList) * ratio)
    # rank candidates; at least one candidate always survives the cut
    ranked = sorted(zip(candidates, scoreList), key=lambda pair: pair[1], reverse=True)
    return ranked[:num + 1]

def isPrefix(token, norm):
    """
    Prefix feature, following the unsupervised normalisation method

    A trailing vowel and then one letter of a trailing doubled letter are
    removed from the token; returns 1.0 when `norm` starts with what is left,
    otherwise 0.0.
    """
    stem = token
    # NOTE(review): a token that is a single vowel is reduced to the empty
    # string here, which every norm starts with.
    if stem[-1:] in vowels:
        stem = stem[:-1]
    endsWithDouble = len(stem) > 2 and stem[-2:-1] == stem[-1:]
    if endsWithDouble:
        stem = stem[:-1]
    return 1.0 if norm.startswith(stem) else 0.0

def isSuffix(token, norm):
    """
    Suffix feature: 1.0 when `norm` ends with `token`, otherwise 0.0.
    """
    return 1.0 if norm.endswith(token) else 0.0

def rankFunc(item, dictList):
    """
    Linear combination of different features

    Adds up, with equal weight, the value stored for the candidate item[0] in
    each of the first `features` dictionaries of dictList.
    """
    candidate = item[0]
    total = 0
    for index in range(features):
        total += dictList[index][candidate]
    return total

def getContextSupp(tokens, pos, candidate):
    """
    Obtain the context support from conventional corpora dependencies given a candidate.

    Every token within `ngram` positions of `pos` (the position itself
    excluded) is looked up in suppDict under the key "<token><relative
    offset>"; the values stored there for `candidate` are summed.
    """
    first = max(pos - ngram, 0)
    last = min(pos + ngram + 1, len(tokens))
    support = 0.0 # accumulated context support, i.e. the sum of dependency occurrences in the corpora.
    # offset is the neighbour's position relative to pos (negative = to the left)
    for offset, neighbour in enumerate(tokens[first:last], first - pos):
        if offset == 0:
            continue
        key = "{0}{1}".format(neighbour, str(offset))
        if key in suppDict and candidate in suppDict[key]:
            support += suppDict[key][candidate]
    return support

def lCSeq(s1, s2):
    """
    Return a longest common subsequence of s1 and s2 ("" if either is empty).

    Uses a dynamic-programming table of len(s1) x len(s2) cells. Candidates are
    compared by length; comparing the strings themselves would order them
    alphabetically and could pick a shorter subsequence.
    """
    if not s1 or not s2:
        return ""
    rows, cols = len(s1), len(s2)
    # table[i][j] = LCS length of s1[:i] and s2[:j]
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if s1[i - 1] == s2[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    # walk back from the bottom-right corner to recover one subsequence
    letters = []
    i, j = rows, cols
    while i > 0 and j > 0:
        if s1[i - 1] == s2[j - 1]:
            letters.append(s1[i - 1])
            i -= 1
            j -= 1
        elif table[i - 1][j] >= table[i][j - 1]:
            i -= 1
        else:
            j -= 1
    letters.reverse()
    return "".join(letters)

def editDistFeature(tokens, pos, item):
    """
    Edit distance feature: 1 / exp(d), with d the edit distance between the
    token at `pos` and the candidate item[0].
    """
    distance = editDist.getEditDistance(tokens[pos], item[0])
    return 1.0 / math.exp(distance)

def soundEditDistFeature(tokens, pos, item):
    """
    Phonetic edit distance feature: 1 / exp(d), with d the edit distance
    between the lower-cased first double metaphone codes of the token at `pos`
    and of the candidate item[0].
    """
    tokenCode = soundexImp.dm(tokens[pos])[0].lower()
    candidateCode = soundexImp.dm(item[0])[0].lower()
    distance = editDist.getEditDistance(tokenCode, candidateCode)
    return 1.0 / math.exp(distance)

def prefixFeature(tokens, pos, item):
    """
    Prefix feature of the candidate item[0] for the token at `pos` (see isPrefix).
    """
    token, candidate = tokens[pos], item[0]
    return isPrefix(token, candidate)

def suffixFeature(tokens, pos, item):
    """
    Suffix feature of the candidate item[0] for the token at `pos` (see isSuffix).
    """
    token, candidate = tokens[pos], item[0]
    return isSuffix(token, candidate)

def lcsFeature(tokens, pos, item):
    """
    Longest common subsequence feature: the LCS length of the candidate
    item[0] and the token at `pos`, divided by the length of the longer of the
    two. Returns 0.0 when both strings are empty.
    """
    s1 = item[0]
    s2 = tokens[pos]
    longest = max(len(s1), len(s2))
    # guard the division: two empty strings share nothing
    if longest == 0:
        return 0.0
    return float(len(lCSeq(s1, s2))) / longest

def lmFeature(tokens, pos, item):
    """
    Language model feature: the score already stored as the second element of
    `item`; `tokens` and `pos` are not used.
    """
    score = item[1]
    return score

def contextSuppFeature(tokens, pos, item):
    """
    Context support feature of the candidate item[0] at position `pos` (see getContextSupp).
    """
    candidate = item[0]
    return getContextSupp(tokens, pos, candidate)
"""
In light of input parameter, e.g. 01111100, it means calculate Word Similarity features, which coppresond the feature table
"""
# feature function table; entry i is switched on by selector[i + 1]
funcTable = [
    editDistFeature,
    soundEditDistFeature,
    prefixFeature,
    suffixFeature,
    lcsFeature,
    lmFeature,
    contextSuppFeature,
]

features = len(funcTable) # number of features for ranking process, including both word similarity and context features
# default model type: everything switched on. The selector carries one leading
# character for dictionary lookup plus one character per feature, so it needs
# features + 1 characters; with only `features` characters the last feature is
# never computed.
selector = '1' * (features + 1)

def rankCandidates(tokens, pos, candidates):
    """
    Calculate feature contributions and rank candidates

    The candidates are first pruned by pruneCandidates. Feature i of funcTable
    is computed when selector[i + 1] is '1' and is 0.0 otherwise (including
    when the selector is too short to mention it). Features from index 2
    onwards are divided by their sum over all surviving candidates; features 0
    and 1 are used as they are. Returns (candidate, combined score) pairs
    sorted by descending score, ties broken by descending candidate.
    """
    # prune candidates
    rankList = pruneCandidates(tokens, pos, candidates)
    dictList = [dict() for _ in range(features)]
    # tiny initial value keeps the normalisation below from dividing by zero
    normList = [0.000000000001] * features

    # calculate features
    for item in rankList:
        key = item[0]
        # iterate over the feature table, not the selector, so that a selector
        # of the wrong length can neither skip a feature nor index past the table
        for i in range(features):
            enabled = i + 1 < len(selector) and selector[i + 1] == '1'
            if enabled:
                dictList[i][key] = funcTable[i](tokens, pos, item)
            else:
                dictList[i][key] = 0.0

        for i in range(2, features):
            normList[i] += dictList[i][key]

    # normalize features
    for item in rankList:
        key = item[0]
        for i in range(2, features):
            dictList[i][key] /= normList[i]

    # integrated feature ranking
    reRankList = [(item[0], rankFunc(item, dictList)) for item in rankList]
    return sorted(reRankList, key=lambda pair: (pair[1], pair[0]), reverse=True)

def model_dl(tokenList, cp):
    """
    Dictionary lookup model

    Returns the internetDict entry of the token at `cp` when there is one and
    it contains no space; otherwise returns the token itself.
    """
    token = tokenList[cp]
    if token not in internetDict:
        return token
    entry = internetDict[token]
    # entries containing a space (several words) are not used
    if ' ' not in entry:
        return entry
    return token

def model_fm(tokenList, cp):
    """
    Integration framework of different features

    Generates candidates for the token at `cp`, ranks them, and returns the
    best one. The token itself is returned when there are no candidates or
    when the best candidate scores 0.
    """
    token = tokenList[cp]
    candidates = genCandidates(token)
    if len(candidates) < 1:
        return token
    ranked = rankCandidates(tokenList, cp, candidates)
    best = ranked[0]
    if best[1] == 0:
        return token
    return best[0]

def part1(token, norm):
    """
    Stylistic variations.

    The process is imitated by the overlap of the double metaphone codes, since
    different graphemes and phonemes can be uniformly expressed in that code;
    this is more flexible than the 2 character distance restriction in the
    paper. Returns the LCS length of the two first codes divided by the length
    of the longer code, or 0.0 when either code is empty.
    """
    tokenCode = soundexImp.dm(token)[0]
    normCode = soundexImp.dm(norm)[0]
    if not (len(tokenCode) and len(normCode)):
        return 0.0
    shared = lCSeq(tokenCode, normCode)
    longer = max(len(tokenCode), len(normCode))
    return float(len(shared)) / longer

def part2(token, norm):
    """
    Subsequence abbreviations.

    Returns 1.0 when the letters of `token` appear in `norm` in the same order
    (not necessarily adjacent), otherwise 0.0. An empty token never matches.
    """
    if not token or len(norm) < len(token):
        return 0.0
    remaining = iter(norm)
    # `letter in remaining` consumes the iterator up to the first match, so the
    # letters have to be found left to right
    for letter in token:
        if letter not in remaining:
            return 0.0
    return 1.0

def part3(token, norm):
    """
    Prefix clippings.

    A trailing vowel (for tokens longer than one letter) and then one letter of
    a trailing doubled letter are removed from the token; returns 1.0 when
    `norm` starts or ends with what is left, otherwise 0.0.
    """
    stem = token
    if len(stem) > 1 and stem[-1:] in vowels:
        stem = stem[:-1]
    if len(stem) > 2 and stem[-2:-1] == stem[-1:]:
        stem = stem[:-1]
    matches = norm.startswith(stem) or norm.endswith(stem)
    return 1.0 if matches else 0.0

def getMax(token):
    """
    Obtain the most probable normalisation used by noisy channel models

    Every word of dictSet is scored with part1, part2 and part3; each part is
    weighted by 1 / (number of words it scores above zero) and the weighted sum
    is multiplied by exp(lmDict[word]). Words missing from lmDict are not
    ranked. The best word is stored in cacheDict and returned; the token
    itself is returned (and not cached) when no word can be ranked or when the
    best score is 0.
    """
    if token in cacheDict:
        return cacheDict[token]

    pdict1 = dict()
    pdict2 = dict()
    pdict3 = dict()
    for norm in dictSet:
        pdict1[norm] = part1(token, norm)
        pdict2[norm] = part2(token, norm)
        pdict3[norm] = part3(token, norm)

    def weight(scores):
        # 1 / number of supporting words, or 0.0 when nothing supports
        support = sum(1 for pscore in scores.values() if pscore > 0)
        return 1.0 / support if support else 0.0

    c1 = weight(pdict1)
    c2 = weight(pdict2)
    c3 = weight(pdict3)

    # track the maximum directly instead of sorting the whole dictionary;
    # (score, word) tuples order by score first and by word on ties
    best = None
    for norm in dictSet:
        if norm not in lmDict:
            continue
        score = (pdict1[norm] * c1 + pdict2[norm] * c2 + pdict3[norm] * c3) * math.exp(lmDict[norm])
        if best is None or (score, norm) > best:
            best = (score, norm)

    # nothing rankable, or nothing with any support: keep the original token
    if best is None or best[0] == 0:
        return token
    cacheDict[token] = best[1]
    return best[1]

def model_cook(tokenList, cp):
    """
    Implementation of noisy channel baseline

    Returns getMax() of the token at `cp`; the surrounding tokens are not used.
    """
    token = tokenList[cp]
    return getMax(token)

def normalize(tokenList, checkpointList):
    """
    Normalise ill-formed words by different methods

    The model is chosen by the module-level selector: '1' is dictionary lookup
    only, '0' is the noisy channel baseline; for any other selector the feature
    model is used, preceded by dictionary lookup when the selector starts with
    '1'. tokenList is modified in place. Returns the joined message and a
    dictionary {position: prediction} of the positions that were changed.
    """
    predictionDict = {}
    for cp in checkpointList:
        original = tokenList[cp]
        if selector == '1':
            guess = model_dl(tokenList, cp)
        elif selector == '0':
            guess = model_cook(tokenList, cp)
        elif selector[0] == '1':
            # dictionary first, feature model only if the dictionary had nothing
            guess = model_dl(tokenList, cp)
            if guess == original:
                guess = model_fm(tokenList, cp)
        else:
            # word similarity (ws) or context support (cs) models
            guess = model_fm(tokenList, cp)
        if guess != original:
            predictionDict[cp] = guess
    # predictions are applied only after all of them were made
    for index, replacement in predictionDict.items():
        tokenList[index] = replacement
    return (" ".join(tokenList), predictionDict)

def getOracle(tokenList, correctionDict):
    """
    Construct normalised messages from tokens.

    Replaces tokenList[position] by the correction for every (position,
    correction) pair of correctionDict, modifying tokenList in place, and
    returns the tokens joined by single spaces.
    """
    # items() works on both Python 2 and Python 3, unlike iteritems()
    for position, correction in correctionDict.items():
        tokenList[position] = correction
    return " ".join(tokenList)

def evalNormalization(tokenList, modifications, correctionDict):
    """
    Evaluate each normalisation, and record logs

    A modification equal to the correction at the same position counts as a
    true positive, any other modification as a false positive, and a
    correction without a modification is counted in the third value. The last
    two cases are logged together with up to three tokens of context on each
    side. Returns (truePos, falsePos, trueNeg).
    """
    total = len(tokenList)

    def context(center):
        # up to three tokens on either side of the position, clipped to the message
        return " ".join(tokenList[max(0, center - 3): min(center + 4, total)])

    truePos = falsePos = trueNeg = 0
    for position, predicted in modifications.items():
        expected = correctionDict[position]
        if expected == predicted:
            truePos += 1
            continue
        falsePos += 1
        logging.warning('falsePos: {0}\t{1}\t{2}'.format(predicted, expected, context(position)))
    for position, expected in correctionDict.items():
        if position in modifications:
            continue
        trueNeg += 1
        logging.warning('trueneg: {0}\t{1}'.format(expected, context(position)))
    return (truePos, falsePos, trueNeg)

def predict(infile):
    """
    Conduct the experiments on a corpus

    The corpus is a sequence of messages: one line with the number of tokens,
    then one line per token holding the token and, optionally after a tab, its
    normalised form. Reading stops at the first empty line. For every message
    the predicted and the oracle normalisation are printed; the accumulated
    counts, precision, recall and F-score are logged at the end. Metrics whose
    denominator is zero are reported as 0.
    """
    # a modification that matches the annotated correction
    totalTruePos = 0
    # a modification that differs from the annotated correction
    totalFalsePos = 0
    # an annotated correction for which no modification was made
    totalTrueNeg = 0

    # the with-block closes the corpus even if a message fails to parse
    with open(infile, 'r') as fin:
        while True:
            line = fin.readline().rstrip()
            if not line:
                break
            num = int(line)
            tokenList = list()
            correctionDict = dict()
            for i in range(num):
                line = fin.readline().rstrip()
                tokens = line.split('\t')
                text = tokens[0].lower().strip()
                if len(tokens) == 2:
                    norm = tokens[1].lower().strip()
                    if norm != text:
                        correctionDict[i] = norm
                tokenList.append(text)
            # identify OOV words
            cpList = langToolkit.getCheckpoints(tokenList)
            # we separate the ill-formed word detection process, here we offer the ill-formed words.
            checkpointList = [cp for cp in cpList if cp in correctionDict]
            # normalisation; both calls get copies because they modify their token list
            _, modifications = normalize(list(tokenList), checkpointList)
            oracle = getOracle(list(tokenList), correctionDict)
            truePos, falsePos, trueNeg = evalNormalization(tokenList, modifications, correctionDict)
            totalTruePos += truePos
            totalFalsePos += falsePos
            totalTrueNeg += trueNeg
            # construct normalised message from tokens
            for correctionKey in correctionDict:
                if correctionKey in modifications:
                    tokenList[correctionKey] = modifications[correctionKey]
            print(" ".join(tokenList))
            print(oracle)

    # calculate evaluation metrics, guarding every division against an empty
    # corpus or a run without any modification
    modified = totalTruePos + totalFalsePos
    annotated = modified + totalTrueNeg
    prec = float(totalTruePos) / modified if modified else 0.0
    rec = float(totalTruePos) / annotated if annotated else 0.0
    fscore = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    logging.warning("TruePositive: {0}, FalsePositive: {1}, TrueNegative: {2}, Prec:{3:4.2%}, Recall: {4:4.2%}, F-score: {5:4.2%}".format(
            totalTruePos,
            totalFalsePos,
            totalTrueNeg,
            prec,
            rec,
            fscore
            )
            )
        
if __name__ == '__main__':
    # both arguments are required; fail with a usage line instead of an IndexError
    if len(sys.argv) < 3:
        sys.exit('usage: {0} <corpus file> <selector>'.format(sys.argv[0]))
    inFile = sys.argv[1]
    # overrides the module-level default model type
    selector = sys.argv[2]
    predict(inFile)
