import os
import json
import re
import ast

# Dataset name -> reference file name.  load_reference_data() picks the first
# entry whose dataset name occurs in the evaluated file's path and opens the
# associated file from the 'reference_data' directory.
reference_data_mapping = dict((
    ("HalluEval", "HalluEval_marked.json"),
    ("HalluQA", "HalluQA_marked.json"),
    ("Chinese_Domain", "Chinese_Domain_marked.json"),
    ("triviaQA", "triviaQA_marked.json"),
    ("WebQA", "WebQA_marked.json"),
))


def load_reference_data(filename):
    """Load the per-question reference labels that belong to *filename*.

    The first key of ``reference_data_mapping`` that occurs as a substring of
    *filename* selects the reference file, which is read from the
    ``reference_data`` directory (resolved relative to the current working
    directory).

    Args:
        filename: Path or name of the evaluation file being scored.

    Returns:
        A pair ``(reference_en_dict, reference_ch_dict)`` mapping each
        record's ``en_question`` to its ``en_low`` value and each
        ``ch_question`` to its ``ch_low`` value.  Two empty dicts are returned
        when no mapping key occurs in *filename*.

    Raises:
        OSError: If the selected reference file cannot be opened.
        KeyError: If a reference record lacks one of the four fields above.
    """
    for key, value in reference_data_mapping.items():
        if key in filename:
            reference_path = os.path.join('reference_data', value)
            # The reference files contain Chinese text; read them as UTF-8
            # explicitly instead of relying on the platform's locale encoding.
            with open(reference_path, 'r', encoding='utf-8') as file:
                reference_data = json.load(file)
            reference_en_dict = {el['en_question']: el['en_low'] for el in reference_data}
            reference_ch_dict = {el['ch_question']: el['ch_low'] for el in reference_data}
            return reference_en_dict, reference_ch_dict
    return {}, {}


def evaluate_data(filename, include_selection=True, detect_language=False):
    """Compute plain and reference-adjusted accuracy for one evaluation file.

    A record adds to the English / Chinese total when it has a non-empty
    ``en_answer`` / ``ch_answer``.  It is "plain" correct for a language when
    that language's ``evaluation_*`` text contains "correct" and does not
    contain "incorrect".

    The "final" counters additionally use the labels returned by
    ``load_reference_data``:

    * a question labelled ``1`` is credited when the *other* language's
      evaluation is correct (and, with ``detect_language``, the selection text
      names that other language);
    * a question labelled ``0`` is credited when its own evaluation is correct;
    * a question that is absent from the reference data is never credited.

    Args:
        filename: Path of the JSON file holding the list of evaluated records.
        include_selection: When true, skip records that lack a truthy
            ``en_selection`` or ``ch_selection``.
        detect_language: When true, require ``en_selection`` to mention
            "chinese" and ``ch_selection`` to mention '英语', '英文' or
            "english" before crediting the cross-language answer.

    Returns:
        ``(original_en_acc, original_ch_acc, final_en_acc, final_ch_acc)``;
        each value is ``0`` when the corresponding total is zero.
    """
    reference_en_dict, reference_ch_dict = load_reference_data(filename)
    # Records contain Chinese text, so do not depend on the locale encoding.
    with open(filename, 'r', encoding='utf-8') as f:
        data = json.load(f)

    def _lowered(record, key):
        # Missing keys and explicit JSON nulls both count as empty text.
        value = record.get(key)
        return value.lower() if isinstance(value, str) else ''

    def _is_correct(verdict):
        # A bare substring test would also accept "incorrect".
        return 'correct' in verdict and 'incorrect' not in verdict

    def _ratio(hits, total):
        return hits / total if total else 0

    en_total = ch_total = 0
    plain_en_correct = plain_ch_correct = 0
    final_en_correct = final_ch_correct = 0

    for el in data:
        if include_selection and not (el.get('en_selection') and el.get('ch_selection')):
            continue

        if el.get('en_answer'):
            en_total += 1
        if el.get('ch_answer'):
            ch_total += 1

        en_correct = _is_correct(_lowered(el, 'evaluation_en'))
        ch_correct = _is_correct(_lowered(el, 'evaluation_ch'))

        if en_correct:
            plain_en_correct += 1
        if ch_correct:
            plain_ch_correct += 1

        if detect_language:
            en_lang_check = 'chinese' in _lowered(el, 'en_selection')
            ch_selection = _lowered(el, 'ch_selection')
            ch_lang_check = any(x in ch_selection for x in ['英语', '英文', 'english'])
        else:
            en_lang_check = ch_lang_check = True

        en_label = reference_en_dict.get(el.get('en_question'))
        ch_label = reference_ch_dict.get(el.get('ch_question'))

        if (en_label == 1 and ch_correct and en_lang_check) or (en_label == 0 and en_correct):
            final_en_correct += 1

        if (ch_label == 1 and en_correct and ch_lang_check) or (ch_label == 0 and ch_correct):
            final_ch_correct += 1

    original_en_acc = _ratio(plain_en_correct, en_total)
    original_ch_acc = _ratio(plain_ch_correct, ch_total)
    final_en_acc = _ratio(final_en_correct, en_total)
    final_ch_acc = _ratio(final_ch_correct, ch_total)

    print(plain_en_correct, plain_ch_correct, final_en_correct, final_ch_correct)
    print('original_en_acc:', original_en_acc)
    print('original_ch_acc:', original_ch_acc)
    print('final_en_acc:', final_en_acc)
    print('final_ch_acc:', final_ch_acc)

    return original_en_acc, original_ch_acc, final_en_acc, final_ch_acc





def extract_score(text, language='en'):
    """Pull the overall score out of an evaluation text.

    For ``language='en'`` the number following ``'Overall Score': `` is used,
    for ``language='ch'`` the number following ``'综合得分': ``.  The number
    is one or more digits with an optional fraction of one or two digits, and
    the first occurrence in *text* wins.

    Returns:
        The score as a ``float``, ``-1`` when *text* contains no such score,
        or ``None`` when *language* is neither ``'en'`` nor ``'ch'``.
    """
    score_patterns = {
        'en': r"'Overall Score': (\d+(\.\d{1,2})?)",
        'ch': r"'综合得分': (\d+(\.\d{1,2})?)"
    }
    if language not in score_patterns:
        return None

    found = re.search(score_patterns[language], text)
    if found is None:
        return -1
    return float(found.group(1))


def _last_braced_span(text):
    """Return the last balanced ``{...}`` substring of *text*, or ``None``."""
    end = text.rfind('}')
    if end == -1:
        return None
    depth = 0
    # Walk backwards from the final '}' to the '{' that balances it, so that a
    # nested dictionary is returned together with its enclosing one.
    for index in range(end, -1, -1):
        char = text[index]
        if char == '}':
            depth += 1
        elif char == '{':
            depth -= 1
            if depth == 0:
                return text[index:end + 1]
    return None


def extract_dictionary(text):
    """Extract the last ``{...}`` literal found in *text* and evaluate it.

    Typographic quotes are converted to ASCII quotes first.  The span is
    located by brace matching from the end of the text, so earlier
    dictionaries are ignored and nested or multi-line dictionaries are kept
    whole.  Braces inside string values are not treated specially.

    Args:
        text: Free-form text that may contain a Python-literal dictionary.

    Returns:
        The value produced by ``ast.literal_eval`` for that span, or ``None``
        when no balanced span exists or it is not a valid literal (a message
        is printed in the latter case).
    """
    text = text.replace('‘', "'").replace('“', '"').replace('”', '"')
    dictionary_str = _last_braced_span(text)
    if dictionary_str is None:
        return None

    # The right single quote is only normalised as a fallback: it can be the
    # closing half of a curly-quoted key, but it can equally be an apostrophe
    # inside a string value, where replacing it would break the literal.
    candidates = [dictionary_str]
    if '’' in dictionary_str:
        candidates.append(dictionary_str.replace('’', "'"))

    last_error = None
    for candidate in candidates:
        try:
            return ast.literal_eval(candidate)
        except (SyntaxError, ValueError, TypeError) as e:
            # literal_eval raises ValueError/TypeError for well-formed but
            # non-literal content (e.g. bare names, unhashable keys).
            last_error = e
    print(f"Syntax error in dictionary conversion: {last_error}")
    return None


def parse_evaluations(data_list, reference_dicts):
    """Collect original and integrated scores from evaluation entries.

    Args:
        data_list: Iterable of entries providing ``score_original``,
            ``score_integrated``, ``language`` and ``question``.
        reference_dicts: Mapping ``language -> {question: reference value}``.

    Returns:
        A dict of lists:

        * ``original`` / ``integrated_pure``: the two scores of every entry
          whose scores could both be extracted;
        * ``integrated`` / ``direct_improve`` / ``avg_improve``: for entries
          whose reference value is truthy, the integrated score, the original
          score and their difference (integrated minus original).

        Entries that cannot be processed are reported on stdout and skipped.
    """
    scores = {
        'original': [],
        'integrated': [],
        'integrated_pure': [],
        'direct_improve': [],
        'avg_improve': []
    }

    for entry in data_list:
        try:
            original_score = extract_score(entry['score_original'])
            integrated_score = extract_score(entry['score_integrated'])

            # extract_score reports "no score found" as -1 (and an unsupported
            # language as None).  Neither is a real score, so keep both out of
            # the lists and out of the improvement deltas.
            if (original_score is None or integrated_score is None
                    or original_score < 0 or integrated_score < 0):
                print("Error processing evaluation data: score could not be extracted")
                continue

            scores['original'].append(original_score)
            scores['integrated_pure'].append(integrated_score)

            ref_score = reference_dicts[entry['language']][entry['question']]
            if ref_score:
                scores['integrated'].append(integrated_score)
                scores['direct_improve'].append(original_score)
                scores['avg_improve'].append(integrated_score - original_score)
        except (KeyError, TypeError) as e:
            # KeyError: missing field or unknown language/question;
            # TypeError: a score field that is not text.
            print(f"Error processing evaluation data: {e}")

    return scores


def alignbench_eval():
    """Score the 'version-3' AlignBench evaluation files of each model.

    Reads the reference labels from ``test_data/AlignBench_new.json``, then
    for every model directory under ``evaluation`` parses each file whose name
    contains ``version-3`` (after filtering entries with ``should_include``).
    Finally writes the collected bad cases to ``bad_case2.json``.

    All paths are relative to the current working directory.  The per-file
    scores are computed but not yet aggregated, and no bad cases are collected
    yet, so the output file currently holds an empty list.
    """
    model_list = ['llama3', 'chatgpt']
    reference_path = 'test_data/AlignBench_new.json'
    evaluation_dir = 'evaluation'

    # The data contains Chinese text: always read and write it as UTF-8
    # rather than with the platform's locale encoding.
    with open(reference_path, 'r', encoding='utf-8') as file:
        reference_data = json.load(file)
    ref_en_dict = {item['en_question']: item['en_low'] for item in reference_data}
    ref_ch_dict = {item['ch_question']: item['ch_low'] for item in reference_data}
    reference_dicts = {'en': ref_en_dict, 'ch': ref_ch_dict}

    bad_cases = []

    for model in model_list:
        model_dir = os.path.join(evaluation_dir, model)
        for file_name in os.listdir(model_dir):
            if 'version-3' not in file_name:
                continue
            with open(os.path.join(model_dir, file_name), 'r', encoding='utf-8') as file:
                data = json.load(file)
            # Filtering data based on various conditions can be abstracted into a function if complex
            filtered_data = [d for d in data if should_include(d)]
            scores = parse_evaluations(filtered_data, reference_dicts)

            # Add detailed analysis or further processing of scores here

    with open('bad_case2.json', 'w', encoding='utf-8') as f:
        json.dump(bad_cases, f, indent=4, ensure_ascii=False)


def should_include(entry):
    """Decide whether an evaluation entry takes part in the scoring.

    An entry is rejected when its ``question_id`` is on the exclusion list,
    when its ``ch_question`` has 50 or more characters, or when its
    ``en_question`` has 50 or more whitespace-separated words.

    Returns:
        ``True`` if the entry should be kept, ``False`` otherwise.
    """
    blocked_ids = frozenset((
        3, 4, 7, 32, 76, 96, 156, 247, 285, 397, 400,
        510, 520, 523, 524, 529, 543, 564, 587, 617, 643,
    ))
    if entry['question_id'] in blocked_ids:
        return False

    chinese_is_short = len(entry['ch_question']) < 50
    # The English length is only inspected when the Chinese text passed.
    return chinese_is_short and len(entry['en_question'].split()) < 50


import pandas as pd
# Default model names for gen_res(); each name is lower-cased there to form
# the model's result directory.
model_list = [
    'ChatGLM3',
    'ChatGPT',
    'GPT-4',
    'Yi-34b',
    'Qwen-turbo',
    'Llama3',
]
# Default dataset names for gen_res(); a result file belongs to a dataset when
# the dataset name occurs in the file name.
# NOTE(review): 'history_geograph_data' has no key in reference_data_mapping
# (which lists 'Chinese_Domain' instead) -- confirm which name the result
# files actually carry.
file_list = [
    'HalluEval',
    'HalluQA',
    'history_geograph_data',
    'triviaQA',
    'WebQA',
]
def gen_res(model_list, file_list, root_path, save_file='res.csv', selection=True, detector=True):
    """Evaluate every model/dataset pair and write the accuracies to a CSV.

    For each model, the directory ``<root_path>/<model lower-cased>`` is
    listed and every file whose name contains a dataset name from *file_list*
    is scored with ``evaluate_data``.  If several files match the same
    dataset, the one that sorts last by name determines the reported values.

    Args:
        model_list: Model names; they also label the CSV columns.
        file_list: Dataset names to look for in the result file names.
        root_path: Directory containing one sub-directory per model.
        save_file: Path of the CSV file to write.
        selection: Forwarded to ``evaluate_data`` as ``include_selection``.
        detector: Forwarded to ``evaluate_data`` as ``detect_language``.

    The CSV has one row per dataset and language (``(en)`` / ``(ch)``) and,
    per model, an "Original Acc" and an "Improved Acc" column; combinations
    without a matching result file contain ``N/A``.
    """
    all_stat = {}
    for model in model_list:
        model_stats = all_stat[model] = {}
        # str.lower must be *called*; join the path the same way the result
        # file path is built below instead of concatenating strings.
        model_dir = os.path.join(root_path, model.lower())
        # Sorted so the outcome does not depend on directory listing order.
        all_model_file = sorted(os.listdir(model_dir))
        for file in file_list:
            for model_file in all_model_file:
                if file not in model_file:
                    continue
                filename = os.path.join(model_dir, model_file)
                original_en_acc, original_ch_acc, improve_en_acc, improve_ch_acc = evaluate_data(
                    filename=filename, include_selection=selection, detect_language=detector)
                model_stats[file + '(en)'] = {'Original Acc': original_en_acc, 'Improved Acc': improve_en_acc}
                model_stats[file + '(ch)'] = {'Original Acc': original_ch_acc, 'Improved Acc': improve_ch_acc}
    print(all_stat)

    columns = ['File', 'Language']
    for model in model_list:
        columns.extend([f"{model} Original Acc", f"{model} Improved Acc"])

    data = []
    for file in file_list:
        for language in ['(en)', '(ch)']:
            row = [file, language]
            for model in model_list:
                stats = all_stat[model].get(f"{file}{language}", {})
                row.extend([stats.get('Original Acc', 'N/A'), stats.get('Improved Acc', 'N/A')])
            data.append(row)

    df_corrected = pd.DataFrame(data, columns=columns)
    df_corrected.to_csv(save_file, index=False)

