from model.common_layer import evaluate
from model.memory import Memory, LocalAdapt, RNN, NN
from model.transformer import Transformer
import transformers
import matplotlib
matplotlib.use('Agg')
from utils.data_reader import Personas
import pickle
from utils import config
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import numpy as np 
from random import shuffle
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import math, pickle
from utils import record_time as time_record
import pdb
from tensorboardX import SummaryWriter

def change_all_keys(prefix, dictionary):
    """Prefix every key of ``dictionary`` in place and return the same dict.

    Each key ``k`` becomes ``prefix + str(k)``; values and insertion order are
    preserved. The renamed mapping is built first and then swapped in, so a
    new key that happens to equal a not-yet-renamed old key (e.g. prefix
    ``'p'`` with keys ``'a'`` and ``'pa'``) can no longer clobber that entry
    or get prefixed twice.

    Args:
        prefix: string prepended to every key.
        dictionary: dict that is modified in place.

    Returns:
        ``dictionary`` itself, now holding the prefixed keys.
    """
    renamed = {prefix + str(k): v for k, v in dictionary.items()}
    # Mutate in place (rather than returning `renamed`) so callers holding a
    # reference to the original dict observe the change, as before.
    dictionary.clear()
    dictionary.update(renamed)
    return dictionary
        
def merge_dicts(*dict_args):
    """
    Shallow-merge any number of dicts into a brand-new dict.

    When a key appears in several arguments, the value from the right-most
    argument wins. None of the inputs is modified.
    """
    if not dict_args:
        return {}
    first, *rest = dict_args
    merged = dict(first)
    for extra in rest:
        merged.update(extra)
    return merged

def do_learning(model, train_iter, val_iter, iterations,tasks_idx=(0,0),memory=None,Adapt_model=None):
    """Fine-tune ``model`` on one test fold and track its query-set metrics.

    Step 0 only evaluates the model as it is on entry. Every later step ``i``
    (1 .. iterations-1) does the following:
      1. unless ``config.only_train_query`` is set, evaluate on ``val_iter``
         (``eval_before_train=True``), record ``'loss_before'``, then make one
         pass over ``train_iter`` with ``model.train_one_batch``;
      2. evaluate on ``val_iter`` again and record ``'loss_after'``.
    Each recorded loss is also sent to the module-level TensorBoard ``writer``
    under the 'Test' tag.

    Args:
        model: object providing ``train_one_batch``; updated in place.
        train_iter: iterable of support-set batches.
        val_iter: query-set data handed to ``evaluate``.
        iterations: number of logged steps, step 0 included.
        tasks_idx: task identifier forwarded to ``evaluate``/``train_one_batch``.
        memory: memory object, forwarded to ``train_one_batch`` only when
            ``config.use_memory`` is set (always forwarded to ``evaluate``).
        Adapt_model: adapter forwarded alongside ``memory``.

    Returns:
        ``(logger, loss_rec)``: ``logger[str(i)]`` is
        ``[loss, ppl, ent_b, bleu]`` from step ``i``'s final evaluation, and
        ``loss_rec[i]`` holds the ``'loss_before'`` / ``'loss_after'`` values
        recorded for step ``i``.
    """
    logger = {str(step): [] for step in range(iterations)}
    loss_rec = {step: {} for step in range(iterations)}
    print('task idx',tasks_idx,flush=True)

    # Step 0: baseline score before any support-set update.
    loss, ppl_val, ent_b, bleu_score_b = evaluate(
        model, val_iter, 0,
        model_name=config.model, ty="test", verbose=False,
        memory=memory, adapt_model=Adapt_model, task_idx=tasks_idx,
        write=True, eval_before_train=True)
    writer.add_scalars('Test', {'loss_before': loss}, 0)
    logger['0'] = [loss, ppl_val, ent_b, bleu_score_b]
    loss_rec[0]['loss_before'] = loss

    adapt = True
    for step in range(1, iterations):
        print('Iteration %d:' % step, flush=True)

        if not config.only_train_query:
            # Query-set score before this step's support-set training.
            pre_loss, _, _, _ = evaluate(
                model, val_iter, step,
                model_name=config.model, ty="test", verbose=False,
                memory=memory, adapt_model=Adapt_model, task_idx=tasks_idx,
                write=False, eval_before_train=True)
            loss_rec[step]['loss_before'] = pre_loss
            writer.add_scalars('Test', {'loss_before': pre_loss}, step)

            # Inner-loop update on the support set.
            for batch in train_iter:
                if config.use_memory:
                    _, _, _ = model.train_one_batch(step, batch, tasks_idx, Adapt_model, memory, adapt=adapt)
                else:
                    _, _ = model.train_one_batch(step, batch, tasks_idx)

        # Query-set score after this step's update.
        loss, ppl_val, ent_b, bleu_score_b = evaluate(
            model, val_iter, step,
            model_name=config.model, ty="test", verbose=False,
            adapt=adapt, memory=memory, adapt_model=Adapt_model,
            task_idx=tasks_idx, write=True)
        loss_rec[step]['loss_after'] = loss
        writer.add_scalars('Test', {'loss_after': loss}, step)
        print('loss %.4f, ppl %.4f, bleu %.4f' % (loss, ppl_val, bleu_score_b), flush=True)
        logger[str(step)] = [loss, ppl_val, ent_b, bleu_score_b]

    return logger, loss_rec

# Build the dataset, the TensorBoard writer and the Transformer to be evaluated.
p = Personas()
print("Test model",config.model,flush=True)
writer = SummaryWriter(log_dir=config.save_tb)

timing = time_record.Time()
timing.begin('load_model')
meta_net = Transformer(p.vocab, model_file_path=config.pt_model, is_eval=False)
timing.end('load_model')

if not config.use_memory:
    # Plain fine-tuning: no memory, no adapter, no binding.
    memory = None
    Adapt_model = None
    binding = None
else:
    timing.begin('load_memory')
    memory_path = config.memory_path
    buffer = {}
    # The literal string 'None' selects a fresh, empty Memory(); otherwise the
    # memory object is unpickled from disk, so the file must be trusted.
    memory = Memory() if memory_path == 'None' else torch.load(memory_path, pickle_module=pickle)
    timing.end('load_memory')
    binding = meta_net.binding
    bert = transformers.BertModel.from_pretrained('bert-base-uncased')
    Adapt_model = LocalAdapt(binding, bert)

fine_tune = []    # one `logger` dict (from do_learning) per evaluated fold
loss_record = []  # one `loss_rec` dict (from do_learning) per evaluated fold
iter_per_task = []
iterations = config.ft_iters
test_tasks = p.get_personas('test')

# * Load tasks IDs' dictionary (persona -> task id).
# `with` guarantees the file is closed even if unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code; tasks_dict.pkl must be trusted.
with open('tasks_dict.pkl','rb') as dict_file:
    tasks_idxs = pickle.load(dict_file)

# Snapshot of the initial Transformer weights, restored after every fold.
# state_dict() returns references to the live parameters, so a deep copy is
# required to keep the snapshot from changing as the model is fine-tuned.
weights_original = deepcopy(meta_net.state_dict())

# ! For each person: leave-one-dialog-out evaluation over every test persona.
for per in test_tasks:
    # Task id of this persona (None if it is missing from tasks_dict.pkl).
    task_id = tasks_idxs.get(per)
    num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
    for val_dial_index in range(num_of_dialog):
        # Dialog `val_dial_index` is held out as the query set for this fold.
        if config.fix_dialnum_train:
            train_iter, val_iter, num_dialogs = p.get_balanced_loader(
                persona=per, batch_size=config.batch_size, split='test',
                fold=val_dial_index, dial_num=config.k_shot)  # k_shot default 20
        else:
            train_iter, val_iter, num_dialogs = p.get_data_loader(
                persona=per, batch_size=config.batch_size, split='test',
                fold=val_dial_index)

        per_idx = (task_id, num_dialogs)

        logger, loss_rec = do_learning(meta_net, train_iter, val_iter, iterations, per_idx, memory, Adapt_model)
        fine_tune.append(logger)
        loss_record.append(loss_rec)

        # * Reset Transformer parameters to initial parameters. keep NN, RNN new parameters.
        # load_state_dict copies the values into the existing parameters and
        # does not mutate `weights_original`, so the snapshot stays reusable.
        # NOTE(review): only module parameters/buffers are restored; any
        # optimizer state kept by meta_net is not reset here — confirm intended.
        meta_net.load_state_dict(weights_original)

    # Drop this persona's memory entry before moving to the next persona.
    # There is no memory when config.use_memory is False (memory is None);
    # the unguarded pop used to crash with AttributeError in that case.
    # The key is this persona's task id (the value previously read from the
    # inner-loop `per_idx`, which is undefined when a persona has no dialogs).
    if memory is not None:
        memory.memory.pop(task_id)


if config.fix_dialnum_train:
    config.save_path = config.save_path+'_fix_dialnum_'+str(config.k_shot)+'_'

# Persist the raw per-fold results; `with` flushes and closes the files.
with open(config.save_path+'evaluation.p', "wb") as f:
    pickle.dump([fine_tune,iterations], f)
with open(config.save_path+'loss_record.p', "wb") as f:
    pickle.dump([loss_record,iterations], f)

# Metric order matches the lists stored by do_learning: [loss, ppl, ent_b, bleu].
measure = ["LOSS","PPL","Entl_b","Bleu_b"]
# temp[m][i] collects metric m at iteration i across every evaluated fold.
temp = {m: [[] for _ in range(iterations)] for m in measure}
for expe in fine_tune:
    for idx_measure, m in enumerate(measure):
        for i in range(iterations):
            temp[m][i].append(expe[str(i)][idx_measure])

fig = plt.figure(figsize=(20,80))

log = {}
for id_mes, m in enumerate(measure):
    ax1 = fig.add_subplot(331 + id_mes)
    values = np.array(temp[m])  # rows: iterations, columns: evaluated folds
    y = values.mean(axis=1)     # mean over all (persona, fold) runs
    e = values.std(axis=1)
    x = range(len(y))
    ax1.errorbar(x, y, e)
    ax1.set_title(m)
    log[m] = y

fig.savefig(config.save_path+'epoch_vs_ppl.pdf')

# The perplexity column is exp(mean loss), not the mean of the per-run PPLs.
header = "epoch\tloss\tPeplexity\tEntl_b\tBleu_b\n"
rows = [
    "{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n".format(
        i, log['LOSS'][i], math.exp(log['LOSS'][i]), log['Entl_b'][i], log['Bleu_b'][i])
    for i in range(iterations)
]
print("----------------------------------------------------------------------")
print(header)
for row in rows:
    print(row)
print("----------------------------------------------------------------------")
with open(config.save_path+'result.txt', 'w', encoding='utf-8') as f:
    f.write(header)
    f.writelines(rows)
