import enum
import demjson
import re

from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import torch.nn.functional as F
from skimage import io,transform
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import transformers as tfs
import math
import random
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import scipy as sp
import scipy.stats
import warnings
#import preprocessor
import os
import json
import tqdm
import createdata
import tree2seq 
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction


from createdata import sememeDataset
from transformers import BertModel, BertTokenizer
from model import TreePred

import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): this marks torch.cuda as already initialized, presumably to bypass
# torch's lazy CUDA initialization. The reason is not visible in this file — confirm it
# is still needed, since it may make later torch.cuda calls misbehave on machines
# without a usable GPU.
torch.cuda._initialized = True
# Default device. The evaluation functions in this file move their models to CPU
# explicitly, and this module-level name is not referenced in the code shown here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def find_synset_with_different_sense_but_same_sememe():
    """Print pairs of training entries with the same sememe-token set but different trees.

    Two entries are reported when all of the following hold:

    * the sets of tokens in their sequences ``cont[2][0]`` are equal;
    * that set has more than 4 distinct tokens;
    * the set contains neither ``"~"`` nor ``"part|部件"``;
    * the two token sequences themselves differ.

    For each pair, the earlier entry, the tree built from it, the later entry and
    its tree (both from ``tree2seq.computeStrict``) are printed.
    """
    trainData = torch.load("/data3/private/yyn/structPred/datas/trainSet.pkl")
    model = TreePred(wordDim = 768, maxPosEmbed = 50,sememeDim = 768, head=8,attentionLayer = 8,hiddenDim = 128, pretrained = 1, TUPE=1, MASK=1, seq = False, depthMethod="depth", biasMethod="distance")
    # frozenset of tokens -> indices of earlier entries having exactly that token set.
    # A dict lookup replaces the quadratic scan over every earlier entry; indices are
    # appended in increasing order, so pairs print in the same order as before.
    seen = {}
    for contid in tqdm.tqdm(range(len(trainData))):
        cont = trainData[contid]
        key = frozenset(cont[2][0])
        if len(key) > 4 and "~" not in key and "part|部件" not in key:
            for prev_id in seen.get(key, ()):
                prev = trainData[prev_id]
                if prev[2][0] == cont[2][0]:
                    continue
                print("*"*20)
                # Token -> id strings, terminated by the end marker 2084.
                a = [str(model.sememe2id[i]) for i in prev[2][0]] + ["2084"]
                b = [str(model.sememe2id[i]) for i in cont[2][0]] + ["2084"]
                scores,meTree,answerTree = tree2seq.computeStrict(a,b,id2sememe=model.id2sememe)
                print(prev)
                print(meTree)
                print(cont)
                print(answerTree)
        seen.setdefault(key, []).append(contid)
        
#find_synset_with_different_sense_but_same_sememe()
#exit()

def check_zero_shot(train_path="/data3/private/yyn/structPred/datas/trainSet.pkl",
                    test_path="/data3/private/yyn/structPred/datas/testSet.pkl"):
    """Count, for every test tree, the parent->child edges never seen in training.

    A tree is the token sequence ``cont[2][0]``. Its first token is the root. Every
    following non-``"end"`` token becomes a child of the current stack top and is
    pushed; ``"end"`` pops. An edge of a test tree counts as unseen when that child
    never appears under that parent in any training tree.

    Test items are visited in the shuffled order produced by ``random.seed(2077)``
    + ``random.shuffle``. That is the same order ``demo()`` uses when it writes its
    ``*_out.pkl`` files, so ``zeros[i]`` lines up with record ``i`` of those files.

    Prints ``"<items with an unseen edge>, <number of test items>"``.

    Args:
        train_path: pickle with the training items (loaded with ``torch.load``).
        test_path: pickle with the test items.

    Returns:
        List with the number of unseen edges per test item, in the shuffled order.
    """
    trainData = torch.load(train_path)
    testData = torch.load(test_path)
    shunxu = list(range(len(testData)))
    random.seed(2077)
    random.shuffle(shunxu)

    # parent token -> set of child tokens observed in the training trees.
    children = {}
    for cont in trainData:
        seq = cont[2][0]
        stack = [seq[0]]
        for data_str in seq[1:]:
            if data_str == "end":
                stack.pop()
            else:
                children.setdefault(stack[-1], set()).add(data_str)
                stack.append(data_str)

    bad = 0
    zeros = []
    for testid in shunxu:
        seq = testData[testid][2][0]
        stack = [seq[0]]
        drop = 0
        for data_str in seq[1:]:
            if data_str == "end":
                stack.pop()
            else:
                if data_str not in children.get(stack[-1], ()):
                    drop += 1
                stack.append(data_str)
        if drop != 0:
            bad += 1
        zeros.append(drop)
    print(f"{bad}, {len(testData)}")
    return zeros

def visualize_attn():
    """Draw attention heat maps for one hard-coded sentence/tree pair.

    Loads the ``tupe_depth_bias_380`` checkpoint on CPU. It then asks
    ``model.visualize`` for three attention maps, places them side by side
    (concatenated along dim 1) and saves the heat map to ``./pictures/heat_attn.png``.
    """
    model = TreePred(
        wordDim = 768, maxPosEmbed = 50, sememeDim = 768, head=8, attentionLayer = 8,
        hiddenDim = 128, pretrained = 1, TUPE=1, MASK=1, seq = False,
        depthMethod="depth", biasMethod="distance",
    )
    checkpoint_dir = "/data3/private/yyn/structPred/models/" + "tupe_depth_bias_380" + "/"
    state = torch.load(checkpoint_dir + "modelWeight.pt", map_location=torch.device('cpu'))
    model.load_state_dict(state['net'], strict=False)
    model.loadBert(checkpoint_dir)
    model.device = torch.device("cpu")
    model.to('cpu')
    model.eval()

    with torch.no_grad():
        sentences = ["A rapid descent by a submarine"]
        tree_tokens = [
            "start", "GoDown|下去", "abnormal|不正常", "end", "location|位置", "beneath|下", "end",
            "waters|水域", "end", "end", "weapon|武器", "military|军", "end", "ship|船", "end", "end", "end", "end",
        ]
        sem, pos, tot = model.visualize(sentences, seq = tree_tokens)
        # Keep only the English part of each "english|中文" token for the axis labels.
        labels = [token.split("|")[0] for token in tree_tokens]
        heat = torch.cat((sem, pos, tot), dim=1)
        print(heat.size())
        plt.figure(figsize=(20, 6))
        # The three maps sit side by side, so the x labels repeat three times.
        sns.heatmap(heat, xticklabels=labels * 3, yticklabels=labels, robust=True)
        plt.savefig("./pictures/heat_attn.png")
        plt.close()
        
    
#visualize_attn()
#exit()

def check_zero_shot_with_size():
    """Print the fraction of outputs with at least one unseen edge, split by tree length.

    Entry ``i`` of ``check_zero_shot()`` is paired with record ``i`` of
    ``./data/base_cls_400_out.pkl``. Records whose gold sequence ``cont[2][0]`` has
    at most 4 tokens form the first group; all others form the second. For each
    group, the mean of ``min(unseen_edges, 1)`` is printed.
    """
    unseen = check_zero_shot()
    outputs = torch.load("./data/base_cls_400_out.pkl")
    short_group, long_group = [], []
    for idx, record in enumerate(outputs):
        target = short_group if len(record[2][0]) <= 4 else long_group
        target.append(min(unseen[idx], 1))
    short_arr = np.array(short_group)
    print(np.mean(short_arr))
    long_arr = np.array(long_group)
    print(np.mean(long_arr))
    
def check_output_size(name):
    """Average the last two ``computeStrict`` scores over outputs with short gold sequences.

    Loads ``./data/<name>.pkl`` and keeps only records whose gold sequence
    ``cont[2][0]`` has at most 4 tokens. For each kept record, the stored
    prediction ``cont[-2]`` is wrapped in the ids ``"2083"``/``"2084"`` and scored
    against the gold ids.

    Returns:
        Column-wise mean of ``scores[-2:]`` over the kept records.
    """
    records = torch.load("./data/" + name + ".pkl")
    collected = []
    model = TreePred(
        wordDim = 768, maxPosEmbed = 50, sememeDim = 768, head=8, attentionLayer = 8,
        hiddenDim = 128, pretrained = 1, TUPE=0, MASK=1, seq = False,
        depthMethod="", biasMethod="bias",
    )
    for record in records:
        gold_tokens = record[2][0]
        if len(gold_tokens) > 4:
            continue
        gold = [str(model.sememe2id[token]) for token in gold_tokens] + ["2084"]
        predicted = ["2083"] + record[-2] + ["2084"]
        scores, _, _ = tree2seq.computeStrict(predicted, gold, id2sememe=model.id2sememe)
        collected.append(scores[-2:])
    return np.mean(np.array(collected), axis=0)

#for name in ["base_cls_400_out","tupe_depth_bias_380_out"]:
#    print(check_output_size(name))


#check_zero_shot_with_size()

def gen_demo(testResult = ("base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out")):
    """Print every model's prediction next to the gold tree for small test items.

    Test items are visited in the shuffled order given by ``random.seed(2077)``.
    Record ``dataid`` of each ``./data/<name>.pkl`` is paired with the ``dataid``-th
    shuffled test item. Only items whose gold sequence has at most 4 tokens and
    contains no ``"~"`` are shown. For each item this prints:

    * the input ``testData[testid][1][0]``;
    * for every result file, the predicted tree and its stored scores;
    * finally, the gold tree.

    Args:
        testResult: iterable of result-file names (without ``.pkl``).
    """
    model = TreePred(wordDim = 768, maxPosEmbed = 50,sememeDim = 768, head=8,attentionLayer = 8,hiddenDim = 128, pretrained = 1, TUPE=0, MASK=1, seq = False, depthMethod="", biasMethod="bias")
    testData = torch.load("/data3/private/yyn/structPred/datas/testSet.pkl")
    shunxu = list(range(len(testData)))
    outputs = [torch.load("./data/"+name+".pkl") for name in testResult]
    random.seed(2077)
    random.shuffle(shunxu)
    for dataid, testid in enumerate(shunxu):
        gold_tokens = testData[testid][2][0]
        if len(gold_tokens) > 4 or "~" in gold_tokens:
            continue
        print("*********************")
        print(testData[testid][1][0])
        # Gold ids do not depend on the result file; build them once per item.
        gold_ids = [str(model.sememe2id[i]) for i in gold_tokens] + ["2084"]
        answerTree = None
        for dataout in outputs:
            ls = ["2083"] + dataout[dataid][-2] + ["2084"]
            # Fresh copy each call: it is not visible here whether computeStrict mutates its inputs.
            scores,meTree,answerTree = tree2seq.computeStrict(ls,list(gold_ids),id2sememe=model.id2sememe)
            print(meTree)
            print(dataout[dataid][-1])
        # With no result files there is no computed gold tree to print.
        if answerTree is not None:
            print(answerTree)


def test_zero_shot(test_name):
    """Print mean stored scores grouped by the number of unseen edges (0, 1 or 2).

    Record ``i`` of ``./data/<test_name>.pkl`` is paired with entry ``i`` of
    ``check_zero_shot()``. Records with more than two unseen edges are ignored. The
    score vectors ``cont[-1]`` in each group are averaged column-wise, and the three
    means are printed as one array.

    An empty group is reported as a row of NaN with the same shape as the other rows.
    This keeps the stacked array rectangular; a bare scalar NaN would make it ragged.
    """
    outData = torch.load("./data/"+test_name+".pkl")
    zeros = check_zero_shot()
    buckets = [[] for _ in range(3)]
    for idx, cont in enumerate(outData):
        if zeros[idx] <= 2:
            buckets[zeros[idx]].append(cont[-1])
    means = [np.mean(np.array(bucket), axis=0) if bucket else None for bucket in buckets]
    template = next((m for m in means if m is not None), None)
    filler = np.full(np.shape(template), np.nan) if template is not None else np.nan
    zero_result_list = np.array([filler if m is None else m for m in means])
    print(zero_result_list)

def data_distribute(testSet = "/data3/private/yyn/structPred/datas/testSet.pkl"):
    """Save histograms of tree size and tree depth for a dataset.

    Size and depth come from ``compute_depth_and_size(cont[2][0])``. The plots are
    written to ``./pictures/dataset_length.png`` (20 bins) and
    ``./pictures/dataset_depth.png``.

    Args:
        testSet: pickle with the items to analyse (loaded with ``torch.load``).
    """
    testData = torch.load(testSet)
    length_list = []
    depth_list = []
    for cont in testData:
        size, depth, _ = compute_depth_and_size(cont[2][0])
        length_list.append(size)
        depth_list.append(depth)

    os.makedirs("./pictures", exist_ok=True)
    plt.hist(length_list, bins=20)
    plt.savefig("./pictures/dataset_length.png")
    plt.close()
    plt.hist(depth_list)
    plt.savefig("./pictures/dataset_depth.png")
    plt.close()
    
def test_set_similar_with_train_set():
    """Annotate each test item with its similarity to the closest training trees.

    Every test tree is scored against every training tree with
    ``tree2seq.computeStrict(train_ids, test_ids)``, keeping ``scores[0]``. Two lists
    are appended to each test item, which is mutated in place:

    * ``[s1, s10, s100, s500]`` -- the mean of the top-1/10/100/500 scores;
    * ``[size, depth, branches]`` -- from ``compute_depth_and_size``.

    When fewer than k training items exist, the top-k mean is taken over all of them.
    The mean is NaN if there are none. The annotated test set is saved to
    ``./data/data_similar.pkl``.
    """
    model = TreePred(wordDim = 768, maxPosEmbed = 50,sememeDim = 768, head=8,attentionLayer = 8,hiddenDim = 128, pretrained = 1, TUPE=0, MASK=1, seq = False, depthMethod="", biasMethod="bias")
    trainData = torch.load("/data3/private/yyn/structPred/datas/trainSet.pkl")
    testData = torch.load("/data3/private/yyn/structPred/datas/testSet.pkl")

    def to_ids(tokens):
        # Sememe tokens -> id strings, terminated by the end marker "2084".
        return [str(model.sememe2id[tok]) for tok in tokens] + ["2084"]

    def top_mean(sorted_scores, k):
        # Mean of the k best scores (of all of them when fewer than k exist).
        best = sorted_scores[:k]
        return sum(best) / len(best) if best else float("nan")

    # The training-side id lists do not depend on the test item: build them once
    # instead of once per (test, train) pair.
    train_ids = [to_ids(data[2][0]) for data in trainData]
    test_data_modified = []

    for cont in tqdm.tqdm(testData):
        me = to_ids(cont[2][0])
        sims = []
        for he in train_ids:
            # Fresh copies: it is not visible here whether computeStrict mutates its inputs.
            scores, _, _ = tree2seq.computeStrict(list(he), list(me))
            sims.append(scores[0])
        sims.sort(reverse=True)
        l, d, c = compute_depth_and_size(cont[2][0])
        cont.append([top_mean(sims, k) for k in (1, 10, 100, 500)])
        cont.append([l, d, c])
        test_data_modified.append(cont)

    torch.save(test_data_modified,"./data/data_similar.pkl")

def compute_depth_and_size(answer_str):
    """Return ``(size, depth, branches)`` for a bracketed tree token sequence.

    The first token is treated as the root and is only used as the predecessor of
    the second token. Each later token other than ``"end"`` goes one level down;
    each ``"end"`` goes one level up.

    * ``size``: ``(len(answer_str) - 1) // 2``. This is the number of nodes below
      the root, assuming every node, including the root, is closed by one ``"end"``.
    * ``depth``: the deepest level reached, with the root at level 0.
    * ``branches``: how many non-``"end"`` tokens directly follow an ``"end"``.
      For a tree with at least one node, this is the number of leaves minus one.

    Raises:
        IndexError: if ``answer_str`` is empty.
    """
    node_count = (len(answer_str) - 1) // 2
    level = deepest = branches = 0
    previous = answer_str[0]
    for token in answer_str[1:]:
        opens = token != "end"
        level += 1 if opens else -1
        if opens and previous == "end":
            branches += 1
        deepest = max(deepest, level)
        previous = token
    return node_count, deepest, branches
    
def confidence_interval(array,c = 0.95):
    """Return the half-width ``h`` of the Student-t confidence interval of the mean.

    For ``n`` samples, ``h = sem * t.ppf((1 + c) / 2, n - 1)``. Here ``sem`` is the
    standard error of the mean as computed by ``scipy.stats.sem`` (sample std with
    ddof=1, divided by ``sqrt(n)``). The interval itself is ``mean ± h``.

    Args:
        array: 1-D sequence of numbers.
        c: confidence level, e.g. 0.95.

    Returns:
        The half-width as a float, or NaN when fewer than two samples are given
        (the interval is undefined then).
    """
    a = np.asarray(array, dtype=float)
    n = len(a)
    if n < 2:
        return float("nan")
    # scipy.stats.sem already divides by sqrt(n); no further scaling is needed.
    return scipy.stats.sem(a) * scipy.stats.t.ppf((1 + c) / 2, n - 1)
    
def change_over_size(testResult = ("base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out","sstg_out")):
    """Plot the mean of ``cont[-1][1]`` against tree size, one error-bar series per result file.

    Sizes come from ``compute_depth_and_size(cont[2][0])`` and are bucketed as
    1..7 and "8+". Error bars come from ``confidence_interval(..., c=0.95)``. The
    sample count of every bucket is printed. The figure is saved to
    ``./pictures/change_over_size.pdf``.

    Trees of size 0 have no place on the 1..8+ axis and are skipped. A ``size - 1``
    index of -1 would otherwise put them in the "8+" bucket.

    Args:
        testResult: iterable of result-file names under ``./data`` (without ``.pkl``).
    """
    num_buckets = 8
    fig, ax = plt.subplots()
    for result_name in testResult:
        outData = torch.load("./data/"+result_name+".pkl")
        strict_data = [[] for _ in range(num_buckets)]
        for cont in outData:
            size,_,_ = compute_depth_and_size(cont[2][0])
            if size < 1:
                continue
            strict_data[min(size, num_buckets) - 1].append(cont[-1][1])
        data_mean, data_err = [], []
        for index, dd in enumerate(strict_data):
            print(f"{index}: {len(dd)}")
            data_mean.append(np.mean(dd))
            data_err.append(confidence_interval(dd,c=0.95))

        ax.errorbar(np.arange(1, num_buckets + 1), data_mean, yerr=data_err, fmt="o--")
    plt.xticks(list(range(1, num_buckets + 1)), [str(i) for i in range(1, num_buckets)] + [f"{num_buckets}+"])
    plt.tick_params(labelsize=15)
    #plt.legend(["TSTG","TaSTG",r"$^*$-B",r"$^*$-D", "SSTG"],fontsize=11)
    os.makedirs("./pictures", exist_ok=True)
    plt.savefig("./pictures/change_over_size.pdf",bbox_inches="tight",pad_inches=0.1)
    # Release the figure; every call creates a new one.
    plt.close(fig)
    
def change_over_depth(testResult = ("base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out","sstg_out")):
    """Plot the mean of ``cont[-1][1]`` against tree depth, one error-bar series per result file.

    Depths come from ``compute_depth_and_size(cont[2][0])`` and are bucketed as
    1..4 and "5+". Error bars come from ``confidence_interval(..., c=0.95)``. The
    sample count of every bucket is printed. The figure is saved to
    ``./pictures/change_over_depth.pdf`` with the y label "Strict-F1 Score".

    Trees of depth 0 have no place on the 1..5+ axis and are skipped. A
    ``depth - 1`` index of -1 would otherwise put them in the "5+" bucket.

    Args:
        testResult: iterable of result-file names under ``./data`` (without ``.pkl``).
    """
    num_buckets = 5
    fig, ax = plt.subplots()
    for result_name in testResult:
        outData = torch.load("./data/"+result_name+".pkl")
        strict_data = [[] for _ in range(num_buckets)]
        for cont in outData:
            _,depth,_ = compute_depth_and_size(cont[2][0])
            if depth < 1:
                continue
            strict_data[min(depth, num_buckets) - 1].append(cont[-1][1])
        data_mean, data_err = [], []
        for index, dd in enumerate(strict_data):
            print(f"{index}: {len(dd)}")
            data_mean.append(np.mean(dd))
            data_err.append(confidence_interval(dd,c=0.95))

        ax.errorbar(np.arange(1, num_buckets + 1), data_mean, yerr=data_err, fmt = "o--")
    plt.xticks(list(range(1, num_buckets + 1)), [str(i) for i in range(1, num_buckets)] + [f"{num_buckets}+"])
    plt.tick_params(labelsize=15)
    font = {'family': 'sans-serif',
        'color': 'k',
        'weight': 'normal',
        'size': 25,}
    plt.ylabel("Strict-F1 Score", font)
    #plt.legend(["TSTG","TaSTG",r"$^*$-B",r"$^*$-D", "SSTG"],fontsize=11)
    os.makedirs("./pictures", exist_ok=True)
    plt.savefig("./pictures/change_over_depth.pdf",bbox_inches="tight",pad_inches=0.1)
    # Release the figure; every call creates a new one.
    plt.close(fig)
    
def change_over_chashu(testResult = ("base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out","sstg_out")):
    """Plot the mean of ``cont[-1][1]`` against the branch count, one series per result file.

    The branch count is the third value of ``compute_depth_and_size(cont[2][0])``:
    the number of non-"end" tokens that directly follow an "end". For a tree with
    at least one node, this equals the number of leaves minus one. Bucket ``k``
    (k = 0..3) is therefore drawn at x = k + 1. Bucket 4 collects every tree with
    4 or more branches, which matches its "5+" tick label. Trees with no node below
    the root are skipped.

    Error bars come from ``confidence_interval(..., c=0.95)``. The sample count of
    every bucket is printed. The figure, with a legend, is saved to
    ``./pictures/change_over_chashu.pdf``.

    Args:
        testResult: iterable of result-file names under ``./data`` (without ``.pkl``).
    """
    num_buckets = 5
    fig, ax = plt.subplots()
    for result_name in testResult:
        outData = torch.load("./data/"+result_name+".pkl")
        strict_data = [[] for _ in range(num_buckets)]
        for cont in outData:
            size,_,chashu = compute_depth_and_size(cont[2][0])
            if size < 1:
                continue
            strict_data[min(chashu, num_buckets - 1)].append(cont[-1][1])
        data_mean, data_err = [], []
        for index, dd in enumerate(strict_data):
            print(f"{index}: {len(dd)}")
            data_mean.append(np.mean(dd))
            data_err.append(confidence_interval(dd,c=0.95))

        ax.errorbar(np.arange(1, num_buckets + 1), data_mean, yerr=data_err, fmt = "o--")
    plt.xticks(list(range(1, num_buckets + 1)), [str(i) for i in range(1, num_buckets)] + [f"{num_buckets}+"])
    plt.tick_params(labelsize=15)
    plt.legend(["TSTG","TaSTG",r"$^*$-B",r"$^*$-D", "SSTG"],fontsize=14)
    os.makedirs("./pictures", exist_ok=True)
    plt.savefig("./pictures/change_over_chashu.pdf",bbox_inches="tight",pad_inches=0.1)
    # Release the figure; every call creates a new one.
    plt.close(fig)
            

def get_new_restrict_score(out_names=("base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out","sstg_out")):
    """Average the stored score vectors of restricted-decoding outputs, ignoring tiny trees.

    For every name, ``./data/<name>_restrict.pkl`` is loaded. The suffix
    ``_restrict`` is not added again when the name already ends with it. The stored
    vectors ``cont[-1]`` are averaged column-wise over the records whose gold tree
    ``cont[2][0]`` has size > 1 according to ``compute_depth_and_size``.
    (``demo()`` appends ``[BLEU] + scores`` as ``cont[-1]``.) Each result is printed.

    Args:
        out_names: a single result name or an iterable of names.

    Returns:
        Dict mapping each name to its mean score vector, or to ``None`` when no
        record qualifies.
    """
    if isinstance(out_names, str):
        out_names = [out_names]
    results = {}
    for out_name in out_names:
        file_name = out_name if out_name.endswith("_restrict") else out_name + "_restrict"
        testData = torch.load("./data/"+file_name+".pkl")
        score = []
        for cont in testData:
            size,_,_ = compute_depth_and_size(cont[2][0])
            if size > 1:
                score.append(cont[-1])
        mean_score = np.mean(np.array(score), axis=0) if score else None
        print(out_name, mean_score)
        results[out_name] = mean_score
    return results


def check_new_bleu(testResult = ["base_cls_400_out","tupe_depth_bias_380_out","base_tupe_noBias_280_out","tupe_orderDepth_350_out","sstg_out"]):
    """Recompute BLEU and tree scores for restricted-decoding outputs and print their means.

    For every name, ``./data/<name>_restrict.pkl`` is loaded. Records whose gold
    sequence ``cont[2][0]`` has at most 3 tokens are skipped. For the rest, the
    prediction ``cont[3]`` is wrapped in ids 2083/2084 and scored against the gold
    ids with ``tree2seq.computeStrict``. Sentence BLEU is computed with the 2083/2084
    markers removed.

    Printed per name, one value per line, in this order:

    1. the name followed by ``"*" * 20``;
    2. mean recomputed BLEU;
    3. mean stored BLEU (``cont[-1][0]``);
    4. means of ``scores[-2]``, ``scores[-1]`` and ``scores[2]`` (collected as
       p, r and t; presumably precision/recall-style metrics, to be confirmed
       against tree2seq);
    5. mean size of the predicted tree.
    """
    model = TreePred(
        wordDim = 768, maxPosEmbed = 50, sememeDim = 768, head=8, attentionLayer = 8,
        hiddenDim = 128, pretrained = 1, TUPE=0, MASK=1, seq = False,
        depthMethod="", biasMethod="",
    )
    smoother = SmoothingFunction().method1
    markers = ("2083", "2084")
    for name in testResult:
        records = torch.load("./data/" + name + "_restrict.pkl")
        bleu_recomputed, bleu_stored = [], []
        col_p, col_r, col_t, sizes = [], [], [], []
        for cont in records:
            if len(cont[2][0]) <= 3:
                continue
            gold_ids = [model.sememe2id[token] for token in cont[2][0]] + [2084]
            pred_ids = [2083] + cont[3] + [2084]

            pred_size, _, _ = compute_depth_and_size([model.id2sememe[int(token)] for token in pred_ids])
            sizes.append(pred_size)
            gold = [str(token) for token in gold_ids]
            pred = [str(token) for token in pred_ids]
            scores, _, _ = tree2seq.computeStrict(pred, gold)
            col_p.append(scores[-2])
            col_r.append(scores[-1])
            col_t.append(scores[2])

            gold = [token for token in gold if token not in markers]
            pred = [token for token in pred if token not in markers]
            bleu_recomputed.append(sentence_bleu([gold], pred, smoothing_function=smoother))
            bleu_stored.append(cont[-1][0])
        print(name + "*" * 20)
        for column in (bleu_recomputed, bleu_stored, col_p, col_r, col_t, sizes):
            print(np.mean(np.array(column)))
#check_new_bleu()
#exit()

def demo(testSet = "/data3/private/yyn/structPred/datas/testSet.pkl", restrict=0,loadname = "base_cls_400", tt = 0, dep = "depth", bias = "distance"):
    """Evaluate a saved checkpoint on the test set and store per-item results.

    Test items are visited in a fixed shuffled order (``random.seed(2077)``).
    ``gen_demo`` indexes the saved output by that same order. For every item:

    * the model predicts a sequence (``beamWidth=1``, ``maxLen=15``); when
      ``restrict`` is truthy, the gold sequence is passed as ``stricted``;
    * the prediction is scored against the gold ids with
      ``tree2seq.computeStrict`` and with sentence BLEU.

    Two entries are appended to the item itself. First comes the predicted id
    strings, which are stripped of their first "2083"/"2084" after being appended,
    so the stored list has no markers. Second comes ``[BLEU] + scores``. All items
    are saved to ``./data/<loadname>_out.pkl``, or ``..._out_restrict.pkl`` when
    restricted.

    Args:
        testSet: pickle with the test items.
        restrict: truthy to constrain decoding with the gold sequence.
        loadname: checkpoint directory name under the models folder.
        tt, dep, bias: passed to ``TreePred`` as ``TUPE``, ``depthMethod``, ``biasMethod``.

    Returns:
        Mean of ``scores[0]`` over the test set, or NaN when the test set is empty.
    """
    testData = torch.load(testSet)
    temp = []
    shunxu = list(range(len(testData)))
    random.seed(2077)
    random.shuffle(shunxu)
    model = TreePred(wordDim = 768, maxPosEmbed = 50,sememeDim = 768, head=8,attentionLayer = 8,hiddenDim = 128, pretrained = 1, TUPE=tt, MASK=1, seq = False, depthMethod=dep, biasMethod=bias)

    model_dir = "/data3/private/yyn/structPred/models/" + loadname + "/"
    stateTest = torch.load(model_dir + "modelWeight.pt", map_location=torch.device('cpu'))
    model.load_state_dict(stateTest['net'],strict=False)
    model.loadBert(model_dir)
    model.device = torch.device("cpu")
    model.to('cpu')
    model.eval()
    smoother = SmoothingFunction().method1

    def _ratio(num, den):
        # Running averages: avoid ZeroDivisionError when nothing has been counted yet.
        return num / den if den else float("nan")

    def _summary():
        return f"{format(_ratio(BLEUSig, total),'.4f')},{format(_ratio(acc1, total),'.4f')},{format(_ratio(acc2, totalEdge),'.4f')}, {format(_ratio(acc3, total),'.4f')} "

    acc1,acc2,acc3 = 0.0,0.0,0.0
    total,totalEdge = 0.0,0.0
    BLEUSig = 0.0
    with torch.no_grad():
        bar = tqdm.tqdm(range(len(testData)))
        for idx in bar:
            cont = testData[shunxu[idx]]
            # Gold ids: the root token is replaced by the start id 2083; the end id 2084 is appended.
            ans = [2083]+[model.sememe2id[i] for i in cont[2][0]][1:] + [2084]
            ans = [str(i) for i in ans]
            if restrict:
                out = model.predict([cont[1][0]],beamWidth=1,maxLen=15,stricted=cont[2][0])
            else:
                out = model.predict([cont[1][0]],beamWidth=1,maxLen=15)
            assert(out[0] == 2083)
            ls = [str(i) for i in out]
            scores,meTree,answerTree = tree2seq.computeStrict(ls,ans,id2sememe=model.id2sememe)
            # ls is stored on the item and then stripped in place below, so the saved
            # prediction carries no start/end markers.
            cont.append(ls)
            ls.remove("2083")
            # NOTE(review): guard in case predict stops at maxLen without emitting 2084.
            if "2084" in ls:
                ls.remove("2084")
            ans.remove("2083")
            ans.remove("2084")
            scoreBleu = sentence_bleu([ans], ls,smoothing_function=smoother)
            cont.append([scoreBleu]+scores)
            temp.append(cont)
            acc1 += scores[0]
            # A negative edge score marks "not applicable" and is excluded from the edge average.
            if scores[1] >= 0:
                totalEdge += 1
                acc2 += scores[1]
            acc3 += scores[2]
            BLEUSig += scoreBleu
            total += 1
            if totalEdge > 0:
                bar.set_description(_summary())
        os.makedirs("./data", exist_ok=True)
        if restrict:
            torch.save(temp,"./data/"+loadname+"_out_restrict.pkl")
        else:
            torch.save(temp,"./data/"+loadname+"_out.pkl")
        print(_summary())
        return _ratio(acc1, total)
    
  
# Script driver. Blocks wrapped in ''' ... ''' are disabled evaluation runs kept for
# reference; they are string expressions and do nothing when executed.
# Disabled: evaluate each checkpoint with unrestricted (restrict=0) and restricted
# (restrict=1) decoding via demo().
'''print("deal base cls")
demo(restrict=0,loadname = "base_cls_400",tt=0,dep="depth",bias="distance")
demo(restrict=1,loadname = "base_cls_400",tt=0,dep="depth",bias="distance")
print("deal tupe normal")
demo(restrict=0,loadname = "tupe_depth_bias_380",tt=1,dep="depth",bias="distance")
demo(restrict=1,loadname = "tupe_depth_bias_380",tt=1,dep="depth",bias="distance")
print("deal tupe no bias 400")
demo(restrict=0,loadname = "base_tupe_noBias_400",tt=1,dep="depth",bias="none")
demo(restrict=1,loadname = "base_tupe_noBias_400",tt=1,dep="depth",bias="none")'''
#print("deal tupe no bias 280")
#demo(restrict=0,loadname = "base_tupe_noBias_280",tt=1,dep="depth",bias="none")
#demo(restrict=1,loadname = "base_tupe_noBias_280",tt=1,dep="depth",bias="none")
'''print("deal tupe forward")
demo(restrict=0,loadname = "tupe_orderDepth_350",tt=1,dep="order",bias="distance")
demo(restrict=1,loadname = "tupe_orderDepth_350",tt=1,dep="order",bias="distance")'''
#data_distribute()
# Active entry point. These three calls sit at module level, so they run whenever
# this file is executed or imported. They write
# ./pictures/change_over_size.pdf, change_over_depth.pdf and change_over_chashu.pdf.
change_over_size()
change_over_depth()
change_over_chashu()
#check_zero_shot()

# Disabled: per-checkpoint breakdown of scores by number of unseen edges.
'''print("***************")
test_zero_shot("base_cls_400_out")
print("***************")
test_zero_shot("tupe_depth_bias_380_out")
print("***************")
test_zero_shot("base_tupe_noBias_400_out")
print("***************")
test_zero_shot("tupe_orderDepth_350_out")'''

#gen_demo()
#test_set_similar_with_train_set()

# Disabled: restricted-decoding score summary.
# NOTE(review): check get_new_restrict_score's argument handling before re-enabling.
# These names already end in "_restrict", and each call passes a single string.
'''for name in ["base_cls_400_out_restrict","tupe_depth_bias_380_out_restrict","base_tupe_noBias_280_out_restrict","tupe_orderDepth_350_out_restrict"]:
    get_new_restrict_score(name)'''