from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from copy import deepcopy
import gc
import torch
from sklearn.metrics import accuracy_score
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig

class BASE(nn.Module):
    '''
        Base class for few-shot episode models.

        Provides shared helpers: pairwise distances, label re-indexing,
        one-hot encoding and accuracy. Reads ``args.way`` (number of classes
        per episode) from the ``args`` object passed at construction.
    '''
    def __init__(self, args):
        super(BASE, self).__init__()
        self.args = args

    def _compute_l2(self, XS, XQ):
        '''
            Compute the pairwise l2 distance
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size
        '''
        # Broadcast to query_size x support_size x ebd_dim, then reduce the
        # embedding dimension.
        diff = XS.unsqueeze(0) - XQ.unsqueeze(1)
        dist = torch.norm(diff, dim=2)

        return dist

    def _compute_cos(self, XS, XQ):
        '''
            Compute the pairwise cos distance (1 - cosine similarity)
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size
        '''
        # Batched (1 x ebd_dim) @ (ebd_dim x 1) products give every pairwise
        # dot product as a query_size x support_size matrix.
        dot = torch.matmul(
                XS.unsqueeze(0).unsqueeze(-2),
                XQ.unsqueeze(1).unsqueeze(-1)
                )
        dot = dot.squeeze(-1).squeeze(-1)

        scale = (torch.norm(XS, dim=1).unsqueeze(0) *
                 torch.norm(XQ, dim=1).unsqueeze(1))

        # Floor the norm product to avoid division by zero for zero vectors.
        scale = torch.max(scale,
                          torch.ones_like(scale) * 1e-8)

        dist = 1 - dot/scale

        return dist

    def reidx_y(self, YS, YQ):
        '''
            Map the labels into 0, ..., way - 1.

            Support and query must contain exactly the same ``self.args.way``
            distinct labels; each label is replaced by its rank among the
            sorted unique labels.

            @param YS: support labels, 1-D tensor
            @param YQ: query labels, 1-D tensor

            @return YS_new: support labels re-indexed into 0..way-1
            @return YQ_new: query labels re-indexed into 0..way-1

            @raises ValueError: if the support and query label sets differ,
                or their size is not ``self.args.way``
        '''
        unique1, inv_S = torch.unique(YS, sorted=True, return_inverse=True)
        unique2, inv_Q = torch.unique(YQ, sorted=True, return_inverse=True)

        if len(unique1) != len(unique2):
            raise ValueError(
                'Support set classes are different from the query set')

        if len(unique1) != self.args.way:
            raise ValueError(
                'Support set classes are different from the number of ways')

        # Compare element-wise: a sum of differences can cancel out
        # (e.g. [0, 3] vs [1, 2]) and let mismatched label sets through.
        if not bool((unique1 == unique2).all()):
            raise ValueError(
                'Support set classes are different from the query set classes')

        Y_new = torch.arange(start=0, end=self.args.way, dtype=unique1.dtype,
                             device=unique1.device)

        return Y_new[inv_S], Y_new[inv_Q]

    def _label2onehot(self, Y):
        '''
            Convert labels in 0, ..., way - 1 into one-hot vectors.
            @param Y: batch_size, integer labels

            @return Y_onehot: batch_size x way (float)
        '''
        # The cached identity matrix is no longer created in __init__, so
        # build it here (on Y's device) instead of reading a missing attribute.
        I_way = torch.eye(self.args.way, dtype=torch.float, device=Y.device)
        Y_onehot = F.embedding(Y, I_way)

        return Y_onehot

    @staticmethod
    def compute_acc(pred, true):
        '''
            Compute the accuracy.
            @param pred: batch_size * num_classes
            @param true: batch_size

            @return accuracy as a Python float
        '''
        return torch.mean((torch.argmax(pred, dim=1) == true).float()).item()


    
class Bert(nn.Module):
    '''
        Pretrained transformer encoder followed by a linear classification head.

        @param pretrained_bert: name or path passed to ``AutoModel.from_pretrained``
        @param pool_type: "cls" (hidden state of the first token) or
            "pooler" (the encoder's ``pooler_output``)
        @param dropout_prob: currently unused; the encoder keeps the dropout
            of its pretrained config. Kept for interface compatibility.
        @param n_classes: output size of the classification head

        @raises ValueError: if ``pool_type`` is not a supported option
    '''
    _POOL_TYPES = ("cls", "pooler")

    def __init__(self, pretrained_bert, pool_type="cls", dropout_prob=0.3, n_classes=2):
        super().__init__()
        # Fail early: an unknown pool_type used to surface later as an obscure
        # error when the un-pooled encoder output reached the linear head.
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"pool_type must be one of {self._POOL_TYPES}, got {pool_type!r}")

        self.encoder = AutoModel.from_pretrained(pretrained_bert)
        self.pool_type = pool_type

        self.n_classes = n_classes
        # Size the head from the encoder config instead of hard-coding 768,
        # so encoders with other hidden sizes work; 768 is kept as a fallback.
        hidden_size = getattr(self.encoder.config, "hidden_size", 768)
        self.fc = nn.Linear(hidden_size, self.n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        '''
            Encode a batch and classify it.

            @return output: pooled embedding for each sequence
            @return pred_label: logits from ``self.fc`` (n_classes per sequence)
        '''
        output = self.encoder(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        if self.pool_type == "cls":
            output = output.last_hidden_state[:, 0]
        elif self.pool_type == "pooler":
            output = output.pooler_output
        else:
            raise ValueError(
                f"pool_type must be one of {self._POOL_TYPES}, got {self.pool_type!r}")

        pred_label = self.fc(output)

        return output, pred_label


class PROTO(BASE):
    '''
        Prototypical network on top of a pretrained BERT encoder.

        Each class prototype is the mean support embedding of that class;
        queries are scored by negative L2 distance to the prototypes.
    '''
    def __init__(self, args):
        super(PROTO, self).__init__(args)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # NOTE(review): the model is not moved to self.device here; placement
        # is presumably handled by the caller — TODO confirm.
        self.model = Bert(self.args.pretrained_bert, 'cls', 0.3, self.args.way)

    def _compute_prototype(self, XS, YS):
        '''
            Compute the prototype for each class by averaging over the ebd.

            @param XS (support x): support_size x ebd_dim
            @param YS (support y): support_size, labels in 0, ..., way - 1
                (e.g. the output of ``reidx_y``)

            @return prototype: way x ebd_dim
        '''
        # Select each class by its label rather than slicing a sorted tensor
        # in fixed chunks of ``args.shot``: the chunked version silently mixed
        # classes whenever a class did not have exactly ``shot`` examples.
        prototype = torch.stack(
            [XS[YS == c].mean(dim=0) for c in range(self.args.way)], dim=0)

        return prototype

    def forward(self, support, query):
        '''
            Run one few-shot episode.

            @param support: dict with keys 'label', 'text' (input ids fed to
                the encoder), 'ret_mask' (attention mask) and 'ret_type_ids'
                (token type ids)
            @param query: dict with the same keys as ``support``

            @return acc: accuracy on the query set (Python float)
            @return loss: cross-entropy loss on the query set
        '''
        XS, _ = self.model(support['text'], support['ret_mask'],
                           support['ret_type_ids'])
        XQ, _ = self.model(query['text'], query['ret_mask'],
                           query['ret_type_ids'])

        YS, YQ = self.reidx_y(support['label'], query['label'])

        prototype = self._compute_prototype(XS, YS)

        # Negative distance: the closest prototype gets the largest logit.
        pred = -self._compute_l2(prototype, XQ)

        loss = F.cross_entropy(pred, YQ)

        acc = BASE.compute_acc(pred, YQ)

        return acc, loss
