import torch
import torch.nn as nn
from model_utils import *
import numpy as np
from utils import *
import copy


class AdversarialBaseline(nn.Module):
    """Adversary that tries to recover the sensitive label from an embedding.

    ``forward`` returns the *negated* classification loss of the adversary,
    while ``learning_loss`` returns the (positive) loss used to train it,
    together with a logging dict.
    """

    def __init__(self, args):
        super().__init__()
        self.classifier = Classifier(args.hidden_size, args.number_sensitive_label)
        self.require_learning = True
        # The classifier output is scored with NLLLoss, so it is expected to
        # produce log-probabilities.
        self.loss_classif = torch.nn.NLLLoss()

    def forward(self, inputs_embedded, sensitive_labels_mi):
        """Return minus the adversary's NLL on ``sensitive_labels_mi``."""
        log_probs = self.classifier(inputs_embedded)
        return -self.loss_classif(log_probs, sensitive_labels_mi)

    def learning_loss(self, inputs_embedded, sensitive_labels_mi):
        """Return ``(loss, logs)`` where ``loss`` is the adversary's NLL."""
        negated_loss = self.forward(inputs_embedded, sensitive_labels_mi)
        adv_loss = -negated_loss
        logs = {"tr_advers_loss": adv_loss}
        return adv_loss, logs





class MIReny(nn.Module):
    """Rényi-divergence based dependence penalty for a binary sensitive label.

    Two sub-networks are trained through :meth:`learning_loss`:

    * ``advers_classifier`` predicts the sensitive label from the embedding;
    * ``density_ratio_classifier`` receives the embedding concatenated with a
      one-hot label and is trained to output class 0 for the *true* label
      (pair ``u``) and class 1 for the adversary's *predicted* label (pair
      ``v``).

    :meth:`forward` combines them into ``|loss_h_s - loss_h_sz + reny|``.

    Both classifiers' outputs are fed to ``NLLLoss`` and exponentiated as
    probabilities below, so they are assumed to return log-probabilities.
    """

    def __init__(self, args):
        super(MIReny, self).__init__()
        self.require_learning = True
        self.args = args
        self.loss_classif = torch.nn.NLLLoss()
        assert self.args.number_sensitive_label == 2, 'Sensitive label not binary you need to code a bit :)'
        self.advers_classifier = Classifier(args.hidden_size, args.number_sensitive_label)
        # Input: embedding concatenated with a one-hot sensitive label.
        self.density_ratio_classifier = Classifier(args.hidden_size + args.number_sensitive_label,
                                                   args.number_sensitive_label)

    @staticmethod
    def _binary_one_hot(labels, like):
        """One-hot encode ``labels``: 0 -> ``[1, 0]``, anything else -> ``[0, 1]``.

        Vectorised (no per-element Python loop / host sync); the result takes
        the dtype and device of the tensor ``like``.
        """
        is_positive = labels != 0
        return torch.stack([~is_positive, is_positive], dim=-1).to(like)

    @staticmethod
    def _log_mean_exp(values):
        """Return ``log(mean(exp(values)))`` without overflowing ``exp``."""
        # Shifting by the max keeps every exp() argument <= 0; the shift's
        # gradient contribution cancels exactly, so it can be detached.
        shift = values.max().detach()
        return shift + torch.log(torch.mean(torch.exp(values - shift)))

    def _density_ratio_log_probs(self, inputs_embedded, adv_pred, sensitive_labels):
        """Score the ``u`` (true label) and ``v`` (predicted label) pairs.

        Returns the density-ratio classifier outputs for ``u`` = embedding +
        one-hot true label and ``v`` = embedding + one-hot argmax of
        ``adv_pred``.
        """
        predicted_labels = adv_pred.argmax(dim=-1)
        # TODO: in the ACL paper the label was copied 4 times (bid + 2 layers); impact?
        u = torch.cat([inputs_embedded, self._binary_one_hot(sensitive_labels, inputs_embedded)], dim=-1)
        v = torch.cat([inputs_embedded, self._binary_one_hot(predicted_labels, inputs_embedded)], dim=-1)
        return self.density_ratio_classifier(u), self.density_ratio_classifier(v)

    def learning_loss(self, inputs_embedded, senstive_labels_mi):  # samples have shape [sample_size, dim]
        """Training loss for the adversary and the density-ratio classifier.

        Args:
            inputs_embedded: embeddings, ``[sample_size, dim]`` per the
                original note.
            senstive_labels_mi: sensitive labels (0 / 1), one per sample.

        Returns:
            ``(loss_gamma + loss_adv_classifier, logs)`` where ``logs`` holds
            both terms under ``"loss_gamma"`` and ``"loss_adv_classifier"``.
        """
        adv_pred = self.advers_classifier(inputs_embedded)
        loss_adv_classifier = self.loss_classif(adv_pred, senstive_labels_mi)

        d_u, d_v = self._density_ratio_log_probs(inputs_embedded, adv_pred, senstive_labels_mi)
        # Push u towards class 0 and v towards class 1 (averaged NLL of both).
        loss_gamma = -torch.mean(d_u[:, 0]) / 2 - torch.mean(d_v[:, 1]) / 2

        return loss_gamma + loss_adv_classifier, {
            "loss_gamma": loss_gamma,
            "loss_adv_classifier": loss_adv_classifier}

    def forward(self, inputs_embedded, senstive_labels_mi):
        """Return the dependence penalty ``|loss_h_s - loss_h_sz + reny|``.

        * ``loss_h_sz``: the adversary's NLL on the true labels.
        * ``loss_h_s``: cross-entropy between the empirical label frequencies
          and the adversary's batch-averaged predicted probabilities.
        * ``reny``: ``|log(mean(r ** (alpha - 1))) / (alpha - 1)|`` with
          ``r = exp(d_u[:, 0]) / exp(d_v[:, 1])``, evaluated in log space so
          extreme log-probabilities do not overflow to ``inf``.  For
          ``args.alpha == 1`` the ``alpha -> 1`` limit ``|mean(log r)|`` is
          used instead of the undefined ``0 / 0``.
        """
        adv_pred = self.advers_classifier(inputs_embedded)
        loss_h_sz = self.loss_classif(adv_pred, senstive_labels_mi)

        # Fraction of label-1 samples; .item() makes it a gradient-free constant.
        positive_rate = senstive_labels_mi.sum().item() / senstive_labels_mi.shape[0]
        marginal = torch.exp(adv_pred).mean(dim=0)
        loss_h_s = (-torch.log(marginal[0]) * (1 - positive_rate)
                    - torch.log(marginal[1]) * positive_rate)

        d_u, d_v = self._density_ratio_log_probs(inputs_embedded, adv_pred, senstive_labels_mi)
        # log(exp(d_u0) / exp(d_v1)), never materialising the ratio itself.
        log_ratio = d_u[:, 0] - d_v[:, 1]
        alpha = self.args.alpha
        # abs() on the divergence term is kept from the original
        # ("remove bias from gradients").
        if alpha == 1:
            reny = torch.abs(torch.mean(log_ratio))
        else:
            reny = torch.abs(self._log_mean_exp((alpha - 1) * log_ratio) / (alpha - 1))

        return torch.abs(loss_h_s - loss_h_sz + reny)  # TODO : KL time


# Registry of the estimator classes defined above, keyed by short name.
MI_dict = dict(
    ADV=AdversarialBaseline,
    MIReny=MIReny,
)
