"""
Functions:
    This program generates training and testing data for the ill-formed word detector.
"""
import pickle
import langToolkit
import libSoundexImp
import libEditDist
import random
import sys

# Suffix appended to every generated file name under ../dect/
# (train.<suffix>, test.<suffix>, oracle.<suffix>, feaPickle.<suffix>).
# Overridden from the command line when the module is run as a script.
suffix = 'def'
# Context window size: train() and test() look windowSize tokens to the left
# and windowSize + 1 tokens to the right of the target position.
windowSize = 3 # context window size
# Maps a feature key ("<token><offset><word>") to its integer feature index.
# Filled by getFeaIdx() during training, or loaded from feaPickle.<suffix> in test().
feaIdxDict = dict() #feature index dictionary
# Last feature index handed out by getFeaIdx(); the first index issued is 1.
feaIdx = 0 #feature index
# train() stops once at least this many samples (positive and negative) were written.
maxNum = 2000000 # number of training data
# Scale applied in extractExpDepT() to the best feature value obtained through
# phonetic expansion of a token outside wordSet; a value <= 0 disables expansion.
# Overridden from the command line when the module is run as a script.
expWeight = 0.1 # default expansion weight

# Maps a feature key to its feature value; loaded from the depbank pickle in the
# __main__ block and must be set before any extract* function is called.
suppFullDict = None
# Project helpers: soundexImp provides getSoundVariants / getSoundEditVariants,
# editDist provides getEditVariants. Both are constructed at import time.
soundexImp = libSoundexImp.doubleMetaPhone('../data/soundex.dict')
editDist = libEditDist.editingDistance('../data/dict.dict')
# Container used for membership tests on context tokens
# (presumably the in-vocabulary lexicon -- TODO confirm).
# NOTE(review): pickle must only be loaded from trusted files.
wordSet = pickle.load(open('../data/dict.pickle'))

def getFeaIdx(key):
    """
    Return the integer index assigned to feature `key`.

    Known keys are looked up in the module-level feaIdxDict. An unknown key
    is registered under the next free index (the module-level counter feaIdx
    is incremented first, so the first index ever issued is 1).
    """
    global feaIdx, feaIdxDict
    try:
        return feaIdxDict[key]
    except KeyError:
        # First time this key is seen: allocate a fresh index for it.
        feaIdx += 1
        feaIdxDict[key] = feaIdx
        return feaIdx

def extractExpDepT(tokens, pos, minPos, maxPos, word):
    """
    Build the feature string of candidate `word` for prediction.

    Every token in tokens[minPos..maxPos] (inclusive) is paired with its
    offset relative to `pos` and with `word` to form a feature key. Only keys
    present in both suppFullDict and feaIdxDict contribute, so no new feature
    index is ever created here.

    A token found in wordSet contributes its own feature. Any other
    alphanumeric token is, when expWeight > 0, replaced by its sound
    variants; only the variant with the highest feature value is kept, and
    that value is scaled by expWeight and truncated to an int.

    Returns "idx:value" pairs sorted by feature index and joined by single
    spaces, or an empty string when no feature was found.
    """
    global wordSet, feaIdxDict
    features = dict()
    window = tokens[minPos:maxPos + 1]
    for offset, token in enumerate(window, minPos - pos):
        if token in wordSet:
            depKey = "{0}{1}{2}".format(token, offset, word)
            if depKey in suppFullDict and depKey in feaIdxDict:
                features[feaIdxDict[depKey]] = suppFullDict[depKey]
            continue
        if not (expWeight > 0 and token.isalnum()):
            continue
        # Token is not in wordSet: fall back to its sound variants.
        variantFeatures = dict()
        for variant in soundexImp.getSoundVariants(token):
            depKey = "{0}{1}{2}".format(variant, offset, word)
            if depKey in suppFullDict and depKey in feaIdxDict:
                variantFeatures[feaIdxDict[depKey]] = suppFullDict[depKey]
        if variantFeatures:
            # Keep only the strongest variant (first one wins on ties).
            bestIdx, bestVal = max(variantFeatures.items(), key=lambda d: d[1])
            features[bestIdx] = int(bestVal * expWeight)
    return " ".join("{0}:{1}".format(idx, features[idx]) for idx in sorted(features))

def extractExpDep(tokens, pos, minPos, maxPos, word):
    """
    Build the feature string of `word` for training.

    Every token in tokens[minPos..maxPos] (inclusive) that belongs to wordSet
    is paired with its offset relative to `pos` and with `word` to form a
    feature key. Keys present in suppFullDict contribute their value, and
    getFeaIdx() allocates a feature index for any key not seen before.

    Returns "idx:value" pairs sorted by feature index and joined by single
    spaces, or an empty string when no feature was found.
    """
    global wordSet
    features = dict()
    window = tokens[minPos:maxPos + 1]
    for offset, token in enumerate(window, minPos - pos):
        if token not in wordSet:
            continue
        depKey = "{0}{1}{2}".format(token, offset, word)
        if depKey in suppFullDict:
            features[getFeaIdx(depKey)] = suppFullDict[depKey]
    return " ".join("{0}:{1}".format(idx, features[idx]) for idx in sorted(features))

def train(trainFile):
    """
    Generate the SVM training file for ill-formed word detection.

    Reads `trainFile` (one clean tweet per line), and for every token that
    belongs to the pickled training word set writes to ../dect/train.<suffix>:
      * one "+1" sample built from the token itself, and
      * up to five "-1" samples built from its sound variants.
    Samples with an empty feature string are skipped. Reading stops at the
    end of the file, or after the line on which the number of written samples
    reaches maxNum. The feature index dictionary is finally pickled to
    ../dect/feaPickle.<suffix>.

    Raises IOError if any of the files cannot be opened.
    """
    global feaIdxDict, maxNum, windowSize
    # Words chosen for training, e.g. neighbour words within 1 edit distance.
    # NOTE(review): pickle must only be loaded from trusted files.
    with open('../data/svmtrainset.pickle') as fset:
        svmTrainSet = pickle.load(fset)
    curSen = 0
    with open(trainFile) as fin:
        with open('../dect/train.' + suffix, 'w') as fout:
            # Iterating the file ends at EOF. The previous readline() loop never
            # detected EOF and spun forever on corpora yielding < maxNum samples.
            for line in fin:
                tokens = langToolkit.tokenizeTweet(line.rstrip().lower())
                lastPos = len(tokens) - 1
                for i, token in enumerate(tokens):
                    if token not in svmTrainSet:
                        continue
                    # NOTE(review): the window spans windowSize tokens to the left
                    # but windowSize + 1 to the right (maxPos is inclusive in
                    # extractExpDep). test() builds the same window, so it is
                    # kept unchanged here -- confirm the asymmetry is intended.
                    minPos = max(i - windowSize, 0)
                    maxPos = min(i + windowSize + 1, lastPos)
                    # positive sample
                    feaStr = extractExpDep(tokens, i, minPos, maxPos, token)
                    if not feaStr:
                        continue
                    curSen += 1
                    fout.write("+1 {0}\n".format(feaStr))
                    # negative samples: sound variants other than the token itself.
                    # Filtering into a new list leaves the collection returned by
                    # soundexImp untouched and does not fail when the token is
                    # missing from its own variants.
                    negs = [variant
                            for variant in soundexImp.getSoundVariants(token)
                            if variant != token]
                    negTimer = 5
                    for neg in negs:
                        if negTimer == 0:
                            break
                        feaStr = extractExpDep(tokens, i, minPos, maxPos, neg)
                        if not feaStr:
                            continue
                        fout.write("-1 {0}\n".format(feaStr))
                        negTimer -= 1
                        curSen += 1
                if curSen >= maxNum:
                    break
    with open('../dect/feaPickle.' + suffix, 'w') as fpickle:
        pickle.dump(feaIdxDict, fpickle)

def _writeTestInstances(fin, fout, foracle):
    """
    Convert every sentence record read from `fin` into prediction input.

    A record is a line holding the token count, followed by that many
    "<token>\\t<normalised form>" lines. For each checkpoint reported by
    langToolkit.getCheckpoints, one feature line per candidate with a
    non-empty feature string goes to `fout`, and one
    "<number of feature lines>\\t<label>" line goes to `foracle`. The label
    is "+1" when the normalised form differs from the token, otherwise "-1".
    """
    while True:
        line = fin.readline()
        if not line:
            break
        if not line.strip():
            # Tolerate blank separator/trailing lines instead of failing in int().
            continue
        num = int(line)
        tokens = list()
        cpDict = dict()
        for i in range(num):
            segs = fin.readline().lower().rstrip().split('\t')
            tokens.append(segs[0])
            if segs[1] != segs[0]:
                cpDict[i] = segs[1] # get ill-formed positions
        for cp in langToolkit.getCheckpoints(tokens): # OOV positions
            # NOTE(review): the window spans windowSize tokens to the left but
            # windowSize + 1 to the right (maxPos is inclusive in
            # extractExpDepT). train() builds the same window, so it is kept
            # unchanged here -- confirm the asymmetry is intended.
            minPos = max(cp - windowSize, 0)
            maxPos = min(cp + windowSize + 1, num - 1)
            candidates = (soundexImp.getSoundEditVariants(tokens[cp], 1)
                          | editDist.getEditVariants(tokens[cp], 2))
            # An OOV token is labelled true only when it is ill-formed.
            truthFlag = "+1" if cp in cpDict else "-1"
            validCounter = 0
            for candidate in candidates:
                feaStr = extractExpDepT(tokens, cp, minPos, maxPos, candidate)
                if not feaStr:
                    continue
                validCounter += 1
                fout.write("{0}\n".format(feaStr))
            foracle.write("{0}\t{1}\n".format(validCounter, truthFlag))

def test(testFile):
    """
    Given tweets, generate prediction input data for ill-formed word detection.

    Reads sentence records from `testFile` and writes the candidate feature
    lines to ../dect/test.<suffix> and the per-checkpoint candidate counts and
    labels to ../dect/oracle.<suffix>. When the in-memory feature index
    dictionary is empty it is first loaded from ../dect/feaPickle.<suffix>.

    All files are closed even when processing fails part-way.

    Raises IOError if a file cannot be opened, and ValueError if a count line
    is not an integer.
    """
    global wordSet, windowSize, feaIdxDict
    if not feaIdxDict:
        # Load the feature index dictionary produced by train().
        # NOTE(review): pickle must only be loaded from trusted files.
        with open('../dect/feaPickle.' + suffix) as fpickle:
            feaIdxDict = pickle.load(fpickle)
    with open(testFile) as fin:
        with open('../dect/test.' + suffix, 'w') as fout:
            with open('../dect/oracle.' + suffix, 'w') as foracle:
                _writeTestInstances(fin, fout, foracle)

if __name__ == "__main__":
    # Expected command line: <script> suffixName depbank expWeight
    if len(sys.argv) != 4:
        print("{0} suffixName depbank expWeight".format(sys.argv[0]))
    else:
        suffix, depbank = sys.argv[1], sys.argv[2]
        expWeight = float(sys.argv[3])
        # Feature values keyed by "<token><offset><word>", read by the extract* functions.
        suppFullDict = pickle.load(open(depbank))
        # Training is currently disabled; only the prediction input is generated.
        #train('./data/largeEngTweets')
        test('../data/corpus.tweet1')
