import os
import argparse
import numpy as np
import logging
logger = logging.getLogger(__name__)
import torch
from datasets import load_dataset
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    BertConfig,
    BertForMultipleChoice,
    BertTokenizer,
    RobertaConfig,
    RobertaModel,
    RobertaForMultipleChoice,
    RobertaTokenizer,
    XLNetConfig,
    XLNetForMultipleChoice,
    XLNetTokenizer,
    AlbertConfig,
    AlbertForMultipleChoice,
    AlbertTokenizer,
    DebertaConfig,
    DebertaTokenizer,
    set_seed,
    default_data_collator,
    PreTrainedTokenizer,
    BatchEncoding,
)
from Contrastive_Trainer import ContrastiveTrainer
from transformers import Trainer, TrainingArguments, TrainerCallback
from deberta_multichoice_model import DebertaForMultipleChoice

# Registry of supported architectures. Each key is a lower-case model type
# (main() lower-cases ``--model_type`` before the lookup) and each value is a
# ``(config class, multiple-choice model class, tokenizer class)`` triple.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
    "roberta": (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
    "albert": (AlbertConfig, AlbertForMultipleChoice, AlbertTokenizer),
    # The DeBERTa multiple-choice head comes from the local
    # ``deberta_multichoice_model`` module rather than from transformers.
    "deberta": (DebertaConfig, DebertaForMultipleChoice, DebertaTokenizer),
}


class StopCallback(TrainerCallback):
    """
    A minimal :class:`~transformers.TrainerCallback` that asks the trainer to stop
    once the scheduler reports a learning rate of zero.

    The check only runs after global step 100 -- presumably so that a zero
    learning rate at the very start of warmup does not end the run (TODO confirm).
    """

    def on_step_begin(self, args, state, control, lr_scheduler, **kwargs):
        # Early steps are never a reason to stop; the scheduler is not queried.
        if state.global_step <= 100:
            return

        # Only the first parameter group's learning rate is inspected.
        first_group_lr = lr_scheduler.get_lr()[0]
        if first_group_lr == 0:
            control.should_training_stop = True


def init_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour callers without arguments have always had.

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        SystemExit: Raised by argparse when a required option is missing, a
            value cannot be converted, or ``--help`` is requested.
    """
    parser = argparse.ArgumentParser()

    # Distributed setup / model selection.
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--model_type", default=None, type=str, required=True)
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )
    parser.add_argument(
        "--max_seq_length", default=128, type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--preprocessing_num_workers", default=None, type=int,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--model_name_or_path", default=None, type=str, required=True,
        help="Path to pre-trained model or shortcut name.",
    )
    parser.add_argument(
        "--task_name", default=None, type=str, required=True,
        help="The name of the task to train",
    )

    # Input / output locations.
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain enrich_train_logi.json, enrich_val_logi.json "
             "and enrich_test_logi.json.",
    )
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--run_name", default=None, type=str, required=True,
        help="The name of such run.",
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--pad_to_max_length", action="store_true",
        help="Whether to pad all samples to the maximum sentence length.",
    )
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    # The help text used to be a copy of the --output_dir description.
    parser.add_argument(
        "--load_checkpoint", default=None, type=str,
        help="Optional path of a checkpoint to initialize the model from when training "
             "(used in place of --model_name_or_path for the model weights).",
    )

    # Run mode, logging and precision.
    parser.add_argument("--seed", default=42, type=int, help="set the random seed of the experiment")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run test on the test set")
    parser.add_argument("--logging_steps", type=int, default=10, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--fp16", action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    # Optimisation hyper-parameters.
    parser.add_argument(
        "--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training."
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing_factor", default=0.0, type=float, help="label smoothing factor.")

    return parser.parse_args(argv)


class DataCollatorForMultipleChoice:
    """Collate a list of per-example features into a dict of batched tensors.

    ``label`` / ``label_ids`` are emitted under the key ``labels`` and
    ``contras_label`` keeps its own name. Every other value that is a tensor,
    or a non-empty list whose first element is not a string, is batched under
    its original key. Strings, lists of strings, plain scalars, ``None`` and
    empty lists are skipped.

    No padding is performed: list features are passed to ``torch.tensor`` as
    they are, so they must already have matching lengths across the batch.
    """

    def __call__(self, features):
        """Return the batch built from ``features``.

        Args:
            features: Non-empty list of dicts / ``BatchEncoding`` objects, or of
                arbitrary objects whose attributes (via ``vars``) are the features.

        Returns:
            dict: Maps feature names to ``torch.Tensor`` batches.
        """
        first = features[0]
        # Plain dicts are the common case, so they are recognised first.
        if not (isinstance(first, dict) or isinstance(first, BatchEncoding)):
            features = [vars(f) for f in features]
            first = features[0]

        batch = {}

        # Special handling for labels: make sure the tensor gets the right
        # dtype (integer labels -> long, anything else -> float).
        if "label" in first and first["label"] is not None:
            label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
            dtype = torch.long if isinstance(label, int) else torch.float
            batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
        elif "label_ids" in first and first["label_ids"] is not None:
            if isinstance(first["label_ids"], torch.Tensor):
                batch["labels"] = torch.stack([f["label_ids"] for f in features])
            else:
                dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
                batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

        # Contrastive label: same dtype rule as "label", and the same guard
        # against a missing (None) value.
        if "contras_label" in first and first["contras_label"] is not None:
            contras_label = first["contras_label"]
            if isinstance(contras_label, torch.Tensor):
                contras_label = contras_label.item()
            contras_dtype = torch.long if isinstance(contras_label, int) else torch.float
            batch["contras_label"] = torch.tensor([f["contras_label"] for f in features], dtype=contras_dtype)

        # All remaining keys. The first example decides which keys are batched.
        for k, v in first.items():
            if k in ("label", "label_ids") or v is None:
                continue
            if isinstance(v, torch.Tensor):
                # Previously unreachable: tensors were rejected by a list-only check.
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, list) and v and not isinstance(v[0], str):
                # Raw text columns (str / list of str) and empty lists never
                # reach this point; an empty list used to raise IndexError.
                batch[k] = torch.tensor([f[k] for f in features])

        return batch


def compute_metrics(eval_predictions):
    """Return ``{"accuracy": ...}`` for a ``(predictions, label_ids)`` pair.

    The predicted class of each row is the arg-max of ``predictions`` along
    axis 1; accuracy is the fraction of rows whose arg-max equals the label.
    """
    scores, gold = eval_predictions
    chosen = np.argmax(scores, axis=1)
    is_correct = chosen == gold
    accuracy = is_correct.mean().item()
    return {"accuracy": accuracy}


def main():
    """Train and/or test a multiple-choice model on the enriched LogiQA-style data.

    Steps, in order:

    1. Parse the command line and seed the RNGs.
    2. Load ``enrich_{train,val,test}_logi.json`` from ``--data_dir``.
    3. Derive the total number of optimisation steps and the warmup steps
       (10% of the total) and build the ``TrainingArguments``.
    4. Load config, tokenizer and model, register the ``<ext>`` special token.
    5. Tokenize the 4-way choices and the 2-way contrastive pairs.
    6. Without ``--do_test``: train, then predict on the test split.
       With ``--do_test``: load a saved checkpoint and only predict.
    """
    args = init_args()
    set_seed(args.seed)
    print(args.output_dir)

    if not args.do_test:
        os.environ['WANDB_PROJECT'] = 'logiqa'

    data_files = {
        "train": os.path.join(args.data_dir, "enrich_train_logi.json"),
        "val": os.path.join(args.data_dir, "enrich_val_logi.json"),
        "test": os.path.join(args.data_dir, "enrich_test_logi.json"),
    }
    datasets = load_dataset("json", data_files=data_files, field="data")

    # torch.cuda.device_count() is 0 on a CPU-only host; treat that as a single
    # worker instead of dividing by zero.
    num_devices = max(torch.cuda.device_count(), 1)
    updates_per_epoch = (
        len(datasets['train'])
        // num_devices
        // args.gradient_accumulation_steps
        // args.per_gpu_train_batch_size
    )
    args.total_steps = int(updates_per_epoch * args.num_train_epochs)
    # Warm up over the first 10% of the schedule.
    args.warmup_steps = int(updates_per_epoch * args.num_train_epochs * 0.1)
    print("warm up steps", args.warmup_steps)
    print("label smoothing", args.label_smoothing_factor)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        do_train=args.do_train,
        do_eval=args.do_eval,
        evaluation_strategy='steps',
        # Evaluate on the same cadence as checkpointing.
        eval_steps=args.save_steps,
        per_device_train_batch_size=args.per_gpu_train_batch_size,
        per_device_eval_batch_size=args.per_gpu_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # Training length is given in steps (computed above), not in epochs.
        max_steps=args.total_steps,
        logging_dir='./logs',
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=20,
        seed=args.seed,
        fp16=args.fp16,
        local_rank=args.local_rank,
        dataloader_num_workers=6,
        learning_rate=args.learning_rate,
        lr_scheduler_type='linear',
        warmup_steps=args.warmup_steps,
        deepspeed='ds_config.json',
        report_to=['wandb'] if not args.do_test else [],
        run_name=args.run_name,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        # The collator needs the extra (contrastive) columns, so keep them all.
        remove_unused_columns=False,
        label_smoothing_factor=args.label_smoothing_factor,
    )

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    config = config_class.from_pretrained(
        args.model_name_or_path,
        num_labels=4,
        finetuning_task=args.task_name.lower(),
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
    )

    if args.do_test:
        # NOTE: test mode always loads this fixed checkpoint from the output
        # directory; --load_checkpoint is not consulted here.
        model = model_class.from_pretrained(os.path.join(args.output_dir, "checkpoint-2800"))
    elif args.load_checkpoint is None:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=False,
            config=config,
        )
    else:
        print('*'*40, "Load pretrained checkpoint")
        model = model_class.from_pretrained(args.load_checkpoint)

    # Register the separator placed between an answer and its extended context,
    # then grow the embedding matrix to the new vocabulary size.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ext>']})
    model.resize_token_embeddings(len(tokenizer))

    num_choices = 4   # answer options per question
    num_contras = 2   # contrastive (context, ending) pairs per question

    def _tokenize_pairs(first_texts, second_texts):
        """Tokenize aligned text pairs with the settings shared by both passes."""
        return tokenizer(
            first_texts,
            second_texts,
            add_special_tokens=True,
            truncation=True,
            max_length=args.max_seq_length,
            padding="max_length" if args.pad_to_max_length else False,
            return_attention_mask=True,
            return_token_type_ids=True
        )

    def _regroup(encoded, group_size, prefix=""):
        """Turn flat per-sequence lists back into per-example groups."""
        return {
            prefix + key: [values[i: i + group_size] for i in range(0, len(values), group_size)]
            for key, values in encoded.items()
        }

    # Preprocessing the datasets.
    def preprocess_function(examples):
        ext_token = tokenizer.additional_special_tokens[0]

        # Multiple-choice inputs, built already flattened (example-major order):
        # the context is repeated once per option and paired with
        # "question answer <ext> extended-context".
        first_sentences = [context for context in examples['context'] for _ in range(num_choices)]
        second_sentences = [
            f"{ques} {examples['answers'][i][j]} {ext_token} {examples['extend_contexts'][i][j]}"
            for i, ques in enumerate(examples['question'])
            for j in range(num_choices)
        ]
        tokenized_examples = _tokenize_pairs(first_sentences, second_sentences)

        # Contrastive inputs: each example supplies its own list of contexts.
        conts_first_sentences = [
            conts_context for group in examples['contras_contexts'] for conts_context in group
        ]
        conts_second_sentences = [
            f"{ques} {examples['contras_endings'][i][j]} {ext_token} {examples['contras_extend_context'][i][j]}"
            for i, ques in enumerate(examples['question'])
            for j in range(num_contras)
        ]
        conts_tokenized_examples = _tokenize_pairs(conts_first_sentences, conts_second_sentences)

        # Contrastive features are stored under a "conts_" prefix.
        return {
            **_regroup(tokenized_examples, num_choices),
            **_regroup(conts_tokenized_examples, num_contras, prefix="conts_"),
        }

    tokenized_datasets = datasets.map(
        preprocess_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=not args.overwrite_cache,
    )

    data_collator = DataCollatorForMultipleChoice()

    trainer = ContrastiveTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,  # training arguments, defined above
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["val"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[StopCallback],
    )

    if args.do_test:
        # training_args.device is valid for every launch mode, whereas
        # torch.device("cuda", local_rank) is rejected for the default local_rank of -1.
        trainer.model.to(training_args.device)
    else:
        trainer.train()

    output = trainer.predict(tokenized_datasets["test"], metric_key_prefix="test")
    print(output.metrics)


# Script entry point: run main() only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()