from mi_estimator import *
from multivairate_loss import *
from model_utils import *
from supervized_contrastive_loss import *
from utils import *
from tqdm import tqdm


class FairClassifier(torch.nn.Module):
    """Encoder with a projection head, a downstream classifier and a fairness block.

    The module owns four sub-modules:

    * ``encoder``: built from ``encoder_dic[args.encoder_type]``;
    * ``proj_content``: a two-layer MLP applied to the encoder output;
    * ``downstream_classifier``: predicts the public labels from the embedding;
    * ``fblock``: the fairness term, selected by ``args.fblock``
      (``'SCL'``, ``'MI'`` or ``'MULTI'``).

    NOTE: ``forward`` performs the optimisation steps itself, using
    ``args.optimizer`` and ``args.scheduler``; callers only read the returned
    dictionary of losses.
    """

    def __init__(self, args, mi_dataloader):
        """Build the sub-modules.

        Args:
            args: configuration namespace. Attributes read by this class:
                ``encoder_type``, ``hidden_size``, ``fblock``,
                ``mi_estimator_name``, ``number_of_public_labels``,
                ``add_noise``, ``l2_normalization``,
                ``number_of_training_encoder``, ``mul_lambda``,
                ``max_grad_norm``, ``optimizer`` and ``scheduler``.
            mi_dataloader: iterable of batches used to train the downstream
                classifier and the fairness block inside ``forward``.

        Raises:
            NotImplementedError: if ``args.fblock`` is not a supported value.
        """
        super(FairClassifier, self).__init__()
        self.args = args
        self.mi_dataloader = mi_dataloader
        self.encoder = encoder_dic[args.encoder_type](args)
        self.proj_content = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(args.hidden_size, args.hidden_size),
        )

        self.loss_classif = torch.nn.NLLLoss()
        if self.args.fblock == 'SCL':
            self.fblock = SupConLoss(temperature=1, contrast_mode='all',
                                     base_temperature=0.07)
        elif self.args.fblock == 'MI':
            self.fblock = MI_dict[args.mi_estimator_name](args)
        elif self.args.fblock == 'MULTI':
            self.fblock = MultiVariateLoss(args)
        else:
            raise NotImplementedError(
                'Unknown fblock type: {!r}'.format(self.args.fblock))
        self.downstream_classifier = Classifier(args.hidden_size, args.number_of_public_labels)

    def _embed(self, inputs):
        """Encode ``inputs``, project them and optionally L2-normalise each embedding.

        Normalisation is applied along the last (feature) axis so that every
        sample is scaled independently of the rest of the batch, and so that
        training and prediction paths see identically scaled embeddings.
        """
        hidden = self.proj_content(self.encoder(inputs))
        if self.args.l2_normalization:
            # Per-sample normalisation; the eps clamp avoids NaN on zero vectors.
            hidden = torch.nn.functional.normalize(hidden, p=2, dim=-1)
        return hidden

    @staticmethod
    def _unpack_mi_batch(batch_mi):
        """Return ``(inputs, sensitive_labels, public_labels)`` from an MI batch.

        Batches are either mappings with ``'text'``, ``'sensitive_label'`` and
        ``'public_label'`` keys, or 3-item sequences ordered as
        ``(inputs, public_labels, sensitive_labels)``.
        """
        try:
            return (batch_mi['text'],
                    batch_mi['sensitive_label'],
                    batch_mi['public_label'])
        except (TypeError, KeyError, IndexError):
            # Not a mapping with the expected keys: fall back to positional layout.
            inputs_mi, public_labels_mi, senstive_labels_mi = batch_mi
            return inputs_mi, senstive_labels_mi, public_labels_mi

    def _train_auxiliary_heads(self):
        """Train the downstream classifier (and a learnable fblock) on MI batches.

        The encoder and projection head are passed to ``frozen_params`` and the
        trained heads to ``free_params`` (project helpers; presumably they toggle
        ``requires_grad`` -- confirm in ``model_utils``).

        Returns:
            dict: the logging dictionary of the last completed step, or an
            empty dict when no step was run.
        """
        frozen_params(self.encoder)
        frozen_params(self.proj_content)
        free_params(self.downstream_classifier)
        if self.fblock.require_learning:
            free_params(self.fblock)

        # Same step budget as the historical ``count > n + 1`` check: n + 2 steps.
        max_steps = self.args.number_of_training_encoder + 2
        fblock_tb_dict = {}
        for step, batch_mi in enumerate(self.mi_dataloader):
            if step >= max_steps:
                break
            inputs_mi, senstive_labels_mi, public_labels_mi = self._unpack_mi_batch(batch_mi)
            inputs_embedded = self._embed(inputs_mi)

            loss, step_log = 0, {}
            if self.fblock.require_learning:
                # Learning objective of the fairness block itself.
                loss, step_log = self.fblock.learning_loss(inputs_embedded,
                                                           senstive_labels_mi)

            down_pred = self.downstream_classifier(inputs_embedded)
            loss_down_classifier = self.loss_classif(down_pred, public_labels_mi)
            step_log.update({'loss_down_classifier': loss_down_classifier})

            loss = loss + loss_down_classifier
            loss.backward()
            self.args.optimizer.step()
            self.args.scheduler.step()
            self.proj_content.zero_grad()
            self.fblock.zero_grad()
            self.downstream_classifier.zero_grad()
            self.encoder.zero_grad()
            fblock_tb_dict = step_log
        return fblock_tb_dict

    def forward(self, inputs, senstive_labels, public_labels):
        """Run one training/evaluation step and return the losses.

        In training mode this first trains the auxiliary heads on
        ``self.mi_dataloader`` (see ``_train_auxiliary_heads``), then updates
        the encoder and projection head on
        ``ce_loss + args.mul_lambda * fairness_loss``. In evaluation mode only
        the losses are computed; no optimiser step is taken.

        Args:
            inputs: batch fed to the encoder.
            senstive_labels: sensitive attribute labels passed to the fblock.
            public_labels: targets for the downstream classifier.

        Returns:
            dict: always contains ``'ce_loss'``, ``'fairness_loss'`` and
            ``'total_loss'``; in training mode it also contains the auxiliary
            logs of the last MI step plus ``'gradient_encoder'`` and
            ``'gradient_content_proj'``.
        """
        losses_dic = {}

        if self.args.add_noise:
            inputs = corrupt_input(self, inputs)

        ######################################################
        # Step 1: train downstream classifier / fairness block
        ######################################################
        if self.training:
            # Merged unconditionally so the logs survive a short dataloader.
            losses_dic.update(self._train_auxiliary_heads())

        ###########################################
        # Step 2: downstream loss + lambda * fblock
        ###########################################
        free_params(self.encoder)
        free_params(self.proj_content)
        frozen_params(self.downstream_classifier)
        if self.fblock.require_learning:
            frozen_params(self.fblock)

        encoder_hidden = self._embed(inputs)
        loss_fairness = self.fblock(encoder_hidden, senstive_labels)  # plain forward, not learning_loss

        down_pred = self.downstream_classifier(encoder_hidden)
        loss_down_classifier = self.loss_classif(down_pred, public_labels)
        total_loss = loss_down_classifier + self.args.mul_lambda * loss_fairness

        if self.training:
            total_loss.backward()
            # Gradient norms are recorded before clipping.
            gradient_encoder = comput_gradient_norm(self.encoder)
            gradient_content_proj = comput_gradient_norm(self.proj_content)
            torch.nn.utils.clip_grad_norm_(self.proj_content.parameters(), self.args.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.args.max_grad_norm)
            self.args.optimizer.step()
            self.args.scheduler.step()
            self.proj_content.zero_grad()
            self.encoder.zero_grad()

            losses_dic.update({
                'gradient_encoder': gradient_encoder,
                'gradient_content_proj': gradient_content_proj
            })

        losses_dic.update({'ce_loss': loss_down_classifier,
                           'fairness_loss': loss_fairness,
                           'total_loss': total_loss
                           })
        return losses_dic

    def predict_latent_space(self, input_tensor):
        """Return the (optionally L2-normalised) projected embedding of ``input_tensor``."""
        return self._embed(input_tensor)

    def predict_downstream(self, inputs):
        """Return the downstream classifier output for ``inputs``."""
        return self.downstream_classifier(self._embed(inputs))

    def predict_fairness(self, inputs, sensitive_labels, public_labels):
        """Return the fairness block value for ``inputs`` and ``sensitive_labels``.

        ``public_labels`` is accepted for interface compatibility and is unused.
        """
        return self.fblock(self._embed(inputs), sensitive_labels)
