"""For model training and inference (multi dialogue act & slot detection)
"""
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam, RMSprop
from transformers import BertTokenizer, BertModel, BertConfig, AdamW

from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import pickle
import copy
import numpy as np
import collections
from tqdm import tqdm
from collections import defaultdict, Counter

from model import BertContextNLU, BertFuse, ECA, KASLUM, RecATT
from all_data_context import get_dataloader_context
from config import opt
from utils import *

def train(**kwargs):
    """Train the joint dialogue-act / slot model and keep the best checkpoint.

    Keyword arguments override attributes of the global ``opt`` config object
    (for example ``train_path``, ``epochs`` or ``model``) before anything else
    runs.

    The pickled dataset at ``opt.train_path`` is shuffled with a fixed NumPy
    seed and split 60 % / 10 % / 30 %; the first two portions are used here as
    training and validation data.  Every ``opt.interval`` epochs the model is
    evaluated on the validation portion.  The selection score is the slot F1,
    or the dialogue-act F1 when ``opt.run_baseline == 'laban'``; whenever it
    improves, the weights are written to
    ``checkpoints/best_<datatype>_<model>_<domain>.pth``.

    Returns:
        None.  Progress and final metrics are printed to stdout.
    """
    import os  # local import: only needed to create the checkpoint directory

    # attributes
    for k, v in kwargs.items():
        setattr(opt, k, v)
    np.random.seed(0)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    torch.backends.cudnn.enabled = False

    print('Dataset to use: ', opt.train_path)
    print('Dictionary to use: ', opt.dic_path_with_tokens)
    print('Data Type: ', opt.datatype)
    print('Use pretrained weights: ', opt.retrain)

    # dataset
    with open(opt.dic_path_with_tokens, 'rb') as f:
        dic = pickle.load(f)
    with open(opt.slot_path, 'rb') as f:
        slot_dic = pickle.load(f)
    with open(opt.train_path, 'rb') as f:
        train_data = pickle.load(f)

    # Microsoft Dialogue Dataset / SGD Dataset
    # Shuffle once (seeded above) and cut 60 % train / 10 % validation; the
    # remaining 30 % is the held-out portion that test() evaluates on.
    num_samples = len(train_data)
    indices = np.random.permutation(num_samples)
    data_array = np.array(train_data)
    train_end = int(num_samples * 0.6)
    val_end = int(num_samples * 0.7)
    train_split = data_array[indices[:train_end]]
    val_split = data_array[indices[train_end:val_end]]

    train_loader = get_dataloader_context(train_split, dic, slot_dic, opt)
    val_loader = get_dataloader_context(val_split, dic, slot_dic, opt)

    # label tokens
    intent_texts = [intent for name, (tag, intent) in dic.items()]
    intent_tok, mask_tok = load_data(intent_texts, 10)
    intent_tokens = torch.zeros(len(intent_tok), 10).long().to(device)
    mask_tokens = torch.zeros(len(mask_tok), 10).long().to(device)
    for i, row in enumerate(intent_tok):
        intent_tokens[i] = torch.tensor(row)
    for i, row in enumerate(mask_tok):
        mask_tokens[i] = torch.tensor(row)

    # slot tokens
    slot_texts = [slot for name, (tag, slot) in slot_dic.items()]
    slot_dic_clean = {name: tag for name, (tag, slot) in slot_dic.items()}
    slot_tok, slot_mask_tok = load_data(slot_texts, 10)
    slot_tokens = torch.zeros(len(slot_tok), 10).long().to(device)
    slot_mask_tokens = torch.zeros(len(slot_mask_tok), 10).long().to(device)
    for i, row in enumerate(slot_tok):
        slot_tokens[i] = torch.tensor(row)
    for i, row in enumerate(slot_mask_tok):
        slot_mask_tokens[i] = torch.tensor(row)

    # model
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    if opt.run_baseline == 'eca':
        print("Run ECA baseline...")
        model = ECA(opt, len(dic), len(slot_dic))
    elif opt.run_baseline == 'kaslum':
        print("Run KA-SLUM baseline...")
        model = KASLUM(opt, len(dic), len(slot_dic))
    elif opt.model == 'recatt':
        model = RecATT(config, opt, len(dic), len(slot_dic))
    elif opt.model == 'new_fuse':
        print("Run Fusion Model...")
        model = BertFuse(config, opt, len(dic), len(slot_dic))
    else:
        model = BertContextNLU(config, opt, len(dic), len(slot_dic))

    if opt.model_path:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(opt.model_path, map_location=device))
        print("Pretrained model has been loaded.\n")
    else:
        print("Train from scratch...")
    model = model.to(device)

    optimizer = AdamW(model.parameters(), weight_decay=0.01, lr=opt.learning_rate_bert)
    criterion = nn.BCEWithLogitsLoss(reduction='sum').to(device)
    criterion2 = nn.CrossEntropyLoss(reduction='sum').to(device)

    best_loss = 100
    best_accuracy = 0
    best_f1 = 0
    best_stats = None  # stays None until a checkpoint has been saved

    #################################### Start training ####################################
    for epoch in range(opt.epochs):
        print("====== epoch %d / %d: ======"% (epoch+1, opt.epochs))

        # Training Phase
        total_train_loss = 0
        total_P = 0
        total_R = 0
        total_F1 = 0
        total_acc = 0
        model.train()
        ccounter = 0
        for (result_ids, result_token_masks, result_masks, lengths,
             result_slot_labels, result_labels, result_kg) in tqdm(train_loader):

            result_ids = result_ids.to(device)
            result_token_masks = result_token_masks.to(device)
            result_masks = result_masks.to(device)
            lengths = lengths.to(device)
            result_slot_labels = result_slot_labels.to(device)
            if opt.run_baseline == 'bert_naive' or opt.run_baseline == 'laban':
                # These baselines drop the first position of every slot row.
                result_slot_labels = result_slot_labels[:, :, 1:]
            result_slot_labels = result_slot_labels.reshape(-1)
            result_labels = result_labels.to(device)
            result_kg = result_kg.to(device)

            optimizer.zero_grad()

            outputs, labels, slot_out = model(
                result_ids, result_token_masks, result_masks, lengths, result_slot_labels,
                result_labels, intent_tokens, mask_tokens, slot_tokens, slot_mask_tokens, result_kg)
            train_loss = criterion(outputs, labels)
            slot_loss = criterion2(slot_out, result_slot_labels)
            total_loss = train_loss + slot_loss

            total_loss.backward()
            optimizer.step()

            # .item() detaches the scalar; accumulating the tensor itself would
            # keep every batch's autograd graph alive until the epoch ends.
            total_train_loss += total_loss.item()
            P, R, F1, acc = f1_score_intents(outputs, labels)
            total_P += P
            total_R += R
            total_F1 += F1
            total_acc += acc
            ccounter += 1

        print('Average train loss: {:.4f} '.format(total_train_loss / train_loader.dataset.num_data))
        precision = total_P / ccounter
        recall = total_R / ccounter
        f1 = total_F1 / ccounter
        print(f'P = {precision:.4f}, R = {recall:.4f}, F1 = {f1:.4f}')
        print('Accuracy: ', total_acc/train_loader.dataset.num_data)

        # Validation Phase
        if (epoch+1) % opt.interval == 0:
            total_val_loss = 0
            total_P = 0
            total_R = 0
            total_F1 = 0
            total_acc = 0
            model.eval()
            ccounter = 0
            stats = defaultdict(Counter)
            for (result_ids, result_token_masks, result_masks, lengths,
                 result_slot_labels, result_labels, result_kg) in val_loader:

                result_ids = result_ids.to(device)
                result_token_masks = result_token_masks.to(device)
                result_masks = result_masks.to(device)
                lengths = lengths.to(device)
                result_slot_labels = result_slot_labels.to(device)
                # NOTE(review): the training loop also slices for 'laban' but
                # validation does not; kept as-is -- confirm whether intended.
                if opt.run_baseline == 'bert_naive':
                    result_slot_labels = result_slot_labels[:, :, 1:]
                result_slot_labels = result_slot_labels.reshape(-1)
                result_labels = result_labels.to(device)
                result_kg = result_kg.to(device)

                with torch.no_grad():
                    outputs, labels, predicted_slot_outputs = model(
                        result_ids, result_token_masks, result_masks, lengths, result_slot_labels,
                        result_labels, intent_tokens, mask_tokens, slot_tokens, slot_mask_tokens, result_kg)
                # Only the dialogue-act loss is tracked during validation.
                val_loss = criterion(outputs, labels)

                total_val_loss += val_loss.item()
                P, R, F1, acc = f1_score_intents(outputs, labels)
                total_P += P
                total_R += R
                total_F1 += F1
                total_acc += acc
                ccounter += 1

                _, index = torch.topk(predicted_slot_outputs, k=1, dim=-1)
                evaluate_iob(index, result_slot_labels, slot_dic_clean, stats)

            print('========= Validation =========')
            print('Average val loss: {:.4f} '.format(total_val_loss / val_loader.dataset.num_data))

            precision = total_P / ccounter
            recall = total_R / ccounter
            f1 = total_F1 / ccounter
            print(f'P = {precision:.4f}, R = {recall:.4f}, F1 = {f1:.4f}')
            print('Accuracy: ', total_acc/val_loader.dataset.num_data)
            val_acc = total_acc/val_loader.dataset.num_data

            # print slot stats
            p_slot, r_slot, f1_slot = prf(stats['total'])
            print('========= Slot =========')
            print(f'Slot Score: P = {p_slot:.4f}, R = {r_slot:.4f}, F1 = {f1_slot:.4f}')

            f1_for_save = f1_slot if opt.run_baseline != 'laban' else f1

            if f1_for_save > best_f1:
                print('saving with loss of {}'.format(total_val_loss),
                    'improved over previous {}'.format(best_loss))
                best_loss = total_val_loss
                best_accuracy = val_acc
                best_f1 = f1_for_save
                best_stats = copy.deepcopy(stats)

                # Make sure the target directory exists so a finished epoch is
                # not lost to a missing folder.
                os.makedirs('checkpoints', exist_ok=True)
                torch.save(model.state_dict(), 'checkpoints/best_{}_{}_{}.pth'.format(opt.datatype, opt.model, opt.domain))

            print()

    if best_stats is None:
        # Validation never ran (opt.epochs < opt.interval) or never beat the
        # initial score, so there is nothing meaningful to summarise.
        print('No checkpoint was saved: validation never improved the score.')
        return

    # Report the loss recorded at the best checkpoint, not the last epoch's.
    print('Best total val loss: {:.4f}'.format(best_loss))
    print('Best Test Accuracy: {:.4f}'.format(best_accuracy))
    print('Best F1 Score: {:.4f}'.format(best_f1))

    p_slot, r_slot, f1_slot = prf(best_stats['total'])
    print('Final evaluation on slot filling of the validation set:')
    print(f'Overall: P = {p_slot:.4f}, R = {r_slot:.4f}, F1 = {f1_slot:.4f}')


#####################################################################


def test(**kwargs):
    """Evaluate a trained model on the held-out 30 % of the dataset.

    Keyword arguments override attributes of the global ``opt`` config object
    before anything else runs.  The dataset is shuffled with the same fixed
    NumPy seed and 60/10/30 cut used by ``train``; the last 30 % is evaluated.

    Two modes are selected by ``opt.test_mode``:

    * ``"validation"`` -- print dialogue-act P/R/F1/accuracy, the overall slot
      P/R/F1 and the F1 of every slot listed in ``opt.nslots``.
    * ``"data"`` -- same metrics, plus a per-utterance dump (tokens, predicted
      and real labels, knowledge scores) written to
      ``error_analysis/<datatype>_<data_mode>_context_kg.txt`` and ``.pkl``.
      In this mode the model is expected to return a fourth output holding
      the knowledge scores.

    Returns:
        None.  Metrics are printed to stdout.
    """
    import os  # local import: only needed to create the output directory

    # attributes
    for k, v in kwargs.items():
        setattr(opt, k, v)
    np.random.seed(0)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    torch.backends.cudnn.enabled = False

    print('Dataset to use: ', opt.train_path)
    print('Dictionary to use: ', opt.dic_path_with_tokens)

    # dataset
    with open(opt.dic_path_with_tokens, 'rb') as f:
        dic = pickle.load(f)
    print(dic)
    reverse_dic = {v[0]: k for k, v in dic.items()}
    with open(opt.slot_path, 'rb') as f:
        slot_dic = pickle.load(f)
    with open(opt.train_path, 'rb') as f:
        train_data = pickle.load(f)

    # Microsoft Dialogue Dataset / SGD Dataset
    # Same seeded permutation and cut points as train(), so the evaluated
    # samples are the ones that were never used for training or validation.
    num_samples = len(train_data)
    indices = np.random.permutation(num_samples)
    data_array = np.array(train_data)
    train_split = data_array[indices[:int(num_samples * 0.6)]]
    test_split = data_array[indices[int(num_samples * 0.7):]]

    # NOTE(review): train_loader is never read below; it is still built in case
    # get_dataloader_context has side effects -- confirm and drop if it has none.
    train_loader = get_dataloader_context(train_split, dic, slot_dic, opt)
    test_loader = get_dataloader_context(test_split, dic, slot_dic, opt)

    # label tokens
    intent_texts = [intent for name, (tag, intent) in dic.items()]
    intent_tok, mask_tok = load_data(intent_texts, 10)
    intent_tokens = torch.zeros(len(intent_tok), 10).long().to(device)
    mask_tokens = torch.zeros(len(mask_tok), 10).long().to(device)
    for i, row in enumerate(intent_tok):
        intent_tokens[i] = torch.tensor(row)
    for i, row in enumerate(mask_tok):
        mask_tokens[i] = torch.tensor(row)

    # slot tokens
    slot_texts = [slot for name, (tag, slot) in slot_dic.items()]
    slot_dic_clean = {name: tag for name, (tag, slot) in slot_dic.items()}
    print(slot_dic_clean)
    reverse_slot_dic_clean = {v: k for k, v in slot_dic_clean.items()}
    slot_tok, slot_mask_tok = load_data(slot_texts, 10)
    slot_tokens = torch.zeros(len(slot_tok), 10).long().to(device)
    slot_mask_tokens = torch.zeros(len(slot_mask_tok), 10).long().to(device)
    for i, row in enumerate(slot_tok):
        slot_tokens[i] = torch.tensor(row)
    for i, row in enumerate(slot_mask_tok):
        slot_mask_tokens[i] = torch.tensor(row)

    # model
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    if opt.model == 'new_fuse':
        print("Run Fusion Model...")
        model = BertFuse(config, opt, len(dic), len(slot_dic))
    else:
        model = BertContextNLU(config, opt, len(dic), len(slot_dic))

    if opt.model_path:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(opt.model_path, map_location=device))
        print("Pretrained model {} has been loaded.".format(opt.model_path))
    model = model.to(device)

    # Run multi-intent validation
    if opt.test_mode == "validation":

        total_P = 0
        total_R = 0
        total_F1 = 0
        total_acc = 0
        model.eval()
        ccounter = 0
        stats = defaultdict(Counter)
        for (result_ids, result_token_masks, result_masks, lengths,
             result_slot_labels, result_labels, result_kg) in tqdm(test_loader):

            result_ids = result_ids.to(device)
            result_token_masks = result_token_masks.to(device)
            result_masks = result_masks.to(device)
            lengths = lengths.to(device)
            result_slot_labels = result_slot_labels.to(device)
            result_slot_labels = result_slot_labels.reshape(-1)
            result_labels = result_labels.to(device)
            result_kg = result_kg.to(device)

            with torch.no_grad():
                outputs, labels, predicted_slot_outputs = model(
                    result_ids, result_token_masks, result_masks, lengths, result_slot_labels,
                    result_labels, intent_tokens, mask_tokens, slot_tokens, slot_mask_tokens, result_kg)

            P, R, F1, acc = f1_score_intents(outputs, labels)
            total_P += P
            total_R += R
            total_F1 += F1
            total_acc += acc
            ccounter += 1

            _, index = torch.topk(predicted_slot_outputs, k=1, dim=-1)
            evaluate_iob(index, result_slot_labels, slot_dic_clean, stats)

        precision = total_P / ccounter
        recall = total_R / ccounter
        f1 = total_F1 / ccounter
        print(f'P = {precision:.4f}, R = {recall:.4f}, F1 = {f1:.4f}')
        print('Accuracy: ', total_acc/test_loader.dataset.num_data)

        # print slot stats
        p_slot, r_slot, f1_slot = prf(stats['total'])
        print('========= Slot =========')
        print(f'Slot Score: P = {p_slot:.4f}, R = {r_slot:.4f}, F1 = {f1_slot:.4f}')

        # print number slot stats
        individual_slot_score = {}
        print(stats.keys())
        for k, v in stats.items():
            if k != 'total' and k in opt.nslots:
                p_slot, r_slot, f1_slot = prf(v)
                individual_slot_score[k] = f1_slot

        for k, v in individual_slot_score.items():
            print(k, ' slot F1: ', v)

        # Guard the mean: no slot from opt.nslots may have been scored.
        if individual_slot_score:
            print('Average score: ', sum(individual_slot_score.values()) / len(individual_slot_score))
        else:
            print('Average score: n/a (no slot from opt.nslots was scored)')

    # Run test classification
    elif opt.test_mode == "data":

        # Only this mode converts ids back to tokens, so load the tokenizer here.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        pred_labels = []
        real_labels = []
        pred_slot_labels = []
        real_slot_labels = []
        all_score_kg = []
        error_ids = []
        total_P, total_R, total_F1, total_acc = 0, 0, 0, 0
        ccounter = 0
        stats = defaultdict(Counter)
        model.eval()
        print(len(test_loader.dataset))
        for num, (result_ids, result_token_masks, result_masks, lengths,
                  result_slot_labels, result_labels, result_kg) in enumerate(test_loader):
            print('predict batches: ', num)

            result_ids = result_ids.to(device)
            result_token_masks = result_token_masks.to(device)
            result_masks = result_masks.to(device)
            lengths = lengths.to(device)
            result_slots = result_slot_labels.to(device)
            result_slot_labels = result_slots.reshape(-1)
            result_labels = result_labels.to(device)
            result_kg = result_kg.to(device)

            # Remove padding: keep the first lengths[i] rows of every sample and
            # stack them so row j is the j-th real utterance of the batch.
            texts_no_pad = torch.cat(
                [result_ids[i, :lengths[i], :] for i in range(len(result_ids))], dim=0)
            slots_no_pad = torch.cat(
                [result_slots[i, :lengths[i], :] for i in range(len(result_slots))], dim=0)

            with torch.no_grad():
                outputs, labels, predicted_slot_outputs, score_kg = model(
                    result_ids, result_token_masks, result_masks, lengths, result_slot_labels,
                    result_labels, intent_tokens, mask_tokens, slot_tokens, slot_mask_tokens, result_kg)

                # total
                P, R, F1, acc = f1_score_intents(outputs, labels)
                total_P += P
                total_R += R
                total_F1 += F1
                total_acc += acc

                ccounter += 1

                _, index = torch.topk(predicted_slot_outputs, k=1, dim=-1)
                evaluate_iob(index, result_slot_labels, slot_dic_clean, stats)

                # slot labels: take the *predicted* indices (the original code
                # sliced the gold labels here, so the dump always showed
                # predictions identical to the ground truth).
                index_reshape = index.reshape(result_ids.shape)
                pred_slots_no_pad = torch.cat(
                    [index_reshape[i, :lengths[i], :] for i in range(len(index_reshape))], dim=0)

                score_kg_no_pad = torch.cat(
                    [score_kg[i, :lengths[i], :, :] for i in range(len(score_kg))], dim=0)

                for i, logits in enumerate(outputs):
                    log = torch.sigmoid(logits)
                    wrong_caption = tokenizer.convert_ids_to_tokens(texts_no_pad[i], skip_special_tokens=True)
                    error_ids.append(wrong_caption)
                    pred_ls = list(torch.where(log > 0.5)[0].detach().cpu().numpy())
                    real_ls = [j for j, r in enumerate(labels[i].detach().cpu().numpy()) if r == 1]
                    pred_labels.append(pred_ls)
                    real_labels.append(real_ls)

                    real_row = slots_no_pad[i].detach().cpu().numpy()
                    pred_row = pred_slots_no_pad[i].detach().cpu().numpy()
                    # Positions are kept where the gold tag is not padding, so
                    # the predicted and real sequences stay aligned.
                    keep = [reverse_slot_dic_clean[s] != '[PAD]' for s in real_row]
                    real_s = [reverse_slot_dic_clean[s] for s, k in zip(real_row, keep) if k]
                    pred_s = [reverse_slot_dic_clean[p] for p, k in zip(pred_row, keep) if k]
                    pred_slot_labels.append(pred_s)
                    real_slot_labels.append(real_s)
                    all_score_kg.append(score_kg_no_pad[i])

        # write prediction results in txt
        os.makedirs('error_analysis', exist_ok=True)
        all_weights = []
        with open('error_analysis/{}_{}_context_kg.txt'.format(opt.datatype, opt.data_mode), 'w') as f:
            f.write('----------- Examples ------------\n')
            examples = zip(error_ids, pred_labels, real_labels,
                           pred_slot_labels, real_slot_labels, all_score_kg)
            for i, (caption, pred, real, pred_s, real_s, kg) in enumerate(examples):
                f.write(str(i)+'\n')
                f.write(' '.join(caption)+'\n')
                p_r = [reverse_dic[p] for p in pred]
                r_r = [reverse_dic[r] for r in real]
                f.write('Predicted label: {}\n'.format(p_r))
                f.write('Real label: {}\n'.format(r_r))
                f.write('Predicted slot label: {}\n'.format(' '.join(pred_s)))
                f.write('Real slot label: {}\n'.format(' '.join(real_s)))
                f.write('Score: \n')
                num_words = len(pred_s)
                # Each word keeps its first 5 * num_words knowledge scores.
                sent_weights = np.zeros((num_words, 5 * num_words))
                for w in range(num_words):
                    word_scores = kg[w].detach().cpu().numpy()[:5 * num_words]
                    f.write('Word {} score: {}\n'.format(w, word_scores))
                    sent_weights[w, :] = word_scores
                f.write('------\n')
                all_weights.append((caption, real_s, sent_weights))
        with open('error_analysis/{}_{}_context_kg.pkl'.format(opt.datatype, opt.data_mode), 'wb') as f:
            pickle.dump(all_weights, f)

        # print dialog act stats
        precision = total_P / ccounter
        recall = total_R / ccounter
        f1 = total_F1 / ccounter
        print(f'P = {precision:.4f}, R = {recall:.4f}, F1 = {f1:.4f}')
        print('Accuracy: ', total_acc/test_loader.dataset.num_data)

        # print slot stats
        p_slot, r_slot, f1_slot = prf(stats['total'])
        print('========= Slot =========')
        print(f'Slot Score: P = {p_slot:.4f}, R = {r_slot:.4f}, F1 = {f1_slot:.4f}')

        # print number slot stats
        individual_slot_score = {}
        for k, v in stats.items():
            if k != 'total' and k in opt.nslots:
                if len(v) == 3:
                    p_slot, r_slot, f1_slot = prf(v)
                    individual_slot_score[k] = f1_slot

        for k, v in individual_slot_score.items():
            print(k, ' slot F1: ', v)

        # Guard the mean: no slot from opt.nslots may have been scored.
        if individual_slot_score:
            print('Average score: ', sum(individual_slot_score.values()) / len(individual_slot_score))
        else:
            print('Average score: n/a (no slot from opt.nslots was scored)')

    else:
        # Previously an unrecognised mode fell through silently.
        print('Unknown test_mode {!r}; expected "validation" or "data".'.format(opt.test_mode))



if __name__ == '__main__':
    # Command-line entry point.  Fire is called without an explicit component,
    # so it builds the CLI from this module's namespace (see the Python Fire
    # guide): `train` and `test` become sub-commands and `--key=value` flags
    # arrive as their **kwargs, which those functions copy onto `opt`.
    # `fire` is imported here so importing this module does not require it.
    import fire
    fire.Fire()
    


            








        








    


    