import json
import random
import os
import numpy as np
import re
import gedi

# Subject names, in the order used to map a profile's line index to a subject
# (the evaluation loop assigns 100 consecutive profiles to each entry).
subjects = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
]

# Score values 0..9; passed to gedi and shuffled to build a random ballot
# for the 'range' dataset.
scale = list(range(10))

# Number of voters (ballots) expected in a complete profile.
voting_pop = 10

# Number of real ballots replaced by randomly shuffled ("unreliable") ballots.
unreliable_agent_num = 0

# The gold label must appear within this many top positions of the
# collective ranking to count as a hit.
hitrate_at = 1

# Benchmark whose result files are evaluated.
# Options: 'MMLU', 'ARC-V1-Feb2018-2', 'MMLU-Pro'
benchmark = 'ARC-V1-Feb2018-2'


# Engines to evaluate; each name is also a directory / file-name component
# of the result files.
engines = [
    'mistral-7b-v0.3-32k',
    'glm-4-9b-128k',
    'llama3-8b-8k',
    'llama3-70b-8k',
    'qwen1.5-72b-32k',
    'qwen1.5-110b-32k',
    'gpt-3.5-turbo-1106',
    'gpt-4-0125-preview',
]


# temperature = 1.0

# Maps each voting rule to the dataset (profile file flavour) it is run on.
method2data = {
    "single_select": 'single_select',
    "range_voting": 'range',
    # rules that all operate on the 'base' profiles
    "random": 'base',
    "blind_dictatorial": 'base',
    "plurality": 'base',
    "bucklin": 'base',
    "borda_count": 'base',
    "irv": 'base',
    "minimax": 'base',
    "ranked_pairs": 'base',
    "informed": 'informed',
    "disrupted": 'disrupted',
}

# Every voting rule, in declaration order.
voting_rules = list(method2data)

# Inverse of method2data: dataset -> voting rules evaluated on it.
# Built from method2data so the two mappings cannot drift apart; dataset and
# rule order both follow the declaration order above.
data2methods = {
    dataset_name: [
        rule for rule, rule_dataset in method2data.items()
        if rule_dataset == dataset_name
    ]
    for dataset_name in dict.fromkeys(method2data.values())
}

# Datasets to evaluate; trim this list to run a subset.
# Available: 'single_select', 'range', 'base', 'informed', 'disrupted'
data = ['single_select', 'range', 'base', 'informed', 'disrupted']


if __name__ == "__main__":
    # Evaluate the configured voting rules on the stored ballot profiles.
    #
    # For every (engine, dataset) pair the matching ``*_profile.jsonl`` file is
    # read line by line (one profile per line). Profiles without enough valid
    # ballots are skipped, the collective ranking is computed with ``gedi.cdm``
    # and a hit is recorded when the gold label appears within the top
    # ``hitrate_at`` positions of that ranking. ``results[engine][dataset]``
    # receives the profile counts and, per voting rule of that dataset, the
    # hit rate in percent rounded to one decimal.
    results = {}

    # alternative index, MMLU are multichoice questions with fixed answer index
    if benchmark == 'MMLU-Pro':
        choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
    else:
        choices = ["A", "B", "C", "D"]

    for engine in engines:
        # the temperature is part of the result file name and differs per engine family
        if engine in ['gpt-3.5-turbo-1106', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview']:
            temperature = 1.0
        else:
            temperature = 0.7

        results[engine] = {}

        for dataset in data:
            # hit/miss records per voting rule, overall and per subject
            cors = {voting_system: [] for voting_system in voting_rules}
            accs = {}
            sub_cors = {
                subject: {voting_system: [] for voting_system in voting_rules}
                for subject in subjects
            }
            sub_accs = {subject: {} for subject in subjects}

            total_profile_num = 0
            valid_profile_num = 0
            total_ballot_num = 0
            valid_ballot_num = 0

            results[engine][dataset] = {}

            profile_path = f'./results/{benchmark}/{engine}/{engine}_{temperature}_{dataset}_profile.jsonl'
            # a missing result file is not an error: the dataset simply gets no records
            if os.path.exists(profile_path):
                with open(profile_path, 'r', encoding="utf-8") as fin:

                    profile_index = -1
                    for line in fin:
                        profile_index += 1

                        # Profiles are bucketed 100 per subject in file order.
                        # Lines beyond 100 * len(subjects) have no subject; they
                        # still count towards the overall statistics but are left
                        # out of the per-subject ones instead of raising IndexError.
                        subject_position = profile_index // 100
                        if subject_position < len(subjects):
                            subject = subjects[subject_position]
                        else:
                            subject = None

                        total_profile_num += 1
                        profile = json.loads(line)

                        # extract the full profile and label; the GPT result files for
                        # MMLU (except single_select) store a list whose last item holds the label
                        if engine in ['gpt-3.5-turbo-1106', 'gpt-4-0125-preview'] and dataset != 'single_select' and benchmark == 'MMLU':
                            label = profile[-1]["label"]
                            raw_profile = profile[:-1]
                        else:
                            label = profile["standardized_gold_answer"]
                            raw_profile = profile["ranking"]

                        # create a subset of valid ballots
                        valid_profile = gedi.check_profile(choices, raw_profile, dataset, scale)

                        # exclude profiles that are empty or too small to host the unreliable agents
                        if len(valid_profile) == 0 or (len(valid_profile) - unreliable_agent_num) <= 0:
                            continue

                        # incomplete profiles are only accepted for these datasets
                        if dataset not in ['informed', 'disrupted', ]:
                            if len(valid_profile) != voting_pop:
                                continue

                        valid_profile_num += 1

                        # total counts every ballot of the raw profile, valid only those that
                        # passed gedi.check_profile (the two counters used to be swapped)
                        total_ballot_num += len(raw_profile)
                        valid_ballot_num += len(valid_profile)

                        # replacing actual ballots with unreliable (randomly shuffled) ballots
                        try:
                            if unreliable_agent_num > 0:
                                valid_profile = random.sample(valid_profile, len(valid_profile) - unreliable_agent_num)
                                for i in range(unreliable_agent_num):
                                    if 'range' in dataset:
                                        unreliable_ballot = scale.copy()
                                    else:
                                        unreliable_ballot = choices.copy()
                                    random.shuffle(unreliable_ballot)
                                    valid_profile.append({"preference": unreliable_ballot})

                        except Exception as info:
                            print(f"Adding random ballot error: {info}")
                            continue

                        # downsizing profile
                        # NOTE(review): random.sample raises ValueError when a profile holds
                        # fewer than voting_pop ballots (possible for 'informed'/'disrupted');
                        # confirm whether such profiles should be skipped or used in full.
                        if voting_pop != 10:
                            valid_profile = random.sample(valid_profile, voting_pop)

                        for voting_system in data2methods[dataset]:

                            # these two rules are evaluated 10 times per profile, all others once
                            if voting_system in ['random', 'blind_dictatorial']:
                                sample_time = 10
                            else:
                                sample_time = 1

                            for t in range(sample_time):
                                # reset so the error report below never shows values from a
                                # previous round (or fails on names that were never bound)
                                order = None
                                tally = None
                                cdm_ranking = None
                                try:
                                    order, tally = gedi.cdm(choices, valid_profile, voting_system, method2data, scale)

                                    # collective ranking: for each rank 1..k (k = number of
                                    # distinct ranks) take the first alternative holding it
                                    candidates = list(order.keys())
                                    ranks = list(order.values())
                                    cdm_ranking = []
                                    for i in range(len(set(ranks))):
                                        cdm_ranking.append(candidates[ranks.index(i+1)])

                                    # hit when the label is found in one of the top positions
                                    hit = any(label in cdm_ranking[j] for j in range(hitrate_at))

                                    cors[voting_system].append(hit)
                                    if subject is not None:
                                        sub_cors[subject][voting_system].append(hit)

                                except Exception as info:
                                    # best effort: report the failing round and leave it out of the stats
                                    print(f"CDM Error: {info}")
                                    print("order: ", order)
                                    print("tally: ", tally)
                                    print('cdm_ranking: ', cdm_ranking)
                                    print('type(cdm_ranking): ', type(cdm_ranking))
                                    print('voting_system: ', voting_system)
                                    continue

            # record stats
            results[engine][dataset]["total_profile_num"] = total_profile_num
            results[engine][dataset]["valid_profile_num"] = valid_profile_num

            # NOTE(review): np.mean of an empty list is nan (with a RuntimeWarning), so a
            # missing result file or a rule without any recorded round shows up as nan here.
            for voting_system in data2methods[dataset]:
                results[engine][dataset][voting_system] = round(100*np.mean(cors[voting_system]), 1)