### TAKEN FROM https://github.com/kolloldas/torchnlp
'''
Using guidance embedding from memory combined with MAML. 
'''
from model.common_layer import EncoderLayer, DecoderLayer, MultiHeadAttention, Conv, PositionwiseFeedForward, LayerNorm , _gen_bias_mask ,_gen_timing_signal, share_embedding, LabelSmoothing, NoamOpt, _get_attn_subsequent_mask,  get_input_from_batch, get_output_from_batch
from model.memory import RNN, NN
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from utils import config
import random
from numpy import random 
import os
import pprint
from tqdm import tqdm
import os
import time
from torch.autograd import Variable 
from utils import record_time as time_record


# Fix every random number generator used by this module so runs are repeatable.
# NOTE: `random` here is `numpy.random`, because `from numpy import random`
# rebinds the name after the earlier `import random`.
_SEED = 123
random.seed(_SEED)
torch.manual_seed(_SEED)
if torch.cuda.is_available():
    # Seed all visible CUDA devices as well.
    torch.cuda.manual_seed_all(_SEED)

#torch.autograd.set_detect_anomaly(True) # Cost much time. Can enable this when do toy testing.
class Encoder(nn.Module):
    """
    Transformer encoder (see Fig. 1 of https://arxiv.org/pdf/1706.03762.pdf).

    `forward` applies input dropout, projects the last dimension of the input
    from `embedding_size` to `hidden_size`, adds the timing signal, runs the
    encoder layer(s) and finishes with a layer normalisation.
    Dimension 1 of the input is used as the sequence length.
    """
    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
                 filter_size, max_length=1000, input_dropout=0.0, layer_dropout=0.0, 
                 attention_dropout=0.0, relu_dropout=0.0, use_mask=False, universal=False):
        """
        Parameters:
            embedding_size: Size of the incoming embeddings
            hidden_size: Hidden size the embeddings are projected to
            num_layers: Number of encoder layers (or of repeated steps when universal)
            num_heads: Number of attention heads
            total_key_depth: Size of last dimension of keys (falls back to hidden_size when falsy)
            total_value_depth: Size of last dimension of values (falls back to hidden_size when falsy)
            filter_size: Hidden size of the middle layer in FFN
            max_length: Max sequence length (used to build the timing signal)
            input_dropout: Dropout applied to the raw inputs
            layer_dropout: Dropout for each layer
            attention_dropout: Dropout probability after attention
            relu_dropout: Dropout probability after relu in FFN
            use_mask: When True a bias mask from _gen_bias_mask is handed to the layers
            universal: When True a single EncoderLayer is shared across all steps
        """
        super().__init__()
        self.universal = universal
        self.num_layers = num_layers
        self.timing_signal = _gen_timing_signal(max_length, hidden_size)

        if universal:
            # Extra signal indexed by the layer/step number.
            self.position_signal = _gen_timing_signal(num_layers, hidden_size)

        bias_mask = _gen_bias_mask(max_length) if use_mask else None
        layer_args = (
            hidden_size,
            total_key_depth or hidden_size,
            total_value_depth or hidden_size,
            filter_size,
            num_heads,
            bias_mask,
            layer_dropout,
            attention_dropout,
            relu_dropout,
        )

        # Sub-modules are created in a fixed order so parameter ordering and
        # random initialisation stay reproducible.
        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
        if universal:
            self.enc = EncoderLayer(*layer_args)
        else:
            stack = []
            for _ in range(num_layers):
                stack.append(EncoderLayer(*layer_args))
            self.enc = nn.ModuleList(stack)

        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

        if config.act:
            self.act_fn = ACT_basic(hidden_size)
            self.remainders = None
            self.n_updates = None

    def forward(self, inputs, mask):
        """Encode `inputs` under `mask` and return the layer-normalised hidden states."""
        timing = time_record.Time()

        timing.begin('enc_input_dropout')
        hidden = self.input_dropout(inputs)
        timing.end('enc_input_dropout')

        timing.begin('enc_project')
        hidden = self.embedding_proj(hidden)
        timing.end('enc_project')

        seq_len = inputs.shape[1]

        if not self.universal:
            # Standard stack: timing signal once, then every layer in turn.
            timing.begin('enc_timing_signal')
            hidden = hidden + self.timing_signal[:, :seq_len, :].type_as(inputs.data)
            timing.end('enc_timing_signal')

            for layer_idx in range(self.num_layers):
                hidden = self.enc[layer_idx](hidden, mask)

            timing.begin('enc_layernorm')
            encoded = self.layer_norm(hidden)
            timing.end('enc_layernorm')
            if config.print_time:
                timing.print_all()
            return encoded

        if config.act:
            # Adaptive computation time: the ACT module drives the shared layer.
            hidden, (self.remainders, self.n_updates) = self.act_fn(
                hidden, inputs, self.enc, self.timing_signal, self.position_signal, self.num_layers)
            return self.layer_norm(hidden)

        # Universal transformer: the shared layer is applied num_layers times,
        # re-adding the timing and per-step position signals before each pass.
        for step_idx in range(self.num_layers):
            hidden += self.timing_signal[:, :seq_len, :].type_as(inputs.data)
            hidden += self.position_signal[:, step_idx, :].unsqueeze(1).repeat(1, seq_len, 1).type_as(inputs.data)
            hidden = self.enc(hidden, mask=mask)
        return self.layer_norm(hidden)

class Decoder(nn.Module):
    """
    Transformer decoder (see Fig. 1 of https://arxiv.org/pdf/1706.03762.pdf).

    `forward` applies input dropout, projects the last dimension of the input
    from `embedding_size` to `hidden_size`, adds the timing signal, runs the
    decoder layer(s) against the encoder output and finishes with a layer
    normalisation. It returns the hidden states together with the attention
    distribution reported by the decoder layers.
    """
    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
                 filter_size, max_length=config.max_enc_steps, input_dropout=0.0, layer_dropout=0.0, 
                 attention_dropout=0.0, relu_dropout=0.0, universal=False):
        """
        Parameters:
            embedding_size: Size of the incoming embeddings
            hidden_size: Hidden size the embeddings are projected to
            num_layers: Number of decoder layers (or of repeated steps when universal)
            num_heads: Number of attention heads
            total_key_depth: Size of last dimension of keys (falls back to hidden_size when falsy)
            total_value_depth: Size of last dimension of values (falls back to hidden_size when falsy)
            filter_size: Hidden size of the middle layer in FFN
            max_length: Max sequence length (used for the timing signal and the masks)
            input_dropout: Dropout applied to the raw inputs
            layer_dropout: Dropout for each layer
            attention_dropout: Dropout probability after attention
            relu_dropout: Dropout probability after relu in FFN
            universal: When True a single DecoderLayer is shared across all steps
        """
        super().__init__()
        self.universal = universal
        self.num_layers = num_layers
        self.timing_signal = _gen_timing_signal(max_length, hidden_size)

        if universal:
            # Extra signal indexed by the layer/step number.
            self.position_signal = _gen_timing_signal(num_layers, hidden_size)

        # Mask combined with the target padding mask in forward().
        self.mask = _get_attn_subsequent_mask(max_length)

        layer_args = (
            hidden_size,
            total_key_depth or hidden_size,
            total_value_depth or hidden_size,
            filter_size,
            num_heads,
            _gen_bias_mask(max_length),  # always supplied for the decoder
            layer_dropout,
            attention_dropout,
            relu_dropout,
        )

        # Sub-modules are created in a fixed order so parameter ordering and
        # random initialisation stay reproducible.
        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
        if universal:
            self.dec = DecoderLayer(*layer_args)
        else:
            self.dec = nn.Sequential(*(DecoderLayer(*layer_args) for _ in range(num_layers)))

        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)
        if config.act:
            self.act_fn = ACT_basic(hidden_size)
            self.remainders = None
            self.n_updates = None

    def forward(self, inputs, encoder_output, mask):
        """Decode `inputs` against `encoder_output`.

        `mask` is a `(source_mask, target_mask)` pair. Returns `(hidden, attn_dist)`.
        """
        timing = time_record.Time()
        src_mask, trg_mask = mask
        trg_len = trg_mask.size(-1)
        # A position is masked when either the target mask or self.mask marks it.
        dec_mask = torch.gt(trg_mask + self.mask[:, :trg_len, :trg_len], 0)

        hidden = self.input_dropout(inputs)
        hidden = self.embedding_proj(hidden)
        seq_len = inputs.shape[1]

        if not self.universal:
            hidden += self.timing_signal[:, :seq_len, :].type_as(inputs.data)

            # The sequential stack threads one tuple through every layer.
            timing.begin('dec_decoding')
            decoded, _, attn_dist, _ = self.dec((hidden, encoder_output, [], (src_mask, dec_mask)))
            timing.end('dec_decoding')

            timing.begin('dec_layernorm')
            decoded = self.layer_norm(decoded)
            timing.end('dec_layernorm')
            if config.print_time:
                timing.print_all()
            return decoded, attn_dist

        if config.act:
            # Adaptive computation time: the ACT module drives the shared layer.
            hidden, attn_dist, (self.remainders, self.n_updates) = self.act_fn(
                hidden, inputs, self.dec, self.timing_signal, self.position_signal,
                self.num_layers, encoder_output, decoding=True)
            return self.layer_norm(hidden), attn_dist

        # Universal transformer: timing signal once, then the shared layer is
        # applied num_layers times with a per-step position signal.
        hidden += self.timing_signal[:, :seq_len, :].type_as(inputs.data)
        for step_idx in range(self.num_layers):
            hidden += self.position_signal[:, step_idx, :].unsqueeze(1).repeat(1, seq_len, 1).type_as(inputs.data)
            hidden, _, attn_dist, _ = self.dec((hidden, encoder_output, [], (src_mask, dec_mask)))
        return self.layer_norm(hidden), attn_dist


class Generator(nn.Module):
    """Output head: linear projection to the vocabulary followed by a log-distribution.

    With `config.pointer_gen` the vocabulary distribution is mixed with the
    attention (copy) distribution through a learned gate `p_gen`.
    """
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.p_gen_linear = nn.Linear(config.hidden_dim, 1)

    def forward(self, x, attn_dist=None, enc_batch_extend_vocab=None, extra_zeros=None, temp=1, beam_search=False):
        """Return log-probabilities over the vocabulary for every position of `x`.

        `attn_dist` and `enc_batch_extend_vocab` are only read in pointer-generator
        mode; `extra_zeros` is accepted but not used here. `temp` scales the logits
        and attention scores in pointer-generator mode only.
        """
        timing = time_record.Time()
        pointer_mode = config.pointer_gen

        if pointer_mode:
            # Gate in (0, 1): share of probability mass given to the vocabulary.
            p_gen = torch.sigmoid(self.p_gen_linear(x))

        timing.begin('gene_proj')
        logit = self.proj(x)
        timing.end('gene_proj')

        if not pointer_mode:
            return F.log_softmax(logit,dim=-1)

        vocab_part = p_gen * F.softmax(logit/temp, dim=2)
        copy_part = (1 - p_gen) * F.softmax(attn_dist/temp, dim=-1)

        # Repeat the source token ids once per decoding position.
        copy_index = torch.cat([enc_batch_extend_vocab.unsqueeze(1)]*x.size(1),1)
        if beam_search:
            # Every beam shares the first example's source ids.
            copy_index = torch.cat([copy_index[0].unsqueeze(0)]*x.size(0),0)

        timing.begin('gene_logit')
        # Add the copy mass onto the vocabulary entries of the source tokens.
        logit = torch.log(vocab_part.scatter_add(2, copy_index, copy_part))
        timing.end('gene_logit')
        if config.print_time:
            timing.print_all() 
        return logit

def l2_dist(x,y):
    """Return the squared L2 distance between `x` and `y`, summed over all elements."""
    return torch.sum((x - y).pow(2))

def bind_loss(pred_emb, tgt_emb): # ? how about MSE loss?
    """Return a scalar loss that shrinks as `pred_emb` approaches `tgt_emb`.

    Parameters:
        pred_emb: Predicted embeddings.
        tgt_emb: Target embeddings, same shape as `pred_emb`.

    Returns:
        A 0-dim tensor: the summed squared L2 distance when `config.use_l2`
        is set, otherwise the summed cosine distance (1 - cosine similarity)
        taken along dim 1.
    """
    if config.use_l2: # Use l2 norm distance
        return l2_dist(pred_emb, tgt_emb)

    # Cosine branch. The raw similarity must not be returned as the loss:
    # minimising it would push the embeddings apart, and its per-sample shape
    # cannot be back-propagated without a reduction. Convert it to a distance
    # and sum it, mirroring the reduction used by l2_dist.
    cos_similarity = nn.CosineSimilarity(dim=1, eps=1e-6) # default dim=1
    return torch.sum(1.0 - cos_similarity(pred_emb, tgt_emb))

# * guidance_embedding--> perform memory_storing --> perform local adaptation.
def meet_adapt_condition(iters, train, adapt):
    """Decide whether memory adaptation should run at iteration `iters`.

    Always True when loading from a pretrained model or when testing.
    Otherwise adaptation starts at `min_iter + 3*step_apart`; from then on it
    applies to every call if `config.adapt_support` is set, else only when
    `train` is false. `adapt` is accepted but not read.
    """
    if config.load_frompretrain != 'None' or config.test:
        return True

    warmup_end = config.min_iter + 3*config.step_apart
    if iters < warmup_end:
        return False

    if config.adapt_support:
        return True
    return not train

def meet_save_memory_condition(iters,train):
    """Decide whether the current samples should be pushed into the memory.

    Always True when loading from a pretrained model, and when testing with
    `train` set. Otherwise storing starts at `min_iter + 2*step_apart`; from
    then on every call stores if `config.store_query` is set, else only calls
    with `train` set.
    """
    if config.load_frompretrain != 'None':
        return True
    if config.test and train:
        return True

    storing_start = config.min_iter + 2*config.step_apart
    if iters < storing_start:
        return False

    return True if config.store_query else bool(train)
def meet_use_guidence_condition(iters,train):
    """Decide whether guidance embeddings are used at iteration `iters`.

    True when loading from a pretrained model, when testing, or once `iters`
    has reached `config.min_iter`. `train` is accepted but not read.
    """
    forced = config.load_frompretrain != 'None' or config.test
    if forced:
        return True
    return bool(iters >= config.min_iter)

def meet_update_other_param_condition(iters,train):
    """Return True on every 10th training iteration (except 0) when
    `config.inner_alter_update` is enabled, False otherwise."""
    on_schedule = iters % 10 == 0 and iters != 0
    return bool(on_schedule and train and config.inner_alter_update)

class Transformer(nn.Module):

    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        """Build the model, optionally restore a checkpoint and create the optimizers.

        Parameters:
            vocab: Vocabulary object; `vocab.n_words` sets the output size.
            model_file_path: Optional path of a checkpoint to restore weights from.
            is_eval: When True every sub-module is switched to eval mode.
            load_optim: When True and the checkpoint contains an 'optimizer'
                entry, the optimizer state is restored as well.
        """
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab,config.preptrained)
        self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
            
        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth,total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
        self.generator = Generator(config.hidden_dim,self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        if config.use_memory:
            self.rnn = RNN() # Guidance embedding.
            self.binding = NN() # Key-value binding.

        # Training criterion; with label smoothing a plain NLL criterion is kept
        # alongside it as `criterion_ppl`.
        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        all_param = list(self.parameters())
        print('all parameters number',len(all_param),flush=True)

        state = None
        if model_file_path is not None:
            print("Loading weights")
            # map_location keeps every tensor on the CPU while loading.
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            print("Begin Training from %d epoch"%state['iter'])
            self.load_iter=state['iter']
            print("LOSS",state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])

            if not config.from_pretrained:
                self.generator.load_state_dict(state['generator_dict'])
                self.embedding.load_state_dict(state['embedding_dict'])
            if config.use_memory and not config.from_pretrained:
                self.rnn.load_state_dict(state['rnn_dict'])
                self.binding.load_state_dict(state['binding_dict'])

        # Split parameters: memory modules (rnn / binding) vs. everything else.
        other_params = []
        for pname, param in self.named_parameters():
            if pname.startswith(('rnn', 'binding')):
                print(pname,flush=True)
                other_params.append(param)
        other_ids = {id(param) for param in other_params}
        trans_params = [param for param in self.parameters() if id(param) not in other_ids]

        def build_param_groups():
            # Fresh dicts on every call: a torch optimizer writes its own
            # defaults (betas, eps, ...) into the group dicts it receives, so
            # reusing one list would leak the first optimizer's settings into
            # the next. Empty groups are dropped.
            groups = [
                {"params": other_params, 'lr': config.other_lr},
                {"params": trans_params, 'lr': config.lr},
            ]
            return [group for group in groups if group["params"]]

        # Adam rejects an empty parameter list, so the memory-only optimizer is
        # created only when there are memory parameters.
        if other_params:
            self.other_optimizer = torch.optim.Adam(other_params, lr=config.other_lr)
        else:
            self.other_optimizer = None

        # Precedence: SGD overrides Noam, Noam overrides plain Adam.
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(build_param_groups(), lr=config.lr)
        elif config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(build_param_groups(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        else:
            self.optimizer = torch.optim.Adam(build_param_groups(), lr=config.lr)

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
            if config.use_memory:
                self.rnn, self.binding = self.rnn.cuda(),self.binding.cuda()

        # Restore optimizer state only now: the optimizer must exist, and the
        # parameters must already sit on their final device.
        if load_optim and state is not None:
            if 'optimizer' in state:
                self.optimizer.load_state_dict(state['optimizer'])
            else:
                print("Checkpoint has no optimizer state; using a fresh optimizer",flush=True)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()
            if config.use_memory:
                self.rnn = self.rnn.eval()
                self.binding = self.binding.eval()

        # exist_ok avoids the check-then-create race between processes.
        self.model_dir = config.save_path
        os.makedirs(self.model_dir, exist_ok=True)
        self.memory_dir=config.memory_path
        os.makedirs(self.memory_dir, exist_ok=True)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iters):
        """Write a checkpoint to `self.model_dir` and remember its path in `self.best_path`.

        The file is named `model_<iters>_<running_avg_ppl>` (4 decimals) and holds
        the iteration, the state dicts of encoder, decoder, generator and
        embedding, the current loss and, with `config.use_memory`, the state
        dicts of the rnn and binding modules. Optimizer state is not saved.
        """
        checkpoint = {'iter': iters}
        named_modules = (
            ('encoder_state_dict', self.encoder),
            ('decoder_state_dict', self.decoder),
            ('generator_dict', self.generator),
            ('embedding_dict', self.embedding),
        )
        for key, module in named_modules:
            checkpoint[key] = module.state_dict()
        checkpoint['current_loss'] = running_avg_ppl

        if config.use_memory:
            checkpoint.update(rnn_dict=self.rnn.state_dict(),
                              binding_dict=self.binding.state_dict())

        file_name = 'model_{}_{:.4f}'.format(iters,running_avg_ppl)
        self.best_path = os.path.join(self.model_dir, file_name)
        torch.save(checkpoint, self.best_path)

    def save_memory(self,epoch,memory=None):
        """Serialise `memory` to `<memory_dir>/memory_<epoch>`; do nothing when it is None."""
        target = os.path.join(self.memory_dir,'memory_{}'.format(epoch))
        if memory is None:
            return
        torch.save(memory,target)

# TODO: Training function used in the inner loop update on support sets or calculate loss on query sets.
    def train_one_batch(self,iters, batch,task_idx=(0,0),Adapt_model=None, memory=None, train=True,adapt=False,eval_before_train=False):
        """Run one forward pass on `batch` and, when `train` is true, one update step.

        Parameters:
            iters: Current iteration; drives the `meet_*_condition` schedules.
            batch: Batch read through `get_input_from_batch` / `get_output_from_batch`.
            task_idx: Task identifier forwarded to `memory.push` and
                `memory.get_neighbours`.
            Adapt_model: Object providing `get_keys`, `get_guidance` and `infer`;
                needed whenever guidance embeddings are used.
            memory: Key/value memory. The binding loss and the storing step only
                run when it is truthy; the adaptation path additionally calls
                `memory.get_neighbours` unless `eval_before_train` or
                `config.only_binding` bypasses retrieval.
            train: When true, back-propagate and step an optimizer.
            adapt: Forwarded to `meet_adapt_condition`.
            eval_before_train: When true the adaptation path does not read the
                memory; the guidance embedding comes from the guidance token
                (retrieval mode) or from the binding prediction instead.

        Returns:
            `(trs_loss, ppl, rec_loss)` when `config.use_memory` is set, else
            `(loss, ppl)`, where `ppl = exp(min(loss, 50))`.
        """
        # * Conditions to make changes for model training.
        # Every flag defaults to False so the memory-free configuration runs
        # the plain encoder/decoder path.
        perform_adapt = perform_storing = use_guidence = only_update_other = False
        if config.use_memory:
            perform_adapt = meet_adapt_condition(iters,train,adapt)
            perform_storing = meet_save_memory_condition(iters,train)
            use_guidence = meet_use_guidence_condition(iters,train)
            only_update_other = meet_update_other_param_condition(iters,train) # only update RNN, NN params and fix Transformer params.

        def on_device(tensor):
            # Helper tensors follow the model: GPU only when config.USE_CUDA is set.
            return tensor.cuda() if config.USE_CUDA else tensor

        # * Get input and target.
        timing=time_record.Time()
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        self.optimizer.zero_grad()
        if only_update_other:
            self.other_optimizer.zero_grad()

        # * Encoding (without guidance). Always executed; its output is used
        # directly only when no guidance embedding is prepended.
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        timing.begin('enc_embedding')
        enc_embs=self.embedding(enc_batch)
        timing.end('enc_embedding')
        timing.begin('enc_encoding') 
        encoder_outputs = self.encoder(enc_embs,mask_src)
        timing.end('enc_encoding')

        # * Keys / guidance values, binding loss and memory storing.
        binding_loss = None
        if config.use_memory:
            if use_guidence:
                src_embs=Adapt_model.get_keys(enc_batch,mask_src)
                mask_tgt = on_device(dec_batch.data.eq(config.PAD_idx).unsqueeze(1))

                timing.begin('rnn_get_gui')
                # ! Guidance embeddings computed from the target side.
                value_embs=Adapt_model.get_guidance(dec_batch,mask_tgt,self.rnn)
                timing.end('rnn_get_gui')

            if memory:
                # * Get pred embeddings from NN binding
                if perform_adapt:
                    timing.begin('bind_get_pred')
                    pred_embs = self.binding(src_embs) # Compute loss need.
                    timing.end('bind_get_pred')
                    binding_loss = bind_loss(pred_embs, value_embs)

                if perform_storing:
                    # * Push detached keys and values into the corresponding memory slot.
                    timing.begin('push_memory')
                    memory.push(task_idx,src_embs.detach(),value_embs.detach())
                    timing.end('push_memory')

        # * Decoder input: <SOS> + target without its last token.
        sos_token = on_device(torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1))
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1) # Not consider EOS token.

        if use_guidence:
            # * The guidance embedding is prepended to the source sequence, so
            # the source ids / masks get one extra leading position.
            gui_token = on_device(torch.LongTensor([config.GUI_idx] * enc_batch.size(0)).unsqueeze(1))
            enc_batch1=torch.cat((gui_token,enc_batch),1)
            mask_src1=enc_batch1.data.eq(config.PAD_idx).unsqueeze(1)
            if config.pointer_gen:
                new_enc_batch_extend_vocab=torch.cat((gui_token,enc_batch_extend_vocab),1)
            else:
                new_enc_batch_extend_vocab=None

        timing.begin('dec_trg_mask')
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        timing.end('dec_trg_mask')

        rec_logit = None
        if perform_adapt:
            # ! Use memory information (retrieved or predicted guidance embedding).
            keys = src_embs.detach()
            if eval_before_train and config.use_retrieval:
                # Evaluating before support-set training: no memory lookup,
                # use the embedding of the guidance token as a placeholder.
                gui_embs=self.embedding(gui_token)
            elif (eval_before_train or config.only_binding) and not config.use_retrieval:
                gui_embs=pred_embs.unsqueeze(1)
            else:
                timing.begin('mm_get_neighbors')
                rt_keys, rt_values = memory.get_neighbours(keys,task_idx)
                timing.end('mm_get_neighbors')
                # Retrieved entries are constants for this step.
                rt_keys_dev = on_device(torch.stack(rt_keys).detach())
                rt_values_dev = on_device(torch.stack(rt_values).detach())

            timing.begin('dec_emb')
            dec_emb = self.embedding(dec_batch_shift) # (batch_size, sen_length, emb_dim)
            timing.end('dec_emb')

            rec_encoder_outputs = None
            if config.use_retrieval:
                # * Retrieval: mean of the retrieved values instead of adaptation.
                if not eval_before_train: 
                    gui_embs=torch.mean(rt_values_dev,dim=1,keepdim=True)
                new_enc_embs=torch.cat((gui_embs,enc_embs),1)
                new_encoder_outputs = self.encoder(new_enc_embs,mask_src1)
                if config.reconstruct:
                    rec_enc_embs=torch.cat((value_embs.unsqueeze(1),enc_embs),1)
                    rec_encoder_outputs=self.encoder(rec_enc_embs,mask_src1)
            else:
                # * Local adaptation.
                if (not eval_before_train) and (not config.only_binding):
                    timing.begin('adapt_infer')
                    gui_embs = Adapt_model.infer(keys, rt_keys_dev, rt_values_dev).unsqueeze(1)
                    timing.end('adapt_infer')
                new_enc_embs=torch.cat((gui_embs,enc_embs),1)
                new_encoder_outputs = self.encoder(new_enc_embs,mask_src1)

            timing.begin('decoding')
            adapt_pre_logit, adapt_attn_dist = self.decoder(dec_emb,new_encoder_outputs, (mask_src1,mask_trg))
            timing.end('decoding')
            timing.begin('generating')
            logit = self.generator(adapt_pre_logit,adapt_attn_dist,new_enc_batch_extend_vocab, extra_zeros)
            timing.end('generating')
            # The reconstruction pass exists only where its encoder output was built.
            if rec_encoder_outputs is not None:
                rec_pre_logit, rec_attn_dist = self.decoder(dec_emb,rec_encoder_outputs, (mask_src1,mask_trg))
                rec_logit = self.generator(rec_pre_logit,rec_attn_dist,new_enc_batch_extend_vocab, extra_zeros)

        elif use_guidence:
            # ! No memory lookup, but the 'true' guidance embedding from the target.
            dec_emb = self.embedding(dec_batch_shift) # (batch_size, sen_length, emb_dim)
            gui_embs = value_embs.unsqueeze(1)
            new_enc_embs=torch.cat((gui_embs,enc_embs),1)
            new_encoder_outputs = self.encoder(new_enc_embs,mask_src1)

            timing.begin('decoding')
            pre_logit, attn_dist = self.decoder(dec_emb,new_encoder_outputs, (mask_src1,mask_trg))
            timing.end('decoding')
            timing.begin('generating')
            logit = self.generator(pre_logit,attn_dist,new_enc_batch_extend_vocab, extra_zeros)
            timing.end('generating')

        else:
            # ! No guidance embedding (also the path taken without memory):
            # decode against the plain encoder output and the unextended source ids.
            timing.begin('decoding')
            pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg))
            timing.end('decoding')
            timing.begin('generating')
            logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab, extra_zeros)
            timing.end('generating')

        # * Transformer loss on the generated distribution.
        trs_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        # * Combine with the auxiliary loss of the active mode.
        if config.use_memory:
            if memory and perform_adapt and not config.use_retrieval:
                loss = trs_loss + 0.1 * binding_loss
                rec_loss = binding_loss # reported in place of a reconstruction loss
            elif rec_logit is not None and memory:
                rec_loss = self.criterion(rec_logit.contiguous().view(-1, rec_logit.size(-1)), dec_batch.contiguous().view(-1))
                loss = trs_loss + 0.5 * rec_loss
            else:
                loss = trs_loss 
                rec_loss = loss
        else:
            loss = trs_loss

        if train:
            loss.backward()
            if memory and perform_adapt and not config.use_retrieval:
                nn.utils.clip_grad_norm_(self.parameters(), 1e10)
            # Step the optimizer whose gradients were reset for this batch:
            # the memory-only one on "only update other" iterations.
            if only_update_other:
                self.other_optimizer.step()
            else:
                self.optimizer.step()

        if config.print_time:
            timing.print_all()

        if config.use_memory:
            return trs_loss, math.exp(min(trs_loss.item(), 50)), rec_loss
    
        return loss, math.exp(min(loss.item(), 50))
