import json
import string
import pandas as pd
from tqdm import tqdm
import collections

# JSON result files handled by this script, in processing order. Every file
# name is a model-variant stem followed by the shared "_rep_1_2.json" suffix.
file_list = [
    f"{variant}_rep_1_2.json"
    for variant in (
        "llama_chat_dpo",
        "llama_chat",
        "llama_chat_sft_dpo",
        "llama_chat_sft",
        "llama",
        "llama_sft_dpo",
        "llama_sft",
        "mistral",
        "mistral_sft_dpo",
        "mistral_sft",
        "zephyr_dpo",
        "zephyr",
        "zephyr_sft_dpo",
        "zephyr_sft",
    )
]

# Module-level bookkeeping: `file_data` starts out as an empty dict and
# `matched_counts` as an empty list that the per-file loop appends to.
file_data = dict()
matched_counts = list()

# Human ratings table: a semicolon-separated CSV that uses a comma as the
# decimal mark and whose rows are keyed by the "Question Index" column.
_human_eval_csv_options = {
    "sep": ";",
    "decimal": ",",
    "index_col": "Question Index",
    "encoding": 'utf8',
}
human_eval = pd.read_csv("human_eval_results.csv", **_human_eval_csv_options)

# Replace missing cells with empty strings so that text columns can be
# handled as `str` everywhere (e.g. calling `.strip()` on "prediction").
human_eval.fillna("", inplace=True)
# Number of human-rated answers expected per model file. It is only used for
# the completeness report printed after each file.
_EXPECTED_ENTRIES = 50

# Only ASCII letters take part in the comparison between a human-eval
# prediction and a model prediction; every other character is ignored.
_ALLOWED_CHARS = frozenset(string.ascii_letters)

# Rating columns copied verbatim from the human-eval table into the output.
_RATING_COLUMNS = ["Relevance", "Specificity", "Simplicity", "Helpfulness", "Objectivity"]

# Column layout of every CSV written below.
_OUTPUT_COLUMNS = ["prediction"] + _RATING_COLUMNS


def _strip_prompt_wrappers(question):
    """Return *question* without the chat-template markers around it.

    The removals are applied one after another in a fixed order, so a marker
    is only stripped if it is at the very start/end at the time of its step.
    """
    question = question.removeprefix("[INST] ")
    question = question.removesuffix(" [/INST]")
    question = question.removeprefix("<|system|>\n\n<|user|>\n")
    question = question.removeprefix("<|system|>\n</s>\n<|user|>\n")
    question = question.removesuffix("\n<|assistant|>\n")
    question = question.removesuffix("</s>\n")
    return question


def _clean_prediction(text):
    """Return *text* stripped of outer whitespace and of "&x200B" markers."""
    # The newline-prefixed form has to be removed first: once the bare
    # "&x200B" marker is gone, "\n&x200B" can no longer match anything.
    return text.strip().replace("\n&x200B", "").replace("&x200B", "")


def _letters_only(text):
    """Return the ASCII letters of *text*, in their original order."""
    return "".join(char for char in text if char in _ALLOWED_CHARS)


# For every model file: load its predictions, pair them with the rows of the
# human-eval table that contain the same answer, and write the paired
# prediction + ratings to predictions_human/<model>.csv.
for file in file_list:
    # Explicit encoding so the result does not depend on the locale default.
    with open(f"eval/final_data_to_handle/{file}", "r", encoding="utf-8") as inputfile:
        cur_dataset = json.load(inputfile)

    # The cleaned question text only lives in memory; nothing below writes it.
    for el in cur_dataset:
        el["questions"] = _strip_prompt_wrappers(el["questions"])

    # Row positions already assigned to an earlier file. A snapshot is enough
    # because positions appended for the current file are never visited again
    # during the same pass over the table.
    claimed_rows = set(matched_counts)

    file_result = []
    last_idx = None  # question index of the most recent match in this file
    for count, (idx, row) in enumerate(human_eval.iterrows()):
        # Skip rows owned by another file, and rows that repeat the question
        # index that was matched last for this file.
        if count in claimed_rows or idx == last_idx:
            continue

        human_eval_prediction_filtered = _letters_only(row["prediction"].strip())

        # The "Question Index" value is used as a position into the dataset.
        entry = cur_dataset[idx]
        entry_answer = _clean_prediction(entry["prediction"])

        if human_eval_prediction_filtered != _letters_only(entry_answer):
            continue

        new_el = {"prediction": entry_answer}
        new_el.update({column: row[column] for column in _RATING_COLUMNS})
        file_result.append(new_el)
        matched_counts.append(count)
        last_idx = idx

    # Report how the number of matches compares with the expected count.
    if len(file_result) < _EXPECTED_ENTRIES:
        print(f"{_EXPECTED_ENTRIES - len(file_result)} Entries missing for file: {file}")
    elif len(file_result) == _EXPECTED_ENTRIES:
        print(f"{file} complete!")
    else:
        print(f"{len(file_result) - _EXPECTED_ENTRIES} Entries too much for file: {file}")

    # Explicit columns keep the header in place even when nothing matched.
    file_result_df = pd.DataFrame(file_result, columns=_OUTPUT_COLUMNS)
    new_file_name = file.removesuffix(".json")
    file_result_df.to_csv(f"predictions_human/{new_file_name}.csv", index=False)
# unmatched_counts = []
# for i in range(0,700):
#     if i not in matched_counts:
#         unmatched_counts.append(i + 2)
# print(matched_counts)
# print(len(matched_counts))
# print([(item, count) for item, count in collections.Counter(matched_counts).items() if count > 1])
# print(unmatched_counts)
# print(len(unmatched_counts))
