# coding=utf-8
import csv
import argparse
import logging
from argparse import ArgumentParser
import os
import random
import numpy as np
import torch
from multiprocessing import Pool
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
from transformers import AdamW
from data_processor import reader
import json
from metric_adv_disen import AdvDisentanglement
from metric_viz import MetricViz
from voronoi import Voronoi
from metric_fairness import MetricFairness
import sys

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from tensorboardX import SummaryWriter
import copy
from models_classifier import *
from model_utils import *
from data_processor import FairClassificationDataset

logger = logging.getLogger(__name__)


def set_seed(args):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU) with ``args.seed``;
    when CUDA is available, all CUDA devices are seeded as well.

    Args:
        args: any object exposing an integer ``seed`` attribute.
    """
    seed = args.seed
    # The three CPU-side generators are always seeded.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # The GPU generators only exist when CUDA is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def build_training_parser():
    """Build the command-line parser of this evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing the loading/saving options
        and the hyper-parameters forwarded to the metrics.
    """
    parser = argparse.ArgumentParser()
    # Loading and Saving Results
    parser.add_argument("--path", default='models/classif_blog_gender_MI_ADV_0.1_RNN',
                        help="Directory of the trained model; its name is scanned for the encoder, fblock, "
                             "loss/estimator and dataset tags.")
    parser.add_argument("--checkpoints", default='checkpoint-6000',
                        help="Checkpoint sub-directory of --path to load from.")
    parser.add_argument("--output_dir", default='debug',
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Training Args for Classifier
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--save_step", type=int, default=1, help="Number of steps between two saves.")
    parser.add_argument("--eval_step", type=int, default=1, help="Number of steps between two evaluations.")
    parser.add_argument("--max_number_of_steps", type=int, default=10000, help="Maximum number of steps.")
    parser.add_argument("--warmup_steps", type=int, default=500, help="Number of learning-rate warmup steps.")
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="Learning rate of the optimizer.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--filter", default=600, type=int,
                        help="Filter value forwarded to the model/dataset arguments.")
    parser.add_argument("--n_sample", default=100000, type=int,
                        help="Number of samples passed to the Voronoi metric.")
    parser.add_argument("--batch_size", default=64, type=int, help="Batch size of the data loaders.")
    parser.add_argument("--reduce_training_size", action="store_true",
                        help="Forwarded to the model arguments as reduce_training_size.")
    parser.add_argument("--fairness_only", action="store_true", help="Not read by this script at present.")
    return parser


def json_safe_args(values):
    """Return a copy of ``values`` that ``json.dump`` can handle at top level.

    ``None``, booleans, numbers, strings, tuples, lists and dicts are kept
    untouched; any other value (e.g. a ``torch.device``) is replaced by ``0``.
    Containers are not inspected recursively.

    Args:
        values: mapping of argument names to values.

    Returns:
        dict: new dictionary with the same keys, in the same order.
    """
    kept_types = (bool, int, float, str, tuple, list, dict)
    return {key: (value if value is None or isinstance(value, kept_types) else 0)
            for key, value in values.items()}


def last_match_in_path(path, candidates):
    """Return the last element of ``candidates`` that is a substring of ``path``.

    Args:
        path: string to search in.
        candidates: iterable of strings, scanned in order.

    Returns:
        The last matching candidate, or ``None`` when nothing matches.
    """
    found = None
    for candidate in candidates:
        if candidate in path:
            found = candidate
    return found


if __name__ == "__main__":
    args_training = build_training_parser().parse_args()
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args_training)

    # The arguments the model was trained with are stored next to the checkpoint.
    checkpoint_dir = os.path.join(args_training.path, args_training.checkpoints)
    with open(os.path.join(checkpoint_dir, 'training_args.txt'), 'r', encoding='utf-8') as f:
        args_model = argparse.Namespace(**json.load(f))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args_training.device = device
    args_model.device = device
    args_model.reduce_training_size = args_training.reduce_training_size

    # The model configuration is (re)derived from tags embedded in the model path.
    args_model.encoder_type = "RNN" if "RNN" in args_training.path else "BERT"
    args_model.fblock = "SCL" if "SCL" in args_training.path else "MI" if "MI" in args_training.path else "MULTI"
    if args_model.fblock not in args_training.path:
        # Only reachable for the "MULTI" fallback: none of the known tags is in the path.
        raise ValueError(
            "Cannot infer fblock from path {!r}: expected 'SCL', 'MI' or 'MULTI' in it.".format(
                args_training.path))

    # For each option the last matching tag wins; without a match the value
    # loaded from training_args.txt (if any) is left untouched.
    tags_per_option = (
        ("multi_loss_type", ["SINKHORN", "HAUSDORFF", "ENERGY", "GAUSSIAN", "LAPLACIAN", 'RAO', 'FRECHET', 'JS']),
        ("mi_estimator_name", ["KNIFE", "MIReny", "ADV"]),
        ("dataset", reader.keys()),
    )
    for option_name, tags in tags_per_option:
        matched_tag = last_match_in_path(args_training.path, tags)
        if matched_tag is not None:
            setattr(args_model, option_name, matched_tag)

    args_model.batch_size = args_training.batch_size
    args_model.max_number_of_steps = args_training.max_number_of_steps

    # Keep a record of the evaluation arguments; non-serialisable values
    # (such as the torch device set above) are stored as 0.
    os.makedirs(args_training.output_dir, exist_ok=True)
    with open(os.path.join(args_training.output_dir, 'training_args.txt'), 'w', encoding='utf-8') as f:
        json.dump(json_safe_args(vars(args_training)), f, indent=2)

    args_model.filter = args_training.filter

    logger.info("------------------------------------------ ")
    logger.info("ARGS %s ", args_training)
    logger.info("------------------------------------------ ")
    model = FairClassifier(args_model, None)
    # strict=False silently ignores missing/unexpected keys, so snapshot one
    # weight matrix to verify that the checkpoint really changed the model.
    weights_before_loading = model.proj_content[0].weight.detach().clone()
    model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'fair_classifier.pt'),
                                     map_location=torch.device(args_model.device)), strict=False)
    if torch.equal(weights_before_loading, model.proj_content[0].weight):
        raise RuntimeError(
            "Loading {!r} left proj_content[0].weight unchanged; the checkpoint does not seem to match "
            "the model.".format(os.path.join(checkpoint_dir, 'fair_classifier.pt')))
    model.to(args_model.device)
    model.eval()

    # NOTE: the train/val loaders and the classifier below are only consumed
    # by the metric stages that are currently commented out.
    training_set = FairClassificationDataset(args_model, 'train')
    train_sampler = RandomSampler(training_set)
    train_dataloader = DataLoader(training_set, sampler=train_sampler, batch_size=args_training.batch_size,
                                  drop_last=True)

    val_set = FairClassificationDataset(args_model, 'val')
    val_sampler = SequentialSampler(val_set)
    val_dataloader = DataLoader(val_set, sampler=val_sampler, batch_size=args_training.batch_size, drop_last=True)

    test_set = FairClassificationDataset(args_model, 'test')
    test_sampler = SequentialSampler(test_set)
    test_dataloader = DataLoader(test_set, sampler=test_sampler, batch_size=args_training.batch_size, drop_last=True)

    disentanglement_classifier = Classifier(args_model.hidden_size, 2).to(
        args_model.device)  # TODO: generalize to multiple labels

    # optimizer = AdamW(disentanglement_classifier.parameters(), lr=args_training.learning_rate)
    # scheduler = get_linear_schedule_with_warmup(optimizer, args_training.warmup_steps,
    #                                            100 * args_training.max_number_of_steps)

    metric_results = {}
    logger.info(" Fitting  Voronois ")
    metric_voronoi = Voronoi(model, test_dataloader, args_training.n_sample, metric="euclidian")
    # From here on, outputs go to results_voronoi/<output_dir>.
    args_training.output_dir = os.path.join("results_voronoi", "{}".format(args_training.output_dir))
    os.makedirs(args_training.output_dir, exist_ok=True)
    # Always dump the data (was meant to be gated by args_training.prepare_data_for_voronoi).
    metric_voronoi.save_data_from_dataloader(args_training.output_dir)

    # metric_voronoi.fit()

    # Z, U = metric_voronoi.cells_counting()

    # def fast_voronoi_loop(u):
    #    return np.argmin([np.linalg.norm(u - z) for z in Z])

    # with Pool(10) as p:
    #    results = p.map(fast_voronoi_loop, U)
    # metric_voronoi.update(results)

    # evaluaition_voronoi = metric_voronoi.predict()

    # metric_results.update(evaluaition_voronoi)

    # The three stages below are disabled; only their log lines are emitted.
    logger.info(" Fitting  Adv ")
    # metric_adv = AdvDisentanglement(model, disentanglement_classifier, train_dataloader, val_dataloader,
    #                                 test_dataloader, args_training.output_dir, optimizer, scheduler,
    #                                 args_training.save_step, args_training.eval_step,
    #                                 args_training.max_number_of_steps,
    #                                 args_training.batch_size)
    #
    # metric_adv.fit()
    # evaluaition_adv = metric_adv.predict()
    # metric_results.update(evaluaition_adv)
    #
    # metric_results.update({"fblock": args_model.fblock,
    #                        "encoder_type": args_model.encoder_type,
    #                        "mul_lambda": args_model.mul_lambda,
    #                        "dataset": args_model.dataset,
    #                        "mi_estimator_name": args_model.mi_estimator_name,
    #                        "multi_loss_type": args_model.multi_loss_type})
    # os.makedirs(args_training.output_dir, exist_ok=True)
    # with open(os.path.join(args_training.output_dir, 'evaluation_results.json'), 'w') as file:
    #     json.dump(metric_results, file)

    logger.info(" Fitting  Viz ")
    # metric_viz = MetricViz(model, disentanglement_classifier, train_dataloader, val_dataloader,
    #                       test_dataloader, args_training.output_dir, optimizer, scheduler,
    #                       args_training.save_step, args_training.eval_step,
    #                       args_training.max_number_of_steps,
    #                       args_training.batch_size)

    # viz = metric_viz.predict()
    # args_training.output_dir_viz = os.path.join("results_evaluation", "viz_{}".format(args_training.output_dir))
    # os.makedirs(args_training.output_dir_viz, exist_ok=True)
    # torch.save(viz, os.path.join(args_training.output_dir_viz, 'viz.pt'))

    logger.info(" Fitting  Fairness ")
    # metric_fairness = MetricFairness(model, disentanglement_classifier, train_dataloader, val_dataloader,
    #                                  test_dataloader, args_training.output_dir, optimizer, scheduler,
    #                                  args_training.save_step, args_training.eval_step,
    #                                  args_training.max_number_of_steps,
    #                                  args_training.batch_size)
    #
    # results_fairness = metric_fairness.predict()
    # metric_results.update(results_fairness)
    # args_training.output_dir = os.path.join("results_evaluation", "fairness_only_{}".format(args_training.output_dir))
    # os.makedirs(args_training.output_dir, exist_ok=True)
    # metric_results.update({"fblock": args_model.fblock,
    #                        "encoder_type": args_model.encoder_type,
    #                        "mul_lambda": args_model.mul_lambda,
    #                        "dataset": args_model.dataset,
    #                        "mi_estimator_name": args_model.mi_estimator_name,
    #                        "multi_loss_type": args_model.multi_loss_type})
    #
    # with open(os.path.join(args_training.output_dir, 'evaluation_results.json'), 'w') as file:
    #    json.dump(metric_results, file)

#    main()
