# Supervised fine-tuning of language models.

import argparse
import time
import torch.nn as nn

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import Trainer, TrainingArguments, Seq2SeqTrainingArguments, default_data_collator

from utils import *
from data import *
from dataset import *
from dataset_trainer import *
from transfer_utils import *
from model import FTModel



# Spellings accepted for boolean command-line flags (compared case-insensitively).
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1", "on"})
# The empty string stays falsy, as it was under the previous ``type=bool``.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", "off", ""})


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so a flag could never be switched
    off from the command line.

    Args:
        value: Raw argument string (an actual ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false, yes/no, 1/0), got {!r}".format(value)
    )


parser = argparse.ArgumentParser()

# general
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--cuda', type=_str2bool, default=True)
parser.add_argument('--fp16', type=_str2bool, default=False)  # Pegasus: False, BART: True
parser.add_argument('--mp', type=_str2bool, default=False)
parser.add_argument('--debug', type=_str2bool, default=False)
parser.add_argument('--debug_size', type=int, default=20)
parser.add_argument('--deepspeed', type=str, default=None)  # "ds_config.json"
parser.add_argument('--sharded_ddp', type=str, default="simple")  # ["", "simple"]
parser.add_argument("--local_rank", type=int, default=0, help="Local rank. Necessary for using the torch.distributed.launch utility.")

# task
parser.add_argument('--train', type=_str2bool, default=True)
parser.add_argument('--prompt', type=str, default="summarize: ")  # in ["summarize: "]
parser.add_argument('--add_prompt_to_text', type=_str2bool, default=False)

# data
parser.add_argument('--data_folder', type=str, default="../../DATASETS/RedditTIFU/data/")  # CNNDM / WikiHow / XSum / RedditTIFU / BillSum
parser.add_argument('--max_length', type=int, default=512)  # CNNDM: 1024 / WikiHow: 512 / XSum: 512 / Reddit: 512 / BillSum: 768 (700)
# train
parser.add_argument('--train_dataset', type=str, default="train")
parser.add_argument('--train_size', type=int, default=1000000)
# val
parser.add_argument('--val_dataset', type=str, default="small_val")
parser.add_argument('--val_size', type=int, default=10000000)
# test
parser.add_argument('--test_dataset', type=str, default="test")
parser.add_argument('--test_size', type=int, default=100000)

# model
parser.add_argument('--model_type', type=str, default="pegasus")  # in ["t5", "pegasus", "bart"]
parser.add_argument('--model', type=str, default="google/pegasus-large")  # in ["t5-small", "t5-base", "google/t5-v1_1-base", "google/pegasus-large", "facebook/bart-large"]
parser.add_argument('--hidden_size', type=int, default=768)
parser.add_argument('--cache_dir', type=str, default="../../hf_models/pegasus-large/")  # in ["t5-base", "t5-base-v1", "pegasus-large", "bart-large"]
parser.add_argument('--load_model', type=_str2bool, default=False)
parser.add_argument('--load_model_path', type=str, default="")

# optimization
parser.add_argument('--n_epochs', type=int, default=15)
parser.add_argument('--adafactor', type=_str2bool, default=True)  # Pegasus: True / Bart: False
parser.add_argument('--scheduler', type=str, default="constant")  # in ["constant", "linear"] Pegasus: constant / Bart: linear
parser.add_argument('--warmup_ratio', type=float, default=0.025)  # Pegasus: _ / Bart: 0.025
parser.add_argument('--train_bs', type=int, default=2)  # Pegasus: 2 / Bart: 4
parser.add_argument('--inference_bs', type=int, default=2)  # Pegasus: 2 / Bart: 4
parser.add_argument('--lr', type=float, default=1e-4)  # CNNDM: Pegasus: 5e-5, Bart: 3e-5 / WikiHow: Pegasus: 8e-4 / XSum: Pegasus: 1e-4. BART: 3e-5 / Reddit: Pegasus: 1e-4 / BillSum: Pegasus: 2e-4
parser.add_argument('--gradient_accumulation_steps', type=int, default=128)  # Pegasus: 128 / Bart: 20
parser.add_argument('--wd', type=float, default=0)
# Effectively disables clipping: the threshold is far above any realistic gradient norm.
parser.add_argument('--gradient_clipping', type=float, default=100000000000.0)
parser.add_argument('--label_smoothing', type=float, default=0.1)  # Pegasus: 0.1 / Bart: 0.1

# evaluation
parser.add_argument('--eval_epoch_0', type=_str2bool, default=True)
parser.add_argument('--evaluation_strategy', type=str, default="steps")
parser.add_argument('--eval_every', type=int, default=250)  # Pegasus: 250 / Bart: 500
parser.add_argument('--eval_test', type=_str2bool, default=False)

# summaries
parser.add_argument('--generate_summaries', type=_str2bool, default=False)
parser.add_argument('--stemmer', type=_str2bool, default=True)
parser.add_argument('--show_summaries', type=_str2bool, default=True)
parser.add_argument('--show_summaries_count', type=int, default=1)  # batches
parser.add_argument('--rouge_to_use', type=str, default="rouge_score")  # in ["rouge_score", "rouge"]

# summary generation
parser.add_argument('--num_beams', type=int, default=5)
parser.add_argument('--max_summary_length', type=int, default=128)  # CNNDM: 128 / WikiHow: 256 / XSum: 64 / Reddit: 128

# export
parser.add_argument('--n_checkpoints_to_save', type=int, default=2)
parser.add_argument('--save_model_path', type=str, default="ft_saved_models/pegasus_reddit_train_1")

# Parse the command line once at import time; main() receives this namespace.
args = parser.parse_args()

# Echo the resolved configuration beneath a separator line.
print("*" * 50, args, sep="\n")


# time.sleep(10000)


def main(args):
    """Fine-tune the model described by ``args`` and print validation metrics.

    Loads the train/val/test splits, builds the tokenizer and the datasets,
    wraps the base model in ``FTModel``, optionally restores a checkpoint,
    then trains with ``CustomTrainer``. The validation set is evaluated before
    training (when ``args.eval_epoch_0`` is set) and again after training.

    Args:
        args: Parsed command-line namespace (see the module-level ``parser``).

    Raises:
        ValueError: If ``args.mp`` is set but ``args.model_type`` does not
            contain ``"t5"``.
    """
    # seed
    seed_everything(args.seed)

    # data
    train_data = load_data(args.train_dataset, args, individual_txt=False)
    val_data = load_data(args.val_dataset, args, individual_txt=False)
    test_data = load_data(args.test_dataset, args, individual_txt=False)

    # tokenizer
    tokenizer = build_tokenizer(args)

    # datasets
    size_limits = {"train": args.train_size, "val": args.val_size, "test": args.test_size}
    datasets = {}
    for mode, (texts, summaries) in (("val", val_data), ("test", test_data), ("train", train_data)):
        print(len(texts), len(summaries))
        if args.debug:
            texts = texts[:args.debug_size]
            summaries = summaries[:args.debug_size]
        # Per-split cap, applied after the optional debug truncation.
        limit = size_limits[mode]
        texts = texts[:limit]
        summaries = summaries[:limit]
        if mode == "train":
            datasets[mode] = TrainFTDatasetTrainer(mode, tokenizer, texts, summaries, args)
            print("There are {} train data points".format(len(texts)))
        else:
            dataset = InferenceFTDatasetTrainer(mode, tokenizer, texts, summaries, args)
            datasets[mode] = dataset
            # Val/test are batched with the inference batch size (see
            # per_device_eval_batch_size below); round up so a trailing
            # partial batch is counted.
            n_batches = (len(dataset.texts) + args.inference_bs - 1) // args.inference_bs
            print("There are {} {} batches".format(n_batches, mode))
    train_dataset = datasets["train"]
    val_dataset = datasets["val"]
    # NOTE: the "test" dataset is built above but not consumed in this function.

    # model
    base_model = build_model(args)
    model = FTModel(base_model, args)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nThe model has {} trainable parameters".format(n_params))

    # loading checkpoint
    if args.load_model:
        print("Loading checkpoint: {}".format(args.load_model_path))
        # map_location="cpu" lets a checkpoint saved on a GPU be restored on a
        # host without that device; load_state_dict copies the values into the
        # model's existing parameters.
        state_dict = torch.load(args.load_model_path, map_location="cpu")
        model.load_state_dict(state_dict)

    if args.mp:
        if "t5" not in args.model_type:
            raise ValueError(
                "Can't do Model Parallelism on that model (model_type={!r})".format(args.model_type)
            )
        print("Using model parallelism...")
        model.parallelize()

    train_args = Seq2SeqTrainingArguments(
        output_dir=args.save_model_path,  # will be changed
        overwrite_output_dir=True,
        do_train=True,
        do_eval=True,
        do_predict=False,
        evaluation_strategy=args.evaluation_strategy,
        eval_steps=args.eval_every,
        save_total_limit=args.n_checkpoints_to_save,
        save_steps=args.eval_every,
        num_train_epochs=args.n_epochs,
        adafactor=args.adafactor,
        lr_scheduler_type=args.scheduler,
        warmup_ratio=args.warmup_ratio,
        per_device_train_batch_size=args.train_bs,
        per_device_eval_batch_size=args.inference_bs,
        learning_rate=args.lr,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        weight_decay=args.wd,
        max_grad_norm=args.gradient_clipping,
        label_smoothing_factor=args.label_smoothing,
        logging_strategy="no",
        save_strategy=args.evaluation_strategy,
        fp16=args.fp16,
        load_best_model_at_end=True,
        greater_is_better=False,
        disable_tqdm=False,
        deepspeed=args.deepspeed,
        sharded_ddp=args.sharded_ddp,
        local_rank=args.local_rank,
    )

    trainer = CustomTrainer(
        model=model,
        args=train_args,
        data_collator=default_data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    if args.eval_epoch_0:
        results = trainer.evaluate()
        print("*" * 50, "EPOCH 0 RESULTS")
        print(results)

    # training loop
    trainer.train()

    # validate with the best model (load_best_model_at_end=True above)
    results = trainer.evaluate()
    print("*" * 50, "RESULTS")
    print(results)


class CustomTrainer(Trainer):
    """``Trainer`` subclass that feeds the ``text_*`` batch fields to the model.

    Batches are read through the keys ``text_input_ids``,
    ``text_attention_mask`` and ``labels``; the model is called positionally
    with the first two and returns a mapping that contains ``"loss"`` when
    labels are supplied.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run a forward pass with labels and return the model's own loss.

        Args:
            model: The model to call.
            inputs: Batch with ``text_input_ids``, ``text_attention_mask`` and
                ``labels``.
            return_outputs: When true, also return the raw model outputs.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        # NOTE(review): this override never consults ``self.label_smoother``, so
        # ``label_smoothing_factor`` presumably only has an effect if the model
        # applies smoothing itself -- confirm against FTModel.
        outputs = model(
            inputs["text_input_ids"],
            inputs["text_attention_mask"],
            labels=inputs["labels"],
        )
        loss = outputs["loss"]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        With labels present the loss and outputs come from :meth:`compute_loss`;
        otherwise the model is called on the text inputs alone and no loss is
        produced.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model, keyed by
                ``text_input_ids``, ``text_attention_mask`` and (optionally)
                the names in ``self.label_names``.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels:
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    # The loss is returned separately, so keep it out of the logits tuple.
                    skipped_keys = list(ignore_keys) + ["loss"]
                    logits = tuple(v for k, v in outputs.items() if k not in skipped_keys)
                else:
                    logits = outputs[1:]
            else:
                loss = None
                # Same batch keys and positional call convention as compute_loss
                # (the previous "text_inputs_ids" key does not exist in the batch).
                # NOTE(review): assumes the model's forward accepts being called
                # without labels -- confirm against FTModel.
                text_input_ids = inputs["text_input_ids"]
                text_attention_mask = inputs["text_attention_mask"]
                if self.use_amp:
                    with autocast():
                        outputs = model(text_input_ids, text_attention_mask)
                else:
                    outputs = model(text_input_ids, text_attention_mask)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)


# Script entry point: run fine-tuning with the namespace parsed at import time.
if __name__ == "__main__":
    main(args)
