import sys
import os
import editdistance as ED
import LCS
import numpy as np
import re

print('loading...')


def _read_feature_index(path):
    """Return the rows of a tab-separated feature index file.

    Each row is the list of tab-separated fields of one stripped line.  Rows
    are kept in file order because the rest of the script indexes these lists
    by feature id (e.g. ``ch_feature[ch_f][0]``), i.e. by row position.
    """
    with open(path) as feature_file:
        return [line.strip().split('\t') for line in feature_file]


ch_feature = _read_feature_index('ch_feature_index.txt')
en_feature = _read_feature_index('en_feature_index.txt')
py_feature = _read_feature_index('py_feature_index.txt')

def _read_pr_table(path):
    """Return ``{int(col0): float(col1)}`` from a tab-separated file.

    Only the first two tab-separated columns of each stripped line are read.
    """
    table = {}
    with open(path) as pr_file:
        for line in pr_file:
            attr = line.strip().split('\t')
            table[int(attr[0])] = float(attr[1])
    return table


ch_pr = _read_pr_table('ch_pr.txt')
en_pr = _read_pr_table('en_pr.txt')

# Maps the first letter of a pinyin form to the English initial letters it is
# allowed to correspond to when generating alternative spellings.
init_rule = dict(
    a=['a', 'e', 'i', 'o'],
    b=['b', 'v'],
    c=['c', 't'],
    d=['d', 't'],
    e=['e', 'o'],
    f=['f'],
    g=['c', 'g'],
    h=['h'],
    i=['i'],
    j=['j', 'g', 'z'],
    k=['c', 'k', 'q'],
    l=['l', 'r'],
    m=['m'],
    n=['n'],
    o=['o', 'e'],
    p=['p'],
    q=['c', 'z'],
    r=['r', 'g'],
    s=['s', 't'],
    t=['t'],
    u=['u'],
    w=['a', 'o', 'u', 'v'],
    x=['s', 'h'],
    y=['i', 'a', 'u', 'y'],
    z=['z', 'j'],
)

# Vowel letters ("yuan yin" is pinyin for vowel).
yuan_yin = list('aeiou')



def _read_adjacency_table(path):
    """Return ``{int(col0): eval(col1)}`` from a tab-separated adjacency file.

    Column 1 is evaluated once per line.  Elsewhere in this script each
    evaluated entry is iterated as a sequence of items indexed ``[0]``
    (neighbour id) and ``[1]`` (weight).
    """
    table = {}
    with open(path) as ad_file:
        for line in ad_file:
            attr = line.strip().split('\t')
            # NOTE(review): eval() executes arbitrary code from the file.  If
            # the column only ever holds literal lists/tuples of numbers,
            # ast.literal_eval is a safe drop-in -- confirm the file format
            # before switching.
            table[int(attr[0])] = eval(attr[1])
    return table


ch_neighbor_dict = _read_adjacency_table('ch_ad_table.txt')

# Total adjacency weight per Chinese term (sum of item[1] over its neighbour
# list); used to normalise neighbour weights.
ch_neighbor_weight = dict(
    (node, sum((e[1] for e in neighbors), 0.0))
    for node, neighbors in ch_neighbor_dict.items()
)

en_neighbor_dict = _read_adjacency_table('en_ad_table.txt')
# Bilingual lexicon: line i of en_lexicon translates to line i of ch_lexicon.
# English entries are lower-cased with spaces replaced by '_' so they use the
# same multi-word form as the checks against en_feature in this script.
with open('lexicon/en_lexicon') as en_lexicon_file:
    en_lexicon_lines = en_lexicon_file.readlines()
with open('lexicon/ch_lexicon') as ch_lexicon_file:
    ch_lexicon_lines = ch_lexicon_file.readlines()

if len(en_lexicon_lines) != len(ch_lexicon_lines):
    # zip() below silently drops the unmatched tail; make misalignment visible.
    print('warning: lexicon files have different line counts')

en_ch_lexicon = {}
for en_line, ch_line in zip(en_lexicon_lines, ch_lexicon_lines):
    en_key = en_line.strip().lower().replace(' ', '_')
    en_ch_lexicon.setdefault(en_key, set()).add(ch_line.strip())

print('loading finished.')

# Initial candidate pairs, one "ch_id<TAB>en_id<TAB>score" per line.
#   key_list / value_list : parallel lists of "ch_id en_id" keys and scores
#   key_dict              : key -> position in key_list/value_list
#                           (the last occurrence wins for duplicate keys)
#   ch_cand_dict          : ch_id -> set of candidate en_ids
#   ch_score_sum_dict     : ch_id -> sum of its candidates' scores
ch_cand_dict = {}
ch_score_sum_dict = {}
key_dict = {}
key_list = []
value_list = []
with open('init_cand.txt') as cand_file:
    for cand_index, line in enumerate(cand_file):
        attr = line.strip().split('\t')
        cand_key = attr[0] + ' ' + attr[1]
        cand_score = float(attr[2])
        cand_ch = int(attr[0])

        key_list.append(cand_key)
        value_list.append(cand_score)
        key_dict[cand_key] = cand_index

        ch_cand_dict.setdefault(cand_ch, set()).add(int(attr[1]))
        ch_score_sum_dict[cand_ch] = (
            ch_score_sum_dict.get(cand_ch, 0.0) + cand_score)

def _load_burst_table(path):
    """Return ``{term: np.array(int states)}`` from a burst file.

    Each line is ``term state state ...`` separated by single spaces.  The
    term is everything before the first space.  A line with no space now
    fails on int('') instead of silently dropping the term's last character
    (the old ``find(' ')`` returned -1).
    """
    table = {}
    with open(path) as burst_file:
        for line in burst_file:
            term, _, burst_seq = line.strip().partition(' ')
            table[term] = np.array(list(map(int, burst_seq.split(' '))))
    return table


ch_burst_dict = _load_burst_table('ch_kleinburg_burst.txt')
en_burst_dict = _load_burst_table('en_kleinburg_burst.txt')

# One or more ASCII letters/digits; used to detect Latin-script content.
ascii_pattern = re.compile(r'[A-Za-z0-9]+')
# A CamelCase segment: one uppercase letter, then lowercase letters/digits.
cap_pattern = re.compile(r'[A-Z][a-z0-9]+')

def foo(ch_n_list,en_n_list,ch_f,en_f):
    """Score how well the neighbourhoods of a candidate (ch_f, en_f) pair agree.

    Items of ``ch_n_list`` are indexed as ``[0]`` = neighbour id and
    ``[1]`` = weight; only ``[0]`` of ``en_n_list`` items is used.

    For every Chinese neighbour:
    * its weight is divided by that neighbour's own total adjacency weight
      (``ch_neighbor_weight``);
    * the weight is multiplied by the best current score in ``value_list``
      between the neighbour and any English neighbour, and added to the total.

    A Chinese neighbour whose surface string contains ASCII letters/digits is
    also looked up directly against ``en_f``.

    Returns ``(score, high_count)``.  ``high_count`` counts neighbours whose
    best match exceeds 0.5, plus every direct ASCII hit found in ``key_dict``.
    ``ch_f`` is currently unused.
    """
    total = 0.0
    strong_links = 0
    for ch_nei in ch_n_list:
        nei_id = ch_nei[0]
        # A disabled alternative divided by ch_neighbor_weight[ch_f] instead.
        nei_weight = ch_nei[1] * 1.0 / ch_neighbor_weight[nei_id]

        best = 0.0
        for en_nei in en_n_list:
            pos = key_dict.get(str(nei_id) + ' ' + str(en_nei[0]))
            if pos is not None and value_list[pos] > best:
                best = value_list[pos]
        total += nei_weight * best
        if best > 0.5:
            strong_links += 1

        if ascii_pattern.search(ch_feature[nei_id][0]) is not None:
            pos = key_dict.get(str(nei_id) + ' ' + str(en_f))
            if pos is not None:
                total += nei_weight * value_list[pos]
                strong_links += 1
    return total, strong_links

# Per-pair feature caches keyed by "ch_id en_id".  These features depend only
# on the pair, not on the iteration, so each is computed once on first use.
py_score_dict = {}
LCS_score_dict = {}
coburst_score_dict = {}

n_iter = 20

arith = ED.arithmetic()

# Divisors weighting the individual evidence sources
# (l1 pinyin, l2 LCS, l3 neighbour, l4 co-burst, l5 ch_pr/en_pr difference).
l1 = 4.0
l2 = 4.0
l3 = 2.0
l4 = 5.0
l5 = 10.0

# Sentinel "distance" meaning no usable transliteration match.
_NO_MATCH = 9999.9


def _pinyin_score(ch_f, en_f):
    """Return the best normalised edit distance between a Chinese term's pinyin
    (or CamelCase ASCII parts) and an English word; lower is better.

    Returns ``_NO_MATCH`` when nothing comparable is found, or when the term
    is CamelCase, the English word is multi-part, and too few parts match.
    """
    ch_term = ch_feature[ch_f][0]
    py_form = py_feature[ch_f][0].lower()
    en_word = en_feature[en_f][0]
    if not py_form or not en_word:
        # Guards the [0] indexing and the len() divisions below.
        return _NO_MATCH

    min_py_score = _NO_MATCH
    rejected = False

    # Chinese terms written as CamelCase ASCII (e.g. "TomCruise").  The
    # segments are lower-cased once so they can be compared with the
    # lower-case English word and its '_'-separated parts.
    cap_parts = [part.lower() for part in cap_pattern.findall(ch_term)]
    if len(cap_parts) > 1:
        for cap_part in cap_parts:
            edis = arith.levenshtein(cap_part, en_word)
            min_py_score = min(min_py_score, edis * 1.0 / len(en_word))
        if '_' in en_word:
            parts = en_word.split('_')
            part_count = sum(1 for part in parts
                             for cap_part in cap_parts if part == cap_part)
            if part_count == min(len(cap_parts), len(parts)):
                min_py_score = 0.0
            elif part_count > 1:
                min_py_score = 0.5
            else:
                rejected = True

    # Alternative spellings obtained by swapping the pinyin initial for each
    # English initial allowed by init_rule.  A vowel replacing a consonant
    # initial drops the consonant instead.
    init_py = py_form[0]
    init_en = en_word[0]
    has_rule = init_py in init_rule
    possible_init_en = init_rule.get(init_py, [])
    possible_form = set([py_form])
    for p_init in possible_init_en:
        if p_init == init_py:
            continue
        if p_init in yuan_yin and init_py not in yuan_yin:
            possible_form.add(py_form[1:])
        else:
            possible_form.add(p_init + py_form[1:])

    # Whole English word, only if its initial is an allowed correspondence.
    if init_en in possible_init_en:
        dist = min(arith.levenshtein(form, en_word) for form in possible_form)
        min_py_score = min(min_py_score, dist * 1.0 / len(en_word))

    # Individual parts (length >= 4) of a multi-word English term.  The
    # distance is computed afresh for each part so it is normalised by that
    # same part's length.
    if '_' in en_word:
        for part in en_word.split('_'):
            if len(part) < 4:
                continue
            if has_rule and part[0] not in possible_init_en:
                continue
            dist = min(arith.levenshtein(form, part) for form in possible_form)
            min_py_score = min(min_py_score, dist * 1.0 / len(part))

    return _NO_MATCH if rejected else min_py_score


def _lcs_score(ch_f, en_f):
    """Return the best LCS ratio between a Chinese term and the lexicon
    translations of an English word (0.0 when the word is not in the lexicon).
    """
    en_word = en_feature[en_f][0]
    ch_word = ch_feature[ch_f][0]
    max_lcs = 0
    for ch_lexicon in en_ch_lexicon.get(en_word, ()):
        max_lcs = max(max_lcs, LCS.LCSlength(ch_word, ch_lexicon))
    # NOTE(review): LCSlength receives the raw (undecoded) strings while the
    # denominator counts decoded UTF-8 characters.  If LCSlength counts bytes,
    # multi-byte characters inflate this ratio; confirm LCS's behaviour.
    n_chars = len(ch_word.decode('utf-8'))
    if n_chars == 0:
        return 0.0
    return max_lcs * 1.0 / n_chars


def _coburst_score(ch_f, en_f):
    """Return the co-burst overlap of the two terms' burst-state vectors."""
    ch_burst = ch_burst_dict[ch_feature[ch_f][0]]
    en_burst = en_burst_dict[en_feature[en_f][0]]
    denom = ch_burst.sum() + en_burst.sum()
    if denom == 0:
        # 0/0 would give NaN.  NaN survives the final min/max clamp as 0.99
        # (min(0.99, nan) == 0.99), so treat it as no co-burst evidence.
        return 0.0
    return float(np.sum(ch_burst * en_burst) * 1.0 / denom)


def _combine_scores(py_score, lcs_score, neighbor_score, high_count,
                    co_burst_score, in_lexicon, pr_diff):
    """Combine the per-pair evidence into a new score clamped to [0, 0.99]."""
    update_score = 0.0
    # Pinyin / transliteration similarity (lower distance is better).
    if py_score <= 0.25:
        update_score += 3/l1
    elif py_score <= 0.34:
        update_score += 2.5/l1
    elif py_score <= 0.5:
        update_score += 2/l1
    # Lexicon LCS overlap.
    if lcs_score >= 0.75:
        update_score += 3/l2
    elif lcs_score >= 0.66666:
        update_score += 1.5/l2
    elif lcs_score >= 0.5:
        update_score += 1.0/l2
    # Neighbourhood agreement, capped at 0.4.
    update_score += min(0.4, neighbor_score/l3)
    # Temporal co-burst.
    update_score += co_burst_score/l4
    # Veto: a lexicon word whose translations do not overlap the term.
    if in_lexicon and lcs_score < 0.5:
        update_score = 0.0
    # Veto: no strong neighbour support and no close transliteration.
    if high_count < 1 and py_score > 0.25:
        update_score = 0.0
    # Penalty for a large ch_pr/en_pr difference.
    if pr_diff > 2:
        update_score -= 3/l5
    elif pr_diff > 1:
        update_score -= 1.5/l5
    elif pr_diff > 0.55:
        update_score -= 0.5/l5
    return max(0, min(0.99, update_score))


# Iteratively rescore every candidate pair.  Pairs whose score is exactly 1.0
# are treated as fixed seeds.  foo() reads the previous iteration's
# value_list, which is only replaced after a full pass.
for iters in range(n_iter):
    print(iters)
    temp_value = []
    temp_ch_value = {}
    for key, score in zip(key_list, value_list):
        attr = key.split(' ')
        ch_f = int(attr[0])
        en_f = int(attr[1])
        if score == 1.0:
            temp_value.append(score)
            # setdefault: do not discard deltas already accumulated for ch_f
            # by earlier pairs in this pass.
            temp_ch_value.setdefault(ch_f, ch_score_sum_dict[ch_f])
            continue

        # Terms missing from an adjacency table have no neighbours.
        ch_n_list = ch_neighbor_dict.get(ch_f, [])
        en_n_list = en_neighbor_dict.get(en_f, [])
        neighbor_score, high_count = foo(ch_n_list, en_n_list, ch_f, en_f)

        if key not in py_score_dict:
            py_score_dict[key] = _pinyin_score(ch_f, en_f)
        if key not in LCS_score_dict:
            LCS_score_dict[key] = _lcs_score(ch_f, en_f)
        if key not in coburst_score_dict:
            coburst_score_dict[key] = _coburst_score(ch_f, en_f)

        update_score = _combine_scores(
            py_score_dict[key],
            LCS_score_dict[key],
            neighbor_score,
            high_count,
            coburst_score_dict[key],
            en_feature[en_f][0] in en_ch_lexicon,
            abs(ch_pr[ch_f] - en_pr[en_f]),
        )

        temp_value.append(update_score)
        # Running per-term total: replace this pair's old score with the new.
        previous_total = temp_ch_value.setdefault(ch_f, ch_score_sum_dict[ch_f])
        temp_ch_value[ch_f] = update_score + previous_total - score
    if len(value_list) != len(temp_value):
        print('warning big...')
    value_list = list(temp_value)
    ch_score_sum_dict = dict(temp_ch_value)

with open('validation_output.txt', 'w') as out:
    for k, v in zip(key_list, value_list):
        out.write(k + '\t' + str(v) + '\n')
