import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
import json
from tqdm import tqdm
from huggingface_hub import login
from collections import Counter

# Authenticate with the Hugging Face Hub before touching any model files.
login()

# Training-run directory; the model evaluated here is its final merged checkpoint.
output_dir = "./results_7b_llama_chat_sft_no_packing"
final_checkpoint_dir = os.path.join(output_dir, "final_merged_checkpoint")
print(f"Final checkpoint dir: {final_checkpoint_dir}")

# The "" key places the entire model on GPU 0.
device_map = {"": 0}

# Keyword arguments for loading the fine-tuned (merged) checkpoint in bfloat16.
model_load_kwargs = dict(
    device_map=device_map,
    return_dict=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    cache_dir="model_cache",
)
model = AutoModelForCausalLM.from_pretrained(final_checkpoint_dir, **model_load_kwargs)

# Use the key/value cache during generation.
model.config.use_cache = True
# pretraining_tp=1 selects the plain (non tensor-parallel-sliced) linear path in
# Llama modeling code.
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(final_checkpoint_dir, use_fast=True)


# Judge pipeline: generate exactly one new token (the expected "A"/"B" verdict)
# and return only that continuation, not the echoed prompt.
# Decoding is pinned to greedy so verdicts are reproducible across runs. Without
# do_sample=False the pipeline inherits the checkpoint's generation_config, which
# may enable sampling (NOTE(review): Llama-2-chat generation configs ship with
# do_sample=True — confirm for this checkpoint).
reloaded_generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1,
    do_sample=False,
    return_full_text=False,
    clean_up_tokenization_spaces=True,
    repetition_penalty=1.0,
)

# Reference test set: each entry provides the question text/context and two
# reference answers.
data_path = "/data"
with open(os.path.join(data_path, "test_tokenized_filter.json"), "r", encoding="utf-8") as data_file:
    test_dataset = json.load(data_file)

# Model generations to be judged; entry i is the prediction for test example i.
generated_data_path = "./final_data_to_handle"
file = "zephyr_sft_rep_1_2.json"
with open(os.path.join(generated_data_path, file), "r", encoding="utf-8") as generated_file:
    generated_dataset = json.load(generated_file)

# Predictions are matched to test examples purely by position, so a shorter
# generated file means the two files are out of sync. Fail with a clear message
# instead of an IndexError partway through.
if len(generated_dataset) < len(test_dataset):
    raise ValueError(
        f"{file} has {len(generated_dataset)} predictions but the test set has "
        f"{len(test_dataset)} examples; the files must be aligned by index."
    )

questions = [f"{el['text']} {el['context']}" for el in test_dataset]
answers_1 = [el["answer_1"] for el in test_dataset]
answers_2 = [el["answer_2"] for el in test_dataset]
predicted_answers = [gen["prediction"] for gen in generated_dataset[: len(test_dataset)]]

# Filled in by the pairwise evaluation loop: one verdict and one raw
# [response_1, response_2] pair per question.
predictions = []
both_values = []

# Llama-2-style instruction prompt asking the judge to pick the better of two
# comments. Placeholders: {question}, {comment_a}, {comment_b}. The judge is told
# to reply with a single letter.
PROMPT_TEMPLATE = "\n".join(
    [
        "",
        "[INST]",
        "Question:",
        "{question}",
        "",
        "Comment A: {comment_a}",
        "",
        "Comment B: {comment_b}",
        "",
        "Which comment is a better response to the question, Comment A or Comment B?",
        "Only answer with 'A' or 'B'.",
        "[/INST]",
        "",
    ]
)

def _pairwise_verdict(response_1, response_2):
    """Combine two position-swapped judge answers into a single verdict.

    ``response_1`` is the judge's answer when the reference is Comment A and the
    prediction is Comment B; ``response_2`` is the answer with the order swapped.
    Surrounding whitespace (leading spaces, trailing newlines) is ignored, so a
    decoded token such as " A" still counts as "A".

    Returns:
        "Initial Data" if the reference collects more votes,
        "Predicted Answer" if the prediction collects more votes,
        "Tie" if both answers are valid but pick the same slot (AA or BB),
        "Unknown" if neither answer is a usable "A"/"B".
    """
    first = response_1.strip()
    second = response_2.strip()
    # Prompt 1 puts the reference in slot A; prompt 2 puts it in slot B.
    reference_votes = (first == "A") + (second == "B")
    prediction_votes = (first == "B") + (second == "A")
    if reference_votes > prediction_votes:
        return "Initial Data"
    if prediction_votes > reference_votes:
        return "Predicted Answer"
    if reference_votes:
        # One vote each: the judge picked the same position both times.
        return "Tie"
    return "Unknown"


# Ask the judge twice per example with the two comments swapped, to cancel out
# position bias.
for question, reference, predicted in tqdm(
    zip(questions, answers_1, predicted_answers), total=len(questions)
):
    prompt_1 = PROMPT_TEMPLATE.format(
        question=question,
        comment_a=reference,
        comment_b=predicted,
    )
    response_1 = reloaded_generator(prompt_1)[0]["generated_text"]
    prompt_2 = PROMPT_TEMPLATE.format(
        question=question,
        comment_a=predicted,
        comment_b=reference,
    )
    response_2 = reloaded_generator(prompt_2)[0]["generated_text"]
    predictions.append(_pairwise_verdict(response_1, response_2))
    # Keep the raw (unstripped) judge outputs for later inspection.
    both_values.append([response_1, response_2])
print("Predictions: ")
print(Counter(predictions))

# One record per question: the verdict, the raw judge outputs, and both references.
final_result = [
    {
        "questions": question,
        "prediction": prediction,
        "both_predictions": pair,
        "reference_1": reference_1,
        "reference_2": reference_2,
    }
    for question, prediction, pair, reference_1, reference_2 in zip(
        questions, predictions, both_values, answers_1, answers_2
    )
]

output_file_predictions = f"./pairwise_evaluation/{file}"
# Create the output directory up front. Otherwise a missing directory makes the
# whole (long) evaluation run fail at this final write.
os.makedirs(os.path.dirname(output_file_predictions), exist_ok=True)
with open(output_file_predictions, "w", encoding="utf-8") as f:
    json.dump(final_result, f)
