# coding=utf-8
import argparse
import logging
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
from transformers import AdamW
from models_classifier import *
import time
import sys

# Newer ``transformers`` releases expose scheduler factory functions; older ones
# only ship the class-based ``Warmup*Schedule`` schedulers. Both names used in
# this file are aliased in the fallback so that ``train`` (which calls
# ``get_constant_schedule_with_warmup``) does not hit a NameError there.
try:
    from transformers import (get_linear_schedule_with_warmup, get_constant_schedule_with_warmup)
except ImportError:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
    from transformers import WarmupConstantSchedule as get_constant_schedule_with_warmup
from tensorboardX import SummaryWriter
from data_processor import *
from utils import *

# Module-level logger named after this module; configured by logging.basicConfig in main().
logger = logging.getLogger(__name__)


def _save_checkpoint(args, model, global_step):
    """Save the model weights and a JSON snapshot of ``args`` for one step.

    Creates ``<args.output_dir>/checkpoint-<global_step>/`` containing
    ``fair_classifier.pt`` (the model ``state_dict``) and ``training_args.txt``.

    Args:
        args: Namespace whose ``output_dir`` is the checkpoint root and whose
            attributes are dumped to ``training_args.txt``.
        model: Module whose ``state_dict()`` is written with ``torch.save``.
        global_step: Step count used in the checkpoint directory name.
    """
    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, 'fair_classifier.pt'))

    # Only None, bool/int/float, tuple/list and dict values are written as-is;
    # every other attribute (optimizer, scheduler, device, logger, ...) is
    # replaced by 0, presumably so that json.dump does not fail on it.
    # NOTE(review): str values are replaced by 0 as well, so string options
    # such as the dataset name are not recoverable from this file — confirm
    # whether that is intended before relying on it.
    kept_types = (bool, int, float, tuple, list, dict)
    dict_to_save = {
        key: value if value is None or isinstance(value, kept_types) else 0
        for key, value in args.__dict__.items()
    }
    with open(os.path.join(output_dir, 'training_args.txt'), 'w') as f:
        json.dump(dict_to_save, f, indent=2)
    logger.info("Saving model checkpoint to %s", output_dir)


def train(args, train_dataloader, val_dataloader, model):
    """Train the model for a fixed budget of steps.

    Iterates over ``train_dataloader`` repeatedly until more than
    ``args.max_number_of_steps`` batches have been processed, then returns.
    Along the way it logs averaged training losses to the logger and
    TensorBoard every 100 steps, runs ``evaluate`` every ``args.eval_step``
    steps, and writes a checkpoint at step 1 and every ``args.save_step`` steps.

    This function creates the optimizer and scheduler and stores them on
    ``args`` but never calls ``backward`` or ``step`` itself; presumably the
    model's forward pass performs the update through ``args`` — confirm
    against the model implementation.

    Args:
        args: Namespace read for ``fblock``, ``dataset``, ``output_dir``,
            ``max_number_of_steps``, ``learning_rate``, ``adam_epsilon``,
            ``weight_decay``, ``warmup_steps``, ``batch_size``, ``eval_step``
            and ``save_step``. ``args.scheduler`` and ``args.optimizer`` are
            set as a side effect.
        train_dataloader: Iterable of batches; each batch is indexed with the
            keys ``'text'``, ``'sensitive_label'`` and ``'public_label'``.
        val_dataloader: Validation batches, passed unchanged to ``evaluate``.
        model: Callable as ``model(inputs, sensitive_labels, public_labels)``
            returning a dict whose values support ``.item()``.
    """
    tb_writer = SummaryWriter('runs_{}_{}/{}'.format(args.fblock, args.dataset, args.output_dir))

    t_total = args.max_number_of_steps
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon, weight_decay=args.weight_decay)
    scheduler = get_constant_schedule_with_warmup(optimizer, args.warmup_steps)

    args.scheduler = scheduler
    args.optimizer = optimizer
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num of Steps = %d", args.max_number_of_steps)
    logger.info("  Instantaneous batch size = %f", args.batch_size)
    logger.info("  Total optimization steps = %f", t_total)

    log_every = 100  # number of steps over which training losses are averaged
    global_step = 0
    dict_loss = {}
    model.zero_grad()
    train_iterator = trange(0, 10000000, desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    try:
        for _ in train_iterator:
            steps_before_epoch = global_step
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for batch in epoch_iterator:
                # NOTE(review): the strict ``>`` lets max_number_of_steps + 1
                # batches run in total; kept so step counts match earlier runs.
                if global_step > args.max_number_of_steps:
                    break
                inputs = batch['text']
                senstive_labels = batch['sensitive_label']
                public_labels = batch['public_label']
                # evaluate() leaves the model in train mode, but re-assert it every step.
                model.train()
                tr_loss_dic = model(inputs, senstive_labels, public_labels)
                global_step += 1

                # Accumulate each named loss until the next logging point.
                for key, value in tr_loss_dic.items():
                    dict_loss[key] = dict_loss.get(key, 0) + value.item()

                if global_step % log_every == 0:
                    for key, total in dict_loss.items():
                        mean_loss = total / log_every
                        logger.info("  Training %s = %5f", key, mean_loss)
                        tb_writer.add_scalar("train_{}".format(key), mean_loss, global_step)
                    dict_loss = {}

                if global_step % args.eval_step == 0:
                    val_dict_loss = evaluate(args, model, val_dataloader)
                    for key, value in val_dict_loss.items():
                        tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    current_lr = scheduler.get_lr()[0]
                    tb_writer.add_scalar("lr", current_lr, global_step)
                    logger.info("  lr = %5f", current_lr)

                if global_step % args.save_step == 0 or global_step == 1:
                    _save_checkpoint(args, model, global_step)

            # Leaving the batch loop is not enough: without this check the
            # epoch loop would keep restarting the dataloader until its
            # 10,000,000 iterations are used up.
            if global_step > args.max_number_of_steps:
                break
            # A dataloader that yields nothing can never reach the step budget.
            if global_step == steps_before_epoch:
                logger.warning("  Training dataloader produced no batches; stopping training")
                break
    finally:
        # Flush pending TensorBoard events and release the progress bar even on errors.
        train_iterator.close()
        tb_writer.close()


def evaluate(args, model, val_dataloader):
    """Average every loss returned by the model over the validation batches.

    The model is put in eval mode for the pass and switched back to train
    mode before returning, whatever mode it was in on entry.

    Args:
        args: Namespace; only ``batch_size`` is read, for logging.
        model: Callable as ``model(inputs, sensitive_labels, public_labels)``
            returning a dict whose values support ``.item()``.
        val_dataloader: Iterable of batches indexed with the keys ``'text'``,
            ``'sensitive_label'`` and ``'public_label'``.

    Returns:
        dict mapping each loss name to its mean over the batches, or an empty
        dict when ``val_dataloader`` yields no batch.
    """
    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Batch size = %d", args.batch_size)
    model.eval()
    f_dict_loss = {}
    nb_eval_steps = 0
    for batch in tqdm(val_dataloader, desc="Evaluating"):
        inputs = batch['text']
        senstive_labels = batch['sensitive_label']
        public_labels = batch['public_label']
        # Re-asserted per batch, as in the original loop, in case the forward
        # pass changes the mode — TODO confirm whether this is needed.
        model.eval()
        with torch.no_grad():
            dict_loss = model(inputs, senstive_labels, public_labels)
        if not f_dict_loss:
            logger.info('Initialization of validation tensorboard dictionary')
        for key, value in dict_loss.items():
            f_dict_loss[key] = f_dict_loss.get(key, 0) + value.item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        # Nothing was accumulated, so there is nothing to average (and no
        # per-batch result to iterate over); report it instead of failing.
        logger.warning("  Validation dataloader produced no batches; returning no losses")
    else:
        # Average over the accumulator itself rather than the last batch's output.
        f_dict_loss = {key: total / nb_eval_steps for key, total in f_dict_loss.items()}
    model.train()
    return f_dict_loss


def _build_arg_parser():
    """Return the argument parser holding every command-line option of this script."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # NOTE(review): many options below carry the same help text as --seed,
    # which looks copy-pasted; it is kept verbatim here.
    seed_help = "random seed for initialization"

    add("--seed", type=int, default=23, help=seed_help)  # change of seed

    # Saving Parameters
    add("--output_dir", default='debug',
        help="The output directory where the model predictions and checkpoints will be written.")
    add("--eval_step", type=int, default=2, help="Log every X updates steps.")
    add("--save_step", type=int, default=100, help="Save checkpoint every X updates steps.")

    # Encoder
    add("--encoder_type", default='BERT', choices=["RNN", "BERT", "MLP"], help=seed_help)
    add("--number_of_layers", type=int, default=2, help=seed_help)
    add("--hidden_size", type=int, default=67, help=seed_help)
    add("--filter", type=float, default=0.001, help=seed_help)
    add("--dropout", type=float, default=0.5, help=seed_help)
    add("--l2_normalization", action="store_true", help="output of the encoder need to be normalize")
    # TODO : normalization

    # Fairness
    add("--fblock", type=str, default="MI", choices=["SCL", "MI", "MULTI"])
    add("--multi_loss_type", type=str, default="RAO",
        choices=["SINKHORN", "HAUSDORFF", "ENERGY", "GAUSSIAN", "LAPLACIAN", 'RAO', 'FRECHET', 'JS'])
    add("--mi_estimator_name", type=str, default="ADV", choices=["KNIFE", "MIReny", "ADV"])
    # Wasserstein
    add("--power", type=int, default=2, choices=[1, 2])
    add("--blur", type=float, default=0.05)
    add("--kernel_size", type=int, default=128, help="KNIFE kernel size")
    add("--alpha", type=float, default=1.5, help="alpha for reny divergence")
    add("--no_reny", action="store_true", help=seed_help)
    add("--number_of_training_encoder", type=int, default=10, help=seed_help)
    add("--mul_lambda", type=float, default=1, help=seed_help)

    # Data
    add("--reduce_training_size", action="store_true", help=seed_help)
    add("--dataset", default='bio', type=str,
        choices=['blog_age', 'blog_gender', 'pan_age', 'pan_gender', 'dial_mention', 'dial_sentiment',
                 'trust_age', 'trust_gender', 'bio'])

    # Training Details
    add("--add_noise", action="store_true", help=seed_help)
    add("--noise_p", type=float, default=0.1, help=seed_help)
    add("--number_of_perm", type=int, default=3, help=seed_help)
    add("--batch_size", default=10, type=int, help="Batch size per GPU/CPU for training.")
    add("--learning_rate", default=1e-3, type=float, help="The initial learning rate for Adam.")
    add("--weight_decay", default=5e-2, type=float, help="Weight decay if we apply some.")
    add("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    add("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    add("--max_number_of_steps", default=1000, type=float,
        help="Total number of training epochs to perform.")
    add("--warmup_steps", default=1000, type=int, help="Linear warmup over warmup_steps.")
    return parser


def main():
    """Parse the command line, build the data loaders and the model, then train.

    After parsing, ``mul_lambda`` is multiplied by 1000 for the MI block with
    the KNIFE estimator, the device is chosen (CUDA when available), and the
    option constraints are asserted: ``filter`` in (0, 1], L2 normalization
    when ``fblock`` is SCL, and ``alpha`` > 1. Logging is configured, the seed
    is set, three data loaders are created (train, a second train loader handed
    to the model, and validation), and ``train`` is run.
    """
    args = _build_arg_parser().parse_args()

    if args.fblock == 'MI' and args.mi_estimator_name == "KNIFE":
        args.mul_lambda *= 1000

    # TODO : MIReny MIKNIFE

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = torch.device(device_name)
    assert 0 < args.filter <= 1.0

    if args.fblock == 'SCL':
        assert args.l2_normalization, "L2 normalization required for SCL"

    # Bounds hold only if considered reny > 1
    assert args.alpha > 1

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%m/%d/%Y %H:%M:%S",
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    )

    # Seed first so that dataset and model construction happen after seeding.
    set_seed(args)
    args.logger = logger

    logger.info("Training Dataset")
    training_set = FairClassificationDataset(args, 'train')
    train_dataloader = DataLoader(
        training_set,
        sampler=RandomSampler(training_set),
        batch_size=args.batch_size,
        drop_last=True,
    )

    # A second, independently shuffled loader over the train split for the model.
    logger.info("MI Dataset")
    training_mi_set = FairClassificationDataset(args, 'train')
    mi_dataloader = DataLoader(
        training_mi_set,
        sampler=RandomSampler(training_mi_set),
        batch_size=args.batch_size,
        drop_last=True,
    )

    logger.info("Val Dataset")
    val_set = FairClassificationDataset(args, 'val')
    val_dataloader = DataLoader(
        val_set,
        sampler=SequentialSampler(val_set),
        batch_size=args.batch_size,
        drop_last=True,
    )

    #########
    # Model #
    #########
    model = FairClassifier(args, mi_dataloader)
    model.to(args.device)
    # Count the elements of every parameter that requires gradients.
    params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print(params)
    logger.info("Training/evaluation parameters %s", args)

    #########
    # Train #
    #########
    train(args, train_dataloader, val_dataloader, model)
    logger.info(" Training Over ")


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
