# from model.transformer import Transformer
from model.transformer import Transformer
from model.memory import Memory, LocalAdapt, RNN, NN
from model.beam_omt import Translator
import transformers
from torch.distributions.normal import Normal
from bayes.bayesian_torch.layers import LinearReparameterization
import matplotlib
matplotlib.use('Agg')
from utils.data_reader import Personas
import pickle
from utils import config
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import numpy as np 
from random import shuffle
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import math,pickle
import pprint
pp = pprint.PrettyPrinter(indent=1)

def change_all_keys(prefix, dictionary):
    """Prepend ``prefix`` to every key of ``dictionary`` in place.

    Each key is converted with ``str`` before the prefix is added, and the
    original insertion order is kept. The same (mutated) object is returned,
    so the call can be used in an assignment or chained.

    All new names are computed from a snapshot before the mapping is
    rewritten. A new key that happens to equal a not-yet-renamed old key
    (e.g. prefix ``'x'`` with keys ``'a'`` and ``'xa'``) therefore no longer
    overwrites and loses that entry.

    Args:
        prefix: string prepended to each key.
        dictionary: mutable mapping (e.g. a module ``state_dict``) to rename.

    Returns:
        ``dictionary`` itself, now holding the prefixed keys.
    """
    renamed = [(prefix + str(key), value) for key, value in dictionary.items()]
    # clear()/update() keeps the object identity and type (and any instance
    # attributes such as an OrderedDict's extra metadata).
    dictionary.clear()
    dictionary.update(renamed)
    return dictionary
        
def merge_dicts(*dict_args):
    """Combine the given mappings into a brand-new dict (shallow copy).

    Mappings are applied left to right, so when a key appears more than
    once the value from the right-most mapping wins. None of the inputs is
    modified. Calling with no arguments yields an empty dict.
    """
    merged = dict()
    for source in dict_args:
        merged.update(source)
    return merged

def _strip_repeats(token_ids, eos_idx):
    """Cut ``token_ids`` at the first ``eos_idx`` and drop immediate repeats.

    A token equal to the previously kept token is discarded once at least
    two tokens have been kept; repeats within the first two tokens are kept.
    """
    kept = []
    for tok in token_ids:
        if tok == eos_idx:
            break
        if len(kept) >= 2 and kept[-1] == tok:
            continue
        kept.append(tok)
    return kept


def generate(model, data, persona,per_id=0,task_idx=(0,0),memory=None,adapt_model=None):
    """Adapt on and decode every batch of ``data``, recording the hypotheses.

    For each batch, ``model.train_one_batch`` is called with ``train=False``
    and ``adapt=True``, then ``Translator.translate_batch`` decodes it. Every
    decoded sample is printed to the module-level ``gene_file``. It is also
    stored in the module-level ``sample_dict`` under ``int(per_id)`` as
    ``{'persona description': str, 'contexts': {dialog: {'Ref', 'Pred'}}}``.

    Args:
        model: the Transformer under evaluation; must expose ``vocab`` and
            ``train_one_batch``.
        data: iterable of batch dicts with at least ``"input_txt"`` and
            ``"target_txt"`` entries.
        persona: sequence of persona sentences, joined with ``' | '``.
        per_id: persona identifier; ``int(per_id)`` is the ``sample_dict`` key.
        task_idx: forwarded to ``translate_batch``.
        memory: forwarded to ``train_one_batch`` and ``translate_batch``.
        adapt_model: forwarded to ``train_one_batch`` (as ``Adapt_model``)
            and ``translate_batch``.
    """
    t = Translator(model, model.vocab)
    idx = int(per_id)
    # generate() runs once per dialogue fold of the same persona, so merge
    # into the existing entry instead of replacing it; replacing it kept only
    # the last fold's contexts.
    entry = sample_dict.setdefault(idx, {})
    entry['persona description'] = ' | '.join(map(str, persona))
    context_dict = entry.setdefault('contexts', {})
    for batch in data:
        _, _, _ = model.train_one_batch(0, batch, train=False, memory=memory, Adapt_model=adapt_model, adapt=True, eval_before_train=False)
        sent_b, _ = t.translate_batch(batch, memory, adapt_model, task_idx, eval_before_train=False)
        for i, target in enumerate(batch["target_txt"]):
            # Best beam of sample i.
            tokens = _strip_repeats(sent_b[i][0], config.EOS_idx)
            sent_beam_search = ' '.join([model.vocab.index2word[tok_id] for tok_id in tokens])
            dialog = ' | '.join(map(str, batch['input_txt'][i]))
            context_dict[dialog] = {}
            context_dict[dialog]['Ref'] = target
            context_dict[dialog]['Pred'] = sent_beam_search
            print("-"*50, file=gene_file)
            print("persona set", file=gene_file)
            print(pp.pformat(persona), file=gene_file)
            print("dialogue context:", file=gene_file)
            print(pp.pformat(batch['input_txt'][i]), file=gene_file)
            print("Hyp: {}".format(sent_beam_search), file=gene_file)
            print("Ref: {}".format(target), file=gene_file)
            print("-"*50, file=gene_file)

def do_learning(model, train_iter, val_iter, iterations, persona,per_id=0,task_idx=(0,0),memory=None,Adapt_model=None):
    """Fine-tune ``model`` on ``train_iter``, then decode ``val_iter``.

    The step counter runs from 1 up to ``iterations - 1``. For each step,
    every batch of ``train_iter`` goes through ``model.train_one_batch`` with
    ``adapt=True``. Afterwards ``generate`` decodes ``val_iter`` with the same
    persona, task index, memory and adaptation model.
    """
    # NOTE(review): arguments are passed positionally here as
    # (step, batch, task_idx, Adapt_model, memory), while generate() passes
    # memory/Adapt_model by keyword -- confirm this order matches the
    # train_one_batch signature.
    for step in range(1, iterations):
        for batch in train_iter:
            _, _, _ = model.train_one_batch(step, batch, task_idx, Adapt_model, memory, adapt=True)
    generate(
        model,
        val_iter,
        persona,
        per_id=per_id,
        task_idx=task_idx,
        memory=memory,
        adapt_model=Adapt_model,
    )


p = Personas()
# * Build model, optimizer, and set states
print("Test model",config.model,flush=True)
model = Transformer(p.vocab,model_file_path=config.pt_model,is_eval=False)
# get persona map
filename = 'data/ConvAI2/test_persona_map'
with open(filename,'rb') as f:
    persona_map = pickle.load(f)

memory_path = config.memory_path 
buffer = {}
# config.memory_path is compared against the string 'None', not the None object.
if memory_path=='None':
    memory=Memory()
else:
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # memory checkpoints.
    memory=torch.load(memory_path,pickle_module=pickle)
binding = model.binding
bert=transformers.BertModel.from_pretrained('bert-base-uncased')
Adapt_model=LocalAdapt(binding,bert)
# * Load tasks IDs' dictionary
# `with` guarantees the handle is closed even if unpickling fails.
with open('tasks_dict.pkl','rb') as dict_file:
    tasks_idxs = pickle.load(dict_file)

# * Store Transformer initial parameters which are learned from meta-training.
enc_dict, dec_dict, gen_dict, emb_dict = model.encoder.state_dict(), model.decoder.state_dict(),model.generator.state_dict(), model.embedding.state_dict()
enc_dict, dec_dict, gen_dict, emb_dict = change_all_keys('encoder.',enc_dict),change_all_keys('decoder.',dec_dict),change_all_keys('generator.',gen_dict),change_all_keys('embedding.',emb_dict)
trs_dict = merge_dicts(enc_dict, dec_dict, gen_dict,emb_dict)
# NOTE(review): trs_weights_original is not read anywhere in this file and
# holds a full deep copy of the weights; confirm it is still needed.
trs_weights_original = deepcopy(trs_dict)

# * Generate to compare.
iterations = config.generate_steps
# Snapshot of the meta-learned weights; restored after every dialogue fold so
# each fold is fine-tuned from the same starting point.
weights_original = deepcopy(model.state_dict())
tasks = p.get_personas('test')

# generate() writes to the module-level names `gene_file` and `sample_dict`.
with open(config.save_path+'generated_samples'+'.txt','a') as gene_file:
    print('BEGIN GENERATING',flush=True)
    sample_dict={}
    for per in tasks:
        print('*'*60,file=gene_file)
        print('For person with id',per,file=gene_file)
        num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
        if not num_of_dialog:
            # Nothing was fine-tuned for this persona, so there is no memory
            # entry of its own to drop (the old code popped a stale id here).
            continue
        # The persona set and task id depend only on `per`; build them once.
        persona=[]
        for ppp in persona_map[per]:
            persona+=ppp
        persona = list(set(persona))
        task_id = tasks_idxs.get(per)
        for val_dial_index in range(num_of_dialog):
            train_iter, val_iter,num_dialogs = p.get_data_loader(persona=per,batch_size=config.batch_size, split='test', fold=val_dial_index)
            per_idx = (task_id,num_dialogs)
            # * Finetune first.
            do_learning(model, train_iter, val_iter, iterations=iterations, persona=persona,per_id=per,task_idx=per_idx,memory=memory,Adapt_model=Adapt_model)
            # Reset to the meta-learned weights before the next fold.
            model.load_state_dict(weights_original)
        memory.memory.pop(task_id)

with open(config.save_path+'sample_dict.pkl','wb') as dict_file:
    pickle.dump(sample_dict,dict_file)
