import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import time
import numpy as np
import random
import json
import argparse
import os, sys
from torch.utils.data import DataLoader, RandomSampler

from transformers import BertTokenizer, BertForMaskedLM, XLMRobertaTokenizer, XLMRobertaForMaskedLM, RobertaTokenizer, RobertaForMaskedLM

from data import Data

# NER tag -> integer id. Ids follow the order of the tuple below, so "PAD" is 0,
# "O" is 1 and the B-/I- pairs of PER, ORG, LOC and MISC take ids 2..9.
label_map = {
    tag: idx
    for idx, tag in enumerate(
        ("PAD", "O", "B-PER", "I-PER", "B-ORG", "I-ORG",
         "B-LOC", "I-LOC", "B-MISC", "I-MISC")
    )
}


def NT_Xnet(gold_hids, pos_hids, neg_hids, T=1):
    """Temperature-scaled contrastive loss over cosine similarities.

    For every position (all dims except the last, which is reduced by the
    cosine similarity) this returns

        -log( exp(cos(gold, pos) / T)
              / (exp(cos(gold, pos) / T) + sum_j exp(cos(gold, neg_j) / T)) )

    Args:
        gold_hids: anchor vectors.
        pos_hids: positive vectors, compared with ``gold_hids`` along dim -1.
        neg_hids: sequence of negative vector sets (a list of tensors, or a
            tensor iterated along its first dim). May be empty, in which case
            the loss is 0 everywhere.
        T: temperature dividing every similarity.

    Returns:
        Un-reduced loss tensor with the shape of ``cos(gold_hids, pos_hids)``.
    """
    cos = nn.CosineSimilarity(dim=-1)

    pos_logit = cos(gold_hids, pos_hids) / T
    logits = [pos_logit]
    logits.extend(cos(gold_hids, neg) / T for neg in neg_hids)

    # -log(exp(p) / sum_k exp(l_k)) == logsumexp(l) - p.  Working in log space
    # avoids the inf/inf = NaN the naive exp() form produces once 1/T is large
    # enough to overflow (e.g. float32 exp(200)).
    stacked = torch.stack(torch.broadcast_tensors(*logits), dim=0)
    return torch.logsumexp(stacked, dim=0) - pos_logit

def neg_loss(gold_hids, neg_hids, weight):
    """Weighted mean cosine similarity between the anchors and each negative.

    Args:
        gold_hids: anchor vectors, compared along the last dim.
        neg_hids: sequence of negative vector sets.
        weight: scalar multiplier applied to every similarity term.

    Returns:
        ``weight * mean_j cos(gold_hids, neg_hids[j])`` as an un-reduced
        tensor, or the integer 0 when ``neg_hids`` is empty.
    """
    n_negatives = len(neg_hids)
    total = 0

    for negative in neg_hids:
        similarity = F.cosine_similarity(gold_hids, negative, dim=-1)
        total += weight * similarity / n_negatives

    return total

def MI_loss(gold_embs, gold_hids, neg_embs, weight):
    """Cosine-based contrast between gold and negative embeddings.

    Computes ``weight * (mean_j cos(neg_embs[j], gold_hids)
    - cos(gold_embs, gold_hids))`` along the last dim, so the value drops when
    the hidden states agree with the gold embeddings and rises when they agree
    with the negatives.

    Args:
        gold_embs: embeddings paired with ``gold_hids``.
        gold_hids: hidden states every embedding is compared against.
        neg_embs: sequence of negative embedding sets.
        weight: scalar multiplier applied to the final value.

    Returns:
        Un-reduced loss tensor.
    """
    n_negatives = len(neg_embs)

    # Start from the (negated) similarity to the gold embeddings ...
    total = 0 - F.cosine_similarity(gold_embs, gold_hids, dim=-1)
    # ... then add the average similarity to the negatives.
    for negative in neg_embs:
        total += F.cosine_similarity(negative, gold_hids, dim=-1) / n_negatives

    return total * weight

def train(model, iterator, optimizer, clip, num_neg, grad_acc=1, epoch=0, emb_queue=None):
    """Train ``model`` for one epoch on the masked-LM objective.

    Every batch from ``iterator`` is unpacked into seven tensors
    ``(input_ids, input_mask, label_ids, masked_ids, entity_mask,
    nce_pos_ids, nce_neg_ids)`` and moved to the module-level ``device``.
    The last four are indexed along dim 1 with ``epoch % 30`` so that a
    different pre-computed masking is used in different epochs.

    Args:
        model: called as ``model(masked_ids, input_mask, labels=input_ids,
            output_hidden_states=True)``; the result must expose ``loss``,
            ``logits`` and ``hidden_states``. ``model.roberta`` is read by the
            sanity check that runs after each optimizer step.
        iterator: sized iterable of batches (``len()`` is used for logging).
        optimizer: optimizer stepped once every ``grad_acc`` batches.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        num_neg: number of negative samples; only read by the disabled
            contrastive-loss experiments kept below.
        grad_acc: number of batches whose gradients are accumulated per step.
        epoch: zero-based epoch index; selects the masking variant.
        emb_queue: state of the disabled center-loss experiment; it is not
            modified while that experiment stays disabled.

    Returns:
        ``emb_queue``, as passed in.
    """
    # Log about five times per epoch.  Clamp to 1: with fewer than three
    # batches round() yields 0 and the modulo below would divide by zero.
    log_interval = max(1, round(len(iterator) / 5))
    #log_interval = 1

    model.train()
    train_start = time.time()

    # Running statistics are plain Python numbers so that no autograd graph is
    # kept alive across batches.
    epoch_loss = 0.0
    epoch_nce_loss = 0
    correct_count = 0
    total_count = 0
    entity_correct = 0
    entity_total = 0

    optimizer.zero_grad()

    for i, batch in enumerate(iterator):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, label_ids, masked_ids, entity_mask, nce_pos_ids, nce_neg_ids = batch
        # Different masking for diff epoch
        # NOTE(review): assumes dim 1 holds at least 30 variants -- confirm against data.py
        epoch_remainder = epoch % 30
        masked_ids = masked_ids[:,epoch_remainder]
        entity_mask = entity_mask[:,epoch_remainder]
        nce_pos_ids = nce_pos_ids[:,epoch_remainder]
        nce_neg_ids = nce_neg_ids[:,epoch_remainder]

        #optimizer.zero_grad()

        outputs = model(masked_ids, input_mask, labels=input_ids, output_hidden_states=True)
        loss = outputs.loss
        logits = outputs.logits
        # last_hids / embs are only consumed by the disabled experiments below.
        last_hids = outputs.hidden_states[-1]
        embs = outputs.hidden_states[0]

        # NCE loss
        '''
        # Pos sample without detach
        nce_pos_outputs = model.roberta(nce_pos_ids, input_mask)
        nce_pos_last_hids = nce_pos_outputs.last_hidden_state#.detach()    
        
        # Single neg sample, need to modify data.py
        #nce_neg_outputs = model.roberta(nce_neg_ids, input_mask)
        #nce_neg_last_hids = nce_neg_outputs.last_hidden_state.detach()
        ##nce_neg_outputs = model(nce_neg_ids, input_mask, labels=input_ids, output_hidden_states=True)
        ##nce_neg_last_hids = nce_neg_outputs.hidden_states[-1]
        #nce_loss = NT_Xnet(last_hids, nce_pos_last_hids, nce_neg_last_hids)

        #NCE 7 loss
        nce_neg_last_hids = []
        #for j in range(num_neg):
        for j in random.sample(range(7),num_neg):
            nce_neg_outputs = model.roberta(nce_neg_ids[:,j], input_mask)
            nce_neg_last_hids.append(nce_neg_outputs.last_hidden_state) #.detach())

        # NT_Xnet loss
        #nce_loss = NT_Xnet(last_hids, nce_pos_last_hids, nce_neg_last_hids)
        # Avg neg cos loss
        nce_loss = neg_loss(last_hids, nce_neg_last_hids, weight=WEIGHT)

        if num_neg != 0:
            assert nce_loss.shape == last_hids.shape[:2]
        nce_loss = torch.sum(nce_loss * entity_mask) / (torch.sum(entity_mask) + 1e-8)
        #print(entity_mask)
        #print(nce_loss)
        
        loss += nce_loss
        '''
#       Directly use LM loss of neg samples, training will not converge
#        nce_loss = 0
#        for j in random.sample(range(7),num_neg):
#            nce_loss = nce_loss -  model(nce_neg_ids[:,j], input_mask, labels=input_ids, output_hidden_states=True).loss / num_neg
#            loss = loss + nce_loss

#       MI Loss (ref Cfd EMNLP2020)
#        nce_neg_embs = []
#        for j in random.sample(range(7),num_neg):
#            nce_neg_embs.append(model.roberta.embeddings(nce_neg_ids[:,j]))
#        #print('shape of nce_neg_embs', len(nce_neg_embs), nce_neg_embs[0].shape)
#        #print('shape of embs', embs.shape)
#        #print('shape of last hids', last_hids.shape)
#
#        nce_loss = MI_loss(embs, last_hids, nce_neg_embs)
#        assert nce_loss.shape == last_hids.shape[:2]
#        nce_loss = torch.sum(nce_loss * entity_mask) / (torch.sum(entity_mask) + 1e-8)
#        loss += nce_loss


        # Directly enlarge embeddings distance of mask tokens
#        nce_loss = 0
#        cos = nn.CosineSimilarity(dim=-1)
#        for j in random.sample(range(7),num_neg):
#            nce_loss += cos(embs, model.roberta.embeddings(nce_neg_ids[:,j])) / num_neg
#        if num_neg != 0:
#            assert nce_loss.shape == last_hids.shape[:2]
#        nce_loss = torch.sum(nce_loss * entity_mask) / (torch.sum(entity_mask) + 1e-8)
#        loss += nce_loss

        # Orthogonal embedding cosine similarity
        # With every auxiliary experiment disabled the auxiliary loss is the constant 0.
        nce_loss = 0
#        cos = nn.CosineSimilarity(dim=-1)
#        for k in range(-9,-1): # Excluding O#
#            for j in range(-9,0):
#                if ((k % 2) == 1 and j == (k+1)) or ((k % 2) == 0 and j == (k-1)) or (j == k):
#                    continue
#                else:
#                    nce_loss += WEIGHT * cos(model.roberta.embeddings.word_embeddings.weight[k, :], model.roberta.embeddings.word_embeddings.weight[j, :]) / 56
#
#        loss += nce_loss

        # Center loss
        '''
        input_embs = model.roberta.embeddings(input_ids).detach() # Uncorrupted embs
        cos = nn.CosineSimilarity(dim=-1)

        nce_loss = 0
        updated_class = 1e-8

        for label in range(2,10):
            #print("label ", label)
            class_hids_mask = (((label_ids == label) * (entity_mask == 1)) > 0).unsqueeze(-1)
            class_hids = torch.masked_select(last_hids, class_hids_mask).view(-1, last_hids.shape[-1])
            #print("class_hids...", class_hids.shape)
    
            class_embs_mask = (label_ids == label).unsqueeze(-1)
            class_embs = torch.masked_select(input_embs, class_embs_mask).view(-1, input_embs.shape[-1])
            #print("class_embs...", class_embs.shape)

            if epoch != 0: # Do not include center loss in epoch 0 as centroids are not stable
                if class_hids.nelement() != 0:
                    class_dist = cos(torch.mean(class_hids, dim=0), torch.mean(torch.cat(emb_queue[label]), dim=0))
                    
                    neg_labels = [2,3,4,5,6,7,8,9]
                    neg_labels.remove(math.floor(label/2)*2)
                    neg_labels.remove(math.floor(label/2)*2+1)
                    assert len(neg_labels) == 6
                    inter_class_loss = 0
                    for neg_label in neg_labels:
                        inter_class_loss += cos(torch.mean(class_hids, dim=0), torch.mean(torch.cat(emb_queue[neg_label]), dim=0)) / 6
                    
                    nce_loss = nce_loss - class_dist + inter_class_loss
                    #print("nce_loss", nce_loss)
                    updated_class += 1

            if class_embs.nelement() != 0:
                emb_queue[label].append(class_embs)
                if epoch != 0: # only remove leading elements from 2nd epoch (ep1)
                    emb_queue[label].pop(0)
            #print(len(emb_queue[label]))
                    
        nce_loss = WEIGHT * nce_loss / updated_class
        loss += nce_loss
        '''
        # Scale so that gradients summed over grad_acc batches average out.
        loss = loss / grad_acc

        # .item() detaches the value; adding the tensor itself would keep each
        # batch's autograd history reachable for the whole epoch.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        epoch_nce_loss += nce_loss

        pred = torch.argmax(logits, dim=-1)

        match = (input_ids == pred) * input_mask
        correct_count += torch.sum(match).item()
        total_count += torch.sum(input_mask).item()

        entity_match = (input_ids == pred) * entity_mask
        entity_correct += torch.sum(entity_match).item()
        entity_total += torch.sum(entity_mask).item()

        loss.backward()

        # Freeze special label tokens
        #freeze_idx = torch.LongTensor(range(-9,0))
        #model.roberta.embeddings.word_embeddings.weight.grad[freeze_idx] = 0

        if (i+1) % grad_acc == 0:
            # Clip once, on the fully accumulated gradient, right before the
            # update; clipping after every micro-batch rescales partial sums.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            optimizer.zero_grad()
            # Ensure mask tokens are not the same
            assert not torch.equal(model.roberta.embeddings.word_embeddings.weight[-9, :],
                                   model.roberta.embeddings.word_embeddings.weight[-7, :])

        if (i+1) % log_interval == 0:
            # max(..., 1) keeps the report alive when nothing has been counted
            # yet (e.g. no entity position was masked so far).
            token_acc = correct_count / max(total_count, 1) * 100
            entity_acc = entity_correct / max(entity_total, 1) * 100
            print(f'{i + 1}/{len(iterator)} batches done | Total time: {time.time() - train_start:10.3f} | ',
                  f'Current Batch Loss {batch_loss:.3f} | Epoch average loss: {epoch_loss/(i+1):.3f} |',
                  f'Epoch avg NCE loss: {epoch_nce_loss/(i+1):.3f} | Epoch accuracy: {token_acc:.2f}% | ',
                  f'Entity acc: {entity_acc:.2f}%')
            #print(model.roberta.embeddings.word_embeddings.weight[-9, :5])
            #print(model.roberta.embeddings.word_embeddings.weight[-7, :5])

    # NOTE(review): when len(iterator) is not a multiple of grad_acc, the
    # gradients of the trailing batches are never applied (the next call
    # starts with zero_grad()).  Left as in the original.
    return emb_queue

def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: start timestamp in seconds (e.g. from ``time.time()``).
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated toward zero.
    """
    span = end_time - start_time
    minutes = int(span / 60)
    # Whatever is left after removing the whole minutes, fraction dropped.
    seconds = int(span - minutes * 60)
    return minutes, seconds

def evaluate(model, iterator):
    """Score ``model`` on ``iterator`` without computing gradients.

    Each batch is moved to the module-level ``device`` and unpacked into seven
    tensors; the first masking variant (index 0 along dim 1 of ``masked_ids``
    and ``entity_mask``) is always used so results are comparable across calls.

    Args:
        model: called as ``model(masked_ids, input_mask, labels=input_ids)``;
            the result must expose ``loss`` and ``logits``.
        iterator: iterable of batches.

    Returns:
        ``(mean_loss, token_accuracy, entity_accuracy)`` where ``mean_loss`` is
        the summed batch loss divided by the number of batches, and the two
        accuracies are the fractions of positions selected by ``input_mask``
        and ``entity_mask`` whose arg-max prediction equals ``input_ids``.
    """
    model.eval()

    loss_sum = 0
    token_hits = 0
    token_seen = 0
    entity_hits = 0
    entity_seen = 0

    with torch.no_grad():
        for n_batches, raw_batch in enumerate(iterator, start=1):
            on_device = tuple(t.to(device) for t in raw_batch)
            input_ids, input_mask, label_ids, masked_ids, entity_mask, _, _ = on_device

            # Use first masking for evaluation
            first_masked_ids = masked_ids[:, 0]
            first_entity_mask = entity_mask[:, 0]

            outputs = model(first_masked_ids, input_mask, labels=input_ids)
            loss_sum += outputs.loss

            predicted = outputs.logits.argmax(dim=-1)
            is_correct = input_ids == predicted

            token_hits += (is_correct * input_mask).sum().item()
            token_seen += input_mask.sum().item()

            entity_hits += (is_correct * first_entity_mask).sum().item()
            entity_seen += first_entity_mask.sum().item()

    return loss_sum / n_batches, token_hits / token_seen, entity_hits / entity_seen

# ---------------------------------------------------------------------------
# Script body: read the JSON config, build model/tokenizer/data, then run the
# train/validate loop and keep the checkpoint with the lowest validation loss.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", required=True, type=str)
args = parser.parse_args()

# Only parsing needs the file handle; close it before the (long) training run.
with open(args.config) as json_data:
    config = json.load(json_data)
print(config)

# `device` is read as a module-level global by train() and evaluate().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on ", device)

SEED = config['seed']
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

FILE_DIR = config["FILE_DIR"]
CKPT_DIR = config["CKPT_DIR"]  # path of the checkpoint *file* passed to torch.save
BSIZE = config["BSIZE"]
N_EPOCHS = config["N_EPOCHS"]
CLIP = config["CLIP"]
LR = config["LR"]
GRAD_ACC = config["GRAD_ACC"]
N_NEG = config["N_NEG"]
MASK_RATE = config["MASK_RATE"]
WEIGHT = config["WEIGHT"]

# Folder that will hold the checkpoint file.  A bare file name means the
# current directory; makedirs also creates missing parent folders.
ckpt_folder = os.path.dirname(CKPT_DIR) or "."
if os.path.isdir(ckpt_folder):
    print("\nWarning! Checkpoint dir exist!.......................\n")
else:
    os.makedirs(ckpt_folder, exist_ok=True)
    print("Checkpoints will be saved to: ", CKPT_DIR)
#model = RobertaForMaskedLM.from_pretrained(config["load_bert"], return_dict=True).to(device)
#tokenizer = RobertaTokenizer.from_pretrained(config["load_bert"], do_lower_case=False)

model = XLMRobertaForMaskedLM.from_pretrained(config["load_bert"], return_dict=True).to(device)
tokenizer = XLMRobertaTokenizer.from_pretrained(config["load_bert"], do_lower_case=False)

#model = BertForMaskedLM.from_pretrained(config["load_bert"], return_dict=True).to(device)
#tokenizer = BertTokenizer.from_pretrained(config["load_bert"], do_lower_case=False)

# Add entity labels as special tokens
tokenizer.add_tokens(['<B-PER>', '<I-PER>', '<B-ORG>', '<I-ORG>', '<B-LOC>', '<I-LOC>', '<B-MISC>', '<I-MISC>','<O>'],
                     special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
#assert len(tokenizer) == 250011

# Initialize new mask tokens as weight of <mask>
#for i in range(-9,0):
#    with torch.no_grad():
#        model.roberta.embeddings.word_embeddings.weight[i, :] = model.roberta.embeddings.word_embeddings.weight.data[250001, :].clone()

# Use numpy for initialization
#emb_np = model.roberta.embeddings.word_embeddings.weight.data.cpu().numpy()
#for i in range(-9,0):
#    emb_np[i] = 10 * emb_np[i] + emb_np[250001]
#model.roberta.embeddings.word_embeddings.weight.data = torch.from_numpy(emb_np).to(device)

# Add the embedding of an existing vocabulary entry onto each of the last nine
# rows of the embedding matrix.  Keys are row offsets from the end; assuming
# the nine tokens above were all new and appended in order, -9 is '<B-PER>'
# and -1 is '<O>'.  Values are existing vocabulary ids -- presumably words
# related to each entity type in the XLM-R vocabulary; TODO confirm.
seed_token_ids = {
    -1: 1810,
    -2: 27060,
    -3: 27060,
    -4: 31913,
    -5: 31913,
    -6: 53702,
    -7: 53702,
    -8: 3445,
    -9: 3445,
}
word_emb = model.roberta.embeddings.word_embeddings.weight
with torch.no_grad():
    for new_row, seed_id in seed_token_ids.items():
        word_emb[new_row, :] += word_emb[seed_id, :].clone()


# Build the three splits (original note: the dataset .pt file is generated if
# it does not exist yet -- behavior lives in data.Data).
print("Loading file from: ", FILE_DIR)
train_dataset, valid_dataset, test_dataset = tuple(Data(tokenizer, BSIZE, label_map, FILE_DIR, MASK_RATE).datasets)

# Load datasets
#train_dataset = torch.load(TRAIN_DIR)
train_dataloader = DataLoader(train_dataset, batch_size=BSIZE, sampler=RandomSampler(train_dataset))
#valid_dataset = torch.load(VALID_DIR)
valid_dataloader = DataLoader(valid_dataset, batch_size=BSIZE)
#test_dataset = torch.load(TEST_DIR)
test_dataloader = DataLoader(test_dataset, batch_size=BSIZE)

optimizer = optim.Adam(model.parameters(), lr=LR)

best_valid_loss = float('inf')
best_valid_entity_acc = -float('inf')
best_valid_entity_acc_by_acc = -float('inf')
# Stay None until a checkpoint is actually written, so the final summary can
# tell "nothing improved / no epochs ran" apart from a real best epoch.
best_valid_epoch = None
best_valid_acc = None

# Initialize emb queue for center loss
emb_queue = [[] for _ in range(10)] # idx 0 & 1 list is not updated

for epoch in range(N_EPOCHS):
    start_time = time.time()

    #train(model, train_dataloader, optimizer, CLIP, N_NEG, GRAD_ACC, epoch)
    emb_queue = train(model, train_dataloader, optimizer, CLIP, N_NEG, GRAD_ACC, epoch, emb_queue)
    valid_loss, valid_acc, valid_entity_acc = evaluate(model, valid_dataloader)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'\nEpoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s',
          f'Epoch valid loss: {valid_loss:.3f} | Epoch valid PPL: {math.exp(valid_loss):.1f}',
          f'Epoch valid acc: {valid_acc * 100:.2f}% | Epoch entity acc: {valid_entity_acc*100:.2f}% \n')

    if valid_loss < best_valid_loss:
    #if valid_entity_acc > best_valid_entity_acc:
    #if (epoch+1) % 5 == 0:
        print("By dev ppl, Saving current epoch to checkpoint...\n")
        best_valid_loss = valid_loss
        best_valid_epoch = epoch
        best_valid_acc = valid_acc
        best_valid_entity_acc = valid_entity_acc
        #torch.save(model.state_dict(), "/".join(CKPT_DIR.split('/')[:-1]) + '/best_val_model_epoch' + str(math.ceil((epoch+1)/3)*3) + '.pt')
        torch.save(model.state_dict(), CKPT_DIR)

    #if valid_entity_acc > best_valid_entity_acc_by_acc:
    #    print("By dev acc, Saving current epoch to checkpoint...\n")
    #    best_valid_entity_acc_by_acc = valid_entity_acc
    #    torch.save(model.state_dict(), "/".join(CKPT_DIR.split('/')[:-1])+"/best_val_model_epoch100.pt")


print("Training finished...")
if best_valid_epoch is None:
    # Either N_EPOCHS was 0 or the validation loss never dropped below inf
    # (e.g. it was NaN), so there is no best epoch to report.
    print("\n No checkpoint was saved: validation loss never improved.")
else:
    # ':.2f' (not ':.2') so the accuracy prints with two decimals like the
    # per-epoch lines instead of two significant digits.
    print(f'\n Best valid loss until epoch {epoch} is {best_valid_loss:.3f} at epoch {best_valid_epoch + 1}',
          f'\n valid acc is {best_valid_acc * 100:.2f}%, valid entity acc is {best_valid_entity_acc * 100:.2f}%')
