import numpy as np
import torch
import os
from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
import logging
import copy

logger = logging.getLogger(__name__)


class MetricViz:
    """
    Collect latent-space embeddings and labels from a model over a test set.

    The constructor only stores its arguments. ``predict`` runs the model over
    ``test_dataloader`` and gathers, for every example, the latent embedding
    together with its public and sensitive labels.

    -------------------------------------
    Attributes:
        model: object exposing ``eval()`` and ``predict_latent_space(inputs)``;
            used by ``predict``.
        disentanglement_classifier: stored as given; not used by this class.
        train_dataloader, val_dataloader: stored as given; not used by this
            class.
        test_dataloader: iterable of batch mappings consumed by ``predict``.
        output_dir, optimizer, scheduler, save_step, eval_step,
        max_number_of_steps, batch_size: stored as given; not used by this
            class.
        device: ``torch.device("cuda")`` when CUDA is available, otherwise
            CPU. Not used by ``predict`` (inputs are passed to the model as-is).

    -------------------------------------
    Methods:
        predict(): evaluate on ``test_dataloader`` and return the collected
            embeddings and labels.
    """

    def __init__(self, model, disentanglement_classifier, train_dataloader, val_dataloader, test_dataloader,
                 output_dir, optimizer, scheduler, save_step, eval_step, max_number_of_steps, batch_size):
        self.model = model
        self.disentanglement_classifier = disentanglement_classifier
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.output_dir = output_dir
        self.max_number_of_steps = max_number_of_steps
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.scheduler = scheduler
        self.save_step = save_step
        self.eval_step = eval_step
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def predict(self):
        """
        Run the model over ``test_dataloader`` and gather embeddings and labels.

        Side effect: calls ``self.model.eval()`` and does not restore the
        previous mode afterwards.

        Each batch must be a mapping with the keys ``'text'`` (passed to
        ``self.model.predict_latent_space``), ``'sensitive_label'`` and
        ``'public_label'``. The labels must be tensors, because they are
        joined with ``torch.cat``. Their exact shape is presumably
        ``(batch_size,)``; TODO confirm.

        Returns:
            dict with the keys
              "public_labels": public labels concatenated along dim 0;
              "senstive_labels": sensitive labels concatenated along dim 0
                  (the key is misspelled; it is kept as-is because callers
                  index the result by this exact string);
              "embeddings": model embeddings concatenated along dim 0, on CPU.

        Raises:
            RuntimeError: if ``test_dataloader`` yields no batches.
        """
        logger.info("***** Running evaluation *****")
        self.model.eval()
        sensitive_labels = []
        public_labels = []
        embeddings = []
        # One no_grad context for the whole pass. Evaluation needs no
        # gradients, and entering the context once is clearer than per batch.
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc="Evaluating"):
                inputs = batch['text']
                sensitive_labels.append(batch['sensitive_label'])
                public_labels.append(batch['public_label'])
                # Embeddings are copied to CPU before being accumulated.
                embeddings.append(self.model.predict_latent_space(inputs).cpu())

        if not embeddings:
            # torch.cat([]) fails with an opaque message; say what went wrong.
            raise RuntimeError("test_dataloader yielded no batches; nothing to evaluate")

        results = {
            "public_labels": torch.cat(public_labels, dim=0),
            "senstive_labels": torch.cat(sensitive_labels, dim=0),
            "embeddings": torch.cat(embeddings, dim=0),
        }
        return results


if __name__ == '__main__':
    # Smoke test on synthetic data: Nobs points on the unit sphere in R^d,
    # a Uniform(0, 1) target Y and a Bernoulli(0.5) attribute S.
    Nobs = 100
    d = 10
    n_samp = 500  # NOTE(review): unused below; presumably meant for Voronoi — confirm.

    Z = np.random.multivariate_normal(
        np.zeros(d), np.eye(d), size=Nobs)
    # Scale each row to unit length. Broadcasting the division by the row
    # norms avoids materialising an (Nobs, Nobs) diagonal matrix and a matmul.
    Z = Z / np.linalg.norm(Z, axis=1, keepdims=True)
    Y = np.random.uniform(0, 1, size=Nobs)
    S = np.random.binomial(1, 0.5, size=Nobs).astype('int')

    # Columns: the d coordinates of Z, then Y, then S. np.c_ stacks them
    # column-wise, so the resulting array is float.
    data = np.c_[Z, Y, S]

    # NOTE(review): `Voronoi` is neither defined nor imported in this file, so
    # this line raises NameError as written. scipy.spatial.Voronoi does not
    # match this usage: it requires points at construction and has no
    # fit/predict. Import the intended class before running.
    Voro = Voronoi()
    Voro.fit(data)
    print(Voro.directions)
    areas = Voro.predict()
    print(Voro.cells[0])
    print(areas)
