import argparse
import json
import logging
import math
import os
import shutil
import glob

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss
from fastprogress.fastprogress import master_bar, progress_bar
from attrdict import AttrDict

from transformers import (
    AutoConfig,
    AdamW,
    get_linear_schedule_with_warmup,
    BertTokenizerFast,
)

from model.hate_span_model import HateSpanModel
from src import init_logger, set_seed, compute_metrics, show_ner_report

from processor import ner_load_and_cache_examples as load_and_cache_examples
from processor import ner_tasks_num_labels as tasks_num_labels
from processor import ner_processors as processors

logger = logging.getLogger(__name__)

# Access token for pushing models to the Hugging Face Hub.
# SECURITY: credentials must never be committed to source control, so the token
# is read from the HUGGINGFACE_AUTH_TOKEN environment variable; the value is
# None when the variable is unset. The token that used to be hard-coded on this
# line has been exposed in the repository history and should be revoked.
HUGGINGFACE_AUTH_TOKEN = os.environ.get("HUGGINGFACE_AUTH_TOKEN")


def train(args, model, tokenizer, train_dataset, dev_dataset=None, test_dataset=None):
    """Fine-tune ``model`` on ``train_dataset`` with AdamW and linear warmup/decay.

    Every ``args.logging_steps`` optimizer updates the model is evaluated (on
    ``test_dataset`` when ``args.evaluate_test_during_training`` is set,
    otherwise on ``dev_dataset``), and every ``args.save_steps`` updates a
    checkpoint is written to ``<args.output_dir>/checkpoint-<global_step>``.

    Args:
        args: Attribute-style configuration. Fields read here:
            ``train_batch_size``, ``max_steps``, ``gradient_accumulation_steps``,
            ``num_train_epochs``, ``weight_decay``, ``learning_rate``,
            ``adam_epsilon``, ``warmup_proportion``, ``model_dir``, ``device``,
            ``max_grad_norm``, ``logging_steps``, ``save_steps``,
            ``evaluate_test_during_training``, ``output_dir``, ``save_optimizer``.
            ``num_train_epochs`` is overwritten when ``max_steps > 0``.
        model: Model to train; called with keyword tensors and expected to
            return the loss as the first element of its output.
        tokenizer: Only forwarded to ``evaluate``.
        train_dataset: Dataset whose batches are indexed as
            ``(input_ids, attention_mask, token_type_ids, pooled_labels, labels)``.
        dev_dataset: Dataset used for periodic evaluation (optional).
        test_dataset: Dataset used for periodic evaluation when
            ``args.evaluate_test_during_training`` is set (optional).

    Returns:
        tuple: ``(global_step, average_loss)`` where ``global_step`` is the
        number of optimizer updates performed and ``average_loss`` is the
        accumulated loss divided by that number (``0.0`` if no update ran).
    """
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    num_batches = len(train_dataloader)
    accumulation_steps = args.gradient_accumulation_steps

    # Optimizer updates per epoch. The training loop always performs at least
    # one update per epoch (an epoch shorter than the accumulation window is
    # flushed on its last batch), so clamp to 1. This also avoids a division by
    # zero below and a zero-length learning-rate schedule.
    updates_per_epoch = max(1, num_batches // accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(t_total * args.warmup_proportion),
        num_training_steps=t_total,
    )

    # Resume optimizer/scheduler state when both files sit next to the model.
    optimizer_path = os.path.join(args.model_dir, "optimizer.pt")
    scheduler_path = os.path.join(args.model_dir, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    global_step = 0
    tr_loss = 0.0

    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            # Periodic evaluation may leave the model in eval mode, so training
            # mode is re-enabled on every step.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "pooled_labels": batch[3],
                "labels": batch[4],
            }
            if model.config.model_type not in ["roberta"]:
                inputs["token_type_ids"] = batch[2]
            outputs = model(**inputs)

            loss = outputs[0]

            # Scale so that the accumulated gradient is an average over the window.
            if accumulation_steps > 1:
                loss = loss / accumulation_steps

            loss.backward()
            tr_loss += loss.item()

            window_complete = (step + 1) % accumulation_steps == 0
            short_epoch_end = num_batches <= accumulation_steps and (step + 1) == num_batches
            if window_complete or short_epoch_end:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_test_during_training:
                        evaluate(args, model, tokenizer, test_dataset, "test", global_step)
                    else:
                        evaluate(args, model, tokenizer, dev_dataset, "dev", global_step)

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel-style wrappers that expose `.module`.
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(output_dir))

                    if args.save_optimizer:
                        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        logger.info("Saving optimizer and scheduler states to {}".format(output_dir))

            # Stop as soon as the requested number of updates is reached; the
            # scheduler is only defined for `max_steps` updates.
            if 0 < args.max_steps <= global_step:
                break

        mb.write("Epoch {} done".format(epoch + 1))

        if 0 < args.max_steps <= global_step:
            break

    # Guard against a run that performed no optimizer update at all.
    return global_step, tr_loss / max(global_step, 1)


def evaluate(args, model, tokenizer, eval_dataset, mode, global_step=None):
    """Evaluate ``model`` on ``eval_dataset`` and write a metrics report to disk.

    The model is run over the dataset in order, collecting token-level logits
    and sequence-level ("pooled") logits. Token predictions are kept only at
    positions whose gold label is not the ignore index and whose
    ``token_type_ids`` mark them as second-segment tokens (see the filter
    below). Results are logged and written to
    ``<args.output_dir>/<mode>/<mode>[-<global_step>].txt``.

    Args:
        args: Attribute-style configuration. Fields read here:
            ``eval_batch_size``, ``data``, ``device``, ``task``, ``output_dir``.
        model: Model returning a mapping with ``'loss'``, ``'logits'`` and
            ``'pooled_logits'``.
        tokenizer: Passed to the data processor constructor.
        eval_dataset: Dataset whose batches are indexed as
            ``(input_ids, attention_mask, token_type_ids, pooled_labels, labels)``.
        mode: Name of the split (e.g. ``"dev"`` or ``"test"``); used in log
            messages and output paths.
        global_step: Optional training step, used in log messages and the
            output file name.

    Returns:
        tuple: ``(results, pooled_golds_and_preds, span_golds_and_preds)``.
        ``results`` maps metric names to values (always contains ``"loss"``);
        the other two are ``[gold, predicted]`` pairs of label-name lists for
        the pooled and the span outputs respectively.

    Raises:
        ValueError: If ``eval_dataset`` yields no batches, or if a span task
            (``sp-*``) has an unsupported suffix.
    """
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    if global_step is not None:
        logger.info("***** Running evaluation on {} dataset ({} step) *****".format(mode, global_step))
    else:
        logger.info("***** Running evaluation on {} dataset *****".format(mode))
    logger.info("  Num examples = {}".format(len(eval_dataset)))
    logger.info("  Eval Batch size = {}".format(args.eval_batch_size))

    processor = processors[args.data](args, tokenizer)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected in lists and concatenated once after the
    # loop; appending to a NumPy array per batch would re-copy all previous data.
    logits_chunks = []
    label_chunks = []
    token_type_chunks = []
    pooled_pred_chunks = []
    pooled_label_chunks = []

    model.eval()
    for batch in progress_bar(eval_dataloader):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "pooled_labels": batch[3],
                "labels": batch[4],
            }
            if model.config.model_type not in ["roberta"]:
                inputs["token_type_ids"] = batch[2]

            outputs = model(**inputs)

            tmp_eval_loss = outputs['loss']
            logits = outputs['logits']
            pooled_logits = outputs['pooled_logits']

            eval_loss += tmp_eval_loss.mean().item()

        nb_eval_steps += 1

        pooled_pred_chunks.append(pooled_logits.max(axis=-1).indices.detach().cpu().numpy())
        pooled_label_chunks.append(inputs["pooled_labels"].detach().cpu().numpy())
        logits_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(inputs["labels"].detach().cpu().numpy())
        # Read the segment ids from the batch itself: they are not part of
        # `inputs` for roberta models, but are still needed for the filter below.
        token_type_chunks.append(batch[2].detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError("Cannot evaluate on an empty {} dataset".format(mode))

    preds = np.concatenate(logits_chunks, axis=0)
    out_label_ids = np.concatenate(label_chunks, axis=0)
    token_type_ids_array = np.concatenate(token_type_chunks, axis=0)
    pooled_pred_ids = np.concatenate(pooled_pred_chunks, axis=0)
    pooled_label_ids = np.concatenate(pooled_label_chunks, axis=0)

    eval_loss = eval_loss / nb_eval_steps
    results = {"loss": eval_loss}

    # handling pooled_outputs: map class ids back to label names
    pooled_label_map = dict(enumerate(processor.get_pooled_labels()))
    pooled_preds = [pooled_label_map[pooled_pred_id] for pooled_pred_id in pooled_pred_ids]
    pooled_labels = [pooled_label_map[pooled_label_id] for pooled_label_id in pooled_label_ids]
    pooled_golds_and_preds = [pooled_labels, pooled_preds]

    # handling seq_outputs
    preds = np.argmax(preds, axis=-1)
    label_map = dict(enumerate(processor.get_labels()))
    eval_total, seq_len = out_label_ids.shape[0], out_label_ids.shape[1]

    out_label_list = [[] for _ in range(eval_total)]
    preds_list = [[] for _ in range(eval_total)]
    pad_token_label_id = CrossEntropyLoss().ignore_index    # -100
    for i in range(eval_total):                 # example-wise
        for j in range(seq_len):                # position-wise
            if out_label_ids[i, j] == pad_token_label_id:   # skipping the -100
                continue
            # Keep only positions in segment 1 whose successor is also outside
            # segment 0. The original author's note says this excludes the
            # trailing [SEP] of the second segment.
            in_second_segment = token_type_ids_array[i][j] == 1
            successor_not_first_segment = j + 1 < seq_len and token_type_ids_array[i][j + 1] != 0
            if in_second_segment and successor_not_first_segment:
                out_label_list[i].append(label_map[out_label_ids[i][j]])
                preds_list[i].append(label_map[preds[i][j]])

    span_golds_and_preds = [out_label_list, preds_list]

    task_parts = args.task.strip().split('-')
    task_family, task_suffix = task_parts[0], task_parts[-1]
    result = {}
    if task_family == "sc":
        result = compute_metrics(args.task, pooled_label_ids, pooled_pred_ids)
    elif task_family == "sp":
        result = compute_metrics(args.task, out_label_list, preds_list)
        # Span tasks additionally report the matching sequence-classification
        # metrics for the pooled output, prefixed with "pooled_".
        if task_suffix not in ("off", "tgt", "group"):
            raise ValueError(f"{args.task} NOT SUPPORTED (sp-off|sp-tgt|sp-group)")
        pooled_result = compute_metrics(f"sc-{task_suffix}", pooled_label_ids, pooled_pred_ids)
        for _k, _v in pooled_result.items():
            result[f"pooled_{_k}"] = _v

    results.update(result)

    output_dir = os.path.join(args.output_dir, mode)
    os.makedirs(output_dir, exist_ok=True)

    output_eval_file = os.path.join(
        output_dir,
        "{}-{}.txt".format(mode, global_step) if global_step else "{}.txt".format(mode),
    )
    # Build the per-tag report once; it is both logged and written to the file.
    ner_report = show_ner_report(out_label_list, preds_list)
    with open(output_eval_file, "w", encoding="utf-8") as f_w:
        logger.info("***** Eval results on {} dataset *****".format(mode))
        for key in sorted(results.keys()):
            logger.info("  {} = {}".format(key, str(results[key])))
            f_w.write("  {} = {}\n".format(key, str(results[key])))
        logger.info("\n" + ner_report)  # Show report for each tag result
        f_w.write("\n" + ner_report)

    return results, pooled_golds_and_preds, span_golds_and_preds


def _checkpoint_sort_key(checkpoint_dir):
    """Return a sort key ordering checkpoint directories by numeric step suffix.

    Directories whose name ends in ``-<digits>`` are ordered by that number
    (so ``checkpoint-200`` precedes ``checkpoint-1000``); any other directory
    sorts before them, alphabetically.
    """
    suffix = checkpoint_dir.rstrip("/\\").split("-")[-1]
    if suffix.isdigit():
        return (1, int(suffix), checkpoint_dir)
    return (0, 0, checkpoint_dir)


def main(cli_args):
    """Run training and/or checkpoint evaluation as described by a JSON config.

    The configuration is loaded from
    ``<config_dir>/<data or "target_zeroshot">/<model_size>.json`` and then
    overridden with the command-line values. When ``args.do_train`` is set the
    model is trained; when ``args.do_eval`` is set every selected checkpoint
    under ``args.output_dir`` is evaluated on the test set, the best one
    (highest ``args.metric``) is recorded, and its predictions are exported to
    ``test_prediction.tsv``. Finally, all ``checkpoint-*`` directories other
    than the best one are removed. Cleanup only happens when a best checkpoint
    was actually selected.

    Args:
        cli_args: Parsed command-line namespace. Attributes read here:
            ``target``, ``config_dir``, ``model_size``, ``data``, ``task``,
            ``model_dir``, ``min_offensiveness``, ``learning_rate``, ``version``.
    """
    # Set up logging before the first log call so the parameter dump below is
    # emitted. NOTE(review): assumes init_logger() installs the log handlers
    # and has no dependency on the loaded config - confirm.
    init_logger()

    # Read from config file and make args
    config_name = "target_zeroshot" if cli_args.target else cli_args.data
    config_path = os.path.join(cli_args.config_dir, config_name, f"{cli_args.model_size}.json")
    with open(config_path) as f:
        args = AttrDict(json.load(f))
    logger.info("Training/evaluation parameters {}".format(args))

    args.task = cli_args.task
    args.data = cli_args.data
    args.model_dir = cli_args.model_dir
    args.min_offensiveness = cli_args.min_offensiveness
    args.target = cli_args.target

    if cli_args.learning_rate is not None:
        args.learning_rate = cli_args.learning_rate

    if cli_args.version is not None:
        args.output_dir = os.path.join(args.ckpt_dir, f"{args.task}-{cli_args.version}")
    else:
        args.output_dir = os.path.join(args.ckpt_dir, args.task)

    set_seed(args)

    tokenizer = BertTokenizerFast.from_pretrained(args.model_dir, do_lower_case=args.do_lower_case)

    processor = processors[args.data](args, tokenizer)
    labels = processor.get_labels()
    # TODO: embed in processor
    pooled_labels = processor.get_pooled_labels()
    config = AutoConfig.from_pretrained(
        args.model_dir,
        num_labels=tasks_num_labels[args.task],
        id2label={str(i): label for i, label in enumerate(labels)},
        label2id={label: i for i, label in enumerate(labels)},
        task_specific_params={"num_pooled_labels": len(pooled_labels)},
    )

    model = HateSpanModel.from_pretrained(args.model_dir, config=config)

    # GPU or CPU
    args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    model.to(args.device)

    # Load dataset
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train") if args.train_file else None
    dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev") if args.dev_file else None
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test") if args.test_file else None

    if dev_dataset is None:
        args.evaluate_test_during_training = True  # If there is no dev dataset, only use testset

    if args.do_train:
        global_step, tr_loss = train(args, model, tokenizer, train_dataset, dev_dataset, test_dataset)
        logger.info(" global_step = {}, average loss = {}".format(global_step, tr_loss))

    results = {}
    best_result, best_step = -math.inf, 0
    # Stays None unless evaluation ran and some checkpoint reported args.metric;
    # the prediction export and the checkpoint cleanup both depend on it.
    best_checkpoint = None
    if args.do_eval:
        # Order by numeric step so that the last entry really is the latest
        # checkpoint (a plain string sort puts "checkpoint-1000" before
        # "checkpoint-200").
        checkpoints = sorted(
            (
                os.path.dirname(c)
                for c in glob.glob(args.output_dir + "/**/" + "pytorch_model.bin", recursive=True)
            ),
            key=_checkpoint_sort_key,
        )
        if not args.eval_all_checkpoints:
            checkpoints = checkpoints[-1:]
        else:
            logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN)  # Reduce logging
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        best_model_pooled_golds_preds = []
        best_model_span_golds_preds = []
        best_result_dict = {}
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1]
            model = HateSpanModel.from_pretrained(checkpoint)
            model.to(args.device)
            result, pooled_golds_and_preds, span_golds_and_preds = evaluate(
                args, model, tokenizer, test_dataset, mode="test", global_step=global_step
            )
            if args.metric in result:
                metric_value = result[args.metric]
                if best_result < metric_value:
                    logger.info(
                        f"METRIC[{args.metric}]: {best_result} < {metric_value} UPDATE BEST MODEL: {global_step}"
                    )
                    best_result_dict = result
                    best_result = metric_value
                    best_step = global_step
                    best_model_pooled_golds_preds = pooled_golds_and_preds
                    best_model_span_golds_preds = span_golds_and_preds
                    best_checkpoint = checkpoint
            # Suffix every metric with the step so results of all checkpoints
            # can live in one flat dictionary.
            results.update({k + "_{}".format(global_step): v for k, v in result.items()})

        # Upload best model to model hub
        # best_model = AutoModelForTokenClassification.from_pretrained(best_checkpoint)
        # best_model.push_to_hub(args.output_dir, use_temp_dir=False, use_auth_token=HUGGINGFACE_AUTH_TOKEN)

        # Write Best Result
        best_result_file = os.path.join(args.output_dir, "best_result.txt")
        with open(best_result_file, "w") as f_w:
            f_w.write(f"Step {best_step}, ")
            for key in sorted(best_result_dict.keys()):
                f_w.write("{} = {:.4f}\t".format(key, best_result_dict[key]))

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as f_w:
            for key in sorted(results.keys()):
                f_w.write("{} = {}\n".format(key, str(results[key])))

        # Write Best Model Prediction
        if best_checkpoint is None:
            logger.warning(
                "No evaluated checkpoint reported metric '%s'; skipping prediction export and checkpoint cleanup.",
                args.metric,
            )
        else:
            model_prediction_file = os.path.join(args.output_dir, "test_prediction.tsv")
            with open(model_prediction_file, "w", encoding="utf-8") as f_w:
                test_examples = processor.get_examples("test")
                f_w.write("guid\ttitle\ttext\tpooled_gold\tpooled_pred\tspan_gold\tspan_pred\n")
                pooled_golds, pooled_preds = best_model_pooled_golds_preds[0], best_model_pooled_golds_preds[1]
                span_golds, span_preds = best_model_span_golds_preds[0], best_model_span_golds_preds[1]
                rows = zip(test_examples, pooled_golds, pooled_preds, span_golds, span_preds)
                for example, pooled_gold, pooled_pred, span_gold, span_pred in rows:
                    f_w.write(
                        f"{example.guid}\t{example.title}\t{example.comment}\t{pooled_gold}\t{pooled_pred}"
                        f"\t{' '.join(span_gold)}\t{' '.join(span_pred)}\n"
                    )

    # Delete Checkpoint
    # Only clean up when a best checkpoint is known. Otherwise (evaluation
    # disabled, or no checkpoint produced the metric) `best_step` is still its
    # placeholder value and every checkpoint would be removed.
    if best_checkpoint is not None:
        for directory in glob.glob("{}/checkpoint-*/".format(args.output_dir)):
            directory_step = directory.rstrip("/\\").split("-")[-1]
            if directory_step != str(best_step):
                shutil.rmtree(directory)


if __name__ == "__main__":
    # Script entry point: declare the command-line flags, parse them and hand
    # the resulting namespace to main().
    cli_parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options) pairs, registered in this order.
    argument_specs = (
        ("--task", dict(type=str, required=True)),
        ("--data", dict(type=str, required=True)),
        ("--target", dict(type=str, default=None)),
        ("--config_dir", dict(type=str, default="config")),
        ("--model_size", dict(type=str, required=True, help="base, small etc")),
        ("--model_dir", dict(type=str, required=True)),
        ("--learning_rate", dict(type=float)),
        ("--version", dict(type=str)),
        ("--min_offensiveness", dict(type=int, default=2)),
    )
    for flag, options in argument_specs:
        cli_parser.add_argument(flag, **options)

    cli_args = cli_parser.parse_args()

    main(cli_args)
