import json
import os
import random
import torch.nn.functional as F
import numpy as np
import torch
import torch.nn as nn
from varslot.models import VariableSlotModel
import sys
from loguru import logger
from overrides import overrides
from prefect import Task
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm, trange
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)

# Debugging aid: print tensors and arrays (nearly) in full instead of the
# default elided summaries.
torch.set_printoptions(threshold=5000, edgeitems=50)
np.set_printoptions(threshold=sys.maxsize)


class VariableSlotTrainer(Task):
    def __init__(self, **kwargs):
        """Extract the trainer hyper-parameters and forward the rest to ``Task``.

        Each recognised option is removed from ``kwargs`` and stored as an
        attribute of the same name, falling back to the default listed below
        when the caller did not supply it. Whatever remains in ``kwargs`` is
        kept on ``self.args`` and passed to the parent constructor.

        Keyword Args:
            per_gpu_batch_size: Batch size per GPU (default 32).
            cuda: Whether CUDA may be used when available (default True).
            gradient_accumulation_steps: Batches accumulated per optimizer
                step (default 1).
            num_train_epochs: Number of training epochs (default 10).
            learning_rate: Optimizer learning rate (default 1e-4).
            weight_decay: Weight decay for non-bias/LayerNorm parameters
                (default 0.0).
            adam_epsilon: Epsilon passed to AdamW (default 1e-8).
            warmup_steps: Warm-up steps of the linear schedule (default 0).
            max_grad_norm: Gradient clipping norm (default 1.0).
            logging_steps: Optimizer steps between progress-bar updates
                (default 5).
            embedding_layer_size: Stored only; not read elsewhere in this
                class (default 256).
            num_att_heads: Stored only; not read elsewhere in this class
                (default 8).
        """
        option_defaults = (
            ("per_gpu_batch_size", 32),
            ("cuda", True),
            ("gradient_accumulation_steps", 1),
            ("num_train_epochs", 10),
            ("learning_rate", 1e-4),
            ("weight_decay", 0.0),
            ("adam_epsilon", 1e-8),
            ("warmup_steps", 0),
            ("max_grad_norm", 1.0),
            ("logging_steps", 5),
            ("embedding_layer_size", 256),
            ("num_att_heads", 8),
        )
        for option, fallback in option_defaults:
            setattr(self, option, kwargs.pop(option, fallback))

        # Only the options not consumed above reach the parent constructor.
        self.args = kwargs
        super().__init__(**kwargs)

    def set_seed(self, n_gpu, seed=42):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    @staticmethod
    def create_weights(df):
        positives = 0
        negatives = 0
        weights = list()
        for x in df:
            if x[5] == 0:
                negatives = negatives + 1
            else:
                positives = positives + 1

        weight_positive = 1.0 / float(positives)
        weight_negative = 1.0 / float(negatives)

        for x in df:
            if x[5] == 0:
                weights.append(weight_negative)
            else:
                weights.append(weight_positive)

        print(positives)
        print(negatives)
        return torch.tensor(weights)

    @overrides
    def run(
        self,
        train_dataset,
        dev_dataset,
        test_dataset,
        task_name,
        output_dir,
        largest=128,
        model_name="bert-base-uncased",
        is_baseline=0,
        dimension_size=768,
        num_iters=-1,
        mode="train",
        eval_fn=None,
        save_optimizer=False,
        eval_params=None,
        scores=None,
        msg_report="",
    ):
        """Optionally train a ``VariableSlotModel``, then evaluate the saved checkpoint.

        When ``mode == "train"`` a new model is trained with class-balanced
        sampling and the best checkpoint is written to
        ``{output_dir}/{task_name}/training_args.bin``. In every mode the
        checkpoint is then loaded from that path and evaluated on the dev set
        (unless ``mode == "test"``) and on the test set (when given).

        Args:
            train_dataset: Training examples; only read when ``mode == "train"``.
            dev_dataset: Development examples; only read when ``mode != "test"``.
            test_dataset: Test examples, or ``None`` to skip test evaluation.
            task_name: Sub-directory of ``output_dir`` holding the checkpoint.
            output_dir: Root directory for checkpoints.
            largest: Unused by this method.
            model_name: Forwarded to ``VariableSlotModel``.
            is_baseline: Forwarded to ``VariableSlotModel``.
            dimension_size: Forwarded to ``VariableSlotModel``.
            num_iters: Forwarded to ``VariableSlotModel``.
            mode: ``"train"`` trains first; ``"test"`` skips dev evaluation;
                any other value evaluates only.
            eval_fn: Scoring callable forwarded to ``train``/``eval``.
            save_optimizer: Whether ``train`` also stores optimizer and
                scheduler states.
            eval_params: Extra evaluation parameters forwarded to
                ``train``/``eval``; defaults to an empty dict.
            scores: Unused by this method.
            msg_report: Free-form message logged ahead of the final report.

        Returns:
            ``{"dev": <dev result>, "test": {"f1_score": ..., "others": ...}}``.
            ``"dev"`` is ``{"score": "test"}`` when ``mode == "test"``. Both
            test entries are ``None`` when ``test_dataset`` is ``None``.
        """
        # Avoid the shared-mutable-default pitfall of ``eval_params={}``.
        eval_params = {} if eval_params is None else eval_params

        torch.cuda.empty_cache()
        device = torch.device(
            "cuda" if torch.cuda.is_available() and self.cuda else "cpu"
        )

        # Only count GPUs when actually running on CUDA. Otherwise a CPU run
        # (cuda=False) on a multi-GPU host would wrap the model in DataParallel.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
        self.set_seed(n_gpu)
        self.logger.info(f"GPUs used {n_gpu}")

        batch_size = self.per_gpu_batch_size * max(1, n_gpu)
        checkpoint_dir = f"{output_dir}/{task_name}"
        model_kwargs = {
            "device": device,
            "model_name": model_name,
            "is_baseline": is_baseline,
            "dimension_size": dimension_size,
            "num_iters": num_iters,
        }

        criterion = nn.BCEWithLogitsLoss()

        # The dev loader is needed for training and dev evaluation only.
        dev_dataloader = None
        if mode != "test":
            dev_dataloader = DataLoader(
                dev_dataset, batch_size=batch_size, shuffle=False
            )

        outputs = {}
        if mode == "train":
            logger.info("Running train mode")

            # Class-balanced sampling is only relevant for training, so the
            # weights and sampler are not built in evaluation-only modes.
            train_class_weight = self.create_weights(train_dataset)
            g_cpu = torch.Generator()
            g_cpu.manual_seed(42)
            sampler_train = WeightedRandomSampler(
                train_class_weight, len(train_class_weight), generator=g_cpu
            )
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=sampler_train,
            )

            model = VariableSlotModel(**model_kwargs)
            model = model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)
            epoch_results = self.train(
                model,
                criterion,
                train_dataloader,
                dev_dataloader,
                dev_dataset,
                device,
                n_gpu,
                eval_fn,
                checkpoint_dir,
                save_optimizer,
                eval_params,
            )
            outputs["epoch_results"] = epoch_results

        logger.info("Running evaluation mode")
        logger.info(f"Loading from {output_dir}/{task_name}")
        model = VariableSlotModel(**model_kwargs)
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host (and vice versa).
        model.load_state_dict(
            torch.load(
                os.path.join(checkpoint_dir, "training_args.bin"),
                map_location=device,
            )
        )

        model.to(device)
        if mode != "test":
            score, score_complete_dev = self.eval(
                criterion,
                model,
                dev_dataloader,
                dev_dataset,
                device,
                n_gpu,
                eval_fn,
                eval_params,
                mode="dev",
            )
            outputs["dev"] = {
                "score": score,
            }
        else:
            score_complete_dev = {"score": "test"}

        # Defaults keep the return value well-formed when no test set is given.
        test_score = None
        score_complete_test = None
        if test_dataset is not None:
            test_data_loader = DataLoader(
                test_dataset, batch_size=batch_size, shuffle=False
            )
            test_score, score_complete_test = self.eval(
                criterion,
                model,
                test_data_loader,
                test_dataset,
                device,
                n_gpu,
                eval_fn,
                eval_params,
                mode="test",
            )
            outputs["test"] = {
                "score": test_score,
            }

            logger.info(msg_report)
            logger.info("Model")
            logger.info(model_name)
            logger.info("Baseline")
            logger.info(is_baseline)
            logger.info("Num iters")
            logger.info(num_iters)
            logger.info("DEV RESULTS")
            logger.info(score_complete_dev)
            logger.info("TEST RESULTS")
            logger.info(score_complete_test)

        return {
            "dev": score_complete_dev,
            "test": {
                "f1_score": test_score,
                "others": score_complete_test,
            },
        }

    def train(
        self,
        model,
        criterion,
        train_dataloader,
        dev_dataloader,
        dev_dataset,
        device,
        n_gpu,
        eval_fn,
        output_dir,
        save_optimizer,
        eval_params,
    ):
        """Train ``model`` and checkpoint the best-scoring epoch.

        After every epoch the model is scored on the dev set via ``self.eval``.
        The weights are written to ``{output_dir}/training_args.bin`` on the
        first epoch and whenever the dev score strictly improves. If no score
        is available (``eval_fn`` is ``None``) the latest weights are saved
        after every epoch.

        Args:
            model: Model to train, possibly wrapped in ``DataParallel``. It is
                called with keyword tensors and must return ``(logits, _)``.
            criterion: Loss applied to the logits and ``label.unsqueeze(1)``.
            train_dataloader: Yields batches of tensors; index 5 is the label.
            dev_dataloader: Loader forwarded to ``self.eval``.
            dev_dataset: Dataset forwarded to ``self.eval``.
            device: Device every batch tensor is moved to.
            n_gpu: Number of GPUs in use.
            eval_fn: Scoring callable forwarded to ``self.eval``.
            output_dir: Directory receiving the checkpoint files.
            save_optimizer: Also save optimizer and scheduler state dicts.
            eval_params: Extra parameters forwarded to ``self.eval``.

        Returns:
            Dict mapping the epoch index to the dev score of that epoch.
        """
        results = {}
        # None means "nothing saved yet". This guarantees a checkpoint after
        # the first epoch, even if the score is 0.0 or unavailable.
        best_score = None
        t_total = (
            len(train_dataloader)
            // self.gradient_accumulation_steps
            * self.num_train_epochs
        )

        # Biases and LayerNorm parameters are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {
                "params": decay_params,
                "weight_decay": self.weight_decay,
                "lr": self.learning_rate,
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "lr": self.learning_rate,
            },
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters, eps=self.adam_epsilon, lr=self.learning_rate
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=t_total,
        )

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()

        for epoch in trange(int(self.num_train_epochs), desc="Epoch"):
            # self.eval() leaves the model in eval mode at the end of an epoch.
            model.train()
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(device) for t in batch)

                input_model = {
                    "bert_pos_var": batch[0],
                    "bert_pos_exp": batch[1],
                    "sentence_input_ids": batch[2],
                    "sentence_token_type_ids": batch[3],
                    "sentence_attention_mask": batch[4],
                    "all_exp_pos": batch[6],
                }

                pred, _ = model(**input_model)
                label = batch[5].type_as(pred)
                loss = criterion(pred, label.unsqueeze(1))

                if n_gpu > 1:
                    # mean() to average on multi-gpu parallel training
                    loss = loss.mean()
                if self.gradient_accumulation_steps > 1:
                    loss = loss / self.gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                # Keep accumulating gradients until a full group of batches
                # has been processed.
                if (step + 1) % self.gradient_accumulation_steps != 0:
                    continue

                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.max_grad_norm
                )
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if self.logging_steps > 0 and global_step % self.logging_steps == 0:
                    loss_scalar = (tr_loss - logging_loss) / self.logging_steps
                    # get_last_lr() is the supported way to read the current
                    # learning rate; get_lr() is internal to the scheduler.
                    learning_rate_scalar = scheduler.get_last_lr()[0]
                    epoch_iterator.set_description(
                        f"Loss :{loss_scalar} LR: {learning_rate_scalar}"
                    )
                    logging_loss = tr_loss

            score, _ = self.eval(
                criterion,
                model,
                dev_dataloader,
                dev_dataset,
                device,
                n_gpu,
                eval_fn,
                eval_params,
                mode="dev",
            )
            results[epoch] = score

            # Comparing None with a float would raise TypeError, so a missing
            # score is treated as "save the latest weights".
            is_best = score is None or best_score is None or score > best_score
            if not is_best:
                continue

            logger.success(f"Storing the new model with best F1-score: {score}")
            os.makedirs(output_dir, exist_ok=True)
            # Take care of distributed/parallel training
            model_to_save = model.module if hasattr(model, "module") else model

            torch.save(
                model_to_save.state_dict(),
                os.path.join(output_dir, "training_args.bin"),
            )
            logger.info(f"Saving model checkpoint to {output_dir}")
            if save_optimizer:
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(output_dir, "optimizer.pt"),
                )
                torch.save(
                    scheduler.state_dict(),
                    os.path.join(output_dir, "scheduler.pt"),
                )
                # loguru formats with str.format, so a "%s" placeholder would
                # be logged verbatim; interpolate explicitly instead.
                logger.info(
                    f"Saving optimizer and scheduler states to {output_dir}"
                )
            best_score = score

        return results

    def eval(
        self,
        criterion,
        model,
        dataloader,
        dataset,
        device,
        n_gpu,
        eval_fn,
        eval_params,
        mode,
    ):
        """Evaluate ``model`` on ``dataloader`` and compute binary metrics.

        Args:
            criterion: Loss applied to the logits and ``labels.unsqueeze(1)``.
            model: Model to evaluate; wrapped in ``DataParallel`` here when
                ``n_gpu > 1`` and it is not wrapped already. It is called with
                keyword tensors and must return ``(logits, predictions)``.
            dataloader: Yields batches of tensors; index 5 is the label.
            dataset: Unused by this method.
            device: Device every batch tensor is moved to.
            n_gpu: Number of GPUs in use.
            eval_fn: Optional callable invoked as
                ``eval_fn(preds, labels, average="binary")``; its result is
                returned as the score.
            eval_params: Unused by this method.
            mode: Unused by this method.

        Returns:
            ``(score, score_complete)``. ``score`` is the result of ``eval_fn``
            or ``None`` when no ``eval_fn`` is given. ``score_complete`` is the
            tuple returned by ``precision_recall_fscore_support`` with
            ``average="binary"``.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)
        model.eval()

        eval_loss = 0.0
        # Collect per-batch arrays and concatenate once at the end; growing an
        # array with np.append copies all previous data on every batch.
        pred_chunks, label_chunks = [], []

        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                input_model = {
                    "bert_pos_var": batch[0],
                    "bert_pos_exp": batch[1],
                    "sentence_input_ids": batch[2],
                    "sentence_token_type_ids": batch[3],
                    "sentence_attention_mask": batch[4],
                    "all_exp_pos": batch[6],
                }
                labels = batch[5]

                outputs, pred = model(**input_model)
                loss = criterion(outputs, labels.type_as(outputs).unsqueeze(1))
                eval_loss += loss.mean().item()

            pred_chunks.append(pred.detach().cpu().numpy())
            label_chunks.append(labels.detach().cpu().numpy())

        if not pred_chunks:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("Cannot evaluate: the dataloader yielded no batches")

        eval_loss = eval_loss / len(pred_chunks)
        # Predictions carry a singleton second axis, which is dropped here.
        preds = np.concatenate(pred_chunks, axis=0).squeeze(1)
        out_label_ids = np.concatenate(label_chunks, axis=0)

        logger.info(f"EVAL LOSS: {eval_loss}")
        torch.set_printoptions(edgeitems=50)

        # Binarise by truthiness so the metrics receive plain 0/1 labels.
        preds = [1 if p else 0 for p in preds]

        # These metrics do not depend on eval_fn, so they are always computed;
        # this also guarantees score_complete is bound at the return below.
        score_complete = precision_recall_fscore_support(
            y_true=out_label_ids,
            y_pred=preds,
            average="binary",
        )
        logger.info(f"Score complete: {score_complete}")

        score = None
        if eval_fn is not None:
            score = eval_fn(preds, out_label_ids, average="binary")
            logger.info(f"Score:{score}")

        return score, score_complete
