import evaluate
import os
import json
from tqdm import tqdm
import inspect
import textstat
import pandas as pd


def _is_degenerate_text(text):
    """Return True for outputs that carry no scorable content.

    Covers a lone newline, a lone "🙏", and any non-empty string that consists
    solely of question marks (regardless of its length).
    """
    if not isinstance(text, str):
        return False
    if text in ("\n", "🙏"):
        return True
    return bool(text) and not text.strip("?")


def evaluate_text_stat(text_values):
    """Average a set of textstat metrics over a collection of texts.

    Metric names are discovered by introspecting the ``textstat`` module for
    bound methods; names in the skip set below are ignored. Each remaining
    metric is called on every text, the values are summed, and the mean
    (rounded to two decimals) is printed and returned.

    Degenerate texts (see ``_is_degenerate_text``) are scored as the empty
    string instead of being passed to textstat verbatim.

    Args:
        text_values: Iterable of strings to score. Any iterable works; it does
            not need to support ``len()``.

    Returns:
        dict mapping each evaluated metric name to its rounded mean. When
        ``text_values`` is empty every mean is reported as 0.0 instead of
        raising ZeroDivisionError.
    """
    skip = frozenset({
        "count_complex_arabic_words",
        "count_faseeh",
        "difficult_words_list",
        "remove_punctuation",
        "set_lang",
        "set_rm_apostrophe",
        "set_rounding",
        "count_arabic_long_words",
        "count_arabic_syllables",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "is_difficult_word",
        "is_easy_word",
        "letter_count",
        "lexicon_count",
        "lix",
        "long_word_count",
        "miniword_count",
        "monosyllabcount",
        "osman",
        "polysyllabcount",
        "rix",
        "smog_index",
        "spache_readability",
        "syllable_count",
        "avg_sentence_per_word",
        "avg_letter_per_word",
    })

    # The language is loop-invariant, so configure it once up front rather
    # than before every single metric call.
    textstat.set_lang("en")

    # Same window as before (entries 1..46 of the name-sorted member list),
    # but a slice does not raise IndexError when the installed textstat
    # exposes fewer methods.
    # NOTE(review): this window is tied to the installed textstat version --
    # confirm it still selects the intended metrics after an upgrade.
    members = inspect.getmembers(textstat, predicate=inspect.ismethod)[1:47]

    # getmembers already hands back the bound methods, so no eval() is needed
    # to resolve a metric name to its callable.
    metric_fns = {name: fn for name, fn in members if name not in skip}
    totals = dict.fromkeys(metric_fns, 0)

    n_texts = 0
    for text in tqdm(text_values):
        n_texts += 1
        if _is_degenerate_text(text):
            text = ""
        for name, fn in metric_fns.items():
            if name == "text_standard":
                # Request the numeric grade so it can be summed and averaged.
                totals[name] += fn(text, float_output=True)
            else:
                totals[name] += fn(text)

    results = {}
    for name, total in totals.items():
        results[name] = round(total / n_texts, 2) if n_texts else 0.0
        print(f"Metric {name} average: {results[name]}")
    return results


# Directory holding one JSON prediction file per evaluated model. The same
# directory is used for listing AND opening the files; previously the files
# were listed from "eval/final_data_13" but opened from "eval/final_data".
data_dir = "eval/final_data_13"
prediction_files = os.listdir(data_dir)
# prediction_files = [
#     "llama_base",
#     "llama_sft",
#     "llama_sft_non_quantized",
#     "llama_sft_new_non_quantized",
#     "llama_sft_no_wsb_non_quantized",
#     "llama_chat_base",
#     "llama_chat_sft",
#     "llama_chat_base_dpo_non_quantized",
#     "mistral_base_non_quantized",
#     "mistral_sft_non_quantized",
#     "zephyr_base_non_quantized",
#     "zephyr_sft_non_quantized",
# ]
metrics = ["bleu", "rouge", "meteor", "bleurt", "bertscore"]
known_metrics = {"bleu", "rouge", "meteor", "bleurt", "bertscore"}
# Metric modules are loaded on first use and then reused for every file, so
# each one (including the BLEURT checkpoint) is loaded only once per run.
loaded_metrics = {}
overall_results = {}
for file in tqdm(prediction_files):
    print(f"Handling file {file}")
    # Explicit UTF-8: the predictions may contain non-ASCII text such as emoji.
    with open(os.path.join(data_dir, file), "r", encoding="utf-8") as input_file:
        cur_dataset = json.load(input_file)
    if not cur_dataset:
        # Nothing to score; indexing cur_dataset[0] below would fail.
        print(f"Skipping empty file {file}")
        continue
    overall_results[file] = {}
    predictions = [el["prediction"] for el in cur_dataset]
    # Files store the gold text under either "reference" or "reference_1".
    reference_key = "reference" if "reference" in cur_dataset[0] else "reference_1"
    references = [el[reference_key] for el in cur_dataset]
    for metric in metrics:
        if metric not in known_metrics:
            print("select proper metric!")
            break
        if metric not in loaded_metrics:
            if metric == "bleurt":
                loaded_metrics[metric] = evaluate.load(metric, "bleurt-large-512")
            else:
                loaded_metrics[metric] = evaluate.load(metric)
        cur_metric = loaded_metrics[metric]

        if metric == "bertscore":
            result = cur_metric.compute(predictions=predictions, references=references, lang="en")
        else:
            result = cur_metric.compute(predictions=predictions, references=references)

        if metric == "bleurt":
            # BLEURT returns one score per example; report the mean.
            overall_results[file][metric] = sum(result["scores"]) / len(result["scores"])
        elif metric == "bertscore":
            # BERTScore returns per-example lists; report the mean of each.
            for part in ("precision", "recall", "f1"):
                values = result[part]
                overall_results[file][f"bertscore_{part}"] = sum(values) / len(values)
        elif metric == "rouge":
            for variant in ("rouge1", "rouge2", "rougeL", "rougeLsum"):
                overall_results[file][variant] = result[variant]
        else:
            # bleu and meteor expose their score under the metric's own name.
            overall_results[file][metric] = result[metric]
    textstat_results = evaluate_text_stat(predictions)
    overall_results[file].update(textstat_results)
overall_results_df = pd.DataFrame(overall_results)
overall_results_df.to_csv("eval_results_final_data_13.csv", decimal=",", sep=";")
# with open("eval_results_metrics_textstat", "w") as outputfile:
#     json.dump(overall_results, outputfile)