import numpy as np
import torch
import os
from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
import logging
import copy

logger = logging.getLogger(__name__)


class AdvDisentanglement:
    """
    Train and evaluate a probe that predicts a sensitive attribute from the
    latent space of a fixed model.

    The probe (``disentanglement_classifier``) is trained on embeddings returned
    by ``model.predict_latent_space``; the model itself is only run in eval mode
    under ``torch.no_grad()``. Losses are computed with ``torch.nn.NLLLoss``, so
    both the probe and ``model.predict_downstream`` are expected to output
    log-probabilities.

    -------------------------------------
    Attributes:
        model: object exposing ``eval()``, ``predict_latent_space(inputs)`` and
            ``predict_downstream(inputs)``.
        disentanglement_classifier: ``torch.nn.Module`` probe being trained.
        train_dataloader, val_dataloader, test_dataloader: iterables of dict
            batches with keys ``'text'``, ``'sensitive_label'`` and (test only)
            ``'public_label'``.
        output_dir: directory for checkpoints; also names the TensorBoard run.
        optimizer, scheduler: optimizer / LR scheduler of the probe.
        save_step, eval_step: checkpoint / evaluation period in optimizer steps.
        max_number_of_steps: upper bound on the number of optimizer steps.
        batch_size: only used for logging.
        min_save_step: checkpoints are only written after this many steps.
        num_epochs: upper bound on passes over ``train_dataloader``.
        device: CUDA if available, else CPU; used as ``map_location`` on reload.

    -------------------------------------
    Methods:
        fit(): train the probe and reload its best checkpoint.
        predict(): report probe / downstream loss and accuracy on the test set.
    """

    def __init__(self, model, disentanglement_classifier, train_dataloader, val_dataloader, test_dataloader,
                 output_dir, optimizer, scheduler, save_step, eval_step, max_number_of_steps, batch_size,
                 min_save_step=5000, num_epochs=100):
        self.model = model
        self.disentanglement_classifier = disentanglement_classifier
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.output_dir = output_dir
        self.max_number_of_steps = max_number_of_steps
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.scheduler = scheduler
        self.save_step = save_step
        self.eval_step = eval_step
        self.min_save_step = min_save_step
        self.num_epochs = num_epochs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def fit(self):
        """Train the probe on frozen embeddings and reload the best saved checkpoint.

        Runs at most ``num_epochs`` passes over ``train_dataloader`` and stops
        after ``max_number_of_steps`` optimizer steps. Every ``eval_step`` steps
        the mean training loss, the learning rate and the validation loss are
        written to TensorBoard under ``metric_evaluation/<output_dir>/``. On
        every ``save_step``-th step beyond ``min_save_step``, the probe's
        ``state_dict`` is saved to ``<output_dir>/checkpoint-<step>/`` if its
        validation loss (computed at that same step) is the best so far. At the
        end, the best checkpoint, if one was saved, is loaded back into the probe.
        """
        tb_writer = SummaryWriter('metric_evaluation/{}/'.format(self.output_dir))
        loss_fct = torch.nn.NLLLoss()
        # Train!
        logger.info("***** Running training *****")
        logger.info("  Max number of steps = %d", self.max_number_of_steps)
        logger.info("  Instantaneous batch size = %s", self.batch_size)

        checkpoint_prefix = "checkpoint"
        best_loss = float('inf')
        classifier_path = None
        tr_loss = 0.0
        global_step = 0

        self.model.eval()
        self.disentanglement_classifier.zero_grad()
        try:
            for _ in trange(0, int(self.num_epochs), desc="Epoch"):
                if global_step >= self.max_number_of_steps:
                    break
                for batch in tqdm(self.train_dataloader, desc="Iteration"):
                    if global_step >= self.max_number_of_steps:
                        break
                    tr_loss += self._train_step(batch, loss_fct)
                    global_step += 1

                    is_eval_step = global_step % self.eval_step == 0
                    is_save_step = (global_step % self.save_step == 0
                                    and global_step > self.min_save_step)
                    if is_eval_step:
                        # tr_loss now sums exactly the last eval_step batches.
                        tb_writer.add_scalar("train_loss", tr_loss / self.eval_step, global_step)
                        tb_writer.add_scalar("train_lr", self.scheduler.get_last_lr()[0], global_step)
                        tr_loss = 0.0
                    if is_eval_step or is_save_step:
                        # Validate the *current* weights so a save decision never
                        # relies on a loss computed for different weights.
                        val_loss = self._validation_loss(loss_fct)
                        tb_writer.add_scalar("val_loss", val_loss, global_step)
                        if is_save_step and val_loss < best_loss:
                            best_loss = val_loss
                            classifier_path = self._save_checkpoint(checkpoint_prefix, global_step)
        finally:
            tb_writer.close()

        logger.info("Last checkpoint %s", global_step)
        if classifier_path is None:
            logger.warning("No checkpoint was saved (saving starts after step %d); keeping the last weights",
                           self.min_save_step)
            return
        # Load last best classifier :)
        logger.info("Reloading Best Saved model at %s", classifier_path)
        self.disentanglement_classifier.load_state_dict(
            torch.load(classifier_path, map_location=self.device))

    def _train_step(self, batch, loss_fct):
        """Run one optimizer + scheduler step of the probe on ``batch``; return its loss as a float."""
        self.disentanglement_classifier.train()
        with torch.no_grad():
            # No gradients flow into the model: only the probe is trained.
            embeddings = self.model.predict_latent_space(batch['text'])
        prediction = self.disentanglement_classifier(embeddings)
        loss = loss_fct(prediction, batch['sensitive_label'].to(prediction.device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.disentanglement_classifier.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.disentanglement_classifier.zero_grad()
        # .item() detaches the value; summing the tensor itself would keep every step's graph alive.
        return loss.item()

    def _validation_loss(self, loss_fct):
        """Return the mean per-batch ``loss_fct`` of the probe over ``val_dataloader``.

        The probe's train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``val_dataloader`` yields no batches.
        """
        was_training = self.disentanglement_classifier.training
        self.disentanglement_classifier.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for val_batch in tqdm(self.val_dataloader, desc="Validation"):
                embeddings = self.model.predict_latent_space(val_batch['text'])
                prediction = self.disentanglement_classifier(embeddings)
                labels = val_batch['sensitive_label'].to(prediction.device)
                total_loss += loss_fct(prediction, labels).item()
                n_batches += 1
        self.disentanglement_classifier.train(was_training)
        if n_batches == 0:
            raise ValueError("val_dataloader yielded no batches")
        return total_loss / n_batches

    def _save_checkpoint(self, checkpoint_prefix, global_step):
        """Save the probe's ``state_dict`` under ``<output_dir>/<prefix>-<step>/`` and return the file path."""
        output_dir = os.path.join(self.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
        os.makedirs(output_dir, exist_ok=True)
        classifier_path = os.path.join(output_dir, 'classifier_disentanglement.pt')
        torch.save(self.disentanglement_classifier.state_dict(), classifier_path)
        logger.info("Saving model checkpoint to %s", output_dir)
        return classifier_path

    @staticmethod
    def _loss_and_accuracy(loss_fct, log_probs, labels):
        """Return ``(loss, accuracy)`` for one batch as Python floats.

        The predicted class is the top-1 index along the last dimension of
        ``log_probs`` (presumably shaped (batch, n_classes) — TODO confirm).
        """
        labels = labels.to(log_probs.device)
        loss = loss_fct(log_probs, labels).item()
        predicted = log_probs.topk(1)[-1].squeeze(-1)
        accuracy = (predicted == labels).sum().item() / labels.numel()
        return loss, accuracy

    def predict(self):
        """Evaluate the probe and the model's downstream head on ``test_dataloader``.

        Returns:
            dict with keys ``accuracies_sensitive``, ``accuracies_downstream``,
            ``loss_sensitive`` and ``losses_downstream``. Each value is the
            unweighted mean over batches of the per-batch value (NLL loss, or
            top-1 accuracy).

        Raises:
            ValueError: if ``test_dataloader`` yields no batches.
        """
        loss_fct = torch.nn.NLLLoss()
        # Eval!
        logger.info("***** Running evaluation *****")
        self.model.eval()
        self.disentanglement_classifier.eval()
        losses_sensitive, accuracies_sensitive = [], []
        losses_downstream, accuracies_downstream = [], []
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc="Evaluating"):
                inputs = batch['text']
                embeddings = self.model.predict_latent_space(inputs)

                loss, accuracy = self._loss_and_accuracy(
                    loss_fct, self.disentanglement_classifier(embeddings), batch['sensitive_label'])
                losses_sensitive.append(loss)
                accuracies_sensitive.append(accuracy)

                loss, accuracy = self._loss_and_accuracy(
                    loss_fct, self.model.predict_downstream(inputs), batch['public_label'])
                losses_downstream.append(loss)
                accuracies_downstream.append(accuracy)

        if not losses_sensitive:
            raise ValueError("test_dataloader yielded no batches")

        def mean(values):
            return sum(values) / len(values)

        return {
            "accuracies_sensitive": mean(accuracies_sensitive),
            "accuracies_downstream": mean(accuracies_downstream),
            "loss_sensitive": mean(losses_sensitive),
            "losses_downstream": mean(losses_downstream),
        }


if __name__ == '__main__':
    # Smoke run of AdvDisentanglement.predict on synthetic data.
    # NOTE(review): the previous demo called a `Voronoi` class that is neither
    # defined nor imported in this module, so running the file raised NameError.
    Nobs = 100
    d = 10
    batch_size = 20

    # Nobs points drawn from N(0, I_d), each rescaled to unit Euclidean norm.
    Z = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=Nobs)
    Z = Z / np.linalg.norm(Z, axis=1, keepdims=True)
    Y = np.random.uniform(0, 1, size=Nobs)
    S = np.random.binomial(1, 0.5, size=Nobs)

    features = torch.as_tensor(Z, dtype=torch.float32)
    # NLLLoss needs int64 class indices; Y is binarised to get a public label.
    public_labels = torch.as_tensor((Y > 0.5).astype(np.int64))
    sensitive_labels = torch.as_tensor(S.astype(np.int64))
    test_batches = [
        {
            'text': features[start:start + batch_size],
            'sensitive_label': sensitive_labels[start:start + batch_size],
            'public_label': public_labels[start:start + batch_size],
        }
        for start in range(0, Nobs, batch_size)
    ]

    class _IdentityModel(torch.nn.Module):
        """Demo model: the latent space is the input itself; a linear head gives log-probabilities."""

        def __init__(self, dim):
            super().__init__()
            self.head = torch.nn.Sequential(torch.nn.Linear(dim, 2), torch.nn.LogSoftmax(dim=-1))

        def predict_latent_space(self, inputs):
            return inputs

        def predict_downstream(self, inputs):
            return self.head(inputs)

    probe = torch.nn.Sequential(torch.nn.Linear(d, 2), torch.nn.LogSoftmax(dim=-1))
    evaluator = AdvDisentanglement(
        model=_IdentityModel(d), disentanglement_classifier=probe,
        train_dataloader=None, val_dataloader=None, test_dataloader=test_batches,
        output_dir='demo', optimizer=None, scheduler=None, save_step=1, eval_step=1,
        max_number_of_steps=0, batch_size=batch_size,
    )
    print(evaluator.predict())
