#!/usr/bin/env python
# coding: utf-8

import os
import sys
import argparse

# In[0]:
# Parse the command line exactly once. The parser gets its own name so the
# `argparse` module is not shadowed by an ArgumentParser instance.
parser = argparse.ArgumentParser()
parser.add_argument("--train_filepath", type=str, default="train.json")
parser.add_argument("--experiment_name", type=str, default=None)
parser.add_argument("--model_name", type=str, default="llama")
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--master_port", type=str, default="29500")
args = parser.parse_args()

# The checkpoint and tokenizer paths are fixed here; --model_name only selects
# the WANDB_NOTEBOOK_NAME label further down.
print("Using Llama")
model_name = "/mounts/data/corp/huggingface/llama/llama-7b"
tokenizer_name = "/mounts/data/corp/huggingface/llama/tokenizer"


train_filepath = args.train_filepath
experiment_name = args.experiment_name

# Hard-coded flag; it is only printed and is not consulted elsewhere in the
# visible script.
do_generation = True
print("Language Model:", model_name)
print("Do Generation:", do_generation)
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "</s>"
DEFAULT_SEP_TOKEN = "[EOI]"


# Training hyper-parameters.
deepspeed_config = "ds_config_zero3.json"
train_batch_size = 8
gradient_accumulation_steps = 2
eval_batch_size = 8
learning_rate = 5e-6
logging_steps = 25 #50
save_steps = 50 #100

wd = 0
lr_scheduler_type = "linear"
# Suffixes that are appended to the default experiment name only when the
# setting deviates from the baseline (linear schedule, no weight decay).
lr_scheduler_text = "" if lr_scheduler_type == "linear" else f"-{lr_scheduler_type}"
wd_text = "" if wd == 0 else f"-wd{wd}"

if experiment_name is None:
    experiment_name = "llama2-7b"+"-v0.7-mix-"+str(learning_rate)+wd_text+lr_scheduler_text


os.environ["WANDB_PROJECT"] = experiment_name
# WANDB_NOTEBOOK_NAME is left untouched for any other --model_name value.
if args.model_name == "llama2":
    os.environ["WANDB_NOTEBOOK_NAME"] = "llama2-7b_deepspeed_trainer"
elif args.model_name == "llama":
    os.environ["WANDB_NOTEBOOK_NAME"] = "llama-7b_deepspeed_trainer"
# NOTE(review): 69500 is above the largest valid TCP port (65535); confirm
# whether anything actually reads this variable before relying on it.
os.environ["TORCH_DISTRIBUTED_DEFAULT_PORT"] = "69500"
os.environ['MASTER_PORT'] = str(args.master_port)
# NOTE(review): this unconditionally replaces any LOCAL_RANK already present in
# the environment with the --local_rank argument (default 0) — confirm that is
# intended for the launcher in use.
os.environ["LOCAL_RANK"] = str(args.local_rank)

# In[1]:
## Imports

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer
import json
import datasets
from datasets import load_dataset, Dataset, DatasetDict
import random
from torch.utils.data import DataLoader
import evaluate
import nltk
import numpy as np
from transformers import GenerationConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from evaluate import evaluator
import wandb


device = "cuda"

# In[2]:
## Loading Data

def _load_json_split(path):
    """Load a JSON file containing a list of records into a `Dataset`.

    The list of per-example dicts is transposed into one list per column;
    the column names are taken from the first record, so every record must
    provide those keys.

    Raises:
        ValueError: if the file holds no records.
        KeyError: if a later record lacks one of the first record's keys.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    if not records:
        raise ValueError(f"{path} contains no examples")
    columns = {key: [record[key] for record in records] for key in records[0]}
    return Dataset.from_dict(columns)


train_datasets = _load_json_split(train_filepath)
val_datasets = _load_json_split("eval.json")
test_datasets = _load_json_split("test.json")
combined_raw_dataset = DatasetDict({"train":train_datasets,
                                    "validation":val_datasets,
                                    "test": test_datasets})

combined_raw_dataset = combined_raw_dataset.shuffle(seed=42)


# Print one training example as a sanity check (requires at least 5 examples).
example = combined_raw_dataset['train'][4]

print("Instruction:", example["input"])
print()
print("Text:", example["output"])


print(combined_raw_dataset)
# In[3]:
## Preprocessing

# Not referenced elsewhere in the visible script.
context_length = 1024

# Two tokenizers over the same vocabulary files: the training one is created
# with add_eos_token=True, the inference one with the default setting.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name,
    use_fast=False,
    add_eos_token=True,
)
tokenizer_inference = AutoTokenizer.from_pretrained(
    tokenizer_name,
    use_fast=False,
)

# Both tokenizers receive the same extra "[EOI]" token and the same
# special-token mapping so their vocabularies stay aligned.
_special_tokens = {
    "eos_token": DEFAULT_EOS_TOKEN,
    "bos_token": DEFAULT_BOS_TOKEN,
    "unk_token": DEFAULT_UNK_TOKEN,
    "pad_token": DEFAULT_PAD_TOKEN,
}
for _tok in (tokenizer, tokenizer_inference):
    _tok.add_tokens(["[EOI]"])
for _tok in (tokenizer, tokenizer_inference):
    _tok.add_special_tokens(dict(_special_tokens))

# Vocabulary size after the additions.
print(len(tokenizer))

def preprocess_function(examples):
    """Tokenize a batch of examples as ``input + " [EOI]" + output``.

    `examples` is a batch with parallel "input" and "output" columns; the
    return value is whatever the module-level `tokenizer` returns for the
    list of joined strings.
    """
    joined_texts = []
    for prompt, target in zip(examples["input"], examples["output"]):
        joined_texts.append(prompt + " [EOI]" + target)
    return tokenizer(joined_texts)


# Tokenize every split and drop the raw columns so only tokenizer outputs remain.
_raw_columns = combined_raw_dataset["train"].column_names
tokenized_dataset = combined_raw_dataset.map(
    preprocess_function,
    remove_columns=_raw_columns,
    batched=True,
    num_proc=4,
)

block_size = 1024


def group_texts(examples):
    """Concatenate a batch of tokenized sequences and re-split it into blocks.

    Args:
        examples: mapping from column name to a list of token lists. Must
            contain an "input_ids" column; all columns are expected to have
            the same total length.

    Returns:
        A dict with the same columns, each cut into consecutive chunks of
        `block_size` items, plus a "labels" column that is a copy of the
        "input_ids" chunk list. When the batch holds at least `block_size`
        tokens the trailing remainder is dropped; a shorter batch is kept as
        a single short chunk.
    """
    # Flatten each column in one pass. `sum(lists, [])` would rebuild the
    # accumulated list for every sequence, which is quadratic in batch size.
    concatenated_examples = {
        k: [token for sequence in examples[k] for token in sequence]
        for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported
    # it instead of this drop, you can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size

    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Re-pack every tokenized split into fixed-size blocks with `group_texts`.
lm_dataset = tokenized_dataset.map(
    group_texts,
    num_proc=4,
    batched=True,
)

# tokenizer.pad_token = tokenizer.eos_token
# Causal-LM collation (mlm=False) using the training tokenizer.
data_collator = DataCollatorForLanguageModeling(
    mlm=False,
    tokenizer=tokenizer,
)

# In[4]:
## Finetuning

def model_init():
    """Build the causal LM used by the Trainer.

    Loads the checkpoint at `model_name`, moves it to `device`, resizes the
    token embeddings to the training tokenizer's vocabulary size (tokens were
    added to it), and turns off the generation KV cache in the model config.
    """
    # NOTE(review): the run is configured with a ZeRO-3 DeepSpeed config;
    # confirm that an explicit .to(device) here is compatible with it.
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name)
    causal_lm = causal_lm.to(device)
    causal_lm.resize_token_embeddings(len(tokenizer))
    causal_lm.config.use_cache = False
    return causal_lm

_output_dir = "Models/" + experiment_name

training_args = TrainingArguments(
    # Run identity and reporting.
    output_dir=_output_dir,
    run_name=experiment_name,
    report_to="wandb",
    deepspeed=deepspeed_config,
    # Step-based logging, evaluation and checkpointing; evaluation and saving
    # share the same interval.
    logging_strategy="steps",
    logging_steps=logging_steps,
    evaluation_strategy="steps",
    eval_steps=save_steps,
    save_strategy='steps',
    save_steps=save_steps,
    save_total_limit=3,
    # Keep the checkpoint with the lowest eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    prediction_loss_only=True,
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=learning_rate,
    weight_decay=wd,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=0.1,
    # Batch sizes and memory settings.
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=True,
    # Precision: bf16 for training and evaluation (fp16 is not enabled).
    bf16=True,
    bf16_full_eval=True,
)

# Periodic evaluation drives checkpoint selection (the training arguments set
# load_best_model_at_end with metric eval_loss), so it must run on the
# validation split. The test split stays held out and is only scored by the
# explicit evaluation after training.
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=data_collator,
)

trainer.train()

# Persist the final model next to the checkpoints.
trainer.save_model("Models/" + experiment_name)
# trainer.push_to_hub()
# trainer.push_to_hub()

# In[5]:
## Evaluation

# Language-modelling loss of the final model on the two held-out splits.
_val_metrics = trainer.evaluate(lm_dataset["validation"])
val_corpus_loss = _val_metrics["eval_loss"]
_test_metrics = trainer.evaluate(lm_dataset["test"])
test_corpus_loss = _test_metrics["eval_loss"]


# In[6]:
## Generation

def generate(custom_input, seed=0):
    """Sample one continuation of `custom_input` from the trained model.

    The torch RNG is re-seeded with `seed` before sampling. The prompt is
    encoded with the inference tokenizer and moved to `device`; the first
    returned sequence is decoded with special tokens kept in the text.
    """
    import torch

    torch.manual_seed(seed)
    encoded_prompt = tokenizer_inference(custom_input, return_tensors="pt")
    prompt_ids = encoded_prompt.input_ids.to(device)
    sequences = trainer.model.generate(
        prompt_ids,
        max_length=1024,
        temperature=0.7,
        no_repeat_ngram_size=5,
        do_sample=True,
        num_return_sequences=1,
    )
    first_sequence = sequences[0]
    return tokenizer_inference.decode(first_sequence, skip_special_tokens=False)

def generate_batch(dataset):
    """Return a copy of every record in `dataset` with a "prediction" added.

    Each prompt is the record's "input" followed by " [EOI]"; the input
    records themselves are not modified.
    """
    enriched_records = []
    for record in tqdm(dataset):
        enriched = record.copy()
        prompt = record["input"] + " [EOI]"
        enriched["prediction"] = generate(prompt)
        enriched_records.append(enriched)
    return enriched_records



## Generate for corpus
def _generate_split(split, raw_dir):
    """Generate a prediction for every example of `split`, resumably.

    Each finished example is stored as ``<raw_dir>/<1-based index>.json``.
    Examples whose file already exists are loaded from disk instead of being
    generated again, so an interrupted run can be restarted.

    Returns the list of example dicts, each extended with a "prediction" key.
    """
    os.makedirs(raw_dir, exist_ok=True)
    predictions = []
    # tqdm wraps the dataset itself (not the enumerate object) so the progress
    # bar can pick up the total number of examples.
    for idx, el in enumerate(tqdm(split), start=1):
        cache_file = os.path.join(raw_dir, str(idx)+".json")
        if os.path.exists(cache_file):
            print("Skipping", idx, "because it exists")
            with open(cache_file, "r") as f:
                predictions.append(json.load(f))
            continue

        new_el = el.copy()
        # NOTE(review): generate_batch appends " [EOI]" to the prompt and
        # preprocess_function joins input and output with it, whereas the bare
        # input is passed here — confirm this difference is intentional.
        new_el["prediction"] = generate(el["input"])
        predictions.append(new_el)

        # Write to a temporary file and rename it into place so that an
        # interruption cannot leave a truncated JSON file that would break the
        # resume path above.
        tmp_file = cache_file + ".tmp"
        with open(tmp_file, "w") as f:
            json.dump(new_el, f, indent=2)
        os.replace(tmp_file, cache_file)
    return predictions


## Generate for corpus
combined_raw_dataset["test"] = combined_raw_dataset["test"].shuffle(seed=42)
base_generation_path = "Generation/"+experiment_name
os.makedirs(base_generation_path, exist_ok=True)

test_raw_path = os.path.join(base_generation_path, "test_raw")
test_pred = _generate_split(combined_raw_dataset["test"], test_raw_path)

with open(os.path.join(base_generation_path, "test.json"), "w") as f:
    json.dump(test_pred, f, indent=2)

combined_raw_dataset["validation"] = combined_raw_dataset["validation"].shuffle(seed=42)
val_raw_path = os.path.join(base_generation_path, "val_raw")
val_pred = _generate_split(combined_raw_dataset["validation"], val_raw_path)

with open(os.path.join(base_generation_path, "val.json"), "w") as f:
    json.dump(val_pred, f, indent=2)


# Empty cache for BERTScore
torch.cuda.empty_cache()


# In[7]:
## Evaluation Metrics

def _group_by_subset(dataset):
    """Bucket predictions and references by their ``(style, source)`` pair.

    Returns a dict mapping each pair, in order of first appearance, to a
    ``(predictions, references)`` tuple of lists. A single pass over the data
    replaces re-filtering the whole dataset once per subset.
    """
    groups = {}
    for row in dataset:
        preds, refs = groups.setdefault((row["style"], row["source"]), ([], []))
        preds.append(row["prediction"])
        refs.append(row["output"])
    return groups


def _subset_labels(subsets):
    """Choose the result-key label for every ``(style, source)`` subset.

    A style that occurs with a single source keeps the plain style as its
    label. When one style occurs with several sources the label becomes
    ``<style>_<source>``; otherwise those subsets would write the same result
    keys and silently overwrite each other.
    """
    sources_by_style = {}
    for style, source in subsets:
        sources_by_style.setdefault(style, set()).add(source)
    return {
        (style, source): style if len(sources_by_style[style]) == 1 else f"{style}_{source}"
        for style, source in subsets
    }


def _score_by_subset(metric, dataset, prefix, infix="", reduce_lists=True, **compute_kwargs):
    """Run `metric` on the whole dataset and on every ``(style, source)`` subset.

    Result keys have the form ``<label>_<infix><prefix>_<metric key>`` where
    the label is "all" for the full dataset. With `reduce_lists`, list-valued
    metric outputs are replaced by their mean; other values are kept as is.
    Extra keyword arguments are forwarded to ``metric.compute``.
    """
    def scored(label, preds, refs):
        raw = metric.compute(predictions=preds, references=refs, **compute_kwargs)
        return {
            label + "_" + infix + prefix + "_" + name: (
                np.mean(value) if reduce_lists and isinstance(value, list) else value
            )
            for name, value in raw.items()
        }

    results = scored(
        "all",
        [row["prediction"] for row in dataset],
        [row["output"] for row in dataset],
    )
    groups = _group_by_subset(dataset)
    labels = _subset_labels(groups)
    for subset, (preds, refs) in groups.items():
        results.update(scored(labels[subset], preds, refs))
    return results


def calculate_rouge(dataset, prefix):
    """ROUGE for the full dataset and each subset; keys ``<label>_<prefix>_<metric key>``.

    `dataset` is a sequence of dicts with "style", "source", "prediction" and
    "output" entries. Metric values are stored exactly as returned.
    """
    return _score_by_subset(evaluate.load("rouge"), dataset, prefix, reduce_lists=False)


def calculate_bleu(dataset, prefix):
    """BLEU for the full dataset and each subset; keys ``<label>_<prefix>_<metric key>``.

    `dataset` is a sequence of dicts with "style", "source", "prediction" and
    "output" entries. List-valued outputs are averaged.
    """
    return _score_by_subset(evaluate.load("bleu"), dataset, prefix)


def calculate_bertscore(dataset, prefix):
    """BERTScore (lang="en") for the full dataset and each subset.

    Keys have the form ``<label>_bertscore_<prefix>_<metric key>``. `dataset`
    is a sequence of dicts with "style", "source", "prediction" and "output"
    entries. List-valued outputs are averaged; other values are kept as is.
    """
    return _score_by_subset(
        evaluate.load("bertscore"), dataset, prefix, infix="bertscore_", lang="en"
    )

def gen_length(dataset, prefix):
    """Return the mean number of whitespace-separated tokens per prediction.

    The result is a one-entry dict keyed ``<prefix>_gen_length``.
    """
    word_counts = []
    for row in dataset:
        words = row["prediction"].split()
        word_counts.append(len(words))
    key = prefix + "_" + "gen_length"
    return {key: np.mean(word_counts)}


# Start from the corpus-level losses, then add every generation metric for the
# validation and test predictions.
results = {"val_loss":val_corpus_loss, "test_loss":test_corpus_loss}

_metric_functions = (calculate_rouge, calculate_bleu, calculate_bertscore, gen_length)
for prefix, dataset in (("val", val_pred), ("test", test_pred)):
    for _metric_fn in _metric_functions:
        results.update(_metric_fn(dataset, prefix))

# Store the metrics next to the generations and report them to wandb.
_results_path = f"Generation/{experiment_name}/results.json"
with open(_results_path, "w") as f:
    json.dump(results, f, indent=2)

wandb.log(results)
wandb.finish()