import argparse
import json
import random
import re
from tqdm import tqdm
import ast
import os
import pandas as pd
import time
from openai import OpenAI
from random import randrange


# OpenAI API environment.
# `client` handles every engine whose name contains "gpt" and takes its
# credentials from the environment; `alt_client` targets an OpenAI-compatible
# endpoint used for all other engines -- fill in base_url / api_key before use.
client = OpenAI()
alt_client = OpenAI(base_url='', api_key='')

def format_subject(subject):
    """Turn an underscore-separated subject slug into space-prefixed words.

    Every "_"-separated piece is emitted with one leading space, e.g.
    "high_school_physics" -> " high school physics", which lets callers write
    "about{}" without adding their own separator.
    """
    return "".join(f" {word}" for word in subject.split("_"))

# Characters discarded from a raw model reply before a ranking is parsed.
_RANKING_NOISE = str.maketrans('', '', '\n .\'"*()')


def _locate_bracket_list(text):
    """Return ``(text, start, end)`` for the "[...]" span to parse.

    When ``end - start < 4`` (a tiny pair such as "[A]", a "]" before the
    first "[", or no brackets at all) the search is repeated once on the text
    after the first "]". ``start``/``end`` are -1 when a bracket is missing.
    """
    start, end = text.find("["), text.find("]")
    if end - start < 4:
        text = text[end + 1:]
        start, end = text.find("["), text.find("]")
    return text, start, end


def search_ranking(res, rank_method=None, labels=None):
    """Extract a ballot from a raw model reply.

    Args:
        res: raw completion text.
        rank_method: elicitation method; defaults to the module-level ``method``.
        labels: choice labels shown to the model; defaults to the module-level
            ``custom_labels``.

    Returns:
        * 'single_select': the first character left after noise removal.
        * 'range': ``str`` of the list of scores, ordered like ``labels``.
        * any other method: ``str`` of a list of label strings, suitable for
          ``ast.literal_eval``.

    Parsing failures propagate (IndexError on an empty reply, JSON/KeyError
    errors on malformed 'range' output); the callers in this file treat any
    exception as a failed ballot.
    """
    if rank_method is None:
        rank_method = method
    if labels is None:
        labels = custom_labels

    res = res.translate(_RANKING_NOISE)

    if rank_method == 'single_select':
        return res[0]

    if rank_method == 'range':
        # prefer a "{...}" score dict; fall back to a "[...]" list
        start_idx, end_idx = res.find("{"), res.find("}")
        if start_idx == -1 and end_idx == -1:
            res, start_idx, end_idx = _locate_bracket_list(res)
    else:
        res, start_idx, end_idx = _locate_bracket_list(res)

    # Only a colon at index >= 3 is treated as a prefix separator such as
    # "Ranking:..."; an earlier one is presumably part of the ranking ("A:9").
    colon_idx = res.find(":")
    if colon_idx < 3:
        colon_idx = -1

    if start_idx != -1 and end_idx != -1:
        res = res[start_idx + 1:end_idx]
    elif colon_idx != -1:
        res = res[colon_idx + 1:]

    if rank_method == 'range':
        # Rebuild forms like "A9B7..." or "A:9,B:7" into JSON {"A":9,"B":7,...}
        res = '{' + res + '}'
        res = re.sub(r"(?<=[A-Z])(?=[0-9])", ":", res)
        res = re.sub(r"(?<=[0-9])(?=[A-Z])", ",", res)
        res = re.sub(r'([a-zA-Z])', r'"\1"', res)
        preference_dict = json.loads(res)
        return str([preference_dict[alt] for alt in labels])

    if "A" in labels:
        # letter labels: drop scores/colons and split runs such as "BADC"
        res = re.sub(r":", "", res)
        res = re.sub(r"\d+", "", res)
        res = re.sub(r"(?<=[A-Z])(?=[A-Z])", ",", res)

    # Noise removal also strips characters that belong to some labels (the
    # parentheses of "(1)"), so map every token back to the label it came from.
    canonical = {label.translate(_RANKING_NOISE): label for label in labels}
    return str([canonical.get(token, token) for token in res.split(',')])


def formulate_question(df, idx, original_labels, custom_labels, method, include_answer=True, score_scale=None):
    """Render row ``idx`` of ``df`` as a relabelled multiple-choice prompt.

    Row layout, as read below: column 0 is the question, columns
    1..len(original_labels) are the options, and the last column holds the
    gold label expressed in ``original_labels``.

    The "iia_*" methods (presumably independence-of-irrelevant-alternatives
    probes) first drop one option:
      * 'iia_rm_rng': a random non-gold option;
      * 'iia_rm_g': the gold option when ``include_answer`` is False (the
        returned gold label is then 'gold removed'); with ``include_answer``
        a random non-gold option is dropped instead;
      * 'iia_rm_g+1' / '+2' / '+3': the option 1/2/3 places after the gold
        one, wrapping around.

    The remaining options are shown in order under ``custom_labels``, then an
    answer cue is appended. With ``include_answer`` (few-shot examples) the
    cue is followed by the gold label ('single_select'), or by a ranking with
    the gold label first and the other labels shuffled (reversed for
    'reverse'); for 'range' that ranking is paired with descending scores
    sampled from ``score_scale`` (default: the module-level ``scale``).

    Returns:
        (prompt, custom_gold_answer_label, custom_label_answer_mapping,
         original_answer_label_mapping)
    """
    n_options = len(original_labels)
    prompt = df.iloc[idx, 0]
    # the gold label sits in the last column, after the question and options
    original_gold_answer_label = df.iloc[idx, df.shape[1] - 1]
    custom_gold_answer_label = original_gold_answer_label

    gold_answer_str = ''
    gold_answer_idx = -1
    answers_list = []
    # map each original answer text to its original label
    original_answer_label_mapping = {}
    for i, label in enumerate(original_labels):
        answer = df.iloc[idx, i + 1]
        original_answer_label_mapping[answer] = label
        if original_gold_answer_label == label:
            gold_answer_str = answer
            gold_answer_idx = i
        answers_list.append(answer)

    # IIA variants: remove one option before relabelling
    if method == "iia_rm_rng" or (method == "iia_rm_g" and include_answer):
        removed_idx = gold_answer_idx
        while removed_idx == gold_answer_idx:
            removed_idx = randrange(n_options)
        answers_list.pop(removed_idx)
    elif method == "iia_rm_g":
        answers_list.pop(gold_answer_idx)
        custom_gold_answer_label = 'gold removed'
    elif method in ("iia_rm_g+1", "iia_rm_g+2", "iia_rm_g+3"):
        offset = int(method[-1])
        answers_list.pop((gold_answer_idx + offset) % n_options)

    # relabel the remaining options in order and write them into the prompt
    custom_label_answer_mapping = {}
    for j, label in enumerate(custom_labels):
        answer = answers_list[j]
        custom_label_answer_mapping[label] = answer
        if answer == gold_answer_str:
            custom_gold_answer_label = label
        prompt += "\n{}. {}".format(label, answer)

    if method == 'single_select':
        prompt += "\nAnswer: "
    elif method == 'reverse':
        prompt += "\nPreferential Answer Ranking (from least to most preferred): "
    else:
        prompt += "\nPreferential Answer Ranking: "

    # include answer string for in-context learning examples
    if include_answer:
        if method == 'single_select':
            prompt += f"{custom_gold_answer_label}\n\n"
        else:
            rest_of_answers = [label for label in custom_labels if label != custom_gold_answer_label]
            random.shuffle(rest_of_answers)
            preferential_ranking = [custom_gold_answer_label] + rest_of_answers

            if method == 'reverse':
                preferential_ranking = preferential_ranking[::-1]

            if method == 'range':
                if score_scale is None:
                    score_scale = scale
                assign_score = sorted(random.sample(score_scale, len(custom_labels)), reverse=True)
                # one descending score per label, highest for the gold label
                ranking_with_score = dict(zip(preferential_ranking, assign_score))
                prompt += f"{ranking_with_score}\n\n"
            else:
                prompt += f"{preferential_ranking}\n\n"
    return prompt, custom_gold_answer_label, custom_label_answer_mapping, original_answer_label_mapping

def formulate_question_binary(df, idx, original_labels, custom_labels, former=-1, later=-1, include_answer=True):
    """Render row ``idx`` of ``df`` as a two-option (binary) prompt.

    Row layout, as read below: column 0 is the question, columns
    1..len(original_labels) are the options, and the last column holds the
    gold label expressed in ``original_labels``.

    With ``include_answer`` (a few-shot example), the gold option is placed
    under a randomly chosen custom label. Every other custom label shows the
    option at its own position, unless that position holds the gold option
    itself; in that case it shows the option displaced by the gold, so an
    example never pits the gold answer against itself. The gold label follows
    "Answer: ".

    Without ``include_answer`` (the test question), ``custom_labels[0]`` shows
    option ``former`` and ``custom_labels[1]`` shows option ``later`` (indices
    into the original options; the -1 defaults both select the last option).

    Returns:
        (prompt, custom_gold_answer_label, custom_label_answer_mapping,
         original_answer_label_mapping); the gold label is '' when
        ``include_answer`` is False.
    """
    # k = number of answers, excluding question and correct answer column
    k = df.shape[1] - 2

    prompt = df.iloc[idx, 0]
    original_gold_answer_label = df.iloc[idx, k + 1]
    gold_answer_str = ''
    gold_answer_idx = -1
    answers_list = []
    custom_gold_answer_label = ''

    # map each original answer text to its original label
    original_answer_label_mapping = {}
    for i, label in enumerate(original_labels):
        answer = df.iloc[idx, i + 1]
        original_answer_label_mapping[answer] = label
        if original_gold_answer_label == label:
            gold_answer_str = answer
            gold_answer_idx = i
        answers_list.append(answer)

    custom_label_answer_mapping = {}
    if include_answer:
        # randomly choose a label for gold first
        custom_gold_answer_label = random.choice(custom_labels)
        gold_pos = custom_labels.index(custom_gold_answer_label)
        custom_label_answer_mapping[custom_gold_answer_label] = gold_answer_str
        for j, label in enumerate(custom_labels):
            if label == custom_gold_answer_label:
                continue
            # answers_list[j] would repeat the gold answer when gold originally
            # sat at index j; use the option the gold displaced instead.
            source = gold_pos if j == gold_answer_idx else j
            custom_label_answer_mapping[label] = answers_list[source]
    else:
        custom_label_answer_mapping[custom_labels[0]] = answers_list[former]
        custom_label_answer_mapping[custom_labels[1]] = answers_list[later]

    # apply custom answer list
    for label in custom_labels:
        prompt += "\n{}. {}".format(label, custom_label_answer_mapping[label])
    prompt += "\nAnswer: "

    # include answer string for in-context learning examples
    if include_answer:
        prompt += f"{custom_gold_answer_label}\n\n"

    return prompt, custom_gold_answer_label, custom_label_answer_mapping, original_answer_label_mapping

def gen_prompt(train_df, subject, method, k=-1, former=-1, later=-1):
    """Build the few-shot prefix: a header, ``k`` answered examples, an instruction.

    ``k == -1`` renders every row of ``train_df``. Examples come from
    ``formulate_question_binary`` for the 'binary' method and from
    ``formulate_question`` otherwise, always with the module-level
    ``original_labels`` / ``custom_labels``. The closing instruction depends
    on ``method``; callers append the test question directly after it.
    """
    if subject in ('ARC-Challenge', 'ARC-Easy'):
        header = "The following are multiple choice questions.\n"
    else:
        header = "The following are multiple choice questions about{}.\n".format(format_subject(subject))

    n_shots = train_df.shape[0] if k == -1 else k
    shots = []
    for row in range(n_shots):
        if method == 'binary':
            rendered = formulate_question_binary(train_df, row, original_labels, custom_labels,
                                                 former, later, include_answer=True)
        else:
            rendered = formulate_question(train_df, row, original_labels, custom_labels,
                                          method, include_answer=True)
        shots.append(rendered[0])

    if method in ('single_select', 'binary'):
        instruction = ("choose the better answer for the following question.\n"
                       "Return only an answer tag without explanation.\n\n")
    elif method == 'reverse':
        instruction = ("provide a preferential ranking of all answers for the following question.\n"
                       "The ranking should start from the LEAST preferred and end with the MOST preferred.\n"
                       "Return only a ranking list without explanation.\n\n")
    elif method == 'range':
        instruction = ("provide a preferential ranking of all answers with preferential scores on a 0-9 scale for the following question.\n"
                       "The ranking should start from the MOST preferred (highest score) and end with the LEAST preferred (lowest score).\n"
                       "Return only a ranking list of labels and scores, no explanation.\n\n")
    else:
        instruction = ("provide a preferential ranking of all answers for the following question.\n"
                       "The ranking should start from the MOST preferred and end with the LEAST preferred.\n"
                       "Return only a ranking list of labels without explanation.\n\n")

    return header + "".join(shots) + "Refer to above examples, " + instruction

def llm_rank(args, subject, engine, temperature, dev_df, test_df, test_sample_size=0, subject_index=0):
    """Collect ranking ballots from ``engine`` for the first ``test_sample_size`` test questions.

    For each question, a few-shot prompt is built (``gen_prompt`` followed by
    ``formulate_question``). ``response_pop`` ballots are then requested, each
    prefixed with a random "You are the No.<n> rater." line. The parsed
    ballots are appended as one JSON line to
    ``./{save_dir}/{args.engine}/{engine}_{temperature}_{method}_profile.jsonl``.

    'informed' shows the model the ballots stored for the same question in
    the matching ``*_base_profile.jsonl``. 'disrupted' shows ``counsel_pop``
    random rankings instead.

    Relies on module-level settings from ``__main__``: ``method``,
    ``benchmark``, ``original_labels``, ``custom_labels``, ``response_pop``
    and, for 'disrupted', ``counsel_pop``.

    Args:
        test_sample_size: number of test rows to evaluate; 0 means all rows.
        subject_index: position of ``subject`` in the run, used to locate the
            question's line in the base profile for 'informed'.
    """
    # Test full dataset if test sample size not assigned
    if test_sample_size == 0:
        test_sample_size = test_df.shape[0]

    base_profile_lines = []
    if method == 'informed':
        # read the base profile once instead of rescanning the file per question
        with open(f'./results/{benchmark}/{engine}/{engine}_{temperature}_base_profile.jsonl', 'r', encoding="utf-8") as fin:
            base_profile_lines = fin.readlines()

    # `crop` (a prompt-length guard) is not imported in this file; only shrink
    # the few-shot prefix when such a helper has been made available.
    crop_fn = globals().get("crop")

    for i in tqdm(range(test_sample_size), desc=f"Evaluating {subject}:"):

        informed_profile = None
        if method == 'informed':
            # The base profile has one line per question, written subject by
            # subject with `test_sample_size` questions each.
            base_profile = json.loads(base_profile_lines[subject_index * test_sample_size + i])
            informed_profile = [ballot['preference'] for ballot in base_profile['ranking']]
        elif method == 'disrupted':
            rnd_vote = custom_labels.copy()
            random.shuffle(rnd_vote)
            informed_profile = []
            for _ in range(counsel_pop):
                # Append a copy: appending `rnd_vote` itself would alias every
                # entry to one list that the next shuffle keeps mutating.
                informed_profile.append(rnd_vote.copy())
                random.shuffle(rnd_vote)

        # get prompt and make sure it fits
        k = args.ntrain
        # train_prompt holds the in-context examples; prompt_end is the test question
        train_prompt = gen_prompt(dev_df, subject, method, k)
        prompt_end, custom_gold_label, custom_label_answer_mapping, original_answer_label_mapping = formulate_question(
            test_df, i, original_labels, custom_labels, method, include_answer=False)
        prompt = train_prompt + prompt_end

        if crop_fn is not None:
            # Drop examples until the prompt survives cropping. Stop at zero
            # examples, because k == -1 would mean "all rows" in gen_prompt.
            while k > 0 and crop_fn(prompt) != prompt:
                k -= 1
                prompt = gen_prompt(dev_df, subject, method, k) + prompt_end

        if method in ('informed', 'disrupted'):
            # replace the closing cue with one that shows the other raters' ballots
            prompt = prompt[:-len('Preferential Answer Ranking: ')]
            prompt += (f"Other 10 raters have provided their preferences as {informed_profile}\n"
                       "Your Preferential Answer Ranking: ")

        actual_gold_label = test_df.iloc[i, test_df.shape[1] - 1]

        # gather profile P of preferential rankings (ballots) from agents
        profile = {
            "choice_labels": custom_labels,
            "gold_answer": custom_gold_label,
            "standardized_gold_answer": actual_gold_label,
            "question": prompt_end,
            "ranking": [],
        }

        # collect individual ballots
        for _ in range(response_pop):
            rnd = random.randint(1, 2024)
            request = f"You are the No.{rnd} rater.\n" + prompt

            # reset per agent so a failed request never reuses an earlier reply
            response = None
            for _attempt in range(3):
                try:
                    if re.search(r"gpt", engine):
                        response = client.chat.completions.create(
                            model=engine,
                            messages=[{"role": "user", "content": request}],
                            temperature=temperature,
                        )
                    else:
                        response = alt_client.chat.completions.create(
                            model=engine,
                            messages=[{"role": "user", "content": request}],
                            temperature=temperature,
                            max_tokens=100,
                        )
                    break
                except Exception as info:
                    print(f"Response Error: {info}")
                    time.sleep(1)

            if response is None:
                print("Ballot Error: no response retrieved")
                continue

            # take out answers from string
            try:
                ans = search_ranking(response.choices[0].message.content)
                ballot = {'preference': ans if method == 'single_select' else ast.literal_eval(ans)}
                if method != 'range':
                    # map each custom label back to its original label
                    ballot['standardized_preference'] = [
                        original_answer_label_mapping[custom_label_answer_mapping[label]]
                        for label in ballot['preference']
                    ]
                profile["ranking"].append(ballot)
            except Exception as info:
                print(f"Ballot Error: {info}")

        # record profile
        with open(f'./{args.save_dir}/{args.engine}/{engine}_{temperature}_{method}_profile.jsonl', 'a', encoding="utf-8") as fout:
            try:
                out_str = json.dumps(profile, ensure_ascii=False)
                fout.write(f"{out_str}\n")
            except Exception as info:
                print(info)

def llm_rank_binary(args, subject, engine, temperature, dev_df, test_df, test_sample_size=0):
    """Collect pairwise preferences for every ordered pair of original options.

    Each test question is handled for every ordered pair (former, latter) of
    option indices. A few-shot binary prompt is built (``gen_prompt`` plus
    ``formulate_question_binary``) and ``response_pop`` replies are
    requested. The first character of each reply is read as the chosen custom
    label and mapped back to an original label. One JSON line per question is
    appended to
    ``./{save_dir}/{args.engine}/{engine}_{temperature}_{method}_profile.jsonl``
    with:

    * ``pairs``: "X>Y" / "X<Y" per decisive vote, or "Failed response: ..."
      for unusable replies;
    * ``matrix``: ``matrix[a][b]`` is 1 (a preferred), -1 (b preferred), or 0
      (diagonal, failed, or indecisive vote). NOTE(review): every vote
      overwrites the cell, so only the last agent's vote per pair survives in
      the matrix -- confirm that is intended.

    Relies on module-level ``method``, ``original_labels``, ``custom_labels``
    and ``response_pop``.
    """
    # Test full dataset if test sample size not assigned
    if test_sample_size == 0:
        test_sample_size = test_df.shape[0]

    n_alts = len(original_labels)
    for i in range(test_sample_size):

        # the full question (all options, original labels) kept for the record
        prompt_end, custom_gold_label, _, _ = formulate_question(
            test_df, i, original_labels, original_labels, 'base', include_answer=False)

        actual_gold_label = test_df.iloc[i, test_df.shape[1] - 1]

        binary_matrix = [list(range(n_alts)) for _ in range(n_alts)]

        profile = {
            "choice_labels": original_labels,
            "gold_answer": custom_gold_label,
            "standardized_gold_answer": actual_gold_label,
            "question": prompt_end,
            "pairs": []
        }

        for former_alt in range(n_alts):
            for latter_alt in range(n_alts):
                if former_alt == latter_alt:
                    binary_matrix[former_alt][latter_alt] = 0
                    continue

                # train_prompt is the in-context learning examples
                k = args.ntrain
                train_prompt = gen_prompt(dev_df, subject, method, k, former=former_alt, later=latter_alt)

                pair_question, _, custom_label_answer_mapping, original_answer_label_mapping = formulate_question_binary(
                    test_df, i, original_labels, custom_labels, former_alt, latter_alt, include_answer=False)
                prompt = train_prompt + pair_question

                # collect individual ballots
                for _ in range(response_pop):
                    # reset per agent so a failed request never reuses an earlier reply
                    response = None
                    for _attempt in range(3):
                        try:
                            if re.search(r"gpt", engine):
                                response = client.chat.completions.create(
                                    model=engine,
                                    messages=[{"role": "user", "content": prompt}],
                                    temperature=temperature,
                                )
                            else:
                                response = alt_client.chat.completions.create(
                                    model=engine,
                                    messages=[{"role": "user", "content": prompt}],
                                    temperature=temperature,
                                    max_tokens=80,
                                )
                            break
                        except Exception as info:
                            print(f"Response Error: {info}")
                            time.sleep(1)

                    if response is None:
                        print("Ballot Error: no response retrieved")
                        profile['pairs'].append('Failed response: Not retrieved')
                        binary_matrix[former_alt][latter_alt] = 0
                        continue

                    # take out the answer tag; keep `content` readable for the error path
                    content = None
                    try:
                        content = response.choices[0].message.content
                        preference = content[0]
                        # standardized to the original labels [A, B, C, D, ...]
                        standardized = original_answer_label_mapping[custom_label_answer_mapping[preference]]

                        if standardized == original_labels[former_alt]:
                            binary_matrix[former_alt][latter_alt] = 1
                            profile['pairs'].append(f'{original_labels[former_alt]}>{original_labels[latter_alt]}')
                        elif standardized == original_labels[latter_alt]:
                            binary_matrix[former_alt][latter_alt] = -1
                            profile['pairs'].append(f'{original_labels[former_alt]}<{original_labels[latter_alt]}')
                        else:
                            binary_matrix[former_alt][latter_alt] = 0
                    except Exception as info:
                        print(f"Ballot Error: {info}")
                        print("original response: ", content)
                        profile['pairs'].append(f'Failed response: {content}')
                        binary_matrix[former_alt][latter_alt] = 0

        profile["matrix"] = binary_matrix

        # record profile
        with open(f'./{args.save_dir}/{args.engine}/{engine}_{temperature}_{method}_profile.jsonl', 'a', encoding="utf-8") as fout:
            try:
                out_str = json.dumps(profile, ensure_ascii=False)
                fout.write(f"{out_str}\n")
            except Exception as info:
                print(info)

def main(args):
    """Create the output directory and run the configured benchmark.

    Reads the module-level ``benchmark``, ``method``, ``temperature`` and
    ``test_sample_size`` set in ``__main__``. The 'binary' method uses
    ``llm_rank_binary``; every other method uses ``llm_rank``.

    Raises:
        ValueError: if ``benchmark`` is not one of the supported names.
    """
    # makedirs: the default save_dir ("results/<benchmark>") is nested, and
    # os.mkdir fails when its parent directory does not exist yet.
    os.makedirs(os.path.join(args.save_dir, args.engine), exist_ok=True)

    print("args: ", args)
    print("args.engine: ", args.engine)

    def run(subject, dev_df, test_df, subject_index=0):
        # dispatch one subject to the ranking routine for the current method
        if method != "binary":
            llm_rank(args, subject, args.engine, temperature, dev_df, test_df, test_sample_size, subject_index)
        else:
            llm_rank_binary(args, subject, args.engine, temperature, dev_df, test_df, test_sample_size)

    if benchmark in ('MMLU', 'MMLU-Pro'):
        test_dir = os.path.join(args.data_dir, "test")
        subjects = sorted(f.split("_test.csv")[0] for f in os.listdir(test_dir) if "_test.csv" in f)

        for subject_index, subject in enumerate(tqdm(subjects)):
            if benchmark == 'MMLU':
                dev_path = os.path.join(args.data_dir, "dev", subject + "_dev.csv")
            else:
                # MMLU-Pro reads one shared dev.csv for every subject
                dev_path = os.path.join(args.data_dir, "dev.csv")
            dev_df = pd.read_csv(dev_path, header=None)[:args.ntrain]
            test_df = pd.read_csv(os.path.join(test_dir, subject + "_test.csv"), header=None)
            run(subject, dev_df, test_df, subject_index)

    elif benchmark == 'ARC-V1-Feb2018-2':
        subject = 'ARC-Challenge'
        # 'ARC-Easy'
        arc_dir = os.path.join(args.data_dir, subject)
        dev_df = pd.read_csv(os.path.join(arc_dir, subject + "-Dev-Re.csv"), header=None)[:args.ntrain]
        test_df = pd.read_csv(os.path.join(arc_dir, subject + "-Test-Re.csv"), header=None)
        run(subject, dev_df, test_df)

    else:
        raise ValueError(f"Unknown benchmark: {benchmark}")


if __name__ == "__main__":

    # voting_rules = ["dictatorial", "random", "plurality", "borda_count", "irv", "minimax", "ranked_pairs", "bucklin"]
    # scores sampled for the few-shot examples of the 'range' method
    scale = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # if output a result statistics after the ranking process
    output_stat = False

    # set method
    methods = ['range']

    benchmark = 'MMLU'
    # 'MMLU', 'ARC-V1-Feb2018-2', 'MMLU-Pro'

    # test sample size per subject, set 0 to test full dataset
    test_sample_size = {'MMLU': 100, 'ARC-V1-Feb2018-2': 0, 'MMLU-Pro': 100}[benchmark]

    # option labels used by the dataset files: MMLU/ARC have 4 choices, MMLU-Pro 10
    if benchmark in ['MMLU', 'ARC-V1-Feb2018-2']:
        original_labels = ["A", "B", "C", "D"]
    elif benchmark == 'MMLU-Pro':
        original_labels = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
    roman_numerals = ["I", "II", "III", "IV", "V", "VI", "VII", "VIII", "IX", "X"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--data_dir", "-d", type=str, default=f"data/{benchmark}")       # --data_dir
    parser.add_argument("--save_dir", "-s", type=str, default=f"results/{benchmark}")    # --save_dir
    parser.add_argument(
        "--engine", "-e",
        type=str,
        default='gpt-3.5-turbo-0125',
        # nargs="+"
    )
    args = parser.parse_args()

    # gpt-3.5-turbo-1106, gpt-4-0125-preview,
    # llama3-8b-8k, llama3-70b-8k,
    # qwen1.5-72b-32k, qwen1.5-110b-32k,
    # mistral-7b-v0.3-32k,

    # set temperature
    if args.engine in ['gpt-3.5-turbo-1106', 'gpt-4-0125-preview', 'gpt-3.5-turbo-0125']:
        temperature = 1.0
    else:
        temperature = 0.7

    for method in methods:

        # set the number of responses collected
        if method in ['informed', 'disrupted']:
            response_pop = 1
            counsel_pop = 10
        else:
            response_pop = 10

        # Labels shown to the model are derived from original_labels, so every
        # method also works for the 10-option MMLU-Pro.
        if re.search("iia", method):
            custom_labels = original_labels[:-1]    # one option is removed
        elif re.search("binary", method):
            custom_labels = original_labels[:2]     # one pair at a time
        elif re.search("alt_arabic", method):
            custom_labels = [f"({n})" for n in range(1, len(original_labels) + 1)]
        elif re.search("alt_roman", method):
            custom_labels = roman_numerals[:len(original_labels)]
        else:
            custom_labels = list(original_labels)

        print(f"""Initializing, engine={args.engine}, temperature={temperature}, few shot example={args.ntrain},
        response_pop={response_pop}, test_sample_size={test_sample_size}""")

        main(args)



