#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library's seq2seq models for question answering using the 🤗 Seq2SeqTrainer.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

import logging
import os
import sys
import json
import random
import copy
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
from itertools import chain

from torch.utils.data import DataLoader
from promptsource.templates import DatasetTemplates
import torch
import datasets
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk, load_metric
import numpy as np
import pandas as pd
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    set_seed
)
from transformers.trainer_utils import get_last_checkpoint
from src.utils import ModelArguments, DataTrainingArguments, Seq2SeqArguments, \
                LoraArguments, TASK_DEMONSTRATION_ONLY, get_dataset_sizes, \
                load_bbh_dataset, load_fs_glue_dataset, load_fs_super_glue_dataset, save_predictions, load_p3_eval_dataset
from src.metrics import Accuracy
from src.models import HyperLoRAConfig, HyperLoRAModelForPretrainP3, HyperLoRAModelForFinetune
from src.trainer import Seq2SeqHyperTrainer
from src.dataset import template_list as task2template

import os
# Set the WANDB_PROJECT environment variable for this process at import time.
# NOTE(review): presumably this selects the Weights & Biases project used for run logging — confirm.
os.environ["WANDB_PROJECT"]="HyperLoRA"

# Module-level logger; its level is set inside main().
logger = logging.getLogger(__name__)

from transformers import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice features.

    Every incoming feature stores, for each key except ``"targets"``, one entry per answer choice.
    The collator flattens those per-choice entries into individual rows, pads them, and returns a
    dict of tensors. ``"targets"`` (one value per original feature) is added back unflattened.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data. Its ``pad`` method and ``pad_token_id`` are used.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        label_pad_token_id (:obj:`int`, `optional`):
            Accepted for interface compatibility but not applied: labels are always padded with
            ``tokenizer.pad_token_id`` and the padded positions are marked with 0 in ``labels_attention_mask``.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
            Note that using fp16 for any kind of inference with T0 is strongly discouraged, as the predictions will
            vastly differ from the predictions using fp32.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    label_pad_token_id: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate ``features`` into a dict of tensors.

        Args:
            features: Sequence of dicts. Each dict holds a ``"targets"`` entry plus per-choice lists
                (at least ``"input_ids"``, ``"labels"`` and ``"labels_attention_mask"``).

        Returns:
            Dict of tensors with one row per (feature, choice) pair for every flattened key, plus
            ``"targets"`` with one entry per feature. The input dicts are left unmodified.
        """
        # The number of choices is read from the first feature and applied to all of them.
        num_choices = len(features[0]["input_ids"])

        # One row per (feature, choice); "targets" is per-feature and handled separately below.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != "targets"}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        # Pad the labels because it's not padded automatically
        max_label_length = max(len(elem["labels"]) for elem in flattened_features)
        pad_id = self.tokenizer.pad_token_id
        batch["labels"] = [
            elem["labels"] + [pad_id] * (max_label_length - len(elem["labels"]))
            for elem in flattened_features
        ]
        batch["labels_attention_mask"] = [
            elem["labels_attention_mask"] + [0] * (max_label_length - len(elem["labels_attention_mask"]))
            for elem in flattened_features
        ]

        # Convert to tensors
        batch = {k: torch.tensor(v) for k, v in batch.items()}

        # Read (rather than pop) the targets so the caller's feature dicts can be collated again.
        batch["targets"] = torch.tensor([f["targets"] for f in features])
        return batch

def generate_lora_weights(hyperlora, demo_datasets):
    """Generate LoRA weights from one tokenized demonstration sequence.

    The model and the demonstration tensors are moved to CUDA when it is available and to the CPU
    otherwise, so the helper also works on hosts without a GPU.

    Args:
        hyperlora: Model exposing ``to(device)`` and
            ``generate_lora_weights(demo_input_ids=..., demo_attention_mask=...)``.
        demo_datasets: Mapping with ``'demo_input_ids'`` and ``'demo_attention_mask'`` entries that
            ``torch.tensor`` can convert. A leading dimension of size 1 is added to each.

    Returns:
        Whatever ``hyperlora.generate_lora_weights`` returns, computed with gradients disabled.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hyperlora.to(device)
    with torch.no_grad():
        input_ids = torch.tensor(demo_datasets['demo_input_ids'], device=device)
        attention_mask = torch.tensor(demo_datasets['demo_attention_mask'], device=device)
        # unsqueeze(0) adds the leading dimension expected by the call below.
        lora_weights = hyperlora.generate_lora_weights(
            demo_input_ids=input_ids.unsqueeze(0),
            demo_attention_mask=attention_mask.unsqueeze(0),
        )

    return lora_weights

def main():
    """Evaluate the selected model on every task/template pair and export the accuracies.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataTrainingArguments``, ``Seq2SeqArguments`` and
         ``LoraArguments`` from a single JSON file or from the command line.
      2. Configure logging and seed the random number generators.
      3. Load the evaluation datasets, the tokenizer, the backbone seq2seq model and, unless
         ``model_args.not_use_hyperlora`` is set, the HyperLoRA wrapper around its encoder.
      4. For each task and each prompt template: load the demonstrations, build multiple-choice
         features, run the model and compute accuracy.
      5. Write the accuracy table to ``task2acc.xlsx`` under ``training_args.output_dir``.

    Raises:
        ValueError: If the backbone config has no ``decoder_start_token_id``, or if HyperLoRA is
            enabled together with ``model_args.finetune`` (no model is selected for that case).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqArguments, LoraArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, lora_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # local_rank is overwritten unconditionally, so the summary below always reports
    # "distributed training: False".
    training_args.local_rank = -1
    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load data
    raw_datasets = load_p3_eval_dataset(data_args, n_demonstrations=data_args.n_demonstrations, seed=training_args.seed)
    # Short task key -> name passed to promptsource's DatasetTemplates.
    task2taskname = {'rte': 'super_glue/rte', 'cb': 'super_glue/cb', 'wsc_fixed': 'super_glue/wsc.fixed',
                        "wic": 'super_glue/wic', 'copa': 'super_glue/copa', 'hellaswag': 'hellaswag',
                        "winogrande": "winogrande/winogrande_xl", "anli_r1": 'anli',
                        "anli_r2": 'anli', "anli_r3": 'anli'}

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    hyperlora_config = HyperLoRAConfig.from_pretrained(
        model_args.hypelora_name_or_path, lora_rank=lora_args.lora_rank
    )
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    plm_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = plm_model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        plm_model.resize_token_embeddings(len(tokenizer))

    if plm_model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # Select the model to evaluate and the tasks to iterate over.
    if not model_args.not_use_hyperlora:
        if model_args.finetune:
            # No model is assigned for this combination anywhere in this script; fail with a
            # clear message instead of an UnboundLocalError deep inside the evaluation loop.
            raise ValueError(
                "`model_args.finetune` is not supported by this evaluation script when HyperLoRA is enabled: "
                "no model is selected for that combination."
            )

        all_task_names = task2taskname.keys()
        hyperlora_model = HyperLoRAModelForPretrainP3(config=hyperlora_config,
                                        model_args=model_args,
                                        lora_args=lora_args,
                                        encoder=plm_model.encoder,
                                        pretrain_task_names=all_task_names)
        task2id = hyperlora_model.task2id
        # load pre-trained checkpoint
        if model_args.pretrain_checkpoint is not None:
            logger.info('Loading checkpoint from {}'.format(model_args.pretrain_checkpoint))
            # torch.load unpickles the file: only point this at checkpoints from a trusted source.
            hyperlora_model.load_state_dict(torch.load(os.path.join(model_args.pretrain_checkpoint, "pytorch_model.bin"), map_location='cpu'))

        model = hyperlora_model
    else:
        # NOTE(review): `task2id` is only bound in the HyperLoRA branch above, yet
        # preprocess_function reads it (and the `demo_*` keys produced by
        # preprocess_demonstrations) — this plain-model path looks unsupported as written; confirm.
        model = plm_model
        all_task_names = list(raw_datasets.keys())

    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    max_seq_length = data_args.max_seq_length
    input_column = data_args.input_column
    output_column = data_args.output_column

    def preprocess_demonstrations(task_demo_examples):
        """Concatenate all demonstrations into one string and tokenize it.

        Returns a dict with ``demo_input_ids`` and ``demo_attention_mask`` for the single
        concatenated sequence, truncated to ``2 * max_seq_length`` tokens.
        """
        inputs = task_demo_examples[input_column]
        targets = task_demo_examples[output_column]
        demonstration_inputs = "".join(
            TASK_DEMONSTRATION_ONLY.format(input=demo_input, target=demo_target)
            for demo_input, demo_target in zip(inputs, targets)
        )

        demo_tokenize_output = tokenizer(demonstration_inputs.strip(), max_length=max_seq_length*2, padding=padding, truncation=True)
        return {
            'demo_input_ids': demo_tokenize_output['input_ids'],
            'demo_attention_mask': demo_tokenize_output['attention_mask'],
        }

    def preprocess_function(examples, template, column_names):
        """Turn a batch of raw examples into per-answer-choice features.

        Every output column except ``targets`` holds, per example, one entry for each answer
        choice; ``targets`` holds the index of the gold choice. ``demo_dataset`` and ``task2id``
        are read from the enclosing scope at call time.
        """
        bs = len(examples[column_names[0]])

        input_texts = []
        target_texts = []
        answer_choices_texts = []
        task_ids, demo_input_ids, demo_attention_mask = [], [], []
        for i in range(bs):
            ex = {
                k: examples[k][i]
                for k in column_names
            }
            task = ex.pop('task_name')
            inputs, targets = template.apply(ex)
            ex_answer_choices = template.get_answer_choices_list(ex)
            assert targets in ex_answer_choices
            input_texts.append(inputs)
            target_texts.append(targets)
            answer_choices_texts.append(ex_answer_choices)
            # add demo: every example of this template shares the same demonstration encoding
            task_ids.append(task2id[task])
            demo_input_ids.append(demo_dataset['demo_input_ids'])
            demo_attention_mask.append(demo_dataset['demo_attention_mask'])

        tokenized_inputs = tokenizer(
            input_texts,
            padding=padding,
            max_length=max_seq_length,
            truncation=True,
            add_special_tokens=False,
        )
        # One tokenizer call per example; each call encodes all of that example's answer choices.
        tokenized_targets = [
            tokenizer(
                ans_choi,
                # Left unpadded here; the data collator pads labels on the right.
                padding=False,
                max_length=max_seq_length,
                truncation=True,
            )
            for ans_choi in answer_choices_texts
        ]
        num_choices = [len(tokenized["input_ids"]) for tokenized in tokenized_targets]

        def repeat_per_choice(values):
            # Duplicate each example-level value once per answer choice of that example.
            return [[value] * n for value, n in zip(values, num_choices)]

        features = {k: repeat_per_choice(v) for k, v in tokenized_inputs.items()}

        features["labels"] = [tokenized["input_ids"] for tokenized in tokenized_targets]
        features["labels_attention_mask"] = [tokenized["attention_mask"] for tokenized in tokenized_targets]
        features["targets"] = [
            answer_choices_texts[idx].index(t)
            for idx, t in enumerate(target_texts)
        ]
        features["task_ids"] = repeat_per_choice(task_ids)
        features['demo_input_ids'] = repeat_per_choice(demo_input_ids)
        features['demo_attention_mask'] = repeat_per_choice(demo_attention_mask)

        return features

    # load from checkpoint
    if training_args.resume_from_checkpoint is not None:
        # torch.load unpickles the file: only point this at checkpoints from a trusted source.
        state_dict = torch.load(os.path.join(training_args.resume_from_checkpoint, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        print('Loading model from {} successful.'.format(training_args.resume_from_checkpoint))

    # Data collator: identical for every task/template, so it is built once.
    # NOTE(review): the collator stores label_pad_token_id but pads labels with
    # tokenizer.pad_token_id regardless — confirm this is intended.
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForMultipleChoice(
        tokenizer,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None
    )

    task2acc = {}
    for task_name in tqdm(all_task_names, total=len(all_task_names), desc='Evaluating on each task'):
        task2acc[task_name] = {}
        for template in task2template[task_name]:
            print('Evaluating on task {} with template {}'.format(task_name, template))
            prompt = DatasetTemplates(task2taskname[task_name])[template]

            # load demonstration data
            # NOTE(review): both arguments of this replace appear to be the same character
            # (U+2026), which would make it a no-op — confirm.
            template_name = template.replace('…', '\u2026')
            template_name = task_name + template_name.replace(' ', '_').replace('/', '_')
            demo_dataset = datasets.load_dataset("json", data_files=os.path.join(f"{data_args.dataset_name}_auto_demonstration", f"{template_name}.json"))['train']

            if not model_args.not_use_hyperlora:
                # preprocess demonstration; preprocess_function picks this up through its closure
                demo_dataset = preprocess_demonstrations(demo_dataset)

            eval_dataset = raw_datasets[task_name]

            if data_args.max_eval_samples is not None:
                # We will select sample from whole data
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))
            # Validation Feature Creation
            with training_args.main_process_first(desc="validation dataset map pre-processing"):
                column_names = eval_dataset.column_names
                eval_dataset = eval_dataset.map(
                    preprocess_function,
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc="Running tokenizer on validation dataset",
                    fn_kwargs={"template": prompt, "column_names": column_names}
                )
            if data_args.max_eval_samples is not None:
                # During Feature creation dataset samples might increase, we will select required samples again
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))

            # A fresh metric per template so accuracies do not accumulate across templates.
            metric = Accuracy()

            eval_dataloader = DataLoader(
                            eval_dataset,
                            collate_fn=data_collator,
                            batch_size=training_args.per_device_eval_batch_size
            )

            logger.info("***** Running evaluation *****")
            logger.info(f"  Num examples = {len(eval_dataset)}")
            logger.info(f"  Instantaneous batch size per device = {training_args.per_device_eval_batch_size}")

            model = model.to('cuda')
            model.eval()
            with torch.no_grad():
                for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc='Evaluating'):
                    batch = {key: batch[key].cuda() for key in batch}
                    output = model(**batch)
                    # argmax over the last dimension of the logits gives the predicted choice index
                    predictions = output['logits'].argmax(-1).cpu().detach()
                    metric.add_batch(
                            predictions=predictions,
                            references=batch["targets"].cpu().detach(),
                    )

            metrics = metric.compute()

            # list all task em
            print(metrics)
            print(task_name, metrics['accuracy'] * 100)
            task2acc[task_name][template_name] = metrics['accuracy'] * 100

    print(task2acc)
    task2acc_df = pd.DataFrame.from_dict(task2acc)
    # Make sure the destination exists so a missing directory cannot discard the finished run.
    os.makedirs(training_args.output_dir, exist_ok=True)
    task2acc_df.to_excel(os.path.join(training_args.output_dir, 'task2acc.xlsx'))

def _mp_fn(index):
    """Entry point used by ``xla_spawn`` (TPUs); simply runs :func:`main`.

    Args:
        index: Positional argument supplied by the launcher; it is not used.
    """
    del index  # accepted only to match the launcher's calling convention
    main()


# Script entry point: run the evaluation when this file is executed directly.
if __name__ == "__main__":
    main()