from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from transformers import BertForSequenceClassification
from transformers import BertConfig, BertModel
from copy import deepcopy
import gc
import torch
from sklearn.metrics import accuracy_score
import numpy as np

class BASE(nn.Module):
    '''
        BASE model: shared helpers (pairwise distances, label re-indexing,
        MLP construction, accuracy) for subclassing learners.
    '''
    def __init__(self, args):
        '''
            @param args: configuration namespace; reidx_y reads args.way.
        '''
        super(BASE, self).__init__()
        self.args = args

    def _compute_l2(self, XS, XQ):
        '''
            Compute the pairwise l2 distance
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size
        '''
        # Broadcast to query_size x support_size x ebd_dim, then reduce over
        # the embedding dimension.
        diff = XS.unsqueeze(0) - XQ.unsqueeze(1)
        return torch.norm(diff, dim=2)

    def _compute_cos(self, XS, XQ):
        '''
            Compute the pairwise cosine distance (1 - cosine similarity)
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size
        '''
        # One 2-D matmul yields every query/support dot product without
        # broadcasting both inputs into query_size x support_size batches.
        dot = XQ @ XS.t()

        scale = (torch.norm(XS, dim=1).unsqueeze(0) *
                 torch.norm(XQ, dim=1).unsqueeze(1))
        # Avoid division by zero for all-zero embeddings.
        scale = scale.clamp_min(1e-8)

        return 1 - dot / scale

    def reidx_y(self, YS, YQ):
        '''
            Map the labels into 0, ..., way - 1, following the sorted order of
            the distinct labels. Support and query labels are re-indexed
            independently, so the two mappings agree only when both sets
            contain the same classes.
            @param YS: batch_size
            @param YQ: batch_size

            @return YS_new: batch_size
            @return YQ_new: batch_size
        '''
        unique_s, inv_S = torch.unique(YS, sorted=True, return_inverse=True)
        _, inv_Q = torch.unique(YQ, sorted=True, return_inverse=True)

        # NOTE(review): nothing checks that support and query share the same
        # classes, or that exactly args.way classes are present.

        # Indexing this table raises IndexError when more than args.way
        # distinct labels appear.
        Y_new = torch.arange(start=0, end=self.args.way, dtype=unique_s.dtype,
                             device=unique_s.device)

        return Y_new[inv_S], Y_new[inv_Q]

    def _init_mlp(self, in_d, hidden_ds, drop_rate):
        '''
            Build an MLP: (Dropout -> Linear -> ReLU) for every width in
            hidden_ds except the last, then Dropout -> Linear to hidden_ds[-1].
            @param in_d: input dimension
            @param hidden_ds: sequence of layer widths; the last is the output dim
            @param drop_rate: dropout probability applied before each Linear

            @return nn.Sequential
        '''
        modules = []

        for d in hidden_ds[:-1]:
            modules.extend([
                nn.Dropout(drop_rate),
                nn.Linear(in_d, d),
                nn.ReLU()])
            in_d = d

        modules.extend([
            nn.Dropout(drop_rate),
            nn.Linear(in_d, hidden_ds[-1])])

        return nn.Sequential(*modules)

    @staticmethod
    def compute_acc(pred, true):
        '''
            Compute the accuracy.
            @param pred: batch_size * num_classes
            @param true: batch_size

            @return float: fraction of rows whose argmax equals the label
        '''
        return torch.mean((torch.argmax(pred, dim=1) == true).float()).item()


    
class Bert(nn.Module):
    """
    Pretrained BERT encoder followed by a linear classification head.

    ``forward`` returns both the pooled sentence representation and the
    logits computed from it.
    """
    def __init__(self, pretrained='bert-base-uncased', pool_type="cls", dropout_prob=0.3, n_classes=2):
        """
        :param pretrained: model name or path passed to ``from_pretrained``.
        :param pool_type: "cls" uses the last hidden state of the first token;
            "pooler" uses the encoder's ``pooler_output``.
        :param dropout_prob: overrides both the attention and the hidden
            dropout probabilities of the loaded config.
        :param n_classes: output size of the classification head.
        :raises ValueError: if ``pool_type`` is not "cls" or "pooler".
        """
        super().__init__()
        # Validate before loading weights so a bad argument fails fast; an
        # explicit raise also survives ``python -O`` (``assert`` does not).
        if pool_type not in ("cls", "pooler"):
            raise ValueError("invalid pool_type: %s" % pool_type)
        self.pool_type = pool_type

        conf = BertConfig.from_pretrained(pretrained)
        conf.attention_probs_dropout_prob = dropout_prob
        conf.hidden_dropout_prob = dropout_prob
        self.dropout_prob = dropout_prob
        self.encoder = BertModel.from_pretrained(pretrained, config=conf)

        self.n_classes = n_classes
        # Size the head from the config instead of assuming 768, so
        # checkpoints with a different hidden size also work.
        self.fc = nn.Linear(conf.hidden_size, self.n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        :return: (pooled representation, classification logits)
        """
        output = self.encoder(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        if self.pool_type == "cls":
            output = output.last_hidden_state[:, 0]
        elif self.pool_type == "pooler":
            output = output.pooler_output

        pred_label = self.fc(output)

        return output, pred_label

    
class Learner_RidgeRgressor(BASE):
    """
    Meta Learner: a BERT encoder whose embeddings feed a closed-form ridge
    regression head. The head is fitted on each task's support set and
    applied to that task's query set.
    """
    def __init__(self, args):
        """
        :param args: namespace providing num_k_shot, num_labels,
            outer_batch_size, inner_batch_size, outer_update_lr,
            inner_update_lr, inner_update_step and inner_update_step_eval.
        """
        super(Learner_RidgeRgressor, self).__init__(args)

        self.num_k_shot = args.num_k_shot
        self.num_labels = args.num_labels
        self.outer_batch_size = args.outer_batch_size
        self.inner_batch_size = args.inner_batch_size
        self.outer_update_lr  = args.outer_update_lr
        self.inner_update_lr  = args.inner_update_lr
        self.inner_update_step = args.inner_update_step
        self.inner_update_step_eval = args.inner_update_step_eval

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = Bert('bert-base-uncased', 'cls', 0.3, self.num_labels)

        # Meta parameters to learn: log10 ridge penalty, log10 logit scale,
        # and a logit offset. They are assigned as plain nn.Parameters so the
        # module registers them; calling .to() on a Parameter returns an
        # unregistered plain tensor whenever the device actually changes.
        self.lam = nn.Parameter(torch.tensor(-1, dtype=torch.float))
        self.alpha = nn.Parameter(torch.tensor(0, dtype=torch.float))
        self.beta = nn.Parameter(torch.tensor(1, dtype=torch.float))

        # Constant identity matrices: buffers follow the module's device and
        # are saved in state_dict, but are never optimized. I_support is kept
        # for state_dict compatibility; _compute_w sizes its identity from
        # the actual support set.
        self.register_buffer('I_support', torch.eye(self.num_k_shot, dtype=torch.float))
        self.register_buffer('I_way', torch.eye(self.num_labels, dtype=torch.float))

        # Move everything once, then optimize the encoder AND the meta
        # parameters (previously only the encoder was in the optimizer).
        self.to(self.device)
        self.outer_optimizer = Adam(self.parameters(), lr=self.outer_update_lr)

    def _label2onehot(self, Y):
        '''
            Map integer labels in 0, ..., num_labels - 1 to one-hot rows
            @param Y: batch_size (integer labels)

            @return Y_onehot: batch_size * num_labels
        '''
        return F.embedding(Y, self.I_way)

    def _compute_w(self, XS, YS_onehot):
        '''
            Compute the W matrix of ridge regression in its dual form:
                W = XS^T (XS XS^T + 10^lam * I)^-1 YS_onehot
            @param XS: support_size x ebd_dim
            @param YS_onehot: support_size x way

            @return W: ebd_dim x way
        '''
        # The identity must match XS @ XS.t() (support_size x support_size).
        identity = torch.eye(XS.size(0), dtype=XS.dtype, device=XS.device)
        gram = XS @ XS.t() + (10. ** self.lam) * identity
        # Solving the system is cheaper and numerically more stable than
        # forming the explicit inverse.
        return XS.t() @ torch.linalg.solve(gram, YS_onehot)

    def _encode(self, dataset):
        '''
            Load a whole task split as one batch and embed it with the encoder.
            @param dataset: TensorDataset of (input_ids, attention_mask,
                segment_ids, label_ids)

            @return X: dataset_size x ebd_dim
            @return Y: dataset_size labels
        '''
        loader = DataLoader(dataset, sampler=None, batch_size=len(dataset))
        # Use builtin next(): recent PyTorch DataLoader iterators no longer
        # provide the Python-2 style .next() method.
        input_ids, attention_mask, segment_ids, label_ids = (
            t.to(self.device) for t in next(iter(loader)))
        X, _ = self.model(input_ids, attention_mask, segment_ids)
        return X, label_ids

    def forward(self, batch_tasks, training = True):
        """
        Evaluate (and, when training, optimize on) a batch of tasks.

        batch = [(support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset)]

        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)

        :param training: when True, take one outer-optimizer step per task on
            the query cross-entropy. When False, only evaluate: dropout is
            disabled and no autograd graph is built.
        :return: mean query accuracy over the tasks.
        """
        task_accs = []
        # Dropout should be active only while training.
        self.model.train(training)

        for task in batch_tasks:
            support, query = task[0], task[1]

            with torch.set_grad_enabled(training):
                XS, YS = self._encode(support)
                XQ, YQ = self._encode(query)

                W = self._compute_w(XS, self._label2onehot(YS))
                # NOTE(review): beta shifts every logit equally, so it changes
                # neither the softmax/cross-entropy nor the argmax.
                q_pred = (10.0 ** self.alpha) * XQ @ W + self.beta

                if training:
                    # Zero through the optimizer so lam/alpha/beta grads are
                    # cleared too, not just the encoder's.
                    self.outer_optimizer.zero_grad()
                    loss = F.cross_entropy(q_pred, YQ)
                    loss.backward()
                    self.outer_optimizer.step()

            # argmax over the logits equals argmax over their softmax.
            q_pre_label_id = torch.argmax(q_pred, dim=1).detach().cpu().numpy().tolist()
            q_label_ids = YQ.detach().cpu().numpy().tolist()
            task_accs.append(accuracy_score(q_label_ids, q_pre_label_id))

        return np.mean(task_accs)