import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import Trainer, TrainingArguments, BitsAndBytesConfig, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import pickle
import copy
from itertools import permutations
import os


def format_data(episode):
    """Flatten multi-session dialogue episodes into text-to-text training pairs.

    For every episode and every session name in ``first`` .. ``sixth`` the
    function emits several kinds of (input, target) string pairs:

    * ``"generation: ..."`` -- one example per consecutive utterance pair.  The
      input is the speaker/job prefix, optional ``[MEMORY]``/``[LINK]`` lines
      (non-first sessions only), a ``[NOW] <sid> session`` marker and the
      running dialogue history; the target is the next utterance.  One extra
      example per session has the full history as input and ``"[END]"`` as
      target.
    * ``"summarize [ABOUT_<key>]: ..."`` -- for every session except ``sixth``,
      one example per key of ``<sid>_session_memory``; the target is that
      key's memory entries joined with ``" [SEP] "`` (``"[NONE]"`` if empty).
    * ``"memory sentence1: ... memory sentence2: ..."`` -- for the ``sixth``
      session only, ordered pairs of accumulated memory sentences labelled
      ``"positive"`` (listed in the previous session's memory pairs) or
      ``"negative"`` (every other ordered pair).

    Sessions whose speaker information cannot be resolved are skipped.

    Args:
        episode: Mapping from episode id to a record.  The record is indexed
            with ``'speaker_list'``, ``'speaker_job'`` and, per session id,
            ``'<sid>_session_dialogue'`` (``[0]`` = speakers, ``[1]`` =
            utterances), ``'<sid>_session_memory'`` and
            ``'<sid>_session_memory_pair'``.

    Returns:
        A ``datasets.Dataset`` with two string columns, ``inputs`` and
        ``targets``.

    Side effects:
        Reads the pickle file ``MEMORY_TAG_DATA`` from the working directory
        and prints the number of inputs and targets produced.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use this
    # with retrieval data produced by a trusted pipeline.
    with open('MEMORY_TAG_DATA', 'rb') as f:
        for_retrieval = pickle.load(f)

    session = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth']
    inputs = []
    targets = []

    for item in episode:
        # Memory sentences accumulated over all earlier sessions of this episode.
        all_memory = []
        prev_sid = ""
        for sid in session:
            temp_inputs = []
            temp_targets = []
            model_input = ""

            # Fold the previous session's memory into the running memory list,
            # tagging each sentence with its subject and originating session.
            if sid != 'first':
                original_memory = copy.deepcopy(episode[item][f'{prev_sid}_session_memory'])
                for key in original_memory:
                    for value in original_memory[key]:
                        all_memory.append(value.strip() + f" (about {key}, from {prev_sid} session)")

            try:
                main_speaker_name = episode[item][f'{sid}_session_dialogue'][0][0]
                main_speaker_job = episode[item]['speaker_job'][episode[item]['speaker_list'].index(main_speaker_name)]
                sub_speaker_name = episode[item][f'{sid}_session_dialogue'][0][1]
                sub_speaker_job = episode[item]['speaker_job'][episode[item]['speaker_list'].index(sub_speaker_name)]
                prefix = f"[{main_speaker_name}] {main_speaker_job} [{sub_speaker_name}] {sub_speaker_job} "
            except Exception:
                # Best effort: a session with missing/inconsistent speaker data
                # is skipped.  `Exception` (not a bare except) so that
                # KeyboardInterrupt/SystemExit still propagate.
                prev_sid = sid
                continue

            for idx in range(len(episode[item][f'{sid}_session_dialogue'][0]) - 1):
                current_speaker = episode[item][f'{sid}_session_dialogue'][0][idx].replace('"', '')
                next_speaker = episode[item][f'{sid}_session_dialogue'][0][idx + 1].replace('"', '')
                current_utterance = episode[item][f'{sid}_session_dialogue'][1][idx]
                # History ends with the next speaker's tag so the model
                # generates that speaker's utterance.
                model_input += f"[{current_speaker}] {current_utterance} [{next_speaker}]"
                model_target = episode[item][f'{sid}_session_dialogue'][1][idx + 1]

                memory_sequence = ""
                if sid != 'first':
                    try:
                        # Indices into all_memory tagged as relevant to this target utterance.
                        target_memory = for_retrieval[item][sid][model_target]
                        for tidx in target_memory:
                            memory_sequence += f"[MEMORY] {all_memory[tidx]} "
                            # Fresh copy per index: pair.remove() below mutates the pairs.
                            memory_pair = copy.deepcopy(episode[item][f'{prev_sid}_session_memory_pair'])
                            for pair in memory_pair:
                                # Pair entries are used as 1-based indices into all_memory.
                                if len(pair) == 2 and tidx + 1 in pair:
                                    pair.remove(tidx + 1)
                                    neighbor = int(pair[0]) - 1
                                    memory_sequence += f"[LINK] {all_memory[neighbor]} "
                    except Exception:
                        # Best effort: no (or partial) memory context when the
                        # retrieval tags are missing or an index is out of range.
                        pass

                if len(memory_sequence) > 1:
                    real_inputs = "generation: " + prefix + memory_sequence + f"[NOW] {sid} session " + model_input
                else:
                    real_inputs = "generation: " + prefix + f"[NOW] {sid} session " + model_input

                temp_inputs.append(real_inputs)
                temp_targets.append(model_target)

                # Append the gold utterance so it becomes history for the next turn.
                model_input += f" {model_target} "

            # End-of-session example: full history -> "[END]".
            real_inputs = "generation: " + prefix + f"[NOW] {sid} session " + model_input
            temp_inputs.append(real_inputs)
            temp_targets.append("[END]")

            if sid != 'sixth':
                # Summarization examples: one per memory key of the current session.
                current_memory = copy.deepcopy(episode[item][f'{sid}_session_memory'])
                for key in current_memory:
                    if len(current_memory[key]) == 0:
                        memory_summary = "[NONE]"
                    elif len(current_memory[key]) == 1:
                        # NOTE(review): str() of a one-element container yields its
                        # repr (e.g. "['text']" for a list), unlike the join branch
                        # below -- confirm this target format is intended.
                        memory_summary = str(current_memory[key])
                    else:
                        memory_summary = " [SEP] ".join(current_memory[key])
                    real_inputs = f"summarize [ABOUT_{key}]: " + prefix + model_input
                    temp_inputs.append(real_inputs)
                    temp_targets.append(memory_summary)
            else:
                # Pair-classification examples, built only for the last session.
                memory_pair = copy.deepcopy(episode[item][f'{prev_sid}_session_memory_pair'])
                positive_memory_pair = []
                for pair in memory_pair:
                    if len(pair) != 2:
                        continue
                    else:
                        # Convert to 0-based indices and register both orderings.
                        temp = [pair[0] - 1, pair[1] - 1]
                        positive_memory_pair.append(temp)
                        temp = [pair[1] - 1, pair[0] - 1]
                        positive_memory_pair.append(temp)
                # Start from every ordered pair of memory indices, then remove the positives.
                negative_memory_pair = list(permutations(range(len(all_memory)), 2))
                negative_memory_pair = [list(permutation) for permutation in negative_memory_pair]
                for pair in positive_memory_pair:
                    if pair in negative_memory_pair:
                        negative_memory_pair.remove(pair)
                    try:
                        pair_inputs = f"memory sentence1: {all_memory[pair[0]]} memory sentence2: {all_memory[pair[1]]}"
                    except Exception:
                        # Skip positive pairs whose indices do not resolve in all_memory.
                        continue
                    temp_inputs.append(pair_inputs)
                    temp_targets.append("positive")
                for pair in negative_memory_pair:
                    try:
                        pair_inputs = f"memory sentence1: {all_memory[pair[0]]} memory sentence2: {all_memory[pair[1]]}"
                    except Exception:
                        continue
                    temp_inputs.append(pair_inputs)
                    temp_targets.append("negative")

            inputs.extend(temp_inputs)
            targets.extend(temp_targets)

            prev_sid = sid

    print(len(inputs), len(targets))
    data = {'inputs': inputs, 'targets': targets}
    data = Dataset.from_dict(data)

    return data


def preprocess_function(examples, **kwargs):
    """Tokenize one example's source and target text for seq2seq training.

    Args:
        examples: Mapping with ``'inputs'`` and ``'targets'`` entries.
        **kwargs: Must contain ``tokenizer``, a callable accepting
            ``(text, max_length=..., truncation=...)`` and returning a
            mapping that has an ``"input_ids"`` entry.

    Returns:
        The tokenizer output for ``examples['inputs']`` (truncated to 1024
        tokens) with an added ``"labels"`` entry holding the ``input_ids`` of
        ``examples['targets']`` (truncated to 64 tokens).
    """
    tokenize = kwargs['tokenizer']

    # Encode the source side first, then attach the target ids as labels.
    encoded = tokenize(examples['inputs'], max_length=1024, truncation=True)
    label_encoding = tokenize(examples['targets'], max_length=64, truncation=True)
    encoded["labels"] = label_encoding["input_ids"]

    return encoded


def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Iterates ``model.named_parameters()``, sums ``numel()`` over all
    parameters and over those with ``requires_grad`` set, and prints both
    counts together with the trainable percentage.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
    """
    # (element count, is-trainable) for every parameter.
    sizes = [(tensor.numel(), tensor.requires_grad) for _, tensor in model.named_parameters()]

    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)

    percentage = 100 * trainable / total
    print(f"trainable params: {trainable} || all params: {total} || trainable%: {percentage}")


def main():
    """Fine-tune FLAN-T5-large with 4-bit quantization and LoRA adapters.

    Steps:
      1. Load ``google/flan-t5-large`` quantized to 4-bit NF4 and wrap it with
         LoRA adapters on the ``q`` and ``v`` projection modules.
      2. Load pickled train/validation episodes, flatten them with
         ``format_data`` and tokenize them with ``preprocess_function``.
      3. Train with ``Trainer``, evaluating and checkpointing once per epoch,
         with early stopping (patience 1) on the validation loss.
      4. Save the adapter weights and tokenizer to ``SAVE_PATH``.

    The paths ``TRAIN_DATA_PATH``, ``VALID_DATA_PATH``, ``OUTPUT_PATH`` and
    ``SAVE_PATH`` are literal placeholders resolved relative to the working
    directory.
    """
    MODEL_NAME = "google/flan-t5-large"

    lora_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules=['q', 'v'],
        lora_dropout=0.1,
        bias='none',
        task_type=TaskType.SEQ_2_SEQ_LM
    )

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Place the whole model ("" key) on the GPU selected by LOCAL_RANK
    # (defaults to GPU 0 when the variable is unset or empty).
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16,
                                                  quantization_config=bnb_config,
                                                  max_memory={i: '49000MB' for i in range(torch.cuda.device_count())},
                                                  device_map={
                                                      "": "cuda:" + str(int(os.environ.get("LOCAL_RANK") or 0))})

    model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs={"use_reentrant": False})
    model = get_peft_model(model, lora_config)
    print_trainable_parameters(model)

    # NOTE: pickle.load executes arbitrary code from the file; only load
    # dataset files that come from a trusted source.
    with open('TRAIN_DATA_PATH', 'rb') as f:
        train = pickle.load(f)
    train = format_data(train)

    with open('VALID_DATA_PATH', 'rb') as f:
        valid = pickle.load(f)
    valid = format_data(valid)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.model_max_length = 1024
    tokenized_train = train.map(preprocess_function, fn_kwargs={'tokenizer': tokenizer})
    tokenized_valid = valid.map(preprocess_function, fn_kwargs={'tokenizer': tokenizer})

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    # The evaluation-strategy argument is named `eval_strategy` in newer
    # transformers releases and `evaluation_strategy` in older ones; use
    # whichever field the installed TrainingArguments defines.
    eval_strategy_key = (
        "eval_strategy" if hasattr(TrainingArguments, "eval_strategy") else "evaluation_strategy"
    )

    output_dir = "OUTPUT_PATH"
    trainer = Trainer(
        model=model,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        args=TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=52,
            per_device_eval_batch_size=52,
            learning_rate=1e-3,
            bf16=True,
            optim="paged_adamw_32bit",
            logging_strategy="steps",
            logging_steps=100,
            save_strategy="epoch",
            # EarlyStoppingCallback needs periodic evaluation and a tracked
            # metric; without these settings it fails when training begins.
            # Evaluation must use the same schedule as saving so the best
            # checkpoint can be reloaded at the end.
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            **{eval_strategy_key: "epoch"},
        ),
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=1)]
    )
    # The generation KV cache is not needed during training.
    model.config.use_cache = False
    trainer.train()

    # With load_best_model_at_end=True the trainer holds the best checkpoint
    # (lowest validation loss) at this point, so that is what gets saved.
    peft_model_output = "SAVE_PATH"
    trainer.model.save_pretrained(peft_model_output)
    tokenizer.save_pretrained(peft_model_output)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
