from utils import config
from model.memory import Memory, LocalAdapt, RNN, NN
from model.transformer import Transformer 
from model.common_layer import NoamOpt

import transformers
from utils.data_reader import Personas
import matplotlib
matplotlib.use('Agg')
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np 
from random import shuffle
import random
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import math
from tensorboardX import SummaryWriter
import time, pickle
from utils import record_time as time_record

# Seed Python's `random` module; `shuffle` (imported from `random`) orders the
# training personas in make_infinite_list. torch/numpy RNGs are not seeded here.
random.seed(123)
def merge_dicts(*dict_args):
    """Combine several mappings into one brand-new dict (shallow copy).

    When the same key appears in more than one argument, the value from the
    right-most argument is kept; the key stays at the position where it was
    first seen. None of the inputs are modified.
    """
    return {key: value for mapping in dict_args for key, value in mapping.items()}

def make_infinite(dataloader):
    """Endless generator over ``dataloader``: restart it each time it runs out."""
    while True:
        yield from dataloader

def make_infinite_list(personas):
    """Cycle through ``personas`` forever, reshuffling before each full pass.

    The caller's list is shuffled in place at the start of every pass.
    """
    while True:
        shuffle(personas)
        yield from personas

def do_learning(model, train_iter, iterations):
    """Run ``iterations`` training steps on batches pulled from ``train_iter``.

    Args:
        model: object whose ``train_one_batch(batch)`` returns a 3-tuple
            ``(loss, ppl, extra)``.
        train_iter: iterator yielding training batches.
        iterations: number of training steps to take.

    Returns:
        The loss of the last step, or ``None`` when ``iterations`` is not
        positive (previously this raised ``UnboundLocalError``).
    """
    # NOTE(review): this uses a single-argument ``train_one_batch`` call, while
    # the other call sites in this file pass ``(iters, batch, ...)`` — confirm
    # which model API this helper targets.
    loss = None
    for _ in range(iterations):
        loss, _ppl, _ = model.train_one_batch(next(train_iter))
    return loss

def do_learning_early_stop(model, train_iter, val_iter, iterations, strict=1, iters=0):
    """Train with early stopping on validation perplexity, then restore the best weights.

    Each of up to ``iterations`` epochs makes one full pass over ``train_iter``
    and then evaluates on ``val_iter`` via ``do_evaluation``. Training stops
    once validation perplexity has failed to improve more than ``strict``
    epochs in a row.

    Args:
        model: model exposing ``train_one_batch``, ``state_dict`` and
            ``load_state_dict``.
        train_iter: re-iterable of training batches (one pass per epoch).
        val_iter: re-iterable of validation batches.
        iterations: maximum number of epochs.
        strict: number of consecutive non-improving epochs tolerated.
        iters: iteration counter forwarded as the first argument of
            ``do_evaluation`` (new keyword; defaults to 0).

    Returns:
        ``((mean_train_loss, mean_train_ppl, best_val_loss, best_val_ppl), n_improved)``
        where the training means cover the last epoch that ran (NaN if none
        ran) and ``n_improved`` counts epochs whose validation perplexity
        matched or beat the best seen so far.
    """
    def _mean(values):
        # float() accepts Python numbers and one-element tensors alike; the
        # original torch.mean(list) raised TypeError on a plain Python list.
        return float(np.mean([float(v) for v in values])) if values else float('nan')

    b_loss, b_ppl = 100000, 100000
    best = deepcopy(model.state_dict())
    cnt = 0
    idx = 0
    train_l, train_p = [], []  # defined up front so iterations == 0 cannot leave them unbound
    for _ in range(iterations):
        train_l, train_p = [], []
        for d in train_iter:
            # NOTE(review): single-argument ``train_one_batch`` call; other call
            # sites in this file pass ``(iters, batch, tasks_idx, ...)`` — confirm.
            t_loss, t_ppl, _ = model.train_one_batch(d)
            train_l.append(t_loss)
            train_p.append(t_ppl)

        # do_evaluation takes the iteration counter first; the original call
        # ``do_evaluation(model, val_iter)`` was missing an argument (TypeError).
        n_loss, n_ppl = do_evaluation(iters, model, val_iter)
        # Early stopping on validation perplexity.
        if n_ppl <= b_ppl:
            b_ppl = n_ppl
            b_loss = n_loss
            cnt = 0
            idx += 1
            best = deepcopy(model.state_dict())  # save best weights
        else:
            cnt += 1
        if cnt > strict:
            break

    # Load the best weights seen during training.
    model.load_state_dict({name: best[name] for name in best})

    return (_mean(train_l), _mean(train_p), b_loss, b_ppl), idx

def do_learning_fix_step(iters,model, train_iter, val_iter,iterations, tasks_idx=0, memory=None,Adapt_model=None,test=False):
    """Inner-loop adaptation on a support set followed by a pass over the query set.

    Trains ``model`` for ``iterations`` full passes over ``train_iter`` (the
    support set). Then it either evaluates on ``val_iter`` after each pass
    (``test=True``) or scores ``val_iter`` (the query set) once.

    Args:
        iters: current meta-iteration, forwarded to ``train_one_batch``.
        model: model exposing ``train_one_batch``.
        train_iter: re-iterable of support-set batches (one pass per iteration).
        val_iter: re-iterable of query-set batches.
        iterations: number of passes over ``train_iter``.
        tasks_idx: task descriptor forwarded to ``train_one_batch``; callers in
            this file pass a ``(task index, #dialogs)`` tuple.
        memory: memory object, forwarded in memory mode.
        Adapt_model: adapter model, forwarded to ``train_one_batch`` in memory mode.
        test: when True, return the list of query perplexities measured after
            each pass instead of losses.

    Returns:
        If ``test``: list of query-set perplexities, one per pass.
        Otherwise, in memory mode ``(summed query loss, mean query ppl,
        summed auxiliary loss)``; without memory ``(summed query loss,
        mean query ppl)``. Losses are summed, not averaged, over batches.
    """
    val_p = []
    val_p_list = []
    val_loss = 0
    val_other_loss = 0
    for _ in range(iterations):
        for d in train_iter:
            if config.use_memory:
                t_loss, t_ppl, other_loss = model.train_one_batch(iters,d,tasks_idx,Adapt_model,memory) # train default True
            else:
                t_loss, t_ppl = model.train_one_batch(iters,d,tasks_idx)
        if test:
            # Forward the task context and memory so evaluation sees the same
            # task the support-set updates just used. Previously these fell
            # back to tasks_idx=0 / memory=None.
            _, test_ppl = do_evaluation(iters,model, val_iter, tasks_idx, memory)
            val_p_list.append(test_ppl)
    if test:
        return val_p_list

    # * Report query set loss (train=False; presumably no parameter update).
    if config.use_memory:
        for d in val_iter:
            t_loss, t_ppl, t_other_loss = model.train_one_batch(iters,d,tasks_idx,Adapt_model,memory,train= False)
            val_loss +=t_loss
            val_p.append(t_ppl)
            val_other_loss += t_other_loss
        return val_loss, np.mean(val_p), val_other_loss
    for d in val_iter:
        t_loss, t_ppl = model.train_one_batch(iters,d,tasks_idx,train= False)
        val_loss+=t_loss
        val_p.append(t_ppl)
    return val_loss, np.mean(val_p)

def do_evaluation(iters,model, test_iter,tasks_idx=0, memory=None,eval_before=False):
    """Score ``model`` on every batch of ``test_iter`` with ``train=False``.

    Args:
        iters: current meta-iteration, forwarded to ``train_one_batch``.
        model: model exposing ``train_one_batch(..., train=False)``.
        test_iter: iterable of evaluation batches.
        tasks_idx: task descriptor forwarded in memory mode (callers in this
            file pass a ``(task index, #dialogs)`` tuple).
        memory: memory object forwarded in memory mode.
        eval_before: True when evaluating before training on the support set.

    Returns:
        ``(mean loss, mean perplexity)`` over all batches (NaN when
        ``test_iter`` is empty).

    In memory mode this reads the module-level ``Adapt_model``.
    """
    use_memory = config.use_memory
    losses, ppls = [], []
    for batch in test_iter:
        if use_memory:
            loss, ppl, _ = model.train_one_batch(
                iters, batch, tasks_idx, Adapt_model, memory,
                train=False, eval_before_train=eval_before,
            )
        else:
            loss, ppl = model.train_one_batch(iters, batch, train=False)
        losses.append(loss.item())
        ppls.append(ppl)
    return np.mean(losses), np.mean(ppls)

def change_all_keys(prefix, dictionary):
    """Prepend ``prefix`` to every key of ``dictionary`` in place and return it.

    Keys are converted with ``str`` before prefixing. The same mapping object
    is mutated and returned, so any attributes set on it are kept. Key order
    is preserved.

    All new keys are computed before any entry is written. A renamed key
    therefore can no longer overwrite a key that has not been renamed yet
    (e.g. ``'a'`` and ``'x.a'`` with prefix ``'x.'`` used to collapse into a
    single entry).
    """
    renamed = [(prefix + str(key), value) for key, value in dictionary.items()]
    dictionary.clear()
    dictionary.update(renamed)
    return dictionary
        
# ================================================= Meta Training ================================================
# Script setup: load the persona dataset wrapper and create the TensorBoard
# writer, printing how long those two steps took.
st=time.time()
p = Personas()
writer = SummaryWriter(log_dir=config.save_tb)
sp=time.time()
if config.use_retrieval:
    print('Use retrieval method (No adaptation)',flush=True)
print('writer and persona',sp-st,flush=True)

# * Fresh start (no checkpoint): build the model and, in memory mode, the
#   memory store plus the BERT-based adapter around the binding module.
if config.load_frompretrain == 'None':
    meta_net = Transformer(p.vocab)
    bert = transformers.BertModel.from_pretrained('bert-base-uncased')
    if config.use_memory:
        rnn, binding = meta_net.rnn, meta_net.binding
        memory = Memory()
        Adapt_model = LocalAdapt(binding, bert, BP_adapt_model=False)
    begin_iter = 0

# * Resume from a checkpoint: restore the memory (if used), the model and the
#   adapter, and continue from the saved iteration.
if config.load_frompretrain != 'None':
    if config.use_memory:
        # Load the memory saved by a previous run, or start from an empty one.
        # NOTE: torch.load unpickles the file; only point config.load_memory
        # at files you trust.
        memory_path = config.load_memory
        if memory_path != 'None':
            memory = torch.load(memory_path)
        else:
            memory = Memory()
    # Load the model saved before.
    bert = transformers.BertModel.from_pretrained('bert-base-uncased')
    meta_net = Transformer(p.vocab, model_file_path=config.load_frompretrain, is_eval=False)
    rnn = meta_net.rnn
    binding = meta_net.binding
    # The adapter wraps different modules depending on the experiment setting.
    if config.setting in ('setting4', 'setting8'):
        Adapt_model = LocalAdapt(binding, bert, BP_adapt_model=False)
    elif config.setting == 'setting9':
        Adapt_model = LocalAdapt(meta_net, bert, BP_adapt_model=False)
    else:
        # NOTE(review): unlike the other branches (and the fresh-model path),
        # no BERT model is passed here — confirm LocalAdapt's default is intended.
        Adapt_model = LocalAdapt(binding, BP_adapt_model=False)
    begin_iter = meta_net.load_iter
    # The format operator was missing, so '%d' used to be printed literally.
    print('Begin training from iteration %d' % begin_iter, flush=True)

# * Split parameters: the outer (meta) optimizer only updates Transformer
#   weights; parameters of the ``rnn``/``binding`` sub-modules are excluded.
if config.use_memory:
    Other_params = []
    # Materialise once: a generator would already be exhausted by the filter
    # below, which made the printed "All:" count always 0.
    All_params = list(meta_net.parameters())
    for pname, param in meta_net.named_parameters():
        if pname.startswith(('rnn', 'binding')):
            print(pname, flush=True)
            Other_params.append(param)
    params_id = set(map(id, Other_params))
    Trans_params = [prm for prm in All_params if id(prm) not in params_id]
    print('Number of parameters:', 'Trs:', len(Trans_params), 'Others:', len(Other_params), 'All:', len(All_params), flush=True)
else:
    # Without memory nothing is excluded; previously Trans_params was left
    # undefined here and building the optimizer raised NameError.
    Trans_params = list(meta_net.parameters())

# * Outer-loop optimizer over Trans_params only: Transformer parameters are
#   updated from query-set losses (MAML style).
if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(Trans_params, lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(Trans_params, lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    # lr=0 on the wrapped Adam; presumably NoamOpt (factor 1, warmup 4000)
    # sets the learning rate every step — confirm in model.common_layer.
    meta_optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(Trans_params, lr=0, betas=(0.9, 0.98), eps=1e-9))
else:
    raise ValueError(
        "Unknown meta_optimizer %r; expected 'sgd', 'adam' or 'noam'" % (config.meta_optimizer,)
    )

meta_batch_size = config.meta_batch_size
st = time.time()
# Map every persona (train, test, valid) to an integer task index and dump the
# mapping to tasks_dict.pkl. If a persona appears twice, its last index wins.
test_tasks = p.get_personas('test')
train_tasks = p.get_personas('train')
valid_tasks = p.get_personas('valid')
all_tasks = train_tasks + test_tasks + valid_tasks
tasks_idxs = {task: task_id for task_id, task in enumerate(all_tasks)}
with open('tasks_dict.pkl', 'wb') as dict_file:
    pickle.dump(tasks_idxs, dict_file)

sp = time.time()
print('Load dataset time',sp-st,flush=True)
print('Total number of training persons is %d'%len(train_tasks),flush=True)
# Endless stream of training personas; train_tasks is reshuffled in place on
# every pass (all_tasks above is a separate list, so tasks_idxs is unaffected).
tasks_iter = make_infinite_list(train_tasks)

# Meta-level early stopping on the mean validation loss.
patience = 50
if config.fix_dialnum_train:
    patience = config.patience
best_loss = 10000000
stop_count = 0
print('whole model dict keys number:\n',len(meta_net.state_dict().keys()),flush=True)
# * Main loop: each meta-iteration adapts on a batch of personas (inner loop)
#   and then takes one outer (meta) optimizer step on the query-set losses.
for meta_iteration in range(begin_iter,config.epochs):
    timing=time_record.Time()
    timing.begin('one_epoch')
    print('\n','-'*100,flush=True)
    print('For %dth epoch'%(meta_iteration),flush=True)

    # Snapshot the weights every task's inner loop starts from. With memory,
    # only the Transformer sub-modules (encoder/decoder/generator/embedding)
    # are snapshotted, so rnn/binding keep their inner-loop updates across tasks.
    if config.use_memory:
        timing.begin('copy_trs_params')
        enc_dict, dec_dict, gen_dict, emb_dict = meta_net.encoder.state_dict(), meta_net.decoder.state_dict(),meta_net.generator.state_dict(), meta_net.embedding.state_dict()
        enc_dict, dec_dict, gen_dict, emb_dict = change_all_keys('encoder.',enc_dict),change_all_keys('decoder.',dec_dict),change_all_keys('generator.',gen_dict),change_all_keys('embedding.',emb_dict)
        trs_dict = merge_dicts(enc_dict, dec_dict, gen_dict,emb_dict)
        trs_weights_original = deepcopy(trs_dict)
        timing.end('copy_trs_params')
    else:
        whole_model_dict = meta_net.state_dict()
        weights_original=deepcopy(whole_model_dict)

    train_loss_before = []
    train_loss_meta = []
    train_other_loss_meta = []

    # Query-set loss accumulated over the batch of tasks; drives the outer update.
    batch_loss=0
    if config.use_memory:
        batch_rnn_loss = 0
        batch_other_loss = 0
    # Inner loop over meta_batch_size personas.
    for _ in range(meta_batch_size):
        # Get task
        per = next(tasks_iter)
        if config.fix_dialnum_train:
            train_iter, val_iter, num_dialogs = p.get_balanced_loader(persona=per,batch_size=config.batch_size, split='train', dial_num=config.k_shot)
        else:
            train_iter, val_iter, num_dialogs = p.get_data_loader(persona=per,batch_size=config.batch_size, split='train')

        per_idx = (tasks_idxs.get(per),num_dialogs) # * 0: task index, 1: #dialog per task

        # * Query-set loss before any support-set adaptation.
        timing.begin('first_eval')
        if config.use_memory:
            v_loss, v_ppl = do_evaluation(meta_iteration, meta_net, val_iter, per_idx, memory,eval_before=True)
        else:
            v_loss, v_ppl = do_evaluation(meta_iteration, meta_net, val_iter, per_idx)
        timing.end('first_eval')
        train_loss_before.append(v_loss)

        # * Adapt on the support set, then measure the loss on the query set.
        if config.use_memory:
            timing.begin('one_episode')
            val_loss, val_ppl,val_other_loss = do_learning_fix_step(meta_iteration,meta_net, train_iter, val_iter, config.meta_iteration, per_idx, memory,Adapt_model)
            timing.end('one_episode')
            train_loss_meta.append(val_loss.item())
            batch_loss += val_loss
            if config.reconstruct:
                # Record the auxiliary loss so the reconstruct log below is not
                # the mean of an empty list.
                train_other_loss_meta.append(float(val_other_loss))
                # val_loss is already included above; add only the weighted
                # auxiliary loss (previously val_loss was counted twice here).
                batch_loss += 0.5 * val_other_loss

            # * Reset Transformer params to the snapshot, keeping the adapted rnn/binding weights.
            timing.begin('reset_trs_params')
            NN_new_state = meta_net.binding.state_dict()
            NN_new_state = change_all_keys('binding.',NN_new_state)
            RNN_new_state = meta_net.rnn.state_dict()
            RNN_new_state = change_all_keys('rnn.',RNN_new_state)
            model_dict = merge_dicts(NN_new_state,RNN_new_state,trs_weights_original)
            meta_net.load_state_dict({ name: model_dict[name] for name in model_dict})
            timing.end('reset_trs_params')
        else:
            timing.begin('one_episode')
            val_loss, val_ppl = do_learning_fix_step(meta_iteration,meta_net, train_iter, val_iter, config.meta_iteration, per_idx)
            train_loss_meta.append(val_loss.item())
            batch_loss+=val_loss
            timing.end('one_episode')

            # * Reset every weight to the pre-task snapshot.
            timing.begin('reset_trs_params')
            meta_net.load_state_dict({ name: weights_original[name] for name in weights_original })
            timing.end('reset_trs_params')
    # * Write to tensorboard
    writer.add_scalars('Train', {'train_loss_before': np.mean(train_loss_before)}, meta_iteration)
    writer.add_scalars('Train', {'train_loss_meta': np.mean(train_loss_meta)}, meta_iteration)
    print('-'*100,flush=True)
    if config.reconstruct:
        writer.add_scalars('Train', {'train_reconstruct_loss_meta': np.mean(train_other_loss_meta)}, meta_iteration)
        print('\nTraining binding loss on query set is %.4f'%np.mean(train_other_loss_meta),flush=True)
    print('\nTraining Transformer loss on query set is %.4f'%np.mean(train_loss_meta),flush=True)
    print('\nTraining before loss on query set is %.4f'%np.mean(train_loss_before),flush=True)
    print('-'*100,flush=True)

    # ! Meta Update for query set.
    if config.meta_optimizer == 'noam':
        meta_optimizer.optimizer.zero_grad()
    else:
        meta_optimizer.zero_grad()

    if config.use_memory:
        # rnn/binding are not in meta_optimizer, so clear any gradient they
        # hold (presumably left over from the inner loop) before the outer backward.
        for name, parms in meta_net.named_parameters():
            if parms.grad is not None and name.startswith(('binding', 'rnn')):
                parms.grad.data.zero_()
    timing.begin('outer_updating')
    batch_loss/=meta_batch_size
    batch_loss=batch_loss.cuda()
    batch_loss.backward()
    # Global L2 norm of all gradients, logged before clipping.
    parameters = [prm for prm in meta_net.parameters() if prm.grad is not None and prm.requires_grad]
    if not parameters:
        total_norm = 0.0
    else:
        device = parameters[0].grad.device
        total_norm = torch.norm(torch.stack([torch.norm(prm.grad.detach()).to(device) for prm in parameters]), 2.0).item()
    print('parameters norm',total_norm,flush=True)
    # clip gradient
    nn.utils.clip_grad_norm_(meta_net.parameters(), config.max_grad_norm)
    meta_optimizer.step()
    timing.end('outer_updating')

    if config.use_memory:
        if meta_iteration % 5000==0 and meta_iteration!=0:
            print('save memory',flush=True)
            meta_net.save_memory(meta_iteration, memory)
        # When memory_refresh is on, start a fresh memory after every
        # meta-iteration (only_store_cur_task) or otherwise every 200 iterations.
        if config.only_store_cur_task:
            if config.memory_refresh:
                memory=Memory() # Refresh memory.
        elif meta_iteration%200==0 and meta_iteration!=0:
            if config.memory_refresh:
                memory=Memory()

    timing.end('one_epoch')
    if config.print_time:
        timing.print_all()

    # ---------------------------------- Meta-Validating ----------------------------------
    # * Every 10 iterations: adapt on each validation persona's support set
    #   (inner loop only, no outer update) and restore all weights afterwards.
    if meta_iteration % 10 == 0 and meta_iteration!=0:
        timing.begin('Meta-validating')
        val_loss_before = []
        val_loss_meta = []
        val_all_loss_meta = []
        whole_model_dict = meta_net.state_dict()
        all_weights_original = deepcopy(whole_model_dict)
        valid_tasks = p.get_personas('valid')
        for idx ,per in enumerate(valid_tasks):
            if config.fix_dialnum_train:
                train_iter, val_iter, val_num_dialogs = p.get_balanced_loader(persona=per,batch_size=config.batch_size, split='valid', fold=0, dial_num=config.k_shot)
            else:
                train_iter, val_iter, val_num_dialogs = p.get_data_loader(persona=per,batch_size=config.batch_size, split='valid', fold=0)

            val_per_idx = (tasks_idxs.get(per),val_num_dialogs)
            # Zero-shot result (before adaptation).
            if config.use_memory:
                loss, ppl = do_evaluation(meta_iteration,meta_net, val_iter,val_per_idx,memory,eval_before=True)
            else:
                loss, ppl = do_evaluation(meta_iteration,meta_net, val_iter,val_per_idx)

            val_loss_before.append(loss)

            # Meta-tuning on the support set.
            if config.use_memory:
                val_loss, val_ppl,val_all_loss = do_learning_fix_step(meta_iteration,meta_net, train_iter, val_iter, config.meta_iteration,val_per_idx,memory,Adapt_model)
                val_all_loss_meta.append(val_all_loss.item())
            else:
                # Pass the task index as the training loop does (per_idx);
                # previously it silently defaulted to tasks_idx=0.
                val_loss,val_ppl=do_learning_fix_step(meta_iteration,meta_net, train_iter, val_iter, config.meta_iteration,val_per_idx)
            val_loss_meta.append(val_loss.item())
            # Reset every weight (including rnn/binding) to the pre-validation snapshot.
            model_dict = all_weights_original
            meta_net.load_state_dict({ name: model_dict[name] for name in model_dict})

        writer.add_scalars('Valid', {'val_loss_before': np.mean(val_loss_before)}, meta_iteration)
        writer.add_scalars('Valid', {'val_loss_meta': np.mean(val_loss_meta)}, meta_iteration)
        print('*'*100,flush=True)
        print('Iteration %d, the mean Transformer loss on validation set is %.4f'%(meta_iteration,np.mean(val_loss_meta)),flush=True)
        print('Iteration %d, the mean Transformer loss before on validation set is %.4f'%(meta_iteration,np.mean(val_loss_before)),flush=True)
        print('*'*100, flush=True)

        # Early stopping. A checkpoint is saved only on a new best loss that is also below 5.
        if np.mean(val_loss_meta)< best_loss:
            best_loss = np.mean(val_loss_meta)
            stop_count = 0
            if best_loss<5:
                meta_net.save_model(best_loss,meta_iteration)
        else:
            stop_count+=1
            if stop_count>patience:
                break
        timing.end('Meta-validating')
        if config.print_time:
            timing.print_diff('Meta-validating')