import copy
import importlib
import pdb
import numpy as np
import torch
import torch.nn as nn
from utils import config
from torch.autograd import Variable 
from utils import record_time as time_record
from numpy import linalg as L2
from utils.data_reader import Personas
import transformers
import random

# Module-level handle for the first CUDA device (not referenced elsewhere in this chunk).
cuda0 = torch.device('cuda:0')

# NOTE(review): Personas() is constructed at import time, so whatever it loads
# happens as a side effect of importing this module.
p = Personas()
# Test-split task identifiers. Memory.get_neighbours excludes these from the
# cross-task neighbour pool when config.test is set, so that different test
# tasks do not interfere with each other.
test_tasks = p.get_personas('test')

def l2_dist(x, y):
    """Return the squared Euclidean distance between ``x`` and ``y``.

    Element-wise squared differences are summed over every dimension (after
    broadcasting), giving a 0-d tensor. No square root is taken.
    """
    return (x - y).pow(2).sum()
def l2_loss(pred_emb, tgt_emb):
    """Score predicted embeddings against target embeddings.

    If ``config.use_l2`` is truthy, return the summed squared Euclidean
    distance between the two inputs (a 0-d tensor, via ``l2_dist``).
    Otherwise return ``nn.CosineSimilarity(dim=1, eps=1e-6)`` of the inputs,
    i.e. one similarity value per row.

    NOTE(review): the cosine branch returns a *similarity* (higher means
    closer) and is not reduced to a scalar, whereas the L2 branch returns a
    distance. A caller that minimises this value and calls backward() on it
    (e.g. LocalAdapt.infer) would push embeddings apart, and would fail on a
    multi-element result, unless it negates/reduces it. Confirm intended use.
    """
    if not config.use_l2:
        row_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
        return row_similarity(pred_emb, tgt_emb)
    return l2_dist(pred_emb, tgt_emb)

# TODO: Bind keys and values in the memory. =============================================================
class NN(nn.Module):
    """Feed-forward binding network.

    Maps an input embedding whose last dimension is ``hidden_size``
    (default 768) to a 300-dimensional output through two tanh-activated
    hidden layers of 1024 and 512 units.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Layers are created in this exact order (fc1, tanh, fc2, out) so the
        # attribute names (hence state_dict keys) and the order in which
        # parameters are randomly initialised stay the same.
        self.fc1 = nn.Linear(hidden_size, 1024)  # originally 512 units
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, 300)

    def forward(self, src_emb):
        """Return the 300-d prediction for ``src_emb``."""
        hidden = self.tanh(self.fc1(src_emb))
        hidden = self.tanh(self.fc2(hidden))
        return self.out(hidden)

# TODO: Build GUI token embeddings through LSTM. ====================================
class RNN(nn.Module):
    """Single-layer LSTM encoder producing a 300-d embedding per sequence.

    The last layer's final hidden state is passed through a tanh-activated
    linear layer and then projected to 300 dimensions.
    """

    def __init__(self, input_size=768, hidden_size=512):
        super(RNN, self).__init__()
        self.num_layers = 1
        self.hidden_size = hidden_size
        # batch_first=True: inputs are (batch, seq_len, input_size).
        self.lstm_tgt = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                                num_layers=self.num_layers, batch_first=True,
                                bidirectional=False)

        self.fc = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc1 = nn.Linear(self.hidden_size, 300)
        self.tanh = nn.Tanh()

    def forward(self, tgt_enc):
        """Encode a batch of sequences.

        Args:
            tgt_enc: tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, 300).
        """
        batch_size = tgt_enc.size(0)
        # Zero initial states live on the input's device/dtype rather than the
        # hard-coded current CUDA device, so the module also runs on CPU or on
        # a non-default GPU.
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        h_0 = torch.zeros(state_shape, device=tgt_enc.device, dtype=tgt_enc.dtype)
        c_0 = torch.zeros(state_shape, device=tgt_enc.device, dtype=tgt_enc.dtype)

        _, (tgt_hn, _) = self.lstm_tgt(tgt_enc, (h_0, c_0))

        # Final hidden state of the last LSTM layer: (batch, hidden_size).
        tgt_emb = self.tanh(self.fc(tgt_hn[-1]))
        return self.fc1(tgt_emb)

# TODO: Build memory to store training experience. =========================================================
class Memory(object):
    """
    Episodic key-value memory, partitioned by task.

    ``self.memory`` maps ``task_id -> [keys, values]``, where ``keys`` and
    ``values`` are parallel Python lists of tensors: ``values[i]`` is the
    value stored for ``keys[i]``.
    """
    def __init__(self, buffer=None):
        """
        Create an empty memory, or share the dictionary of an existing one.

        Args:
            buffer: optional object with a ``memory`` dict (e.g. another
                Memory). Its dictionary is reused by reference, not copied.
        """
        if buffer is None:
            self.memory = {}
            print('memory model: current memory has saved %d tasks'%len(self.memory.keys()),flush=True)
        else:
            self.memory = buffer.memory

    def compute_angle(self, v1, v2):
        """Angle in radians between vectors ``v1`` and ``v2`` (cosine clamped to [-1, 1])."""
        norm_prod = torch.sqrt(torch.sum(v1 ** 2)) * torch.sqrt(torch.sum(v2 ** 2))
        dot_prod = torch.matmul(v1, v2)
        cos_clip = torch.clamp(dot_prod / norm_prod, -1, 1)
        return torch.acos(cos_clip)

    # * Diversity score computation.
    def diversity_score(self, A):
        """
        Diversity of a set of key vectors.

        Computes the pairwise angles (radians, from clamped cosines) between
        the rows of ``A`` (shape (K, dim)) and returns
        ``mean(angles) - var(angles)`` over all K*K entries, including the
        (near-)zero diagonal.
        """
        K = len(A)
        norms = torch.sqrt(torch.sum(A ** 2, dim=1))
        unit = torch.div(A.T, norms).T
        cos_clip = torch.clamp(torch.matmul(unit, unit.T), -1, 1)
        angle_matrix = torch.acos(cos_clip)

        mean = torch.sum(angle_matrix) / (K ** 2)
        var = (1 / K ** 2) * torch.sum((angle_matrix - mean) ** 2)
        return mean - var

    # * Update memory module when meet new key-value pairs.
    def memory_update(self, prev_keys, new_key, Tau=config.Tau):
        """
        Choose which key to drop so that the remaining set is most diverse.

        Args:
            prev_keys: stacked keys currently stored for a task, (K, dim).
            new_key: candidate key, (dim,).
            Tau: unused; kept for interface compatibility.

        Returns:
            ``(replace, idx, key)``. ``replace`` is False (with ``idx=0``,
            ``key=None``) when dropping the new key itself gives the most
            diverse set. Otherwise ``idx`` is the position in ``prev_keys`` of
            the stored key to evict and ``key`` is that key.
        """
        timing = time_record.Time()
        combine = torch.cat((prev_keys, torch.unsqueeze(new_key, 0)))
        total_num = combine.size()[0]
        # Start from -inf: diversity scores (mean minus variance of angles) can
        # be below -1, and a finite sentinel would then leave idx=0/key=None
        # and wrongly report a replacement.
        max_score = float('-inf')
        rp_key = None
        rp_key_idx = 0
        for i in range(total_num):
            timing.begin('mm_update_one_loop')
            # Diversity of the set with entry i left out.
            new_memory = torch.cat((combine[:i], combine[i + 1:]))
            div_score = self.diversity_score(new_memory)
            if div_score > max_score:
                max_score = div_score
                rp_key = combine[i]
                rp_key_idx = i
            timing.end('mm_update_one_loop')
        if config.print_time:
            timing.print_all()
        # The last row of `combine` is the new key: dropping it means no change.
        if rp_key_idx == total_num - 1:
            return (False, 0, None)
        return (True, rp_key_idx, rp_key)

    # * Push new key-value pairs into the memory.
    def push(self, task_idx, keys, values, beta=config.store_ratio):
        """
        Add key-value pairs to the memory of one task.

        Each task keeps at most ``int(beta * num_dialogs) + 1`` pairs. Below
        capacity, pairs are appended. Once full, either a random stored pair
        is replaced (``config.no_diverse``), or the pair whose removal
        maximises ``diversity_score`` is evicted. That pair may be the new one
        itself, in which case nothing changes.

        Args:
            task_idx: ``(task_id, num_dialogs)`` tuple.
            keys: iterable of key tensors.
            values: iterable of value tensors, parallel to ``keys``.
            beta: fraction of the task's dialogs to keep.
        """
        task_id, num_dialogs = task_idx
        max_num = int(beta * num_dialogs) + 1

        timing = time_record.Time()
        for key, value in zip(keys, values):
            if task_id not in self.memory:
                timing.begin('mm_list_append')
                self.memory[task_id] = [[key], [value]]
                timing.end('mm_list_append')
                continue

            task_memory = self.memory[task_id]
            stored_keys = task_memory[0]
            stored_values = task_memory[1]
            # Capacity is counted on the key list, not the outer [keys, values] pair.
            if len(stored_keys) < max_num:
                stored_keys.append(key)
                stored_values.append(value)
                continue

            if config.no_diverse:
                # Random-replacement baseline.
                rnd_idx = random.choice(range(len(stored_keys)))
                stored_keys.pop(rnd_idx)
                stored_values.pop(rnd_idx)
                stored_keys.append(key)
                stored_values.append(value)
                continue

            timing.begin('mm_update_diversity')
            prev_keys = torch.stack(stored_keys)  # (K, dim); also valid when K == 1
            rp, rp_key_idx, _ = self.memory_update(prev_keys, key)
            if rp:
                stored_keys.pop(rp_key_idx)
                stored_values.pop(rp_key_idx)
                stored_keys.append(key)
                stored_values.append(value)
            timing.end('mm_update_diversity')

    def _prepare_batch(self, x):
        """Stack a list of equally-shaped tensors along a new leading dimension."""
        return torch.stack(x)

    def _topk_by_similarity(self, bank, key, k, softmax_score):
        """Return (indices, softmax weights) of the ``k`` rows of ``bank`` with the largest dot product with ``key``."""
        similarity_scores = torch.matmul(bank, key)
        top = torch.topk(similarity_scores, k)  # k largest
        return top.indices.tolist(), softmax_score(top.values).tolist()

    def _gather_values(self, values, indices, weights):
        """Pick ``values[i]`` for each index, scaled by its weight when ``config.weighted_value`` is set."""
        if config.weighted_value:
            return [values[idx] * weight for idx, weight in zip(indices, weights)]
        return [values[idx] for idx in indices]

    # * KNN retrieve nearest neighbors.
    def get_neighbours(self, keys, task_idx, k=config.neighbor_num):
        """
        Retrieve nearest stored neighbours for each query key.

        Args:
            keys: query keys, a 2-D tensor (one row per query).
            task_idx: ``(task_id, num_dialogs)``; only ``task_id`` is used.
                Raises KeyError if ``task_id`` has no stored entries.
            k: total number of neighbours requested per query.

        Returns:
            ``(k_keys, k_values)``: lists with one entry per query row. Each
            entry stacks the retrieved neighbour keys / values. When
            ``config.all_tasks`` is set and other tasks are available,
            neighbours from other tasks are appended after the same-task ones.
        """
        timing = time_record.Time()
        task_id = task_idx[0]
        cur_task_memory = self.memory[task_id]
        cur_keys_list = cur_task_memory[0]
        cur_values = cur_task_memory[1]
        num_stored = len(cur_keys_list)

        # Same-task neighbour budget; with config.all_tasks roughly half of k
        # is reserved for neighbours from other tasks.
        if config.all_tasks:
            total_k = k
            k = max(min(int(0.5 * total_k), num_stored), 1)
        else:
            k = min(k, num_stored)

        cur_all_keys = torch.stack(cur_keys_list).to(keys.device)
        softmax_score = nn.Softmax(dim=0)

        k_keys = []
        k_values = []
        for key in keys:
            timing.begin('mm_get_nei_for_one_key')
            if config.no_neighbors:
                # Random-retrieval baseline (sampling with replacement); every
                # draw gets the same weight, as a softmax over equal scores would.
                top_k = random.choices(range(len(cur_all_keys)), k=k)
                weights = [1.0 / k] * k
            else:
                top_k, weights = self._topk_by_similarity(cur_all_keys, key, k, softmax_score)
            k_keys.append(self._prepare_batch([cur_all_keys[i] for i in top_k]))
            k_values.append(self._prepare_batch(self._gather_values(cur_values, top_k, weights)))
            timing.end('mm_get_nei_for_one_key')

        if not config.all_tasks:
            if config.print_time:
                timing.print_all()
            return (k_keys, k_values)

        # * Consider other tasks help.
        timing.begin('mm_all_task_get_nei')
        other_keys = []
        other_values = []
        for per, per_memory in self.memory.items():
            if per == task_id:
                continue
            # ! During testing, different test tasks should not interfere each other.
            if config.test and per in test_tasks:
                continue
            other_keys.extend(per_memory[0])
            other_values.extend(per_memory[1])
        if not other_keys:
            return (k_keys, k_values)

        k1 = max(min(total_k - k, len(other_keys) - 1), 1)
        other_keys_tensor = torch.stack(other_keys).to(keys.device)

        merged_keys = []
        merged_values = []
        for i, key in enumerate(keys):
            top_k1, weights1 = self._topk_by_similarity(other_keys_tensor, key, k1, softmax_score)
            key_batch = self._prepare_batch([other_keys[j] for j in top_k1])
            value_batch = self._prepare_batch(self._gather_values(other_values, top_k1, weights1))
            # Same-task neighbours first, then neighbours from other tasks.
            merged_keys.append(torch.cat((k_keys[i], key_batch)))
            merged_values.append(torch.cat((k_values[i], value_batch)))
        timing.end('mm_all_task_get_nei')
        if config.print_time:
            timing.print_all()
        return (merged_keys, merged_values)

    def get_mean_all_values(self):
        """
        Mean value vector across tasks.

        Each task's stored values are averaged first, then the per-task means
        are averaged, so every task weighs equally regardless of its size.
        """
        per_task_means = [torch.mean(torch.stack(task_memory[1]), dim=0)
                          for task_memory in self.memory.values()]
        return torch.mean(torch.stack(per_task_means), dim=0)

# TODO: Perform Local adaptation in the memory. ================================================       
class LocalAdapt(nn.Module):
    """
    Local adaptation of a binding model on neighbours retrieved from memory.

    ``infer`` fine-tunes a deep copy of the binding model for ``L`` steps on
    the retrieved (key, value) pairs. An L2 penalty keeps the copy close to
    the original weights. The adapted copy then makes the prediction.
    """
    def __init__(self, binding_model, bert_model, L=config.adapt_num, BP_adapt_model=False):
        super(LocalAdapt, self).__init__()
        # Key encoder; get_keys/get_guidance run it under torch.no_grad().
        self.key_encoder = bert_model
        self.key_encoder = self.key_encoder.cuda()

        self.binding = binding_model
        # References to the binding's own parameters (not copies); infer()
        # uses them as the anchor of the weight-drift penalty.
        self.base_weights = list(self.binding.parameters())
        # Local adaptation learning rate - 1e-3 or 5e-3.
        self.loc_adapt_lr = 1e-3

        # Number of local adaptation steps.
        self.L = L
        # If True, infer() returns a prediction still attached to the autograd
        # graph; otherwise the prediction is detached.
        self.BP_adapt_model = BP_adapt_model

    def get_keys(self, contents, attn_masks):
        """
        Return one key per input: the first-token ([CLS]) hidden state of
        ``key_encoder``'s last layer.

        The encoder runs under ``torch.no_grad()`` so key representations do
        not drift as the data distribution changes.
        """
        with torch.no_grad():
            outputs = self.key_encoder(contents, attention_mask=attn_masks)
        keys = outputs.last_hidden_state[:, 0, :]
        return keys.detach()

    def get_guidance(self, contents, attn_masks, Rnn):
        """
        Return ``Rnn`` applied to the frozen key encoder's full sequence of
        last-layer hidden states.

        No gradient reaches the key encoder (no_grad + detach). ``Rnn``'s own
        parameters stay in the autograd graph unless the caller disabled grad.
        """
        with torch.no_grad():
            outputs = self.key_encoder(contents, attention_mask=attn_masks)
        last_hidden_states = outputs.last_hidden_state.detach()
        return Rnn(last_hidden_states)

    def infer(self, src_emb, K_src_embs, K_tgt_embs):
        """
        Predict the target embedding for ``src_emb`` after local adaptation.

        A deep copy of ``self.binding`` is trained for ``self.L`` AdamW steps
        to map ``K_src_embs`` to ``K_tgt_embs``. The objective is
        ``l2_loss + 0.01 * sum ||w - w_base||^2``. The adapted copy is then
        applied to ``src_emb``. The parameters of ``self.binding`` are neither
        modified nor given gradients by this adaptation.

        Parameters:
        src_emb    -> query key embedding(s) to predict for
        K_src_embs -> neighbour keys retrieved from memory
        K_tgt_embs -> neighbour values paired with K_src_embs

        Returns:
        pred_emb -> output of the adapted copy for src_emb; detached unless
                    BP_adapt_model is set.
        """
        timing = time_record.Time()
        adaptive_binding = copy.deepcopy(self.binding).cuda()
        optimizer = torch.optim.AdamW(adaptive_binding.parameters(), lr=self.loc_adapt_lr)
        # Current model weights (the copy's parameters).
        curr_weights = list(adaptive_binding.parameters())

        # * Train the adaptive NN for L steps. Gradients are required here even
        # if the caller runs inference under torch.no_grad().
        with torch.enable_grad():
            for _ in range(self.L):
                timing.begin('one_adaptation')
                optimizer.zero_grad()
                K_pred_embs = adaptive_binding(K_src_embs)
                loss = l2_loss(K_pred_embs, K_tgt_embs)

                # Drift penalty. Base weights are detached so backward() does
                # not accumulate gradients into self.binding's parameters.
                diff_loss = sum(
                    (curr_param - base_param.detach().to(curr_param.device)).pow(2).sum()
                    for base_param, curr_param in zip(self.base_weights, curr_weights)
                )

                # * Update the binding model copy (FNN).
                total_loss = 0.01 * diff_loss + loss
                total_loss.backward()
                optimizer.step()
                timing.end('one_adaptation')

        # * Predict the target embedding.
        pred_emb = adaptive_binding(src_emb)
        if config.print_time:
            timing.print_all()

        if self.BP_adapt_model:
            return pred_emb
        return pred_emb.detach()
