import os
import argparse
import numpy as np
import logging
logger = logging.getLogger(__name__)
import torch
from datasets import load_dataset, set_caching_enabled, is_caching_enabled
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    BertConfig,
    BertForMultipleChoice,
    BertTokenizer,
    RobertaConfig,
    RobertaModel,
    RobertaForMultipleChoice,
    RobertaTokenizer,
    XLNetConfig,
    XLNetForMultipleChoice,
    XLNetTokenizer,
    AlbertConfig,
    AlbertForMultipleChoice,
    AlbertTokenizer,
    DebertaConfig,
    DebertaTokenizer,
    set_seed,
    default_data_collator,
    PreTrainedTokenizer,
)
from transformers import Trainer, TrainingArguments, TrainerCallback
from deberta_multichoice_model import DebertaForMultipleChoice

# Registry mapping a lower-cased ``--model_type`` value to the
# (config class, multiple-choice model class, tokenizer class) triple that
# ``main`` unpacks. The DeBERTa model class is imported from the local
# ``deberta_multichoice_model`` module rather than from ``transformers``.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
    "roberta": (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
    "albert": (AlbertConfig, AlbertForMultipleChoice, AlbertTokenizer),
    "deberta": (DebertaConfig, DebertaForMultipleChoice, DebertaTokenizer),
}


class StopCallback(TrainerCallback):
    """
    A minimal :class:`~transformers.TrainerCallback` that requests the end of
    training once the scheduler reports a learning rate of exactly zero.
    """

    def on_step_begin(self, args, state, control, lr_scheduler, **kwargs):
        """Flag ``control.should_training_stop`` when the first LR value is 0.

        The check is only made after more than 100 global steps have elapsed
        (presumably so that a zero learning rate at the very start of warm-up
        does not end the run immediately -- TODO confirm).
        """
        past_initial_steps = state.global_step > 100
        if not past_initial_steps:
            return

        # Only the first entry returned by the scheduler is inspected.
        first_lr = lr_scheduler.get_lr()[0]
        if first_lr == 0:
            control.should_training_stop = True


def init_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour of the original zero-argument call.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by ``argparse`` when a required option is missing
            or a value cannot be converted to the declared type.
    """
    parser = argparse.ArgumentParser()

    # --- distributed / model selection -------------------------------------
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--model_type", default=None, type=str, required=True)
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        default=None,
        type=int,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        # The original help text ended with a dangling "selected in the list: ".
        help="Path to pre-trained model or shortcut name.",
    )

    # --- task, data and output locations ------------------------------------
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--run_name",
        default=None,
        type=str,
        required=True,
        help="The name of such run.",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument(
        "--pad_to_max_length",
        action="store_true",
        help="Whether to pad all samples to the maximum sentence length.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=3.0,
        type=float,
        help="Total number of training epochs to perform.",
    )

    # --- run mode, logging and checkpointing --------------------------------
    parser.add_argument("--seed", default=42, type=int, help="set the random seed of the experiment")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run test on the test set")
    parser.add_argument("--logging_steps", type=int, default=10, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    # --- optimisation hyper-parameters ---------------------------------------
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--label_smoothing_factor",
        default=0.0,
        type=float,
        help="label smoothing factor.",
    )

    return parser.parse_args(argv)


# class DataCollatorForMultipleChoice:
#     """
#     Data collator that will dynamically pad the inputs for multiple choice received.
#     Args:
#         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
#             The tokenizer used for encoding the data.
#         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
#             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
#             among:
#             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
#               sequence if provided).
#             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
#               maximum acceptable input length for the model if that argument is not provided.
#             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
#               different lengths).
#         max_length (:obj:`int`, `optional`):
#             Maximum length of the returned list and optionally padding length (see above).
#         pad_to_multiple_of (:obj:`int`, `optional`):
#             If set will pad the sequence to a multiple of the provided value.
#             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
#             7.5 (Volta).
#     """
# 
#     tokenizer: PreTrainedTokenizer
#     padding: Union[bool, str, PaddingStrategy] = True
#     max_length: Optional[int] = None
#     pad_to_multiple_of: Optional[int] = None
# 
#     def __call__(self, features):
#         label_name = "label" if "label" in features[0].keys() else "labels"
#         labels = [feature.pop(label_name) for feature in features]
#         batch_size = len(features)
#         num_choices = len(features[0]["input_ids"])
#         flattened_features = [
#             [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
#         ]
#         flattened_features = sum(flattened_features, [])
# 
#         batch = self.tokenizer.pad(
#             flattened_features,
#             padding=self.padding,
#             max_length=self.max_length,
#             pad_to_multiple_of=self.pad_to_multiple_of,
#             return_tensors="pt",
#         )
# 
#         # Un-flatten
#         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
#         # Add back labels
#         batch["labels"] = torch.tensor(labels, dtype=torch.int64)
#         return batch


def compute_metrics(eval_predictions):
    """Compute accuracy from a ``(predictions, label_ids)`` pair.

    The predicted class of each row is the arg-max of ``predictions`` along
    axis 1; the result is the fraction of rows whose predicted class equals
    the corresponding entry of ``label_ids``.

    Returns:
        A dict with a single ``"accuracy"`` key holding a Python float.
    """
    scores, gold = eval_predictions
    chosen = np.argmax(scores, axis=1)
    is_correct = chosen == gold
    accuracy = is_correct.mean().item()
    return {"accuracy": accuracy}


def _data_files_for_task(task_name, data_dir):
    """Return the train/val/test JSON file paths for ``task_name`` in ``data_dir``."""
    # Any task other than 'logiqa' or 'arc' falls back to the 'race' file naming.
    suffix = {"logiqa": "logi", "arc": "arc"}.get(task_name, "race")
    return {
        split: os.path.join(data_dir, "wrap_" + split + "_" + suffix + ".json")
        for split in ("train", "val", "test")
    }


def _schedule_steps(num_examples, args):
    """Return ``(total_steps, warmup_steps)``; warm-up is one tenth of the total before truncation."""
    # torch.cuda.device_count() is 0 on a CPU-only host; clamp to 1 so the
    # floor divisions below cannot raise ZeroDivisionError.
    num_devices = max(torch.cuda.device_count(), 1)
    steps_per_epoch = (
        num_examples
        // num_devices
        // args.gradient_accumulation_steps
        // args.per_gpu_train_batch_size
    )
    total_steps = int(steps_per_epoch * args.num_train_epochs)
    warmup_steps = int(steps_per_epoch * args.num_train_epochs * 0.1)
    return total_steps, warmup_steps


def main():
    """Fine-tune or test a multiple-choice model according to the CLI arguments.

    Steps performed:
      1. Parse arguments, seed the RNGs and (unless ``--do_test``) set the
         ``WANDB_PROJECT`` environment variable to the task name.
      2. Load the train/val/test JSON files for the task.
      3. Derive the total and warm-up step counts and build ``TrainingArguments``.
      4. Instantiate config, tokenizer and model for ``--model_type``.
      5. Tokenize every example as four (context, question + answer) pairs.
      6. Without ``--do_test``: train, then predict on the test split.
         With ``--do_test``: load a saved checkpoint and predict on the test split.
    """
    # Every example is expected to carry exactly this many candidate answers;
    # the same number is used as ``num_labels`` for the model config.
    num_choices = 4

    args = init_args()
    set_seed(args.seed)
    print(args.output_dir)
    if not args.do_test:
        os.environ['WANDB_PROJECT'] = args.task_name

    data_files = _data_files_for_task(args.task_name, args.data_dir)
    datasets = load_dataset("json", data_files=data_files, field="data")

    args.total_steps, args.warmup_steps = _schedule_steps(len(datasets['train']), args)
    print("warm up steps", args.warmup_steps)
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        do_train=args.do_train,
        do_eval=args.do_eval,
        evaluation_strategy='steps',
        # Evaluation and checkpointing share one interval (``--save_steps``).
        eval_steps=args.save_steps,
        per_device_train_batch_size=args.per_gpu_train_batch_size,
        per_device_eval_batch_size=args.per_gpu_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.total_steps,
        # num_train_epochs=args.num_train_epochs,
        logging_dir='./logs',
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=20,
        seed=args.seed,
        fp16=args.fp16,
        local_rank=args.local_rank,
        dataloader_num_workers=6,
        learning_rate=args.learning_rate,
        lr_scheduler_type='linear',
        warmup_steps=args.warmup_steps,
        deepspeed='ds_config.json',
        report_to=['wandb'] if not args.do_test else [],
        run_name=args.run_name,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        label_smoothing_factor=args.label_smoothing_factor,
    )

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    config = config_class.from_pretrained(
        args.model_name_or_path,
        num_labels=num_choices,
        finetuning_task=args.task_name.lower(),
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
    )
    if args.do_test:
        # NOTE: test mode always loads the hard-coded 'checkpoint-4400'
        # directory under ``--output_dir``.
        model = model_class.from_pretrained(os.path.join(args.output_dir, "checkpoint-4400"))
    else:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=False,
            config=config,
        )

    # Preprocessing the datasets.
    def preprocess_function(examples):
        """Tokenize a batch as ``num_choices`` (context, question + answer) pairs per example."""
        answers = examples['answers']

        # Build the flattened lists directly (context repeated once per choice,
        # question paired with each answer) instead of flattening nested lists
        # with ``sum(..., [])``, which is quadratic in the batch size.
        first_sentences = [
            context for context in examples['context'] for _ in range(num_choices)
        ]
        second_sentences = [
            f"{ques} {answers[i][j]}"
            for i, ques in enumerate(examples['question'])
            for j in range(num_choices)
        ]

        # Tokenize
        tokenized_examples = tokenizer(
            first_sentences,
            second_sentences,
            add_special_tokens=True,
            truncation=True,
            max_length=args.max_seq_length,
            padding="max_length" if args.pad_to_max_length else False,
            return_attention_mask=True,
            return_token_type_ids=True
        )
        # Un-flatten: regroup every ``num_choices`` consecutive encodings into one example.
        return {
            k: [v[i: i + num_choices] for i in range(0, len(v), num_choices)]
            for k, v in tokenized_examples.items()
        }

    print("dataset enabled", is_caching_enabled())

    tokenized_datasets = datasets.map(
        preprocess_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=not args.overwrite_cache,
        # cache_file_names={"train":"cache_race_384_albert_train", "val": "cache_race_384_albert_val" , "test":"cache_race_384_albert_test"},
    )

    data_collator = default_data_collator

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,  # training arguments, defined above
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["val"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[StopCallback],
    )

    if args.do_test:
        # Use the device resolved by TrainingArguments: building
        # torch.device("cuda", local_rank) by hand is invalid when local_rank
        # is the default -1 and when no GPU is present.
        trainer.model.to(training_args.device)
    else:
        trainer.train()
        # print(trainer.state.best_model_checkpoint)

    output = trainer.predict(tokenized_datasets["test"], metric_key_prefix="test")
    print(output.metrics)

    # import numpy as np
    # np.save("predictions/val_preds_all_deberta.npy", output.predictions)
    # print(output.label_ids)
    # np.save("predictions/val_labels.npy", output.label_ids)



# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()