import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import random
import time
import json
import higher
import gc
import re
from copy import deepcopy
from tqdm import tqdm, trange
from sklearn.cluster import KMeans
from encode import lstm_encoder
from dataprocess_tacred import data_sampler
from model import proto_softmax_layer
from output import outputer
from dataprocess import get_data_loader
from transformers import BertTokenizer,BertModel
from util import set_seed,process_data,select_similar_data,getnegfrombatch,select_similar_data_new,select_similar_data_new_tac
import faiss


def eval_model(config, basemodel, test_set, mem_relations):
    """Evaluate ``basemodel`` on ``test_set`` and return the accuracy.

    Each sample is scored with ``basemodel.get_mem_feature(rep)``. A sample counts
    as correct when the score of its gold label is strictly higher than the highest
    score among the sample's ``neg_labels`` (the gold label is skipped if it is
    listed there).

    Args:
        config: run configuration, forwarded to ``get_data_loader``.
        basemodel: model callable as ``basemodel(sentences, lengths) -> (logits, rep)``
            that also exposes ``get_mem_feature(rep)``.
        test_set: list of test instances.
        mem_relations: unused; kept for call-site compatibility.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when ``test_set`` is empty.

    The model is switched to eval mode for the pass and always switched back to
    train mode afterwards, even if evaluation raises.
    """
    print("One eval")
    print("test data num is:\t",len(test_set))
    basemodel.eval()

    test_dataloader = get_data_loader(config, test_set, shuffle=False, batch_size=30)
    allnum = 0.0
    correctnum = 0.0
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid,
                 rawtext, lengths, typelabels) in test_dataloader:
                logits, rep = basemodel(sentences, lengths)
                scores = basemodel.get_mem_feature(rep)
                for index in range(len(logits)):
                    score = scores[index]
                    gold = labels[index]
                    allnum += 1.0
                    golden_score = score[gold]
                    # Best score among the candidate negative labels.
                    max_neg_score = -2147483647.0
                    for neg in neg_labels[index]:
                        if neg != gold and score[neg] > max_neg_score:
                            max_neg_score = score[neg]
                    if golden_score > max_neg_score:
                        correctnum += 1
    finally:
        basemodel.train()
    # Avoid ZeroDivisionError on an empty test set.
    acc = correctnum / allnum if allnum else 0.0
    print(acc)
    return acc



def select_data(mem_set, proto_memory, config, model, divide_train_set, num_sel_data, current_relations):
    """Select one memory instance per current relation: the sample nearest the mean.

    For every relation in ``current_relations`` its training samples are encoded with
    ``model.get_feature``. The sample whose feature has the smallest Euclidean
    distance to the mean feature becomes the relation's memory. A COPY of the sample
    is stored with its type label (index 11) set to 3. The original training sample
    is left untouched because it is still used as new-task training data.

    Args:
        mem_set: dict updated in place; ``mem_set[rel]`` becomes
            ``{'0': [instance_copy], '1': [feature]}``.
        proto_memory: indexable by relation id; the selected copy is appended to
            ``proto_memory[rel]``. If ``rel`` already had a memory entry, the instance
            appended by the previous call is popped first.
        config: forwarded to ``get_data_loader``.
        model: encoder exposing ``get_feature(sentences, lengths)``.
        divide_train_set: dict relation -> list of training instances.
        num_sel_data: unused (exactly one instance is selected per relation); kept
            for interface compatibility.
        current_relations: relation ids to process.

    Returns:
        ``mem_set`` (the same dict, mutated in place).
    """
    for thisrel in current_relations:
        if thisrel in mem_set:
            # Re-selection: drop the instance appended by the previous call.
            proto_memory[thisrel].pop()
        mem_set[thisrel] = {'0': [], '1': []}

        thisdataset = divide_train_set[thisrel]
        data_loader = get_data_loader(config, thisdataset, False, False)
        features = np.concatenate([
            model.get_feature(sentences, lengths)
            for _, _, sentences, *_, lengths, _ in data_loader
        ])

        # Mean accumulated in float64, then the first sample with minimal distance.
        center = features.mean(axis=0, dtype=np.float64)
        minindex = int(np.argmin(np.linalg.norm(features - center, axis=1)))

        # Copy so that marking it as memory does not relabel the training sample itself.
        instance = deepcopy(thisdataset[minindex])
        instance[11] = 3  # index 11 holds the type label; 3 marks a memory sample
        mem_set[thisrel]['0'].append(instance)
        mem_set[thisrel]['1'].append(features[minindex])
        proto_memory[thisrel].append(instance)

    return mem_set

def updateaverage(config, model, mem_set, current_proto, mem_relations):
    """Re-encode stored memory samples and shift their prototypes by the feature drift.

    For every relation present in both ``mem_set`` and ``mem_relations``, the memory
    instance ``mem_set[rel]['0'][0]`` is re-encoded with the current ``model``. The
    difference to the previously stored feature ``mem_set[rel]['1'][0]`` is added to
    ``current_proto[rel]``, and the stored feature is replaced by the new one
    (flattened to 1-D, the same shape ``select_data`` stores).

    Args:
        config: forwarded to ``get_data_loader``.
        model: encoder exposing ``get_feature(sentences, lengths)``.
        mem_set: dict relation -> {'0': [instance], '1': [feature]}; updated in place.
        current_proto: prototype container indexable by relation id (a tensor in the
            main script); updated in place.
        mem_relations: relations whose memory should be refreshed.

    Returns:
        0 (kept for compatibility; the function works by side effect).

    Raises:
        ValueError: if a selected relation does not hold exactly one memory instance.
            Features could then no longer be matched to relations one-to-one.
    """
    allkey = [key for key in mem_set if key in mem_relations]
    mem_data = []
    for key in allkey:
        mem_data.extend(mem_set[key]['0'])
    print(len(mem_data))
    if len(mem_data) != len(allkey):
        raise ValueError("updateaverage expects exactly one memory instance per relation")

    # batch_size=1: one encoded feature per memory instance. Presumably the loader
    # keeps dataset order (shuffle=False) — TODO confirm against get_data_loader.
    data_loader = get_data_loader(config, mem_data, False, False, batch_size = 1)
    features = []
    for (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext,
         lengths, typelabels) in data_loader:
        # A one-sample batch yields a (1, dim)-shaped feature; flatten it so the stored
        # feature and the prototype update stay 1-D.
        features.append(np.reshape(model.get_feature(sentences, lengths), -1))

    for key, feature in zip(allkey, features):
        delta = feature - mem_set[key]['1'][0]
        mem_set[key]['1'][0] = feature
        current_proto[key] = current_proto[key] + torch.tensor(delta)
    return 0


def get_average(allaverage, config, model, divide_train_set, num_sel_data, current_relations):
    """Store the mean encoded feature of each current relation in ``allaverage``.

    Args:
        allaverage: dict updated in place. ``allaverage[rel]`` is set (overwriting any
            previous value) to a float64 numpy vector: the mean of
            ``model.get_feature`` over the relation's training samples.
        config: forwarded to ``get_data_loader``.
        model: encoder exposing ``get_feature(sentences, lengths)``.
        divide_train_set: dict relation -> list of training instances.
        num_sel_data: unused; kept for interface compatibility.
        current_relations: relation ids to average.

    Returns:
        1 (kept for compatibility; results are written into ``allaverage``).
    """
    for thisrel in current_relations:
        data_loader = get_data_loader(config, divide_train_set[thisrel], False, False)
        features = np.concatenate([
            model.get_feature(sentences, lengths)
            for _, _, sentences, *_, lengths, _ in data_loader
        ])
        # Accumulate in float64, matching the previous explicit float64 running sum.
        allaverage[thisrel] = features.mean(axis=0, dtype=np.float64)
    return 1


def train_with_M_intro_F(config, model, traindata, epochs, allaverage, mem_set):
    """Train ``model`` on new data plus replayed memory with three distance losses.

    Per batch, with ``rep`` the model's per-sample representation:

    * loss2 ("intra"): mean Euclidean distance of each new sample (typelabel 1) to
      its relation centre ``allaverage[label]``.
    * loss1 ("M"): for each new sample, ``d_own`` is its own-centre distance and
      ``d_other`` is the distance to the nearest other centre in ``allaverage``
      (relation ids below the number of logit columns). If ``d_own`` exceeds
      ``d_other``, adds ``(d_own - d_other) / (d_own + d_other)``. Averaged over
      the new samples.
    * loss3 ("F"): mean distance of each memory sample (typelabel 3) to its stored
      feature ``mem_set[label]['1']``.

    Total loss: ``1.0 * loss1 + 1.0 * loss2 + 0.5 * loss3``. Batches containing no
    new and no memory samples give a constant loss and are skipped; calling
    ``backward()`` on them would raise.

    Args:
        config: reads 'batch_size_per_step', 'learning_rate', 'device', 'max_grad_norm'.
        model: callable as ``model(sentences, lengths) -> (logits, rep)``.
        traindata: list of new training instances.
        epochs: number of passes over ``traindata`` plus memory instances.
        allaverage: dict relation id -> centre vector; only read.
        mem_set: dict relation id -> {'0': [instances], '1': [features]}; only read.

    Returns:
        The trained ``model``.
    """
    device = config['device']
    mem_data = []
    for key in mem_set:
        mem_data.extend(mem_set[key]['0'])
    print(len(mem_data))
    train_set = traindata + mem_data
    data_loader = get_data_loader(config, train_set, batch_size=config['batch_size_per_step'])
    model.train()
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])

    lossesfactor1 = 1.0   # M
    lossesfactor2 = 1.0   # intra
    lossesfactor3 = 0.5   # F

    # allaverage and mem_set are not modified here, so each centre is converted to a
    # device tensor once (on first use) instead of once per sample.
    avg_centers = {}
    mem_centers = {}

    def avg_center(rel):
        if rel not in avg_centers:
            avg_centers[rel] = torch.tensor(allaverage[rel]).to(device)
        return avg_centers[rel]

    def mem_center(rel):
        if rel not in mem_centers:
            mem_centers[rel] = torch.tensor(mem_set[rel]['1']).to(device)
        return mem_centers[rel]

    for epoch_i in range(epochs):
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid,
                rawtext, lengths, typelabels) in enumerate(tqdm(data_loader)):
            model.zero_grad()
            logits, rep = model(sentences, lengths)
            labels = labels.to(device)
            types = typelabels.tolist()
            loss1 = torch.tensor(0.0).to(device)    # M
            loss2 = torch.tensor(0.0).to(device)    # intra
            loss3 = torch.tensor(0.0).to(device)    # F

            newnum = 0
            for index, onetype in enumerate(types):
                if onetype != 1:
                    continue
                newnum += 1
                rel = labels[index].item()
                own_dis = torch.sqrt(torch.sum(torch.square(rep[index] - avg_center(rel))))
                loss2 += own_dis
                nearest_other = 1000000
                for j in range(logits[index].shape[0]):
                    if j != rel and j in allaverage:
                        thisdis = torch.sqrt(torch.sum(torch.square(rep[index] - avg_center(j))))
                        if thisdis < nearest_other:
                            nearest_other = thisdis
                if own_dis - nearest_other > 0.0:
                    loss1 += ((own_dis - nearest_other) / (own_dis + nearest_other)).to(device)
            if newnum != 0:
                loss2 /= newnum
                loss1 /= newnum

            memnum = 0
            for index, onetype in enumerate(types):
                if onetype == 3:
                    memnum += 1
                    rel = labels[index].item()
                    loss3 += torch.sqrt(torch.sum(torch.square(rep[index] - mem_center(rel))))
            if memnum != 0:
                loss3 /= memnum

            loss = loss1 * lossesfactor1 + loss2 * lossesfactor2 + loss3 * lossesfactor3
            if not loss.requires_grad:
                # Neither new nor memory samples in this batch: nothing to optimise.
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            optimizer.step()
    return model

def train_with_ce_intro(config, model, train_set, epochs, allaverage):
    """Train with cross-entropy plus an intra-class compactness term.

    Per batch: ``loss = CE(logits, labels) + sum_i ||rep_i - allaverage[label_i]||_2``.
    The distance term is summed, not averaged, over the batch.

    Args:
        config: reads 'batch_size_per_step', 'learning_rate', 'device', 'max_grad_norm'.
        model: callable as ``model(sentences, lengths) -> (logits, rep)``.
        train_set: list of training instances.
        epochs: number of passes over ``train_set``.
        allaverage: dict relation id -> centre vector; only read.

    Returns:
        The trained ``model``.
    """
    device = config['device']
    data_loader = get_data_loader(config, train_set, batch_size=config['batch_size_per_step'])
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])
    # allaverage is not modified here: convert each centre to a device tensor once.
    centers = {}
    for epoch_i in range(epochs):
        for step, (
        labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext,
        lengths, typelabels) in enumerate(tqdm(data_loader)):
            model.zero_grad()
            logits, rep = model(sentences, lengths)
            labels = labels.to(device)
            loss1 = criterion(logits, labels)
            loss2 = torch.tensor(0.0).to(device)
            for index in range(len(logits)):
                rel = labels[index].item()
                if rel not in centers:
                    centers[rel] = torch.tensor(allaverage[rel]).to(device)
                loss2 += torch.sqrt(torch.sum(torch.square(rep[index] - centers[rel])))
            loss = loss1 + loss2
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            optimizer.step()
    return model

def train_with_ce(config, model, train_set, epochs):
    """Optimise ``model`` with plain cross-entropy on its classifier logits.

    Args:
        config: reads 'batch_size_per_step', 'learning_rate', 'device', 'max_grad_norm'.
        model: callable as ``model(sentences, lengths) -> (logits, rep)``.
        train_set: list of training instances.
        epochs: number of passes over ``train_set``.

    Returns:
        The same ``model`` object, trained in place.
    """
    loader = get_data_loader(config, train_set, batch_size = config['batch_size_per_step'])
    model.train()
    ce_loss = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), config['learning_rate'])
    for _epoch in range(epochs):
        for batch in tqdm(loader):
            (gold, _neg_labels, sentences, _first, _first_idx, _second, _second_idx,
             _head, _tail, _raw, lengths, _types) = batch
            model.zero_grad()
            logits, _rep = model(sentences, lengths)
            batch_loss = ce_loss(logits, gold.to(config['device']))
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            adam.step()
    return model

def train_simple_model(config, model, train_set, epochs, current_proto, batch_size=5):
    """Train with classifier cross-entropy plus prototype cross-entropy.

    Each step minimises ``CE(logits, labels) + CE(model.mem_forward(rep), labels)``.
    ``current_proto`` is installed with ``model.set_memorized_prototypes`` at the
    start of every epoch.

    Args:
        config: reads 'learning_rate', 'device', 'max_grad_norm'.
        model: callable as ``model(sentences, lengths) -> (logits, rep)``; also
            exposes ``mem_forward`` and ``set_memorized_prototypes``.
        train_set: list of training instances.
        epochs: number of passes over ``train_set``.
        current_proto: prototypes handed to ``set_memorized_prototypes``.
        batch_size: loader mini-batch size (default 5, the previously hard-coded value).

    Returns:
        The trained ``model``.
    """
    data_loader = get_data_loader(config, train_set, batch_size = batch_size)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])
    for epoch_i in range(epochs):
        model.set_memorized_prototypes(current_proto)
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext,
                   lengths, typelabels) in enumerate(tqdm(data_loader)):
            model.zero_grad()
            logits, rep = model(sentences, lengths)
            logits_proto = model.mem_forward(rep)

            labels = labels.to(config['device'])
            loss = criterion(logits, labels) + criterion(logits_proto, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            optimizer.step()
    return model

def get_memory(config, model, proto_set):
    """Build a prototype tensor with one row per group in ``proto_set``.

    All instances of all groups are encoded in one pass with ``model.get_feature``.
    Row ``k`` of the result is the mean feature of ``proto_set[k]``.

    Args:
        config: forwarded to ``get_data_loader``.
        model: encoder exposing ``get_feature(sentences, lengths)``.
        proto_set: sequence of instance lists (one list per relation id in the main script).

    Returns:
        ``torch.Tensor`` with ``len(proto_set)`` rows, one mean feature per group.
    """
    flat_instances = []
    offsets = [0]
    for group in proto_set:
        flat_instances += group
        offsets.append(offsets[-1] + len(group))

    loader = get_data_loader(config, flat_instances, False, False)
    encoded = np.concatenate([
        model.get_feature(sentences, lengths)
        for (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex,
             headid, tailid, rawtext, lengths, typelabels) in loader
    ])

    # Average each group's slice; keepdims so the rows concatenate into a 2-D tensor.
    prototypes = [
        torch.tensor(encoded[start:end, :].mean(0, keepdims = True))
        for start, end in zip(offsets[:-1], offsets[1:])
    ]
    return torch.cat(prototypes, 0)

def train_new_model(config, model, mem_set, traindata, epochs, current_proto):
    """Fine-tune ``model`` on new few-shot data replayed together with memory samples.

    Type labels: 0 rel, 1 pos (new train data), 2 neg, 3 memory sample.

    Per batch, using the prototype logits ``logits_proto = model.mem_forward(rep)``:

    * loss2: CE over non-memory samples + 0.5 * CE over memory samples.
    * loss3: for each non-memory sample whose best wrong-class score exceeds its gold
      score, adds ``(second - gold) / (second + gold)``. Averaged over the samples
      that contributed; 0 if none did.

    Args:
        config: reads 'batch_size_per_step', 'learning_rate', 'device', 'max_grad_norm'.
        model: callable as ``model(sentences, lengths) -> (logits, rep)``; also exposes
            ``mem_forward`` and ``set_memorized_prototypes``.
        mem_set: dict relation -> {'0': [instances], '1': [features]}; the '0'
            instances are appended to ``traindata`` for replay.
        traindata: list of training instances for the current task.
        epochs: number of passes over the combined data (all of them now run).
        current_proto: prototypes installed via ``set_memorized_prototypes`` at the
            start of every epoch.

    Returns:
        The trained ``model``.
    """
    print(len(traindata))
    mem_data = []
    for key in mem_set:
        mem_data.extend(mem_set[key]['0'])
    print(len(mem_data))
    train_set = traindata + mem_data
    print(len(train_set))
    data_loader = get_data_loader(config, train_set, batch_size=config['batch_size_per_step'])
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])
    device = config['device']

    lossesfactor2 = 1.0
    lossesfactor3 = 1.0

    for epoch_i in range(epochs):
        model.set_memorized_prototypes(current_proto)
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext, lengths,
                   typelabels) in enumerate(data_loader):
            model.zero_grad()
            labels = labels.to(device)
            typelabels = typelabels.to(device)
            is_mem = typelabels == 3
            is_new = ~is_mem
            mem_flags = is_mem.tolist()

            logits, rep = model(sentences, lengths)
            logits_proto = model.mem_forward(rep)

            # Prototype CE; memory replay samples are down-weighted by 0.5.
            # Boolean masks keep the class count general.
            loss2 = torch.tensor(0.0).to(device)
            if is_new.any():
                loss2 += criterion(logits_proto[is_new], labels[is_new])
            if is_mem.any():
                loss2 += 0.5 * criterion(logits_proto[is_mem], labels[is_mem])

            # Relative margin penalty when a wrong prototype outscores the gold one.
            loss3 = torch.tensor(0.0).to(device)
            loss3num = 0
            for index, in_memory in enumerate(mem_flags):
                if in_memory:
                    continue
                score = logits_proto[index]
                preindex = labels[index].item()
                maxscore = score[preindex]
                # First highest-scoring wrong class (strict >), found on plain floats
                # to avoid a device sync per element.
                best_j = None
                best_val = -100000
                for j, value in enumerate(score.detach().tolist()):
                    if j != preindex and value > best_val:
                        best_j, best_val = j, value
                secondmax = score[best_j] if best_j is not None else -100000
                if secondmax - maxscore > 0.0:
                    loss3num += 1
                    loss3 += ((secondmax - maxscore) / (secondmax + maxscore)).to(device)
            if loss3num:
                # Guard: dividing by zero would turn loss3 into NaN.
                loss3 /= loss3num

            loss = loss2 * lossesfactor2 + loss3 * lossesfactor3
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
            optimizer.step()
    # Returned after ALL epochs (previously this sat inside the epoch loop).
    return model


if __name__ == '__main__':
    # Script entry point. Runs the continual few-shot relation-extraction task
    # sequence from the sampler for 6 seeds, printing per-step and averaged accuracies.

    with open("config_IDLVQ_tacred.json", "r") as f:
        config = json.loads(f.read())
    config['device'] = torch.device('cuda' if torch.cuda.is_available() and config['use_gpu'] else 'cpu')
    config['n_gpu'] = torch.cuda.device_count()
    config['batch_size_per_step'] = int(config['batch_size'] / config["gradient_accumulation_steps"])
    config['neg_sampling'] = False

    root_path = '.'
    with open(os.path.join(root_path, 'glove/word2id.txt')) as f:
        word2id = json.load(f)
    word2vec = np.load(os.path.join(root_path, 'glove/word2vec.npy'))

    donum = 1

    for m in range(donum):
        print(m)
        data_dir = "data/tacred/CFRLdata_10_100_10_10/"
        config["rel_cluster_label"] = data_dir + "rel_cluster_label_" + str(m) + ".npy"
        config['training_file'] = data_dir + "train_" + str(m) + ".txt"
        config['valid_file'] = data_dir + "valid_" + str(m) + ".txt"
        config['test_file'] = data_dir + "test_" + str(m) + ".txt"

        encoderforbase = lstm_encoder(token2id=word2id, word2vec=word2vec, word_size=len(word2vec[0]), max_length=128, pos_size=None,
                                      hidden_size=config['hidden_size'], dropout=0, bidirectional=True, num_layers=1, config=config)
        sampler = data_sampler(config, encoderforbase.tokenizer)
        modelforbase = proto_softmax_layer(encoderforbase, num_class=len(sampler.id2rel), id2rel=sampler.id2rel, drop=0, config=config)
        modelforbase = modelforbase.to(config["device"])

        # Pristine embedding copy used to rebuild the encoder for every run.
        word2vec_back = word2vec.copy()

        sequence_results = []
        result_whole_test = []
        for run in range(6):

            num_class = len(sampler.id2rel)
            seed = config['random_seed'] + 10 * run
            print(seed)
            set_seed(config, seed)
            sampler.set_seed(seed)

            mem_set = {}
            mem_relations = []   # relations of previous tasks (excludes the current task)

            savetest_all_data = None
            saveseen_relations = []

            # One instance list per relation id, seeded with that relation's pattern.
            proto_memory = [[sampler.id2rel_pattern[rel_id]] for rel_id in range(len(sampler.id2rel))]

            oneseqres = []
            # Backup of the latest prototype of every relation seen so far.
            # NOTE(review): 200 is presumably the encoder feature dimension — confirm.
            current_proto_bak = torch.zeros((num_class, 200))
            print(current_proto_bak.shape)
            allseenrel = []
            for steps, (training_data, valid_data, test_data, test_all_data, seen_relations, current_relations) in enumerate(sampler):
                print(len(training_data))
                savetest_all_data = test_all_data
                saveseen_relations = seen_relations

                currentnumber = len(current_relations)
                allseenrel.extend(current_relations)
                print(currentnumber)
                print(current_relations)
                # Group this task's training instances by relation id (instance[0]).
                divide_train_set = {relation: [] for relation in current_relations}
                for data in training_data:
                    divide_train_set[data[0]].append(data)
                print(len(divide_train_set))

                if steps == 0:
                    print("train first base")
                    current_proto = get_memory(config, modelforbase, proto_memory)
                    modelforbase = train_simple_model(config, modelforbase, training_data, 1, current_proto)
                    select_data(mem_set, proto_memory, config, modelforbase, divide_train_set,
                                config['rel_memory_size'], current_relations)
                    for j in range(2):
                        # Replace the current relations' prototypes with their mean features.
                        allaverage = {}
                        get_average(allaverage, config, modelforbase, divide_train_set, config['rel_memory_size'], current_relations)
                        print(len(allaverage))
                        for onekey in allaverage.keys():
                            current_proto[onekey] = torch.tensor(allaverage[onekey])
                        modelforbase = train_simple_model(config, modelforbase, training_data, 1, current_proto)
                        select_data(mem_set, proto_memory, config, modelforbase, divide_train_set,
                                    config['rel_memory_size'], current_relations)
                    for onekey in allseenrel:
                        current_proto_bak[onekey] = current_proto[onekey]
                else:
                    print("train few shot data")
                    current_proto = get_memory(config, modelforbase, proto_memory)
                    # Relations from earlier tasks keep their backed-up prototypes.
                    for onekey in mem_relations:
                        current_proto[onekey] = current_proto_bak[onekey]
                    modelforbase = train_new_model(config, modelforbase, mem_set, training_data, 1, current_proto)
                    select_data(mem_set, proto_memory, config, modelforbase, divide_train_set,
                                config['rel_memory_size'], current_relations)
                    for j in range(2):
                        allaverage = {}
                        get_average(allaverage, config, modelforbase, divide_train_set, config['rel_memory_size'],
                                    current_relations)
                        print(len(allaverage))
                        for onekey in allaverage.keys():
                            current_proto[onekey] = torch.tensor(allaverage[onekey])
                        modelforbase = train_new_model(config, modelforbase, mem_set, training_data, 1, current_proto)
                        select_data(mem_set, proto_memory, config, modelforbase, divide_train_set,
                                    config['rel_memory_size'], current_relations)
                    updateaverage(config, modelforbase, mem_set, current_proto, mem_relations)
                    for onekey in allseenrel:
                        current_proto_bak[onekey] = current_proto[onekey]
                current_proto = get_memory(config, modelforbase, proto_memory)
                for onekey in allseenrel:
                    current_proto[onekey] = current_proto_bak[onekey]
                modelforbase.set_memorized_prototypes(current_proto)
                mem_relations.extend(current_relations)

                currentalltest = []
                for mm in range(len(test_data)):
                    currentalltest.extend(test_data[mm])

                thisstepres = eval_model(config, modelforbase, currentalltest, mem_relations)
                print("step:\t",steps,"\taccuracy:\t",thisstepres)
                oneseqres.append(thisstepres)
            sequence_results.append(np.array(oneseqres))

            allres = eval_model(config, modelforbase, savetest_all_data, saveseen_relations)
            result_whole_test.append(allres)

            print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
            print("after one epoch allres:\t",allres)
            print(result_whole_test)

            # Re-initialise the model for the next run.
            modelforbase = modelforbase.to('cpu')
            del modelforbase
            gc.collect()
            # config['device'] is a torch.device; comparing it to the string 'cuda' is
            # never true, so check its .type instead.
            if config['device'].type == 'cuda':
                torch.cuda.empty_cache()
            encoderforbase = lstm_encoder(token2id=word2id, word2vec=word2vec_back.copy(), word_size=len(word2vec[0]),max_length=128, pos_size=None,
                                          hidden_size=config['hidden_size'], dropout=0, bidirectional=True, num_layers=1, config=config)
            modelforbase = proto_softmax_layer(encoderforbase, num_class=len(sampler.id2rel), id2rel=sampler.id2rel,
                                               drop=0, config=config)
            modelforbase.to(config["device"])
        # Output the final per-run and averaged results.
        print("Final result!")
        print(result_whole_test)
        for one in sequence_results:
            for item in one:
                sys.stdout.write('%.4f, ' % item)
            print('')
        avg_result_all_test = np.average(sequence_results, 0)
        for one in avg_result_all_test:
            sys.stdout.write('%.4f, ' % one)
        print('')
        print("Finish training............................")

