# coding:utf-8
import json
from nltk.stem import WordNetLemmatizer
import string
import re
import random
import torch
from tqdm import tqdm
import opencc
from transformers import XLMRobertaTokenizer, XLMRobertaModel, AdamW
import OpenHowNet
import thulac

# ---------------------------------------------------------------------------
# Data file locations (relative to the working directory).
# ---------------------------------------------------------------------------
# Raw per-synset words and glosses; parsed by read_synset_definition().
src_babel_data_file = './data/babel_synset_list.txt'
# One synset per line: "<synset_id> <sememe> <sememe> ..."; parsed by
# read_synset_id_sememes().
synset_with_sememe_file = './data/synset_sememes.txt'
# Merged JSON written by gen_babel_data() and read back by gen_data().
babel_data_file = './data/babel_data.json'
# Space-separated sememe vocabulary.  NOTE(review): this name is not referenced
# by the functions in this file (gen_tokens() hard-codes the same path and
# read_synset_id_sememes() shadows it with a local list).
sememes = './data/sememe_all.txt'

# ---------------------------------------------------------------------------
# Shared NLP tools, created once at import time.
# ---------------------------------------------------------------------------
# Sub-word tokenizer for the 'xlm-roberta-base' checkpoint; used by gen_tokens().
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
# OpenCC converter using the 't2s' configuration (presumably Traditional ->
# Simplified Chinese); applied to Chinese head words in gen_data().
cc = opencc.OpenCC('t2s')
# WordNet lemmatizer applied to English definitions in gen_data().
wnl = WordNetLemmatizer()
# THULAC instance configured with seg_only=True and T2S=True; used to split
# Chinese definitions into words in gen_data().
seg_thulac = thulac.thulac(seg_only=True, T2S=True)
# HowNet dictionary queried for per-word sememes in gen_tokens().
hownet_dict = OpenHowNet.HowNetDict()

def read_num(fin):
    """Read one line from *fin* and evaluate it as a Python expression.

    Args:
        fin: An open text stream (any object providing ``readline``).

    Returns:
        The value of the stripped line, or ``None`` when the line could not
        be read or evaluated.  In the failure case the offending line (or
        ``None`` if nothing could be read) is printed for diagnosis.
    """
    # Bound up front so the error report below is valid even when
    # readline() itself is what failed.
    line = None
    try:
        line = fin.readline().strip()
        # SECURITY: eval() executes arbitrary expressions.  Only use this on
        # trusted data files; kept because inputs may hold expressions rather
        # than plain literals.
        return eval(line)
    except Exception:
        # Best-effort parse: report what could not be evaluated and signal
        # failure with None.  Exception (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        print(line)
        return None

def read_list(fin, **kword):
    """Read one ``<count> item item ...`` line from *fin* and return the items.

    The first field only has to evaluate as a Python expression (normally
    the item count); it is checked and then discarded.

    Args:
        fin: An open text stream (any object providing ``readline``).
        **kword: Optional ``sep`` giving the field separator.  When ``sep``
            is absent the line is split on runs of whitespace.

    Returns:
        The list of fields after the first one, or ``[]`` when the line is
        empty or its first field cannot be evaluated.
    """
    line = fin.readline().strip()
    # Honour the separator the caller actually passed instead of assuming a
    # tab whenever the keyword is present.
    fields = line.split(kword['sep']) if 'sep' in kword else line.split()
    try:
        # SECURITY: eval() executes arbitrary expressions; only use this on
        # trusted data files.  The value itself is not needed.
        eval(fields[0])
    except Exception:
        # Covers both "no fields at all" (IndexError) and "leading field is
        # not an expression" (SyntaxError, NameError, ...).
        return []
    return fields[1:]

def read_synset_id(fin):
    """Return the next line of *fin* that starts with ``bn:``.

    Lines that do not start with ``bn:`` are skipped.

    Args:
        fin: An open text stream (any object providing ``readline``).

    Returns:
        The matching line with surrounding whitespace stripped, or ``None``
        when the end of the stream is reached before such a line is found.
    """
    # iter(callable, sentinel) stops at the '' that readline() returns at
    # EOF, so a file without further ids terminates instead of spinning.
    for line in iter(fin.readline, ''):
        if line.startswith('bn:'):
            return line.strip()
    return None

def read_synset_definition(fin):
    """Parse the next synset record from *fin*.

    A record is read as seven consecutive lines: the synset id, three head
    words (stored by the caller as English, Chinese and French) and three
    tab-separated gloss lines parsed with ``read_list``.

    Args:
        fin: An open text stream positioned at the start of a record.

    Returns:
        ``[id, word_en, word_cn, word_fr, glosses_en, glosses_zh,
        glosses_fr]``, or ``None`` when the stream is already exhausted.
    """
    header = fin.readline()
    if not header:
        return None
    synset_id = header.strip()
    # Three single-line head words, then three gloss lists, in file order.
    words = [fin.readline().strip() for _ in range(3)]
    glosses = [read_list(fin, sep='\t') for _ in range(3)]
    return [synset_id, *words, *glosses]

def read_synset_id_sememes(fin, with_noun = True):
    """Parse ``<synset_id> <sememe> <sememe> ...`` lines from *fin*.

    Blank lines are ignored.

    Args:
        fin: An open text stream (any object providing ``readline``).
        with_noun: When ``False``, every synset whose id does not end in
            ``'n'`` is skipped, i.e. only ids ending in ``'n'`` are kept.
            When ``True`` (default) all synsets are kept.

    Returns:
        A tuple ``(sememes, synset_dic, synset_id_list)`` where

        * ``sememes`` lists every distinct sememe in first-seen order,
        * ``synset_dic`` maps each synset id to ``{'sememes': [...]}``
          (a repeated id keeps the sememes of its last line),
        * ``synset_id_list`` lists the ids in file order, one entry per
          accepted line.
    """
    synset_dic = {}
    synset_id_list = []
    sememes = []
    seen_sememes = set()  # O(1) de-duplication; `sememes` keeps the order
    for raw_line in iter(fin.readline, ''):
        fields = raw_line.split()
        if not fields:
            # Blank line (e.g. a trailing newline at the end of the file).
            continue
        synset_id = fields[0]
        if not with_noun and synset_id[-1] != 'n':
            continue
        synset_sememes = fields[1:]
        synset_id_list.append(synset_id)
        synset_dic.setdefault(synset_id, {})['sememes'] = synset_sememes
        for sememe in synset_sememes:
            if sememe not in seen_sememes:
                seen_sememes.add(sememe)
                sememes.append(sememe)
    return sememes, synset_dic, synset_id_list

def data_clean(word_str, lg):
    """Strip every character that is not part of the target script.

    Args:
        word_str: The text to clean.
        lg: Language code.  ``'cn'`` keeps only characters in the range
            U+4E00-U+9FA5; any other value keeps only ASCII letters and
            whitespace.

    Returns:
        *word_str* with all other characters removed.
    """
    if lg == 'cn':
        pattern = r'[^\u4e00-\u9fa5]'
    else:
        # Raw string so that \s reaches the regex engine intact.  Only the
        # leading ^ negates a character class; further carets would be
        # literal members and would let '^' survive the cleaning.
        pattern = r'[^a-zA-Z\s]'
    return re.sub(pattern, '', word_str)

def gen_babel_data():
    """Merge sememe annotations with multilingual words and definitions.

    Reads the synset/sememe table from ``synset_with_sememe_file`` and the
    synset records from ``src_babel_data_file``.  For every record whose id
    is present in the sememe table, the head words and gloss lists are
    stored under ``word_en`` / ``word_cn`` / ``word_fr`` and
    ``definition_en`` / ``definition_cn`` / ``definition_fr``.  The
    resulting ``{synset_id: {...}}`` mapping is written as JSON to
    ``babel_data_file``.

    Synsets without a matching record keep only their ``sememes`` entry.
    """
    with open(synset_with_sememe_file, 'r', encoding='utf-8') as fin_ss:
        _, synset_dic, _ = read_synset_id_sememes(fin_ss)

    # Order matches the tail of the list returned by read_synset_definition().
    field_names = (
        'word_en', 'word_cn', 'word_fr',
        'definition_en', 'definition_cn', 'definition_fr',
    )
    with open(src_babel_data_file, 'r', encoding='utf-8') as fin_def:
        while True:
            item = read_synset_definition(fin_def)
            if item is None:
                break
            # synset_dic has exactly the ids of the sememe table, so a dict
            # lookup replaces the linear scan of the id list.
            entry = synset_dic.get(item[0])
            if entry is None:
                continue
            entry.update(zip(field_names, item[1:]))

    # Opened only once everything has been read, so a failure above does
    # not leave a truncated output file behind.
    with open(babel_data_file, 'w', encoding='utf-8') as fout:
        json.dump(synset_dic, fout, ensure_ascii=False)

def split_data(data_file, save_path):
    """Randomly split a JSON list into train / valid / test files.

    The items are shuffled with the module-level ``random`` generator and
    divided 80% / 10% / 10%.  The three parts are written to
    ``train_data.json``, ``valid_data.json`` and ``test_data.json`` inside
    *save_path*, and their sizes are printed.

    Args:
        data_file: Path of a UTF-8 JSON file whose top-level value is a list.
        save_path: Existing directory that receives the three output files.
    """
    with open(data_file, encoding='utf-8') as f:
        data = json.load(f)

    # Shuffle indices rather than the data so the draw sequence (and hence
    # the result for a given random seed) stays the same as before.
    order = list(range(len(data)))
    random.shuffle(order)
    train_end = int(0.8 * len(data))
    valid_end = int(0.9 * len(data))

    train_data = [data[i] for i in order[:train_end]]
    valid_data = [data[i] for i in order[train_end:valid_end]]
    test_data = [data[i] for i in order[valid_end:]]
    print(len(train_data), len(valid_data), len(test_data))

    outputs = (
        ('/train_data.json', train_data),
        ('/valid_data.json', valid_data),
        ('/test_data.json', test_data),
    )
    for file_name, subset in outputs:
        with open(save_path + file_name, 'w', encoding='utf-8') as f:
            json.dump(subset, f, ensure_ascii=False)

def gen_data():
    """Build per-synset records from the merged BabelNet JSON.

    Reads ``./data/babel_data.json`` and writes a list of records to
    ``./data/data/data_all.json``.  Each record contains:

    * ``s`` - for every sememe string, the part after the first ``|``;
    * ``b`` - the synset id;
    * ``w_e`` / ``w_c`` / ``w_f`` - the English (lower-cased), Chinese
      (passed through ``cc.convert``) and French (lower-cased) head word;
    * ``d_e`` / ``d_c`` / ``d_f`` - the first definition of that language
      as a list of tokens;
    * ``d_e_tw`` / ``d_c_tw`` / ``d_f_tw`` - the same tokens prefixed with
      the head word and ``':'``.

    The word key of a language is present only when the source entry has
    both the word and the definition field; the definition keys additionally
    require a non-empty definition list.

    Raises:
        IndexError: If a sememe string contains no ``|``.
    """
    src_data_path = './data/babel_data.json'
    with open(src_data_path, encoding='utf-8') as f:
        src_data = json.load(f)

    dst_data = []
    for k in tqdm(src_data):
        entry = src_data[k]
        dst_data_instance = {
            's': [ss.split('|')[1] for ss in entry['sememes']],
            'b': k,
        }

        if 'word_en' in entry and 'definition_en' in entry:
            word = entry['word_en'].lower()
            dst_data_instance['w_e'] = word
            if entry['definition_en']:
                # NOTE(review): lemmatize() receives the whole definition
                # string, not individual tokens - confirm this is intended.
                tokens = wnl.lemmatize(entry['definition_en'][0]).lower().split(' ')
                dst_data_instance['d_e'] = tokens
                # The head word is a str, so it must be wrapped in a list
                # before being joined with the token list.
                dst_data_instance['d_e_tw'] = [word, ':'] + tokens

        if 'word_cn' in entry and 'definition_cn' in entry:
            word = cc.convert(entry['word_cn'])
            dst_data_instance['w_c'] = word
            if entry['definition_cn']:
                tokens = seg_thulac.cut(entry['definition_cn'][0], text=True).split(' ')
                dst_data_instance['d_c'] = tokens
                dst_data_instance['d_c_tw'] = [word, ':'] + tokens

        if 'word_fr' in entry and 'definition_fr' in entry:
            word = entry['word_fr'].lower()
            dst_data_instance['w_f'] = word
            if entry['definition_fr']:
                tokens = entry['definition_fr'][0].lower().split(' ')
                dst_data_instance['d_f'] = tokens
                dst_data_instance['d_f_tw'] = [word, ':'] + tokens

        dst_data.append(dst_data_instance)

    with open('./data/data/data_all.json', 'w', encoding='utf-8') as f:
        json.dump(dst_data, f, ensure_ascii=False)

def _lookup_sememe_ids(word, sememe_index):
    """Return vocabulary indices of the HowNet sememes found for *word*."""
    found = hownet_dict.get_sememes_by_word(word,structured=False,lang="zh",merge=True)
    if not found:
        return []
    if isinstance(found, dict):
        # Keep only the sememes of the first entry of the mapping.
        found = list(list(found.items())[0][1])
    elif isinstance(found, set):
        found = list(found)
    # Sememes outside the vocabulary are silently dropped.
    return [sememe_index[s] for s in found if s in sememe_index]


def _append_words(words, token_ids, alignments, position, sememe_index):
    """Tokenize *words* onto *token_ids* and record their sememe alignments.

    *position* is the index of the last element of *token_ids*; the updated
    value is returned.  For every word with at least one in-vocabulary
    sememe, ``[sub-token positions, sememe indices]`` is appended to
    *alignments*.
    """
    for word in words:
        # Drop the first and last id the tokenizer adds around each word.
        piece_ids = tokenizer(word)['input_ids'][1:-1]
        positions = []
        for piece_id in piece_ids:
            position += 1
            positions.append(position)
            token_ids.append(piece_id)
        sememe_ids = _lookup_sememe_ids(word, sememe_index)
        if sememe_ids:
            alignments.append([positions, sememe_ids])
    return position


def _encode_definitions(instance, suffix, sememe_index):
    """Encode the en/cn/fr definitions stored under ``d_<lang><suffix>``.

    Returns ``(token_ids, alignments)``.  The sequence starts with id 0,
    the English segment is always closed with id 2, and each Chinese /
    French segment that is present is wrapped in a leading and a trailing
    id 2 (presumably XLM-R's <s> / </s> ids - TODO confirm).
    """
    token_ids = [0]
    alignments = []
    position = 0

    if 'd_e' + suffix in instance:
        position = _append_words(instance['d_e' + suffix], token_ids,
                                 alignments, position, sememe_index)
    # Emitted even when there is no English definition.
    token_ids.append(2)
    position += 1

    for key in ('d_c' + suffix, 'd_f' + suffix):
        if key not in instance:
            continue
        token_ids.append(2)
        position += 1
        position = _append_words(instance[key], token_ids, alignments,
                                 position, sememe_index)
        token_ids.append(2)
        position += 1

    # Only the id sequence is capped at 512; alignment positions are left
    # untouched and may therefore point past the truncated sequence.
    return token_ids[:512], alignments


def gen_tokens(dst_path = './data/data/', all_lang = False):
    """Convert the records of ``data_all.json`` into model-ready id sequences.

    Reads the space-separated sememe vocabulary from
    ``./data/sememe_all.txt`` and the records from
    ``./data/data/data_all.json``, then writes ``data.json`` into
    *dst_path*.  Each output record contains:

    * ``b`` - the synset id;
    * ``s_i`` - vocabulary indices of the synset's sememes;
    * ``d_i`` / ``d_i_tw`` - sub-word ids of the concatenated definitions
      (without / with the head word prefix), at most 512 ids each;
    * ``i2s`` / ``i2s_tw`` - ``[positions, sememe indices]`` pairs linking
      the sub-token positions of a word to that word's HowNet sememes.

    Args:
        dst_path: Output directory prefix; ``'data.json'`` is appended to it
            verbatim, so it should end with a path separator.
        all_lang: When true, records lacking any of ``d_e``, ``d_c`` or
            ``d_f`` are skipped.

    Raises:
        KeyError: If a record lists a sememe that is not in the vocabulary.
    """
    with open('./data/sememe_all.txt', 'r', encoding='utf-8') as f:
        sememe_list = f.read().split(' ')
    # First occurrence wins, matching list.index() semantics, but with O(1)
    # lookups instead of a scan of the vocabulary per sememe.
    sememe_index = {}
    for position, sememe in enumerate(sememe_list):
        sememe_index.setdefault(sememe, position)

    with open('./data/data/data_all.json', encoding='utf-8') as f:
        data = json.load(f)

    final_data = []
    for instance in tqdm(data):
        # Resolved before the language filter so that unknown sememes are
        # reported for every record, including skipped ones.
        sememe_ids = [sememe_index[ss] for ss in instance['s']]
        if all_lang and not all(key in instance for key in ('d_e', 'd_c', 'd_f')):
            continue

        d_i, i2s = _encode_definitions(instance, '', sememe_index)
        d_i_tw, i2s_tw = _encode_definitions(instance, '_tw', sememe_index)
        final_data.append({
            'b': instance['b'],
            's_i': sememe_ids,
            'd_i': d_i,
            'd_i_tw': d_i_tw,
            'i2s': i2s,
            'i2s_tw': i2s_tw,
        })

    with open(dst_path + 'data.json', 'w', encoding = 'utf-8') as fout:
        json.dump(final_data, fout, ensure_ascii=False)

def gen_train_data():
    """Partition ``data.json`` into train / valid / test files by synset id.

    Reads the id lists ``train_list.json``, ``valid_list.json`` and
    ``test_list.json`` plus ``data.json`` from ``./data/data/`` and writes
    ``train_data.json``, ``valid_data.json`` and ``test_data.json`` to the
    same directory.  An instance goes to the first split (train, then
    valid, then test) whose list contains its ``b`` id; instances listed
    nowhere are dropped.

    Raises:
        AssertionError: If the valid and test splits differ in size, or the
            three splits do not total 15461 instances (the size of the
            dataset this script was written for).
    """
    def load_json(path):
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    # Sets give O(1) membership tests instead of a list scan per instance.
    train = set(load_json('./data/data/train_list.json'))
    valid = set(load_json('./data/data/valid_list.json'))
    test = set(load_json('./data/data/test_list.json'))

    train_data = []
    valid_data = []
    test_data = []

    for instance in load_json('./data/data/data.json'):
        synset_id = instance['b']
        if synset_id in train:
            train_data.append(instance)
        elif synset_id in valid:
            valid_data.append(instance)
        elif synset_id in test:
            test_data.append(instance)

    # Dataset-specific sanity checks, kept as in the original pipeline.
    assert len(valid_data) == len(test_data)
    assert len(valid_data) + len(test_data) + len(train_data) == 15461

    outputs = (
        ('./data/data/train_data.json', train_data),
        ('./data/data/valid_data.json', valid_data),
        ('./data/data/test_data.json', test_data),
    )
    for path, subset in outputs:
        with open(path, 'w') as f:
            json.dump(subset, f)

def gen_POS_data():
    """Split the test set by the last character of each synset id.

    Reads ``./data/data/test_data.json`` and writes the instances whose
    ``b`` id ends in ``a``, ``v``, ``n`` or ``r`` to ``a.json``, ``v.json``,
    ``n.json`` and ``r.json`` in ``./data/data/`` (presumably the
    adjective / verb / noun / adverb part-of-speech tags).  Instances with
    any other final character are not written anywhere.
    """
    with open('./data/data/test_data.json', encoding='utf-8') as f:
        test_data = json.load(f)

    # Insertion order fixes the order in which the output files are written.
    by_tag = {'a': [], 'v': [], 'n': [], 'r': []}
    for instance in test_data:
        bucket = by_tag.get(instance['b'][-1])
        if bucket is not None:
            bucket.append(instance)

    for tag, members in by_tag.items():
        with open('./data/data/' + tag + '.json', 'w') as f:
            json.dump(members, f)


if __name__ == "__main__":
    # Run every preprocessing stage in order; each stage reads files written
    # by an earlier one.  NOTE(review): gen_train_data() also needs the
    # *_list.json id lists, which none of these stages creates.
    pipeline = (
        gen_babel_data,
        gen_data,
        gen_tokens,
        gen_train_data,
        gen_POS_data,
    )
    for stage in pipeline:
        stage()
