import json
from Embedding_models import *
import torch.nn
from tokenizers import Tokenizer
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, BertModel, BertConfig
from evaluation1 import *
from fusion_models import *
from torch.utils.data import DataLoader
from utils import *
import os
import args_FB15k
from head2tailDataset import *
# Run configuration comes from the project-level argument module.
arguments = args_FB15k.get_args()
lr = arguments.lr
epochs = arguments.epochs
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = arguments.device


def _reverse_triplets(triplets):
    """Return one inverse triplet per input triplet.

    The first and third elements are swapped and '/be' is prepended to the
    middle element (presumably the relation name -- confirm against the
    data files).
    """
    return [[item[2], '/be' + item[1], item[0]] for item in triplets]


# Raw triplets for each split.
train_data = read_triplets_from_txt(arguments.train_data_path)
valid_data = read_triplets_from_txt(arguments.valid_data_path)
test_data = read_triplets_from_txt(arguments.test_data_path)

# Inverse triplets for each split.
train_data_reverse = _reverse_triplets(train_data)
valid_data_reverse = _reverse_triplets(valid_data)
test_data_reverse = _reverse_triplets(test_data)

# Each split is augmented with its inverse triplets (originals first).
onlyTail_train_data = train_data + train_data_reverse
onlyTail_val_data = valid_data + valid_data_reverse
onlyTail_test_data = test_data + test_data_reverse

# Vocabularies: entity and relation name -> id mappings and their sizes.
entity2id = read_entity(arguments.entity_path)
relation2id = read_entity("data/FB15k-237/reverse_relations.txt")
entity_set = set(entity2id.keys())
label_num = len(entity2id)
relation_num = len(relation2id)

# Ground truth is counted over all three augmented splits.
groundtruth = count_groundtruth(
    onlyTail_train_data,
    onlyTail_val_data,
    onlyTail_test_data,
)
def _embed_inputs(bert, input_ids, device):
    """Sum BERT word embeddings of ``input_ids`` with the position embeddings
    for positions ``0 .. input_ids.shape[1] - 1``."""
    word_embeds = bert.embeddings.word_embeddings(input_ids)
    positions = torch.arange(0, input_ids.shape[1], dtype=torch.long).to(device)
    pos_embeds = bert.embeddings.position_embeddings(positions)
    return word_embeds + pos_embeds


def train_val_epochs(bert,model,train_dataloader,val_dataloader,epochs,lr,label_num,device):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Per epoch:
      * one optimisation pass over ``train_dataloader``; the mean loss is
        appended as a JSON line to ``arguments.train_result_json_path``;
      * one evaluation pass over ``val_dataloader``; loss, MR, MRR and
        Hit@{10,3,1} are appended as a JSON line to
        ``arguments.valid_result_json_path``;
      * ``model.state_dict()`` is saved to ``arguments.weight_path`` whenever
        the validation MRR strictly exceeds the best value seen so far.

    Args:
        bert: model whose ``embeddings.word_embeddings`` and
            ``embeddings.position_embeddings`` produce the input embeddings.
        model: fusion model; called as
            ``model(input_embeddings, input_ids, heads, relations, labels)``
            and expected to return ``(output, structure_loss)``. Must expose
            ``transE.entity_embeddings`` and ``transE.relation_embeddings``.
        train_dataloader: yields ``(input_ids, heads, relations, labels)``.
        val_dataloader: yields batches of the same layout.
        epochs: number of epochs to run.
        lr: learning rate for every parameter whose name does not contain
            ``'transE'`` (the two TransE embedding tables use a fixed 1e-4).
        label_num: unused; kept for interface compatibility.
        device: device the batches are moved to.

    Note:
        The validation loop reads the module-level names ``tokenizer``,
        ``groundtruth`` and ``entity2id``. ``tokenizer`` is only bound when
        this file is executed as a script.
    """
    max_MRR = 0

    # Separate parameter groups: TransE embedding tables get their own fixed
    # learning rate, everything else outside 'transE' uses the caller's `lr`.
    params = [
        {'params':model.transE.entity_embeddings.weight,'lr':1e-4},
        {'params':model.transE.relation_embeddings.weight,'lr':1e-4},
        {'params':[param for name,param in model.named_parameters() if 'transE' not in name],'lr':lr}
    ]
    # Cross-entropy loss over the classifier output.
    criterion = CrossEntropyLoss()
    optimizer = torch.optim.Adam(params,lr=lr)
    for epoch in range(epochs):

        # ---------------- training pass ----------------
        model.train()
        train_tail_loss = 0
        bar = tqdm(total = len(train_dataloader),desc=f'Epoch{epoch+1}/{epochs}', ncols=100)
        for batch in train_dataloader:
            input_ids, heads, relations, labels = (t.to(device) for t in batch)
            input_embeddings = _embed_inputs(bert, input_ids, device)
            output,structure_loss = model(input_embeddings, input_ids,heads,relations,labels)
            classify_loss = criterion(output, labels)
            loss = classify_loss + 1*structure_loss
            train_tail_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            bar.set_description(f'train_tail_Epoch {epoch + 1}/{epochs}')
            bar.set_postfix(c_loss=f'{classify_loss.item():.4f}',s_loss = f'{structure_loss:.4f}')
            bar.update(1)
        bar.close()
        mean_loss = train_tail_loss/len(train_dataloader)
        print('train_epoch: {}, loss:{}'.format(epoch+1,mean_loss))
        train_epoch_result = {
            "epoch":epoch+1,
            "loss":mean_loss
        }
        with open(arguments.train_result_json_path,'a') as jsf:
            json.dump(train_epoch_result,jsf)
            jsf.write('\n')

        # ---------------- validation pass ----------------
        model.eval()
        total_rank = []
        val_loss = 0.0
        bar = tqdm(total = len(val_dataloader),desc=f'Epoch{epoch+1}/{epochs}', ncols=100)
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids, heads, relations, labels = (t.to(device) for t in batch)
                input_embeddings = _embed_inputs(bert, input_ids, device)

                # Columns 1 and 2 of each row are decoded back to tokens and
                # passed to `evaluation` as (token of column 2, token of
                # column 1) pairs. One bulk transfer instead of two `.item()`
                # calls per row.
                r_h = [
                    (tokenizer.id_to_token(r), tokenizer.id_to_token(h))
                    for h, r in input_ids[:, 1:3].tolist()
                ]

                output,structure_loss = model(input_embeddings, input_ids,heads,relations,labels)
                classify_loss = criterion(output, labels)
                # NOTE(review): the structure loss is weighted 0.5 here but 1
                # during training, so the two reported losses are not directly
                # comparable -- confirm this is intended.
                loss = classify_loss + 0.5*structure_loss
                val_loss += loss.item()
                ranks = evaluation(output, groundtruth, entity2id, r_h, labels, modes='tail')

                # 1-based position of the gold label in each ranking.
                for label, rank in zip(labels, ranks):
                    total_rank.append(torch.where(rank == label)[0].item() + 1)
                bar.set_description(f'test_tail_Epoch {epoch + 1}/{epochs}')
                bar.set_postfix(loss=f'{loss.item()}')
                bar.update(1)
        bar.close()

        total_rank = torch.tensor(total_rank)
        num_ranked = len(total_rank)
        tail_MR = total_rank.sum().item() / num_ranked
        # Reciprocal ranks are computed explicitly in floating point.
        tail_MRR = torch.sum(1 / total_rank.float()).item() / num_ranked
        tail_hit10 = torch.sum(total_rank <= 10).item() / num_ranked
        tail_hit3 = torch.sum(total_rank <= 3).item() / num_ranked
        tail_hit1 = torch.sum(total_rank <= 1).item() / num_ranked
        tail_test_loss = val_loss / len(val_dataloader)
        print('MR:{} ,MRR:{} ,HIT10:{} ,HIT3:{} ,HIT1:{}'.format(
            tail_MR,
            tail_MRR,
            tail_hit10,
            tail_hit3,
            tail_hit1))
        test_epoch_result = {
            "epoch": epoch+1,
            "loss": tail_test_loss,
            "MR": tail_MR,
            "MRR": tail_MRR,
            "Hit@10": tail_hit10,
            "Hit@3": tail_hit3,
            "Hit@1": tail_hit1,

        }
        with open(arguments.valid_result_json_path, 'a') as jsf:
            json.dump(test_epoch_result, jsf)
            jsf.write('\n')
        # Keep only the checkpoint with the best validation MRR so far.
        if tail_MRR > max_MRR:
            max_MRR = tail_MRR
            torch.save(model.state_dict(), arguments.weight_path)
            print('model saved')

if __name__ == '__main__':
    # Script entry point: build the fusion model, optionally resume from a
    # saved checkpoint, build the datasets and run training.
    #
    # NOTE: these statements intentionally stay at module level (not inside a
    # main() function): train_val_epochs reads `tokenizer` as a module global.

    Bert = BertModel.from_pretrained(arguments.model_path).to(device)

    tokenizer = Tokenizer.from_file(arguments.tokenizer_path)
    entity2id = read_entity(arguments.entity_path)
    label_num = len(entity2id)

    # Sub-modules of the fusion model.
    transe = NewTransE(label_num,relation_num,arguments.hidden_size,groundtruth,entity2id)
    classifier = Classifier(arguments.hidden_size,label_num)
    simpleMlp = simpleMLP(arguments.hidden_size*2,arguments.hidden_size)

    # Replace BERT's word-embedding table with the pre-built one.
    # map_location keeps the load working when the file was saved from a
    # different device than the one used for this run.
    new_word_embeddings_weight = torch.load("model/Fb15k237_word_embeddings_reverse.pt",
                                            map_location=device)
    Bert.embeddings.word_embeddings.weight = torch.nn.Parameter(new_word_embeddings_weight)

    model = NewMASKBertFusionModel(Bert,transe,simpleMlp,tokenizer,classifier,device).to(device)

    # Resume from an existing checkpoint if one is present.
    weight_path = arguments.weight_path
    if os.path.exists(weight_path):
        model.load_state_dict(torch.load(weight_path, map_location=device))
        print('model has loaded.')
    else:
        print('model dont exists.')
        print('init model...')

    # Training set.
    input_ids,heads,relations,labels = get_data_from_rawdata(tokenizer, onlyTail_train_data, entity2id,relation2id,
                                                             'tail-batch')
    train_dataset = OnlyTailDataset(input_ids,heads,relations,labels)

    # NOTE(review): the *test* split is used for per-epoch validation and thus
    # for checkpoint selection, while onlyTail_val_data is never used. Confirm
    # this is intended; selecting on test data inflates reported test metrics.
    input_ids,heads,relations,labels = get_data_from_rawdata(tokenizer, onlyTail_test_data, entity2id,relation2id,
                                                             'tail-batch')
    test_dataset = OnlyTailDataset(input_ids,heads,relations,labels)

    train_dataloader = DataLoader(train_dataset,batch_size = arguments.train_batch_size,shuffle=True)
    test_dataloader = DataLoader(test_dataset,batch_size = arguments.val_batch_size,shuffle=False)

    train_val_epochs(Bert,model,train_dataloader,test_dataloader,epochs,lr,label_num,device)


