import json
import argparse


# Characters that remove_punctuation() replaces with a space. Membership is
# tested one character at a time, so only single punctuation characters belong
# here. The previous definition joined stop-words ("or", "and", "the", "his",
# ...) into the same string before building the set, which silently added the
# letters a, d, e, h, i, n, o, r, s, t and caused those letters to be blanked
# out of every label and prediction. Stop-word filtering, if it is wanted, has
# to be done on whole tokens rather than through this character set.
PUNCTUATION_SET_TO_EXCLUDE = {
    '‘', '’', '´', '`', '.', ',', '-', '"', '\'', '[', ']', '{', '}', '(', ')',
    '!', '?', '\\', '&', '|',
}


def read_data(file_path):
    """
    Reads a JSON-lines prediction file and returns the usable records.

    Every non-blank line is parsed as one JSON object. Records that lack a
    "model_predict" or "prompt" key are skipped. Each kept record is the parsed
    object itself, extended with these keys:
        "answer":  the record's "answer" value, or its "output" value when
                   "answer" is absent (None if neither key is present).
        "predict": the prediction text with "ā" replaced by "a".
        "refer":   copy of "prompt".
        "link":    copy of "prompt".
    A prediction that is not a string (for example a dict or null) is replaced
    by a single space, in "model_predict" as well as in "predict".

    Args:
        file_path (str): path of a UTF-8 encoded file with one JSON object per line.

    Returns:
        list[dict]: the kept records, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    dataset = []
    with open(file_path, 'r', encoding='utf8') as file:
        for row in file:
            row = row.strip()
            if not row:
                continue
            data = json.loads(row)

            # Skip incomplete records instead of aborting the whole run with a
            # KeyError on the first one.
            if "model_predict" not in data or "prompt" not in data:
                continue

            data["answer"] = data.get("answer", data.get("output"))

            # Only strings support .replace(); anything else gets the same
            # blank placeholder the dict case has always used.
            if not isinstance(data["model_predict"], str):
                data["model_predict"] = " "

            data["predict"] = data["model_predict"].replace("ā", "a")
            data["refer"] = data["prompt"]
            data["link"] = data["prompt"]
            dataset.append(data)
    return dataset


def remove_punctuation(text):
    """
    Normalises a string for comparison.

    Underscores become spaces, the text is lower-cased, every character found
    in PUNCTUATION_SET_TO_EXCLUDE is replaced by a space, and whitespace runs
    are collapsed to single spaces with no leading or trailing whitespace.

    Args:
        text (str): the string to normalise.

    Returns:
        str: the normalised string (possibly empty).
    """
    lowered = text.replace('_', ' ').lower()
    blanked = ''.join(' ' if char in PUNCTUATION_SET_TO_EXCLUDE else char for char in lowered)
    # split() without arguments drops leading/trailing whitespace and splits on
    # whitespace runs, so re-joining with one space yields the final form.
    return ' '.join(blanked.split())


def evaluate(dataset, eval_column):
    """
    Computes substring-match accuracy of one text column against the answers.

    A sample counts as correct when at least one of its normalised answer
    labels occurs as a substring of the normalised text in ``eval_column``.
    Normalisation is done with remove_punctuation(), which also lower-cases.

    Args:
        dataset (list[dict]): samples; each needs an "answer" entry and the
            ``eval_column`` entry.
        eval_column (str): key of the text to search for the labels.

    Returns:
        dict: "eval_type", "column_name", "accuracy", "correct" and "total".
            "accuracy" is 0 when the dataset is empty.
    """
    correct, total = 0, 0
    for sample in dataset:
        total += 1
        answers = sample['answer']
        if not answers:
            # No answer at all (None, empty list or empty string): the sample
            # is counted but cannot be correct.
            continue
        if isinstance(answers, str):
            # A bare string is one label; iterating it would test each
            # character separately and match almost any prediction.
            answers = [answers]

        # Normalise the candidate text once rather than once per label.
        predict = remove_punctuation(sample[eval_column])
        for label in answers:
            label = remove_punctuation(label)
            # An empty label is a substring of every string, so it must not
            # be allowed to count as a hit.
            if label and label in predict:
                correct += 1
                break
    accuracy = correct / total if total > 0 else 0
    return {"eval_type": "accuracy", "column_name": eval_column, "accuracy": accuracy, "correct": correct,
            "total": total}


def main(evaluate_task_data_path):
    """
    Loads the evaluation file and prints one accuracy report per column.

    Args:
        evaluate_task_data_path (str): path passed on to read_data().
    """
    records = read_data(evaluate_task_data_path)
    # The same answers are scored against the "predict" and "prompt" columns.
    for column_name in ("predict", "prompt"):
        print(evaluate(records, column_name))


if __name__ == '__main__':
    # Command-line entry point: one positional argument, the data file path.
    cli = argparse.ArgumentParser(description='Evaluate the model predictions.')
    cli.add_argument('file_path', type=str, help='Path to the evaluation data file.')
    main(cli.parse_args().file_path)
