import os
import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig
from trl import SFTTrainer
import transformers
import evaluate
import json
from tqdm import tqdm
from huggingface_hub import login
import pandas as pd


login()

# Checkpoint to evaluate. This currently points at the base Llama-2-7b model on
# the Hugging Face Hub. To evaluate a fine-tuned run instead, point it at
# os.path.join(output_dir, "final_merged_checkpoint"), e.g. with
# output_dir = "./results_7b_llama_chat_sft_no_packing".
# NOTE(review): the predictions file is named "llama_sft_...", but this is the
# base (non-SFT) model — confirm which checkpoint is intended.
final_checkpoint_dir = "meta-llama/Llama-2-7b-hf"
print(f"Final checkpoint dir: {final_checkpoint_dir}")

# Single local cache directory for everything downloaded for this checkpoint,
# shared by the model and the tokenizer so both end up in the same place.
model_cache_dir = "model_cache"

# Place the entire model on GPU 0.
device_map = {"": 0}

# Load the model weights in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    final_checkpoint_dir,
    device_map=device_map,
    return_dict=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    cache_dir=model_cache_dir,
)
# Enable the key/value cache for generation.
model.config.use_cache = True
# Llama config setting; 1 is the usual value for inference (see LlamaConfig
# docs for `pretraining_tp`).
model.config.pretraining_tp = 1

# Use the same cache directory as the model; previously the tokenizer silently
# fell back to the default Hugging Face cache location.
tokenizer = AutoTokenizer.from_pretrained(
    final_checkpoint_dir,
    use_fast=True,
    cache_dir=model_cache_dir,
)


# Text-generation pipeline wrapping the loaded model and tokenizer.
# Only the newly generated continuation is returned (return_full_text=False),
# capped at 50 new tokens, with tokenization spaces cleaned up on decode.
# The tokenizer's padding side is left at its default.
generation_options = dict(
    max_new_tokens=50,
    return_full_text=False,
    clean_up_tokenization_spaces=True,
)
reloaded_generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_options,
)

data_path = "/data"
# Assumed: the file holds a list of records, each with "question", "answer",
# "answer_idx" and "amount_of_options" keys (the keys read below).
# JSON is UTF-8 by specification; decode explicitly so non-ASCII poll text is
# read correctly regardless of the platform's default locale encoding.
with open(
    os.path.join(data_path, "poll_data_inference_question.json"),
    "r",
    encoding="utf-8",
) as data_file:
    final_dataset = json.load(data_file)

# One-shot Llama-2 chat prompt: a system message plus one worked poll example
# (answered with "4"), after which the actual poll question is appended in a
# new [INST] turn.
few_shot_prefix = (
    "[INST] <<SYS>>\n"
    "You are an artificial intelligence assistant trained on Reddit Data.\n"
    "You are supposed to return the number of the poll option which in your opinion will get the most votes within the mentioned subreddit.\n"
    "<</SYS>>\n\n"
    " This is a poll from the subreddit Language and was created on 2020/07/08.\n"
    " Poll title: Translation Question\n"
    " Poll context: What is the portuguese translation for red? \n"
    " The following are the options to choose from in this poll:\n\n"
    "Option 1: laranja\n\n"
    "Option 2: verde\n\n"
    "Option 3: azul\n\n"
    " Option 4: vermelho\n\n"
    " Which option would you choose? \n"
    " Answer ONLY with a number. This is the list of available options: [1, 2, 3, 4] \n"
    " Answer Form: \n"
    " The most fitting option number is [/INST]\n"
    "4</s>\n\n"
    "<s>[INST] "
)

questions = []  # full prompt per record
answers_1 = []  # reference answer ("answer" field)
answers_2 = []  # reference answer index ("answer_idx" field)
predictions = []  # filled by the generation loop
amount_of_options = []  # number of poll options per record

for el in final_dataset:
    questions.append(f"{few_shot_prefix}{el['question']} [/INST]")
    answers_1.append(el["answer"])
    answers_2.append(el["answer_idx"])
    amount_of_options.append(el["amount_of_options"])


# Run generation one prompt at a time and keep the generated text of the first
# returned candidate for each prompt.
for prompt in tqdm(questions, total=len(questions)):
    first_candidate = reloaded_generator(prompt)[0]
    predictions.append(first_candidate["generated_text"])

# Pair each prompt with its generated prediction and reference labels. The
# input lists are built in lockstep, so zip aligns them record by record.
final_result = [
    {
        "questions": question,
        "prediction": prediction,
        "reference": reference,
        "reference_idx": reference_idx,
        "amount_of_options": n_options,
    }
    for question, prediction, reference, reference_idx, n_options in zip(
        questions, predictions, answers_1, answers_2, amount_of_options
    )
]

output_file_predictions = "poll_predictions/llama_sft_poll_data_50.json"
# Make sure the target directory exists. Otherwise open() raises
# FileNotFoundError here, after all of the expensive generation has run.
output_dir_predictions = os.path.dirname(output_file_predictions)
if output_dir_predictions:
    os.makedirs(output_dir_predictions, exist_ok=True)
with open(output_file_predictions, "w") as f:
    json.dump(final_result, f)

