import copy
import importlib
import pdb
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable 
from numpy import linalg as L2
import transformers
from model.memory_modeling import BertForSequenceClassification, BertConfig
from torch.nn import CrossEntropyLoss, MSELoss
from numpy import linalg as LA
import random

# Module-level handle for the CUDA device with index 0.
cuda0 = torch.device("cuda", 0)

class Memory(object):
    """
        Create the empty memory buffer
    """

    def __init__(self, buffer=None):

        if buffer is None:
            self.memory = {}
            # print('memory keys',self.memory.keys(),flush=True)
            total_keys = len(self.memory.keys())
        else:
            self.memory = buffer
            total_keys = len(buffer.keys())
            # convert the keys from np.bytes to np.float32
            self.all_keys = np.frombuffer(
                np.asarray(list(self.memory.keys())), dtype=np.float32).reshape(total_keys, 768)

    def compute_angle(self, v1,v2):
        norm_prod = np.sqrt(np.sum(v1 **2))*np.sqrt(np.sum(v2 **2))
        dot_prod = np.dot(v1.T,v2)
        cos = dot_prod / norm_prod
        cos_clip = np.clip(cos,-1,1)
        angle = np.arccos(cos_clip)
        #print('angle',angle,flush=True)
        return angle    

    # * Diversity score computation.
    def diversity_score(self,A):
        angle_matrix = np.zeros((len(A),len(A)))
        #print('A',A,flush=True)
        K = len(A)
        B=np.sqrt(np.sum(A**2,axis=1))
        C=np.divide(A.T,B).T
        D=np.dot(C,C.T)
        D_clip=np.clip(D,-1,1)
        angle_matrix=np.arccos(D_clip)
        mean = np.sum(angle_matrix) / (K**2)
        var = (1/K**2) * np.sum((angle_matrix-mean)**2) 
        #print('diversity: K %d'%K,mean, var,flush=True)
        return mean-var

    # * Update memory module when meet new key-value pairs.
    def memory_update(self, prev_keys, new_key): # M means memory; compute diversity scores of all keys for one task;
        # tau is the threshold; prev_keys: np.array
        # Compute diversity scores.
        new_key = np.expand_dims(new_key,0)
        # print('New key',new_key.shape,'prev_keys',prev_keys.shape,flush=True)
        combine = np.concatenate((prev_keys,new_key),axis=0)
        # print('combine',combine.shape,flush=True)
        prev_score = self.diversity_score(prev_keys)
        total_num = len(combine)
        max_score = -1
        rp_key=None
        rp_key_idx=0
        for i in range(total_num):
            new_memory =  np.concatenate((combine[:i],combine[i+1:]),axis=0)
            div_score=self.diversity_score(new_memory)
            if div_score>max_score:
                max_score=div_score
                rp_key=combine[i]
                rp_key_idx=i
        if rp_key_idx==total_num-1:
            return (False,0, None) 
        else:
            return (True, rp_key_idx,rp_key)

    def push(self,task_idx, keys, values,beta=0.5,diverse=True):
        """
        Add the examples as key-value pairs to the memory dictionary with content,attention_mask,label tuple as value
        and key determined by key network
        """
        # update the memory dictionary
        task_id,num_sample = task_idx

        max_num = max(int(beta * min(num_sample,100)) + 1,2)
        # print('task id',task_id,'max num',max_num,'total num',num_sample,flush=True)
        for key, value in zip(keys,values):
            if task_id in self.memory.keys():
                task_memory = self.memory[task_id]
            else:
                self.memory[task_id] = [[],[]]
                self.memory[task_id][0].append(key.tobytes())
                self.memory[task_id][1].append(value)
                continue
            cur_num = len(task_memory)
            if cur_num <= max_num:
                self.memory[task_id][0].append(key.tobytes())
                self.memory[task_id][1].append(value)
            else:
                if len(task_memory[0])==1:
                    prev_keys = np.frombuffer(np.asarray(task_memory[0]), dtype=np.float32).reshape(len(task_memory[0]), 768)
                elif len(task_memory[0])==0:
                    continue
                else:
                    if diverse!=True:
                        rnd_idx = random.choice(range(max_num))
                        task_memory[0].pop(rnd_idx)
                        task_memory[1].pop(rnd_idx)
                        task_memory[0].append(key.tobytes())
                        task_memory[1].append(value)
                        continue

                    prev_keys =np.frombuffer(np.asarray(task_memory[0]), dtype=np.float32).reshape(len(task_memory[0]), 768) 
                # print('prev keys',prev_keys,flush=True)
                rp, rp_key_idx, rp_key = self.memory_update(prev_keys,key)
                if rp:
                    task_memory[0].pop(rp_key_idx)
                    task_memory[1].pop(rp_key_idx)
                    task_memory[0].append(key.tobytes())
                    task_memory[1].append(value)

    def _prepare_batch(self, sample):
        """
        Average the retrieved neighbour values.

        Parameter:
        sample -> list of equally shaped numeric values (one per neighbour)

        Returns:
        float tensor holding the mean over the neighbours (dim 0), placed on
        the GPU when CUDA is available and on the CPU otherwise
        """
        # Fall back to the CPU instead of failing on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.mean(torch.FloatTensor(sample), dim=0).to(device)

    def _prepare_batch_with_labels(self, sample, weights, weighted=False):
        """
        Combine neighbour outputs and labels into one weighted sum each.

        Parameters:
        sample   -> list of ``(out, label)`` pairs, one per neighbour
        weights  -> one numeric weight per neighbour
        weighted -> accepted for call compatibility; not used

        Returns:
        ``(sample_out, sample_label)``: the weighted sums over the neighbours
        as float tensors, on the GPU when CUDA is available, else on the CPU
        """
        # Fall back to the CPU instead of failing on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        weighted_outs = [out * weight for (out, _), weight in zip(sample, weights)]
        weighted_labels = [label * weight for (_, label), weight in zip(sample, weights)]

        sample_out = torch.sum(torch.FloatTensor(weighted_outs), dim=0).to(device)
        sample_label = torch.sum(torch.FloatTensor(weighted_labels), dim=0).to(device)
        return sample_out, sample_label

    def get_neighbours(self, task_id, keys, k=8, with_label=False,weighted=False,all_task=False, self_neighbor_ratio=0.5,no_neighbors=False):
        """
        Returns logits from buffer using nearest neighbour approach
        """
        samples = []
        K_res = []
        task_memory = self.memory[task_id]
        # Iterate over all the input keys to find neigbours for each of them
        total_keys = len(task_memory[0])
        all_keys = np.frombuffer(np.asarray(task_memory[0]), dtype=np.float32).reshape(total_keys, 768)
        if all_task:
            total_k = k
            k = max(min(int(self_neighbor_ratio * total_k), total_keys-1),1)
        else:
            k = min(k, total_keys - 1)
        for key in keys:
            # print('k',k,total_keys,flush=True)
            # compute similarity scores based on Euclidean distance metric
            if not no_neighbors:
                similarity_scores = LA.norm(all_keys-key,axis=1)
                top_k = np.argpartition(similarity_scores, k)[:k]
                # print('topk',top_k,flush=True)
                top_k_score = similarity_scores[top_k]
            else:
                top_k = random.choices(range(len(all_keys)),k=k)
            K_neighbour_keys = all_keys[top_k]
            neighbours = [task_memory[1][p][1] for p in top_k] # neighbor_labels

            if weighted:
                weights = softmax_score(top_k_score/5,tau=10).tolist()
            else:
                weights = [1/len(top_k)] * len(top_k)
            # converts experiences into batch
            if with_label:
                if weighted:
                    batch = self._prepare_batch_with_labels(neighbours,weights,True)
                else:
                    batch = neighbours 
            else:
                batch = self._prepare_batch(neighbours)
            samples.append(torch.tensor(batch))
            K_res.append(torch.tensor(K_neighbour_keys))
        if not all_task:
            return samples, K_res
        else:
            test_tasks = [i+57 for i in range(12)]
            all_ids=self.memory.keys()
            other_keys = []
            other_values = []
            for per in all_ids:
                per_memory=self.memory[per]
                len_per_keys = len(per_memory[0])
                if per!=task_id and per not in test_tasks:
                # if per!=task_id:
                    all_keys = np.frombuffer(np.asarray(per_memory[0]), dtype=np.float32).reshape(len_per_keys, 768)
                    other_keys.extend(all_keys)
                    other_values.extend(per_memory[1])
            if len(other_keys)==0:
                return samples
            k1 = max(min(total_k-k,len(other_keys)-1),1)
            # print('k1',k1,'other_keys num',len(other_keys),flush=True)
            # k1 = 2 * num_k 
            other_keys_array = np.array(other_keys)
            # print('other',len(other_keys_array),len(other_values),flush=True)

            samples1 = []
            K1_res = []
            for key in keys:
                if not no_neighbors:
                    similarity_scores1 = LA.norm(other_keys_array-key,axis=1)
                    top_k1 = np.argpartition(similarity_scores1, k1)[:k1]
                    top_k1_score = similarity_scores1[top_k1]
                else:
                    top_k1 = random.choices(range(len(other_keys_array)),k=k1)
                neighbours1 = [other_values[p][1] for p in top_k1] # for labels
                K1_neighbor_keys = other_keys_array[top_k1]

                if weighted:
                    weights1 = softmax_score(top_k1_score/5,tau=10).tolist()
                else:
                    weights1 = [1/len(top_k1)] * len(top_k1)

                if with_label:
                    if weighted:
                        batch1 = self._prepare_batch_with_labels(neighbours1,weights1,True)
                    else:
                        batch1=neighbours1
                else:
                    batch1 = self._prepare_batch(neighbours1)
                samples1.append(torch.tensor(batch1))
                K1_res.append(torch.tensor(K1_neighbor_keys))
            new_samples = []
            new_keys = []
            for i in range(len(keys)):
                value = torch.cat((samples[i],samples1[i]))
                key = torch.cat((K_res[i],K1_res[i]))
                new_samples.append(value)
                new_keys.append(key)

            return new_samples,new_keys
    

class LocalAdapt(nn.Module):
    """
    Implements Memory based Parameter Adaptation model
    """

    def __init__(self, model, L=5):
        """
        Parameters:
        model -> classifier that is copied and fine-tuned at inference time
        L     -> number of local adaptation steps
        """
        super().__init__()

        # Key network: pretrained BERT used to turn content into key vectors
        self.key_encoder = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.classifier = model
        # Parameters of the base classifier, used as the anchor of the
        # weight-difference penalty in infer()
        self.base_weights = list(self.classifier.parameters())

        # Number of local adaptation steps
        self.L = L
        # local adaptation learning rate - 1e-3 or 5e-3
        self.loc_adapt_lr = 1e-3

    def get_keys(self, contents, attn_masks):
        """
        Return key representation of the documents
        """
        # Freeze the weights of the key network to prevent key
        # representations from drifting as data distribution changes
        with torch.no_grad():
            outputs = self.key_encoder(
                contents, attention_mask=attn_masks)
        last_hidden_states = outputs.last_hidden_state
        # Obtain key representation of every text content by selecting the its [CLS] hidden representation
        keys = last_hidden_states[:, 0, :]

        return keys

    def infer(self, content, K_contents, K_labels, output_mode, num_labels):
        """
        Predict with a copy of the classifier that is first adapted to the
        retrieved neighbours.

        A deep copy of ``self.classifier`` is trained for ``self.L`` steps on
        ``(K_contents, K_labels)``. The loss is the task loss plus 0.01 times
        the squared distance between the adapted and the base weights. The
        base classifier itself is not modified.

        Parameters:
        content     -> input of the query; passed as the only argument to the classifier
        K_contents  -> inputs of the retrieved neighbours, passed the same way
        K_labels    -> tensor with the neighbours' targets
        output_mode -> "classification" (cross entropy) or "regression" (MSE)
        num_labels  -> number of classes; used to reshape the logits for
                       classification

        Returns:
        detached output of the adapted classifier for ``content``

        Raises:
        ValueError -> ``output_mode`` is unknown and at least one adaptation
                      step is requested
        """
        loss_cls = {"classification": CrossEntropyLoss, "regression": MSELoss}.get(output_mode)
        loss_fct = loss_cls() if loss_cls is not None else None

        # create a local copy of the classifier network
        adaptive_classifier = copy.deepcopy(self.classifier)

        # NOTE(review): transformers.AdamW is deprecated upstream in favour of
        # torch.optim.AdamW - confirm the installed version still provides it.
        optimizer = transformers.AdamW(
            adaptive_classifier.parameters(), lr=self.loc_adapt_lr)

        # Weights being adapted, and detached views of the base weights. The
        # detach keeps backward() from accumulating gradients on the shared
        # base classifier every time infer() runs.
        curr_weights = list(adaptive_classifier.parameters())
        anchor_weights = [weight.detach() for weight in self.base_weights]

        # Train the adaptive classifier for L steps on the retrieved batch
        for _ in range(self.L):
            if loss_fct is None:
                raise ValueError(
                    "output_mode must be 'classification' or 'regression', got %r" % (output_mode,))

            optimizer.zero_grad()
            sup_logits = adaptive_classifier(K_contents)

            if output_mode == "classification":
                likelihood_loss = loss_fct(sup_logits.view(-1, num_labels), K_labels.view(-1).long())
            else:
                likelihood_loss = loss_fct(sup_logits.view(-1), K_labels.view(-1))

            # Squared distance between adapted and base weights, created on
            # the device and dtype of the task loss rather than guessed from
            # CUDA availability.
            diff_loss = likelihood_loss.new_zeros(1)
            for base_param, curr_param in zip(anchor_weights, curr_weights):
                diff_loss = diff_loss + (curr_param - base_param).pow(2).sum()

            # Total loss: task loss plus weight restraint
            total_loss = 0.01 * diff_loss + likelihood_loss
            total_loss.backward()
            optimizer.step()

        # No autograd graph is needed for the final prediction; building one
        # only costs memory.
        with torch.no_grad():
            logits = adaptive_classifier(content)

        return logits.detach()
