#!/usr/bin/env python
# coding: utf-8

import os
import sys
# In[0]:
# Run configuration. The two values below were originally read from the
# command line (sys.argv[1] and sys.argv[2]); they are now fixed here.
language_model = "t5-xl"
do_generation = "True"
# Normalise the textual flag to a real bool.
do_generation = (do_generation == "True")
# NOTE(review): in the visible code do_generation is only printed; the
# generation cells further down run regardless of its value.
print("Language Model:", language_model)
print("Do Generation:", do_generation)

# Experiment / model identifiers and per-device batch geometry.
experiment_name = "t5-xl-v0.5-mix"
model_name = "google/t5-xl-lm-adapt"
deepspeed_config = "ds_config_zero2.json"
train_batch_size, eval_batch_size = 1, 2
gradient_accumulation_steps = 1


# Environment variables consumed by the Weights & Biases integration.
os.environ.update({
    "WANDB_PROJECT": "t5_simple_mix_trainer_v0.4",
    "WANDB_NOTEBOOK_NAME": "t5_xl_trainer",
})

# In[1]:
## Imports

import torch
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import json
import datasets
from datasets import load_dataset, Dataset, DatasetDict
import random
from torch.utils.data import DataLoader
import evaluate
import nltk
import numpy as np
from transformers import GenerationConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from evaluate import evaluator
import wandb


# Device the model is moved to in model_init and that generation inputs are
# placed on. NOTE(review): hard-coded; assumes a CUDA GPU is available.
device = "cuda"

# In[2]:
## Loading Data
# Instruction-style "task" splits and raw "corpus" splits, each read from a
# local JSON file into a Hugging Face Dataset.
task_train, task_val, task_test = (
    Dataset.from_json(path)
    for path in ("task_train.json", "task_val.json", "task_test.json")
)
corpus_train, corpus_val, corpus_test = (
    Dataset.from_json(path)
    for path in (
        "corpus_train_length.json",
        "corpus_validation_length.json",
        "corpus_test_length.json",
    )
)


# Mix both sources into one train/validation/test DatasetDict; within each
# split the task rows come first, followed by the corpus rows.
combined_raw_dataset = DatasetDict({
    split: datasets.concatenate_datasets([task_part, corpus_part])
    for split, task_part, corpus_part in (
        ("train", task_train, corpus_train),
        ("validation", task_val, corpus_val),
        ("test", task_test, corpus_test),
    )
})


# Sanity check: print one training example.
example = combined_raw_dataset["train"][4]

print("Instruction:", example["input"])
print()
print("Text:", example["output"])


# In[3]:
## Preprocessing

# Token budgets used when encoding model inputs and target labels.
max_input_length, max_target_length = 512, 1024
# Tokenizer matching the checkpoint being fine-tuned.
tokenizer = T5Tokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    """Tokenise a batch of raw examples for seq2seq training.

    Args:
        examples: a batched ``datasets`` mapping (see ``map(..., batched=True)``)
            with at least "input" and "output" columns of text.

    Returns:
        The tokenizer encoding of ``examples["input"]``, padded and truncated
        to ``max_input_length``, plus a "labels" field holding the token ids
        of ``examples["output"]`` truncated to ``max_target_length``. Labels
        are left unpadded here; DataCollatorForSeq2Seq pads them per batch.
    """
    model_inputs = tokenizer(
        examples["input"],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    # ``text_target`` replaces the deprecated ``as_target_tokenizer()``
    # context manager for encoding the labels.
    labels = tokenizer(
        text_target=examples["output"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenise every split of the combined dataset in batches. The original text
# columns are not removed here; they stay alongside the new tokenised columns.
tokenized_datasets = combined_raw_dataset.map(preprocess_function, batched=True)

# In[4]:
## Finetuning

def model_init():
    """Instantiate a fresh seq2seq model for the Trainer.

    Loads the pretrained ``model_name`` checkpoint, moves it to ``device`` and
    disables the decoder cache before handing it to the Trainer.
    """
    net = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    net = net.to(device)
    # The KV cache is not used in training and conflicts with the gradient
    # checkpointing enabled in the training arguments, so switch it off.
    net.config.use_cache = False
    # Disabled experiment: freeze every parameter whose third dotted name
    # component is a layer index below 16, then print the trainable count.
    return net

output_dir = "Models/" + experiment_name

# Training configuration: evaluate and checkpoint every 100 steps, log every
# 50, keep at most three checkpoints and reload the lowest-eval-loss one at
# the end of training.
args = Seq2SeqTrainingArguments(
    output_dir,
    # Optimisation schedule (weight_decay=0.01 and warmup_ratio=0.1 are disabled).
    num_train_epochs=3,
    learning_rate=2e-5,
    # Per-device batching and memory savings.
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=True,
    # Evaluation / logging / checkpoint cadence.
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Evaluation reports loss only. NOTE(review): predict_with_generate is not
    # set, so generation_max_length likely has no effect here.
    prediction_loss_only=True,
    generation_max_length=max_target_length,
    # Precision and distributed backend (fp16 is disabled in favour of bf16).
    bf16=True,
    bf16_full_eval=True,
    deepspeed=deepspeed_config,
    # Experiment tracking (Hub upload options are disabled).
    report_to="wandb",
    run_name=experiment_name,
)

# Pads each batch; labels are padded with the collator's label pad id.
data_collator = DataCollatorForSeq2Seq(tokenizer)

# Passing model_init (rather than a model instance) lets the Trainer build
# the model itself.
trainer = Seq2SeqTrainer(
    model_init=model_init,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
)


trainer.train()
# With load_best_model_at_end=True this saves the best checkpoint's weights
# into the output directory itself.
trainer.save_model(output_dir)
# trainer.push_to_hub()

# In[5]:
## Evaluation

# Final evaluation losses on the held-out splits.
# NOTE(review): despite the "corpus" in the names, these are the combined
# task+corpus validation/test splits, not corpus-only losses.
val_corpus_loss, test_corpus_loss = (
    trainer.evaluate(tokenized_datasets[split_name])["eval_loss"]
    for split_name in ("validation", "test")
)


# In[6]:
## Generation

def generate(custom_input, seed=0, *, max_length=1024, temperature=0.7,
             no_repeat_ngram_size=5, do_sample=True):
    """Generate one completion for ``custom_input`` with the trained model.

    Args:
        custom_input: prompt text to condition on.
        seed: passed to ``torch.manual_seed`` before generating, so a given
            (input, seed) pair is sampled with the same RNG state.
        max_length, temperature, no_repeat_ngram_size, do_sample: decoding
            settings forwarded to ``generate``; the defaults are the values
            this script previously hard-coded.

    Returns:
        The decoded text of the single returned sequence, special tokens removed.
    """
    torch.manual_seed(seed)
    encoded = tokenizer(custom_input, return_tensors="pt").to(device)
    outputs = trainer.model.generate(
        encoded.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it; for
        # one unpadded prompt it is all ones, so results are unchanged.
        attention_mask=encoded.attention_mask,
        max_length=max_length,
        temperature=temperature,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=do_sample,
        num_return_sequences=1,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_batch(dataset):
    """Attach a model prediction to every example of ``dataset``.

    Each example (a dict with an "input" key) is shallow-copied and given a
    "prediction" entry produced by ``generate``; the original rows are left
    untouched. Progress is displayed with tqdm.

    Returns:
        List of the augmented example dicts, in the same order as ``dataset``.
    """
    return [
        {**row, "prediction": generate(row["input"])}
        for row in tqdm(dataset)
    ]



# Directory that receives every prediction file (and later results.json).
base_generation_path = "Generation/"+experiment_name
os.makedirs(base_generation_path, exist_ok=True)


def _generate_and_save(split, file_name):
    """Run generate_batch over ``split``, dump it as indented JSON, return it."""
    predictions = generate_batch(split)
    with open(os.path.join(base_generation_path, file_name), "w") as out:
        json.dump(predictions, out, indent=2)
    return predictions


## Generate for corpus
val_corpus_pred = _generate_and_save(corpus_val, "val_corpus.json")
test_corpus_pred = _generate_and_save(corpus_test, "test_corpus.json")

## Generate for task
val_task_pred = _generate_and_save(task_val, "val_task.json")
test_task_pred = _generate_and_save(task_test, "test_task.json")

# Release cached GPU memory before the BERTScore model is loaded.
torch.cuda.empty_cache()


# In[7]:
## Evaluation Metrics

def calculate_rouge(dataset, prefix, metric=None):
    """Compute ROUGE over the whole dataset and per (style, source) subset.

    Args:
        dataset: iterable of dicts with "prediction", "output", "style" and
            "source" keys (as produced by ``generate_batch``).
        prefix: label inserted into every result key, e.g. "val_corpus".
        metric: optional pre-loaded ROUGE metric exposing
            ``compute(predictions=..., references=...)``; loaded with
            ``evaluate.load("rouge")`` when omitted.

    Returns:
        Dict with "all_<prefix>_<score>" entries for the full dataset and
        "<subset>_<prefix>_<score>" entries per subset. <subset> is the style
        alone, or "<style>_<source>" when several sources share that style.
    """
    if metric is None:
        metric = evaluate.load("rouge")

    def _score(rows):
        return metric.compute(
            predictions=[row["prediction"] for row in rows],
            references=[row["output"] for row in rows],
        )

    results = {"all_" + prefix + "_" + k: v for k, v in _score(dataset).items()}

    # Group rows once, in first-seen order, so output key order is
    # deterministic (iterating a set of strings depends on hash randomisation).
    groups = {}
    for row in dataset:
        groups.setdefault((row["style"], row["source"]), []).append(row)

    styles = [style for style, _ in groups]
    for (style, source), rows in groups.items():
        # Keys historically use the style alone; add the source only when it
        # is needed to stop subsets that share a style from overwriting each other.
        label = style if styles.count(style) == 1 else f"{style}_{source}"
        results.update({f"{label}_{prefix}_{k}": v for k, v in _score(rows).items()})
    return results

def calculate_bleu(dataset, prefix, metric=None):
    """Compute BLEU over the whole dataset and per (style, source) subset.

    Args:
        dataset: iterable of dicts with "prediction", "output", "style" and
            "source" keys (as produced by ``generate_batch``).
        prefix: label inserted into every result key, e.g. "val_corpus".
        metric: optional pre-loaded BLEU metric exposing
            ``compute(predictions=..., references=...)``; loaded with
            ``evaluate.load("bleu")`` when omitted.

    Returns:
        Dict with "all_<prefix>_<score>" entries for the full dataset and
        "<subset>_<prefix>_<score>" entries per subset. <subset> is the style
        alone, or "<style>_<source>" when several sources share that style.
        List-valued scores are reduced to their mean.
    """
    if metric is None:
        metric = evaluate.load("bleu")

    def _score(rows):
        raw = metric.compute(
            predictions=[row["prediction"] for row in rows],
            references=[row["output"] for row in rows],
        )
        # Collapse list-valued fields to a scalar so every result is loggable.
        return {k: np.mean(v) if isinstance(v, list) else v for k, v in raw.items()}

    results = {"all_" + prefix + "_" + k: v for k, v in _score(dataset).items()}

    # Group rows once, in first-seen order, so output key order is
    # deterministic (iterating a set of strings depends on hash randomisation).
    groups = {}
    for row in dataset:
        groups.setdefault((row["style"], row["source"]), []).append(row)

    styles = [style for style, _ in groups]
    for (style, source), rows in groups.items():
        # Keys historically use the style alone; add the source only when it
        # is needed to stop subsets that share a style from overwriting each other.
        label = style if styles.count(style) == 1 else f"{style}_{source}"
        results.update({f"{label}_{prefix}_{k}": v for k, v in _score(rows).items()})
    return results

def calculate_bertscore(dataset, prefix, metric=None):
    """Compute BERTScore over the whole dataset and per (style, source) subset.

    Args:
        dataset: iterable of dicts with "prediction", "output", "style" and
            "source" keys (as produced by ``generate_batch``).
        prefix: label inserted into every result key, e.g. "val_corpus".
        metric: optional pre-loaded BERTScore metric exposing
            ``compute(predictions=..., references=..., lang=...)``; loaded
            with ``evaluate.load("bertscore")`` when omitted.

    Returns:
        Dict with "all_bertscore_<prefix>_<score>" entries for the full dataset
        and "<subset>_bertscore_<prefix>_<score>" entries per subset. <subset>
        is the style alone, or "<style>_<source>" when several sources share
        that style. List-valued (per-example) scores are reduced to their mean.
    """
    if metric is None:
        metric = evaluate.load("bertscore")

    def _score(rows):
        raw = metric.compute(
            predictions=[row["prediction"] for row in rows],
            references=[row["output"] for row in rows],
            lang="en",
        )
        # Per-example score lists are averaged; scalar fields pass through.
        return {k: np.mean(v) if isinstance(v, list) else v for k, v in raw.items()}

    results = {"all_bertscore_" + prefix + "_" + k: v for k, v in _score(dataset).items()}

    # Group rows once, in first-seen order, so output key order is
    # deterministic (iterating a set of strings depends on hash randomisation).
    groups = {}
    for row in dataset:
        groups.setdefault((row["style"], row["source"]), []).append(row)

    styles = [style for style, _ in groups]
    for (style, source), rows in groups.items():
        # Keys historically use the style alone; add the source only when it
        # is needed to stop subsets that share a style from overwriting each other.
        label = style if styles.count(style) == 1 else f"{style}_{source}"
        results.update(
            {f"{label}_bertscore_{prefix}_{k}": v for k, v in _score(rows).items()}
        )
    return results

def gen_length(dataset, prefix):
    """Return the mean whitespace-token count of the predictions.

    The single result is keyed "<prefix>_gen_length".
    """
    word_counts = [len(row["prediction"].split()) for row in dataset]
    return {prefix + "_gen_length": np.mean(word_counts)}


# Flat summary: the two held-out losses, plus ROUGE, BLEU, BERTScore and mean
# generation length for each prediction set. Later metrics win on key clashes.
results = {"val_corpus_loss": val_corpus_loss, "test_corpus_loss": test_corpus_loss}

for preds, tag in [(val_corpus_pred, "val_corpus"),
                   (test_corpus_pred, "test_corpus"),
                   (val_task_pred, "val_task"),
                   (test_task_pred, "test_task")]:
    for metric_fn in (calculate_rouge, calculate_bleu, calculate_bertscore, gen_length):
        results.update(metric_fn(preds, tag))

with open(f"Generation/{experiment_name}/results.json", "w") as f:
    json.dump(results, f, indent=2)

# Log the summary to W&B (presumably the run opened by the Trainer's
# report_to="wandb" integration), then close the run.
wandb.log(results)
wandb.finish()