import os
import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig
from trl import SFTTrainer
import transformers
import evaluate
import json
from tqdm import tqdm
from huggingface_hub import login
import pandas as pd


# Authenticate with the Hugging Face Hub (login() is called with no arguments).
login()

output_dir = "./results_7b_llama_chat_sft_no_packing"
# Directory from which both the model weights and the tokenizer are loaded.
final_checkpoint_dir = os.path.join(output_dir, "final_merged_checkpoint")
print(f"Final checkpoint dir: {final_checkpoint_dir}")

# The empty-string key places the entire model on GPU 0.
device_map = {"": 0}

# Loading options, grouped so the from_pretrained call stays short.
model_load_kwargs = {
    "device_map": device_map,
    "return_dict": True,
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.bfloat16,
    "cache_dir": "model_cache",
}
model = AutoModelForCausalLM.from_pretrained(final_checkpoint_dir, **model_load_kwargs)
model.config.use_cache = True
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(final_checkpoint_dir, use_fast=True)

# NOTE: tokenizer.padding_side is left untouched here; an override to "right"
# was present but disabled.

# Generation settings passed to the pipeline; only newly generated text is
# returned because return_full_text is False.
generation_kwargs = {
    "max_new_tokens": 1000,
    "return_full_text": False,
    "clean_up_tokenization_spaces": True,
    "repetition_penalty": 1.3,
}
reloaded_generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_kwargs,
)

data_path = "/data"
# Read the test set. The encoding is given explicitly so decoding does not
# depend on the platform's locale default (JSON text is UTF-8).
with open(
    os.path.join(data_path, "test_tokenized_filter.json"), "r", encoding="utf-8"
) as data_file:
    final_dataset = json.load(data_file)

# One prompt per record: the record's text and context wrapped in [INST] tags.
questions = [
    f"[INST] {el['text']} {el['context']} [/INST]" for el in final_dataset
]
# The two reference answers stored with each record, in dataset order.
answers_1 = [el["answer_1"] for el in final_dataset]
answers_2 = [el["answer_2"] for el in final_dataset]
# Filled later by the generation step, one entry per prompt.
predictions = []

# Prompts are sent through the pipeline one at a time (a batched variant of
# this step was present but disabled). For each prompt the generated text of
# the first returned candidate is kept.
predictions.extend(
    reloaded_generator(prompt)[0]["generated_text"]
    for prompt in tqdm(questions, total=len(questions))
)

# Pair every prompt with its generated text and both reference answers.
# Positional lookups are used so that a length mismatch raises IndexError.
final_result = [
    {
        "questions": prompt,
        "prediction": predictions[pos],
        "reference_1": answers_1[pos],
        "reference_2": answers_2[pos],
    }
    for pos, prompt in enumerate(questions)
]

# Persist the per-example records as a single JSON array.
output_file_predictions = "llama_chat_sft_dpo_rep_1_3.json"
with open(output_file_predictions, "w") as out_file:
    json.dump(final_result, out_file)

def _score_and_print(metric_name):
    """Load one `evaluate` metric, score the predictions, and print the result.

    Scoring uses `predictions` against `answers_1` only.
    NOTE(review): `answers_2` is collected and saved but not used for scoring
    here - confirm this is intended.

    Returns the loaded metric object together with its computed results.
    """
    metric = evaluate.load(metric_name)
    scores = metric.compute(predictions=predictions, references=answers_1)
    print(scores)
    return metric, scores


# Each metric is loaded, computed and printed before the next one starts.
bleu, results = _score_and_print("bleu")
rouge, results = _score_and_print('rouge')
meteor, results = _score_and_print('meteor')

