
import sys
import json
import random
import numpy as np
from scipy.stats import t
from collections import Counter
from scipy.spatial import distance
import tiktoken
import matplotlib.pyplot as plt
# import seaborn as sns

# Judge roles grouped under the name "positional bias" (presumably judges found
# to favour one answer position; not referenced in this part of the file).
POSITIONAL_BIAS_MODELS = ['gpt-turbo', 'spark', 'qwen']
# Display titles for perturbation labels (the `perturb` field of an answer).
PERTURB2TITLE = {
    "factual_error": "Factual Error",
    "reference": "Reference",
    "rich_content": "Rich Content",
}
# Display titles for judge roles. The keys cover the roles returned by
# load_data() and load_turnover_data() below, plus a few extra roles.
ROLE2TITLE = {
    'human': 'Human',
    'gpt4': 'GPT-4',
    "gpt4-cot": "GPT-4-COT",
    'gpt-turbo': 'GPT-3.5-Turbo',
    'gpt-turbo-cot': 'GPT-3.5-Turbo-COT',
    'claude-2': 'Claude-2',
    'claude-2-cot': 'Claude-2-COT',
    'llama2-13b-ans': 'LLaMA2-13B-ANS',
    'llama2-70b-turbo-ans': 'LLaMA2-70B-Turbo-ANS',
    "turbo_turbo_fe": "Turbo-Turbo-FE",
    "gpt4-a-cot": "GPT-4-A-COT",
    'ernie': "Ernie",
    'spark': 'Spark',
    'llama2-70B': 'LLaMA2-70B',
    # 'palm': 'PaLM-2-Bison',
    'palm': 'PaLM-2',
    'gpt4-Turbo': "GPT-4-Turbo",
    'qwen': "Qwen",
    'human_not_familar': 'Human-NF',
}

sys.path.insert(0, "../scripts")
from blacklist import user_black_list, question_black_list
# random.seed(0)

def count_evals(data):
    """Count the votes in *data*, skipping blacklisted questions and voters.

    The question key (``question_id`` minus its last two characters, presumably
    the ``-<side>`` suffix) is checked against ``question_black_list``. Each
    voter id in ``evals`` is checked against ``user_black_list``. Both lists are
    imported from the ``blacklist`` module.
    """
    total = 0
    for question_id, record in data.items():
        if question_id[:-2] in question_black_list:
            continue
        total += sum(1 for voter in record['evals'] if voter not in user_black_list)
    return total

def count_evals_by_model(data):
    """Tally how often each answering model was preferred.

    A ``"left"`` vote credits ``answers.answer1.model_id`` and a ``"right"``
    vote credits ``answers.answer2.model_id``. A ``"tie"`` vote is counted under
    the key ``'tie'``. Any other vote value (a skip) is ignored.

    Returns:
        ``{model_id_or_'tie': count}``.
    """
    tally = {}
    for record in data.values():
        for ballot in record['evals'].values():
            choice = ballot['vote']
            if choice == 'left':
                winner = record['answers']['answer1']['model_id']
            elif choice == 'right':
                winner = record['answers']['answer2']['model_id']
            elif choice == 'tie':
                winner = 'tie'
            else:
                continue  # skipped vote
            tally[winner] = tally.get(winner, 0) + 1
    return tally

def combine_res_by_question(data):
    """Re-key raw evaluation results by question and annotate every kept vote.

    Args:
        data: ``{question_id_with_side: record}``. Each record has ``question``,
            ``answers`` (``answer1``/``answer2`` dicts with ``answer_id`` and,
            optionally, ``perturb``) and ``evals`` (``{user_id: {'vote': ...}}``).
            The last character of the key is stored as each vote's ``side``.

    Returns:
        ``{question_id_with_side: question}``. ``question`` holds:

        - ``question_id``
        - ``level``: the key text before the first ``'_'``
        - ``perturb``: answer1's ``perturb``, or answer2's if answer1's is
          empty; ``""`` when a ``perturb`` key is missing
        - ``question``, ``answers`` and ``evals``

        Only votes equal to ``"left"``, ``"right"`` or ``"tie"`` are kept. A
        question with no such vote is omitted.

    Note:
        The kept vote dicts are the caller's own objects. They are annotated in
        place with ``answer_id`` (``""`` for a tie), ``perturb`` and ``side``.
    """
    questions = {}
    for question_id_with_side, val in data.items():
        side = question_id_with_side[-1]
        for user_id, vote in val['evals'].items():
            answers = val['answers']
            choice = vote['vote']
            # "left" selects answer1, "right" selects answer2.
            if choice == "left":
                answer_id = answers['answer1']['answer_id']
            elif choice == "right":
                answer_id = answers['answer2']['answer_id']
            elif choice == "tie":
                answer_id = ""
            else:
                continue  # skipped vote

            question = questions.get(question_id_with_side)
            if question is None:
                # Built once, on the first kept vote of this question.
                try:
                    perturb = answers['answer1']['perturb']
                    if perturb == "":
                        perturb = answers['answer2']['perturb']
                except KeyError:  # answers without a 'perturb' field
                    perturb = ""
                question = {
                    'question_id': question_id_with_side,
                    'level': question_id_with_side.split('_')[0],
                    'perturb': perturb,
                    'question': val['question'],
                    'answers': answers,
                    'evals': {},
                }
                questions[question_id_with_side] = question

            # perturb label of the chosen answer. A tie falls through to
            # answer2, as before.
            try:
                vote_perturb = answers['answer1' if choice == "left" else 'answer2']['perturb']
            except KeyError:
                vote_perturb = ""
            vote['answer_id'] = answer_id
            vote['perturb'] = vote_perturb
            vote['side'] = side
            question['evals'][user_id] = vote
    return questions

def ramdom_choose(vote, preference, num):
    """Pick ``num`` random positions (without replacement) from two parallel lists.

    The same positions are taken from ``vote`` and ``preference``, so the
    pairs stay aligned.

    Returns:
        ``(vote_subset, preference_subset)``.
    """
    picked = random.sample(range(len(vote)), num)
    return [vote[i] for i in picked], [preference[i] for i in picked]

def collect_val(question, question_id, original_id, perturb_id, level, perturb, num_of_sample = 3):
    """Summarise one question's votes as preferences for the original vs. perturbed answer.

    Args:
        question: a per-question record as built by ``combine_res_by_question``.
            ``question['evals']`` maps user ids to votes carrying ``answer_id``,
            ``side`` and ``vote``. ``question['perturb']`` names the perturbation.
        question_id: id stored in the result.
        original_id: answer id of the unperturbed answer.
        perturb_id: answer id of the perturbed answer. A vote also counts as
            perturbed when its answer id equals ``perturb_id`` with the last 4
            or 8 characters removed, or whenever ``question['perturb']``
            contains ``"temperature"``.
        level, perturb: copied into the result unchanged.
        num_of_sample: keep at most this many votes, in ``evals`` order.
            ``None`` keeps all of them.

    Returns:
        dict with the ids/labels above plus four parallel lists:

        - ``votes``: ``"O:<id>"``, ``"P:<id>"`` or ``"T"``
        - ``preference``: 0, 1 or 0.5
        - ``side``
        - ``position``: the raw ``vote`` value

    Raises:
        Exception: a non-tie vote's answer id matches neither answer.
    """
    val = {
        'question_id': question_id,
        'original_id': original_id,
        'perturb_id': perturb_id,
        'perturb': perturb,
        'level': level,
        'votes': [],
        'preference': [],
        'side': [],
        'position': [],
    }

    for vote in question['evals'].values():
        val['side'].append(vote['side'])
        val['position'].append(vote['vote'])
        answer_id = vote['answer_id']
        if answer_id == "":
            # Tie. This is checked first because, for ids of <= 8 characters,
            # perturb_id[:-8] (and for <= 4 characters perturb_id[:-4]) is "".
            # Testing the prefix forms first counted every tie as a perturbed win.
            val['votes'].append("T")
            val['preference'].append(0.5)
        elif answer_id == original_id:
            val['votes'].append(f"O:{answer_id}")
            val['preference'].append(0)
        elif (answer_id == perturb_id
              # Prefix forms of perturb_id also count (presumably the perturbed
              # answer's id is the base id plus a 4- or 8-char suffix; TODO confirm).
              or answer_id == perturb_id[:-4]
              or answer_id == perturb_id[:-8]
              # Temperature perturbations: any answer that is not the original
              # counts as perturbed.
              or "temperature" in question['perturb']):
            val['votes'].append(f"P:{answer_id}")
            val['preference'].append(1)
        else:
            raise Exception(
                f"answer_id error: answer_id={answer_id!r}, original_id={original_id!r}, "
                f"perturb_id={perturb_id!r}, question_id={question_id!r}, "
                f"perturb={question['perturb']!r}"
            )

    # Keep at most num_of_sample votes. The previous code only truncated when
    # there were more than 3 votes, regardless of num_of_sample.
    if num_of_sample is not None and len(val['votes']) > num_of_sample:
        for key in ('votes', 'preference', 'side', 'position'):
            val[key] = val[key][:num_of_sample]

    return val

## Convert by-question vote records into the per-question format used by the bootstrap.
## s1 and s2 must be in by_question format (output of combine_res_by_question),
## and s2 must contain a single perturbation only.
def transform_between_stage(s1, s2, num_of_sample = 3):
    """Pair control (``s1``) and experimental (``s2``) votes question by question.

    Each ``s2`` key has its second-to-last ``'-'``-separated component removed
    (``a-b-c`` -> ``a-c``); the result must then match a key of ``s1``. The
    original and perturbed answer ids come from the ``s2`` record (the answer
    whose ``perturb`` is ``""`` is the original). Both stages are summarised
    with ``collect_val`` using those ids.

    Side effect: ``s1[question_id]['perturb']`` is overwritten with the
    perturbation label of the matching ``s2`` question.

    Questions that are missing from ``s1``, or whose ``s1`` votes cannot be
    classified by ``collect_val``, are printed and dropped from both results.

    Returns:
        ``(s1_transformed, s2_transformed)`` dicts keyed by the shortened id.
    """
    s2_transformed = {}
    for raw_id, question in s2.items():
        parts = raw_id.split('-')
        del parts[-2]
        question_id = '-'.join(parts)
        # The original and perturbed answer ids are determined by which side is unperturbed.
        answer1 = question['answers']['answer1']
        answer2 = question['answers']['answer2']
        if answer1['perturb'] == "":
            original_id, perturb_id = answer1['answer_id'], answer2['answer_id']
        else:
            original_id, perturb_id = answer2['answer_id'], answer1['answer_id']
        s2_transformed[question_id] = collect_val(question=question,
                                                  question_id=question_id,
                                                  original_id=original_id,
                                                  perturb_id=perturb_id,
                                                  level=question_id.split('_')[0],
                                                  perturb=question['perturb'],
                                                  num_of_sample=num_of_sample)

    s1_transformed = {}
    paired_s2 = {}
    for question_id, question in s2_transformed.items():
        if question_id not in s1:
            print(question_id)
            continue
        s1[question_id]['perturb'] = question['perturb']
        try:
            s1_transformed[question_id] = collect_val(question=s1[question_id],
                                                      question_id=question_id,
                                                      original_id=question['original_id'],
                                                      perturb_id=question['perturb_id'],
                                                      level=question['level'],
                                                      perturb=question['perturb'],
                                                      num_of_sample=num_of_sample)
        except Exception as err:
            # Best effort, as before: report and skip questions whose s1 votes
            # don't match the s2 answer ids. KeyboardInterrupt/SystemExit are
            # no longer swallowed.
            print(question_id, err)
            continue
        paired_s2[question_id] = question

    return s1_transformed, paired_s2


def transform_s2(role_data):
    """Summarise experimental-stage (s2) questions of one role with ``collect_val``.

    The output key keeps the first and third ``'-'``-separated parts of each
    input key (``a-b-c`` -> ``a-c``). The answer whose ``perturb`` is ``""`` is
    treated as the original. If answer1's ``perturb`` is non-empty, answer2 is
    taken as the original.
    """
    transformed = {}
    for raw_id, question in role_data.items():
        pieces = raw_id.split('-')
        question_id = f"{pieces[0]}-{pieces[2]}"
        first = question['answers']['answer1']
        second = question['answers']['answer2']
        if first['perturb'] == "":
            original_answer, perturbed_answer = first, second
        else:
            original_answer, perturbed_answer = second, first
        transformed[question_id] = collect_val(
            question=question,
            question_id=question_id,
            original_id=original_answer['answer_id'],
            perturb_id=perturbed_answer['answer_id'],
            level=question_id.split('_')[0],
            perturb=question['perturb'],
        )
    return transformed

## Example of one entry produced by collect_val, as stored by transform_s1 / transform_s2:
## role_transformed[question_id] = {'question_id': 'Evaluating_15-2',
##                                  'original_id': 'ZtruPBow',
##                                  'perturb_id': '7TQNKsta',
##                                  'perturb': '',
##                                  'level': 'Evaluating',
##                                  'votes': ['P:7TQNKsta', 'P:7TQNKsta', 'T'],
##                                  'preference': [1, 1, 0.5],
##                                  'side': ['2', '2', '2'],
##                                  'position': ['left', 'left', 'tie']}
def transform_s1(role_data, original_ids = None, perturb_ids = None):
    """Summarise control-stage (s1) questions of one role with ``collect_val``.

    Ids are looked up per base question (the key text before the first ``'-'``).
    The first time a base question is seen, its answer1 id is recorded as the
    original and its answer2 id as the perturbed answer. ``original_ids`` and
    ``perturb_ids`` are updated in place when both are given; if either is
    ``None``, fresh dicts are used for both.
    """
    if original_ids is None or perturb_ids is None:
        original_ids, perturb_ids = {}, {}

    transformed = {}
    for question_id, question in role_data.items():
        base_id = question_id.split("-")[0]
        if base_id not in original_ids:
            answers = question['answers']
            original_ids[base_id] = answers['answer1']['answer_id']
            perturb_ids[base_id] = answers['answer2']['answer_id']
        transformed[question_id] = collect_val(
            question=question,
            question_id=question_id,
            original_id=original_ids[base_id],
            perturb_id=perturb_ids[base_id],
            level=question_id.split('_')[0],
            perturb=question['perturb'],
        )
    return transformed



def bootstrap_confidience_interval(s1, s2, B1 = 100, B2=100, alpha = 0.01, n = None):
    """Two-level bootstrap confidence interval for the mean preference shift s2 - s1.

    For every question id in ``s1`` (which must also be in ``s2``), the point
    estimate is ``mean(s2 preference) - mean(s1 preference)``.

    - Level 1: ``B1`` times, each question's ``s1`` and ``s2`` preference lists
      are resampled with replacement, giving one vector of per-question diffs
      per replicate.
    - Level 2: ``B2`` times, those replicate vectors are resampled with
      replacement and averaged.

    The interval is the mean point estimate +/- the ``1 - alpha/2`` quantile of
    Student's t (``B2`` degrees of freedom) times the population standard
    deviation of the level-2 means. ``n`` is accepted but unused.

    Returns:
        ``[lower, upper]``.
    """
    # Point estimate: per-question difference of mean preferences.
    point_diffs = [
        np.mean(s2[qid]['preference']) - np.mean(s1[qid]['preference'])
        for qid in s1
    ]

    def resampled_mean(values):
        return np.mean(random.choices(values, k=len(values)))

    # Level 1: resample within each question (s1 first, then s2).
    replicates = []
    for _ in range(B1):
        row = []
        for qid in s1:
            control_mean = resampled_mean(s1[qid]['preference'])
            experiment_mean = resampled_mean(s2[qid]['preference'])
            row.append(experiment_mean - control_mean)
        replicates.append(row)

    # Level 2: resample the level-1 replicates and average each draw.
    nested_means = [
        np.mean(random.choices(replicates, k=len(replicates)))
        for _ in range(B2)
    ]

    center = np.mean(point_diffs)
    spread = np.sqrt(np.var(nested_means, ddof=0))
    t_value = t.ppf((2 - alpha) / 2., len(nested_means))
    return [center - t_value * spread, center + t_value * spread]

def bootstrap_by_level(s1, s2, B1 = 100, B2=100, alpha = 0.01, n = None):
    """Run ``bootstrap_confidience_interval`` separately for each question level.

    Questions are grouped by ``s2[question_id]['level']``. A question absent
    from ``s1`` is printed and left out of that level's ``s1`` group. The
    bootstrap iterates ``s1`` keys, so such a question is then ignored.

    Levels are processed in first-seen order of ``s2``. The previous version
    used ``set`` order, which varies between interpreter runs under string hash
    randomization. With a stable order, the sequence of ``random`` draws (and
    hence the intervals under a fixed ``random.seed``) is reproducible.

    Returns:
        ``{level: [lower, upper]}``.
    """
    s1_level_data = {}
    s2_level_data = {}
    for question_id, entry in s2.items():
        level = entry['level']
        s2_level_data.setdefault(level, {})[question_id] = entry
        level_s1 = s1_level_data.setdefault(level, {})
        if question_id in s1:
            level_s1[question_id] = s1[question_id]
        else:
            print(question_id)

    return {
        level: bootstrap_confidience_interval(s1_level_data[level], s2_level_data[level], B1, B2, alpha, n)
        for level in s2_level_data
    }


def group_by_perturb(data, data_by_perturb):
    """Group questions by their ``perturb`` label, writing into ``data_by_perturb``.

    ``data_by_perturb`` is filled in place as ``{perturb: {question_id: question}}``;
    existing groups are extended. A label outside the known set is printed
    between ``|`` bars and an ``Exception("perturb error")`` is raised.
    Questions seen before the bad one are already grouped at that point.
    """
    known_perturbs = (
        "temperature_0.2", "temperature_0.4", "temperature_0.6", "temperature_0.8",
        "factual_error", "reference", "rich_content", "", " ",
        "reference_rich_content",
        "factual_error+reference", "factual_error+rich_content", 'factual_error+reference_rich_content',
    )
    for question_id, question in data.items():
        label = question['perturb']
        if label not in known_perturbs:
            print(f"|{question['perturb']}|")
            raise Exception("perturb error")
        data_by_perturb.setdefault(label, {})[question_id] = question

def preference_shift_by_stage(s1_data, s2_data, B1 = 100, B2=100, alpha = 0.01):
    """Per-perturbation, per-level bootstrap CIs of the s1 -> s2 preference shift.

    Both stages are regrouped by question. ``s2`` is split by perturbation
    label, and each split is paired with ``s1`` via ``transform_between_stage``
    and bootstrapped by level.

    Returns:
        ``{perturb: {level: [lower, upper]}}``.
    """
    control_by_question = combine_res_by_question(s1_data)
    experiment_by_question = combine_res_by_question(s2_data)

    experiment_by_perturb = {}
    group_by_perturb(experiment_by_question, experiment_by_perturb)

    results = {}
    for perturb, subset in experiment_by_perturb.items():
        control_t, experiment_t = transform_between_stage(control_by_question, subset)
        results[perturb] = bootstrap_by_level(control_t, experiment_t, B1, B2, alpha)
    return results

def preference_shift_by_role(role1_data, role2_data, B1=100, B2=100, alpha = 0.01, stage = 's1'):
    """Per-perturbation, per-level bootstrap CIs comparing two roles on the same stage.

    ``stage`` selects the transform: ``'s1'`` uses ``transform_s1`` and ``'s2'``
    uses ``transform_s2``. Any other value raises ``Exception("stage error")``.
    Every perturbation present for ``role1`` must also be present for ``role2``.

    Returns:
        ``{perturb: {level: [lower, upper]}}``.
    """
    transformers = {'s1': transform_s1, 's2': transform_s2}
    if stage not in transformers:
        raise Exception("stage error")
    transform_same_stage = transformers[stage]

    # Regroup both roles by question, then split each by perturbation label.
    role1_by_question = combine_res_by_question(role1_data)
    role2_by_question = combine_res_by_question(role2_data)
    role1_by_perturb, role2_by_perturb = {}, {}
    group_by_perturb(role1_by_question, role1_by_perturb)
    group_by_perturb(role2_by_question, role2_by_perturb)

    return {
        perturb: bootstrap_by_level(
            transform_same_stage(subset),
            transform_same_stage(role2_by_perturb[perturb]),
            B1, B2, alpha,
        )
        for perturb, subset in role1_by_perturb.items()
    }

## Input: {question_id: {..., 'votes': ['P:7TQNKsta', 'O:ZtruPBow', 'T'], ...}, ...}
def collect_preference_vote(data):
    """Flatten the ``votes`` of every entry into their preference labels.

    Each vote string is cut at its first ``':'``, so ``"O:<id>"`` -> ``"O"``,
    ``"P:<id>"`` -> ``"P"`` and ``"T"`` stays ``"T"``.
    """
    return [ballot.split(":")[0] for entry in data.values() for ballot in entry['votes']]

def get_tranformed_data_between_stage(data, perturb, num_of_sample = 3):
    """Build the paired control/experimental records of one role for a single perturbation.

    ``data`` holds the raw ``'s1'`` (control) and ``'s2'`` (experimental)
    results of one role. Both are regrouped by question, ``s2`` is restricted to
    questions whose ``perturb`` equals *perturb*, and the pair is passed through
    ``transform_between_stage``. Raises ``KeyError`` if no s2 question has that
    perturbation.

    Returns:
        ``(s1_transformed, s2_transformed)``.
    """
    control_raw, experiment_raw = data['s1'], data['s2']
    control = combine_res_by_question(control_raw)
    experiment = combine_res_by_question(experiment_raw)

    experiment_by_perturb = {}
    group_by_perturb(experiment, experiment_by_perturb)

    return transform_between_stage(control, experiment_by_perturb[perturb], num_of_sample)

def get_preference_count(data, perturb, no_claude=True, no_turbo=True, no_gpt4_ref=True, no_gpt4_no_exp=True, no_cot=True, no_spark=True, no_qwen=True):
    """Count O/T/P preference labels per role for the control and experimental stages.

    A role is skipped when an enabled ``no_*`` flag's marker occurs as a
    substring of its name. The markers are ``'claude'``, ``'turbo'``,
    ``'gpt4_ref'``, ``'gpt4_no_exp'``, ``'cot'``, ``'spark'`` and ``'qwen'``.

    Returns:
        ``{role: {'Control': counts, 'Experimental': counts}}``. ``counts`` maps
        each label to its frequency; ``"O"``, ``"T"`` and ``"P"`` are always
        present (0 when unseen).
    """
    exclusions = (
        (no_claude, 'claude'),
        (no_turbo, 'turbo'),
        (no_gpt4_ref, 'gpt4_ref'),
        (no_gpt4_no_exp, 'gpt4_no_exp'),
        (no_cot, 'cot'),
        (no_spark, 'spark'),
        (no_qwen, 'qwen'),
    )

    role_preference_count = {}
    for role in data.keys():
        if any(flag and marker in role for flag, marker in exclusions):
            continue

        s1_transformed, s2_transformed = get_tranformed_data_between_stage(data[role], perturb)
        votes_by_stage = {
            'Control': collect_preference_vote(s1_transformed),
            'Experimental': collect_preference_vote(s2_transformed),
        }

        role_preference_count[role] = {}
        for stage, votes in votes_by_stage.items():
            counts = {label: total for label, total in Counter(votes).items() if label is not None}
            for label in ("O", "T", "P"):
                counts.setdefault(label, 0)
            role_preference_count[role][stage] = counts

    return role_preference_count


# Count the tokens contained in a string.
def num_tokens_from_string(string: str) -> int:
    """Return the number of tokens in *string* under tiktoken's ``cl100k_base`` encoding."""
    return len(tiktoken.get_encoding("cl100k_base").encode(string))

# Add a 'length' field (token count) to every answer in data.
def count_tokens(data):
    """Annotate *data* in place with token counts and return it.

    ``data`` is ``{role: {stage: {question_id: question}}}``. Every answer gets
    ``length = num_tokens_from_string(answer['answer'])``. Every question gets
    ``length_dif = answer1.length - answer2.length``.
    """
    for role_data in data.values():
        for stage_data in role_data.values():
            for question in stage_data.values():
                answers = question['answers']
                for answer in answers.values():
                    answer['length'] = num_tokens_from_string(answer['answer'])
                question['length_dif'] = answers['answer1']['length'] - answers['answer2']['length']
    return data

def load_data(count_token = False):
    """Load control (``s1``) and experimental (``s2``) results for every judge role.

    Files are read from ``../data/ctrl_group/`` (s1) and ``../data/exp_group/``
    (s2), relative to the working directory. When ``count_token`` is true, the
    result is passed through ``count_tokens`` before being returned.

    Returns:
        ``{role: {'s1': parsed_json, 's2': parsed_json}}``.
    """
    ctrl_home_dir = '../data/ctrl_group/'
    exp_home_dir = '../data/exp_group/'
    # role -> (control-group file, experimental-group file), relative to the home dirs.
    layout = {
        "human": ("human/s1_encoded_high_quality.json", "human/s2_encoded_high_quality.json"),
        "gpt4": ("gpt-4/v0/non_cot.json", "gpt4/v1/non_cot.json"),
        "gpt4-Turbo": ("gpt-4-1106-preview/v0/non_cot.json", "gpt-4-1106-preview/v0/non_cot.json"),
        "gpt4-cot": ("gpt-4/v0/cot.json", "gpt4/v1/cot.json"),
        # NOTE(review): this experimental path uses "gpt-4/" while the other
        # GPT-4 experimental entries use "gpt4/" - confirm it is intentional.
        "gpt4-a-cot": ("gpt-4/v0/a_cot.json", "gpt-4/v1/a_cot.json"),
        "gpt-turbo-cot": ("gpt-3.5-turbo-16k/v1/cot.json", "gpt-3.5-turbo-16k/v0/cot.json"),
        "gpt-turbo": ("gpt-3.5-turbo-16k/v1/non_cot.json", "gpt-3.5-turbo-16k/v0/non_cot.json"),
        "claude-2": ("claude-2/v1/non_cot.json", "claude-2/v1/non_cot.json"),
        "claude-2-cot": ("claude-2/v1/cot.json", "claude-2/v1/cot.json"),
        "ernie": ("ernie/v0/non_cot.json", "ernie/v0/non_cot.json"),
        "spark": ("spark-3.1/v0/non_cot.json", "spark-3.1/v0/non_cot.json"),
        "llama2-70B": ("llama-2-70b-chat/v0/non_cot.json", "llama-2-70b-chat/v0/non_cot.json"),
        "qwen": ("qwen-plus/v0/non_cot.json", "qwen-plus/v0/non_cot.json"),
        "palm": ("palm/v0/non_cot.json", "palm/v0/non_cot.json"),
    }

    data = {}
    for role, (ctrl_file, exp_file) in layout.items():
        for stage, path in (("s1", f"{ctrl_home_dir}/{ctrl_file}"), ("s2", f"{exp_home_dir}/{exp_file}")):
            with open(path, 'r') as handle:
                data.setdefault(role, {})[stage] = json.load(handle)

    return count_tokens(data) if count_token else data

def load_turnover_data(judge, home):
    """Load the fixed "turnover" experiment result files under ``{home}/eval_results``.

    ``judge`` is accepted but not used; the paths name their judges explicitly.

    Returns:
        ``{role: {'s1': parsed_json, 's2': parsed_json}}`` for the roles
        ``"turbo_turbo_fe"`` and ``"llama2-70b-turbo-ans"``.
    """
    results_dir = f"{home}/eval_results"
    # (role, stage, result sub-directory), loaded in this order.
    layout = (
        ("turbo_turbo_fe", "s1", "3_eval_answer/claude-2-web/v2.1"),
        ("turbo_turbo_fe", "s2", "3_eval_answer/claude-2-web/v2.1.1"),
        ("llama2-70b-turbo-ans", "s1", "2_val_answer/gpt-4/v7.2"),
        ("llama2-70b-turbo-ans", "s2", "3_eval_answer/gpt-4/v5.2.1"),
    )
    data = {}
    for role, stage, subdir in layout:
        with open(f"{results_dir}/{subdir}/non_cot.json", 'r') as handle:
            data.setdefault(role, {})[stage] = json.load(handle)
    return data

def load_turnover_for_turbo_fe(judge, home):
    """Load the turbo factual-error turnover results of one judge.

    ``s1`` is read from ``.../{judge}/vturbo_turbo-fe/non_cot.json`` and ``s2``
    from ``.../{judge}/vturbo_turbo-fe_ref_rc/non_cot.json``, both under
    ``{home}/eval_results/3_eval_answer``.

    Returns:
        ``{"turbo_fe_turnover": {'s1': parsed_json, 's2': parsed_json}}``.
    """
    judge_dir = f"{home}/eval_results/3_eval_answer/{judge}"
    stage_files = (
        ("s1", f"{judge_dir}/vturbo_turbo-fe/non_cot.json"),
        ("s2", f"{judge_dir}/vturbo_turbo-fe_ref_rc/non_cot.json"),
    )
    data = {}
    for stage, path in stage_files:
        with open(path, 'r') as handle:
            data.setdefault("turbo_fe_turnover", {})[stage] = json.load(handle)
    return data





def load_turnover_for_turbo_llama(judge, roles, home):
    """Load and relabel s1/s2 turnover results for the given roles of one judge.

    For every role:

    - ``s2`` is read from ``{home}/eval_results/3_eval_answer/{judge}/v{role}/non_cot.json``.
    - ``s1`` is read from ``{home}/eval_results/2_val_answer/{judge}/v{role without '_ref'}/non_cot.json``.

    While loading, answers are relabelled in place so that the "perturbed"
    side can be identified in s1:

    - In every s1 file, answers whose ``model_id`` contains ``'llama'`` get
      ``perturb = " "``.
    - For roles containing ``'turbo_turbo'``, the ids of s2 answers with a
      non-empty ``perturb`` are recorded. s1 answers with those ids get
      ``perturb = " "``.

    The recorded ids are shared across all roles. Duplicate roles are loaded once.

    Returns:
        ``{role: {'s2': parsed_json, 's1': parsed_json}}``.
    """
    def add_perturb_for_llama_in_s1(d):
        for qid in d:
            for answer in d[qid]['answers'].values():
                if 'llama' in answer['model_id']:
                    answer['perturb'] = " "  # make it different from ""

    def update_perturb_id(d, turbo_ids_to_be_perturbed):
        # Record the ids of the perturbed answers.
        for qid in d:
            for answer in d[qid]['answers'].values():
                if answer['perturb'] != "":
                    turbo_ids_to_be_perturbed.add(answer['answer_id'])

    def add_perturb_for_turbo_turbo_s1(d, turbo_ids_to_be_perturbed):
        for qid in d:
            for answer in d[qid]['answers'].values():
                if answer['answer_id'] in turbo_ids_to_be_perturbed:
                    answer['perturb'] = " "  # make it different from ""

    data = {}
    # A set: membership is tested once per s1 answer. The previous list made
    # this an O(answers x recorded ids) scan.
    turbo_ids_to_be_perturbed = set()
    # dict.fromkeys de-duplicates roles while keeping their order.
    for role in dict.fromkeys(roles):
        # Order matters: s2 must be read before s1, so that the perturbed ids
        # it records are available when s1 is relabelled.
        stage_paths = (
            ("s2", f"{home}/eval_results/3_eval_answer/{judge}/v{role}/non_cot.json"),
            ("s1", f"{home}/eval_results/2_val_answer/{judge}/v{role.replace('_ref', '')}/non_cot.json"),
        )
        data[role] = data.get(role, {})
        for stage, path in stage_paths:
            with open(path, 'r') as f:
                d_ = json.load(f)
            if stage == 's1':
                add_perturb_for_llama_in_s1(d_)
            if 'turbo_turbo' in role:
                if stage == 's2':
                    update_perturb_id(d_, turbo_ids_to_be_perturbed)
                else:
                    add_perturb_for_turbo_turbo_s1(d_, turbo_ids_to_be_perturbed)
            data[role][stage] = d_

    return data

def combine_side(data):
    """Merge the preference lists of entries that share a base question id.

    The base id is the part of the key before the first ``'-'``. Preferences are
    concatenated in iteration order into a fresh list.

    Returns:
        ``{base_id: {"votes": [preference, ...]}}``.
    """
    combined = {}
    for question_id, entry in data.items():
        base_id = question_id.partition('-')[0]
        combined.setdefault(base_id, {"votes": []})["votes"].extend(entry["preference"])
    return combined

def find_weak_questions(data, perturb, weak = "perturb"):
    """Find control-stage questions where one answer lost the majority of decisive votes.

    For every role whose name does not contain ``"turbo"``, the s1 (control)
    votes for *perturb* are merged across sides. A question is selected when:

    - ``weak == "perturb"``: 0-votes (original preferred) outnumber 1-votes;
    - ``weak == "original"``: 1-votes outnumber 0-votes.

    Any other ``weak`` value selects nothing.

    Returns:
        ``{role: {base_question_id: votes}}``.
    """
    weak_questions = {}
    for role, role_data in data.items():
        if "turbo" in role:
            continue
        s1_transformed, _ = get_tranformed_data_between_stage(role_data, perturb)
        selected = {}
        for question_id, entry in combine_side(s1_transformed).items():
            votes = entry["votes"]
            n_original, n_perturbed = votes.count(0), votes.count(1)
            if weak == "perturb" and n_original > n_perturbed:
                selected[question_id] = votes
            elif weak == "original" and n_original < n_perturbed:
                selected[question_id] = votes
        weak_questions[role] = selected
    return weak_questions



def find_inconsistent_questions(data, perturb):
    """Collect experimental-stage votes of the questions listed by ``get_inconsistent_qid_list``.

    Only roles whose name does not contain ``"turbo"`` are included. Votes are
    merged across sides per base question id.

    Returns:
        ``{role: {base_question_id: votes}}``.
    """
    target_list = get_inconsistent_qid_list(perturb)
    target_questions = {}
    for role, role_data in data.items():
        if "turbo" in role:
            continue
        _, s2_transformed = get_tranformed_data_between_stage(role_data, perturb)
        target_questions[role] = {
            question_id: entry['votes']
            for question_id, entry in combine_side(s2_transformed).items()
            if question_id in target_list
        }
    return target_questions


def find_consistent_questions(data, perturb):
    """Collect experimental-stage votes of the questions listed by ``get_consistent_qid_list``.

    Only roles whose name does not contain ``"turbo"`` are included. Votes are
    merged across sides per base question id.

    Returns:
        ``{role: {base_question_id: votes}}``.
    """
    target_list = get_consistent_qid_list(perturb)
    target_questions = {}
    for role, role_data in data.items():
        if "turbo" in role:
            continue
        _, s2_transformed = get_tranformed_data_between_stage(role_data, perturb)
        target_questions[role] = {
            question_id: entry['votes']
            for question_id, entry in combine_side(s2_transformed).items()
            if question_id in target_list
        }
    return target_questions

def find_random_questions(data, perturb, num = 10):
    """Sample ``num`` distinct control-stage questions per role (``random.sample``).

    Only roles whose name does not contain ``"turbo"`` are included. Raises
    ``ValueError`` when a role has fewer than ``num`` questions.

    Returns:
        ``{role: {base_question_id: s1_votes}}``.
    """
    target_questions = {}
    for role, role_data in data.items():
        if "turbo" in role:
            continue
        s1_transformed, _ = get_tranformed_data_between_stage(role_data, perturb)
        pooled = combine_side(s1_transformed)
        chosen = random.sample(list(pooled), num)
        target_questions[role] = {question_id: pooled[question_id]['votes'] for question_id in chosen}
    return target_questions

def get_random_data(data, perturb, num = 10):
    """Restrict ``data`` to the questions picked by ``find_random_questions``.

    For each role returned by ``find_random_questions``, the ``s1`` and ``s2``
    records whose id prefix (before the first ``'-'``) was sampled are kept.

    Returns:
        ``{role: {'s1': {...}, 's2': {...}}}``.
    """
    random_questions = find_random_questions(data, perturb, num)
    return {
        role: {
            stage: {
                question_id: question
                for question_id, question in data[role][stage].items()
                if question_id.split('-')[0] in chosen
            }
            for stage in ('s1', 's2')
        }
        for role, chosen in random_questions.items()
    }

def find_random_questions_N(data, perturb, num = 10, N = 1):
    """Draw ``N`` random sets of control-stage questions per role.

    Each draw uses ``random.choices`` with ``k=num``, i.e. sampling WITH
    replacement. Only roles whose name does not contain ``"turbo"`` are included.

    Returns:
        ``{role: [ {base_question_id: s1_votes}, ... ]}``, one dict per draw.
    """
    target_questions = {}
    for role, role_data in data.items():
        if "turbo" in role:
            continue
        s1_transformed, _ = get_tranformed_data_between_stage(role_data, perturb)
        pooled = combine_side(s1_transformed)
        candidates = list(pooled.keys())

        draws = []
        for _ in range(N):
            picks = random.choices(candidates, k=num)
            # NOTE(review): duplicates from sampling with replacement collapse in
            # this dict, so a draw may hold fewer than `num` distinct questions.
            # Confirm this is intended.
            draws.append({question_id: pooled[question_id]['votes'] for question_id in picks})
        target_questions[role] = draws
    return target_questions

def get_random_data_N(data, perturb, num = 10, N = 1):
    """Build ``N`` random subsets of ``data``, one per draw of ``find_random_questions_N``.

    Each subset has the shape ``{role: {'s1': {...}, 's2': {...}}}`` and keeps only the
    entries whose raw question id (text before the first ``-``) was drawn for that role.

    Returns:
        A list of ``N`` subsets, or the single subset itself (not wrapped in a list)
        when ``N == 1``.
    """
    sampled = find_random_questions_N(data, perturb, num, N)

    def _subset_for_draw(draw_index):
        # One {role: {stage: {...}}} dict restricted to the questions of one draw.
        subset = {}
        for role, draws in sampled.items():
            chosen_ids = draws[draw_index].keys()
            subset[role] = {
                stage: {
                    qid: entry
                    for qid, entry in data[role][stage].items()
                    if qid.split('-')[0] in chosen_ids
                }
                for stage in ('s1', 's2')
            }
        return subset

    all_draws = [_subset_for_draw(i) for i in range(N)]
    return all_draws[0] if N == 1 else all_draws

def random_sample(data, perturb, sample_size = 50, N = 10):
    """Compute human T/O/P preference counts over ``N`` random question subsets.

    Each subset is produced by ``get_random_data_N`` with ``sample_size`` questions
    per role. For every subset, ``get_preference_count(...)['human']`` is read, and
    its ``'Control'`` and ``'Experimental'`` counts are collected in ``T, O, P`` order.

    Returns:
        ``(control_group, experiment_group)``: numpy arrays with one row per subset
        and three columns (``T``, ``O``, ``P``).
    """
    random_data = get_random_data_N(data, perturb, num=sample_size, N = N)
    # get_random_data_N unwraps its result when N == 1; re-wrap it so that the
    # per-draw iteration below works for every N (indexing [0] on the bare
    # dict would raise KeyError).
    if N == 1:
        random_data = [random_data]

    sample_data = [get_preference_count(draw, perturb)['human'] for draw in random_data]

    keys = ('T', 'O', 'P')
    # Convert the per-draw lists into numpy arrays.
    control_group = np.array([[item['Control'][k] for k in keys] for item in sample_data])
    experiment_group = np.array([[item['Experimental'][k] for k in keys] for item in sample_data])
    return control_group, experiment_group

def get_weak_data(data, perturb, weak = 'perturb'):
    """Restrict ``data`` to the questions selected by ``find_weak_questions``.

    ``weak`` is forwarded unchanged to ``find_weak_questions``. For every role that
    function returns, only stage-1 and stage-2 entries whose raw question id (the
    part before the first ``-``) is among that role's selected keys are kept.

    Returns:
        ``{role: {'s1': {question_id: entry}, 's2': {question_id: entry}}}``
    """
    selected = find_weak_questions(data, perturb, weak)
    subset = {}
    for role, questions in selected.items():
        wanted = questions.keys()
        subset[role] = {}
        for stage in ('s1', 's2'):
            subset[role][stage] = {
                qid: entry
                for qid, entry in data[role][stage].items()
                if qid.split('-')[0] in wanted
            }
    return subset

def get_inconsistent_data(data, perturb):
    """Restrict ``data`` to the questions ``find_inconsistent_questions`` flags for ``perturb``.

    The selection holds only raw question ids (text before the first ``-``), which is
    enough to identify a question for a given perturbation. Every stage-1 and stage-2
    entry whose raw id is selected for its role is kept.

    Returns:
        ``{role: {'s1': {question_id: entry}, 's2': {question_id: entry}}}``
    """
    flagged = find_inconsistent_questions(data, perturb)
    inc_data = {}
    for role in flagged:
        raw_ids = flagged[role].keys()
        per_stage = {}
        for stage in ('s1', 's2'):
            per_stage[stage] = {
                qid: entry
                for qid, entry in data[role][stage].items()
                if qid.split('-')[0] in raw_ids
            }
        inc_data[role] = per_stage
    return inc_data

def get_consistent_data(data, perturb):
    """Restrict ``data`` to the questions ``find_consistent_questions`` selects for ``perturb``.

    The selection holds only raw question ids (text before the first ``-``). Every
    stage-1 and stage-2 entry whose raw id is selected for its role is kept.

    Returns:
        ``{role: {'s1': {question_id: entry}, 's2': {question_id: entry}}}``
    """
    consistent = find_consistent_questions(data, perturb)
    result = {}
    for role in consistent:
        raw_ids = consistent[role].keys()
        per_stage = {}
        for stage in ('s1', 's2'):
            per_stage[stage] = {
                qid: entry
                for qid, entry in data[role][stage].items()
                if qid.split('-')[0] in raw_ids
            }
        result[role] = per_stage
    return result

def get_diff(role_preference_count):
    """Add a ``'diff'`` entry (``s2 - s1`` per preference key) to every role, in place.

    The keys of each role's ``'s2'`` counts drive the difference. The same
    (mutated) mapping is returned for convenience.
    """
    for counts in role_preference_count.values():
        stage1, stage2 = counts['s1'], counts['s2']
        counts['diff'] = {pref: stage2[pref] - stage1[pref] for pref in stage2}
    return role_preference_count

def search(data, perturb, question_id):
    """Print the stage-1 and stage-2 ``'votes'`` of ``question_id`` for every role."""
    rule = "-" * 40
    for role, by_perturb in data.items():
        print(rule)
        print("role: ", role)
        stages = by_perturb[perturb]
        print("s1 vote", stages['s1'][question_id]['votes'])
        print("s2 vote", stages['s2'][question_id]['votes'])
        print(rule)

def get_question_id(q_id):
    """Return the raw question id: everything before the first ``-`` (the whole string if none)."""
    head, _, _ = q_id.partition("-")
    return head

def get_avg_pref(data):
    """Return the arithmetic mean of ``data['preference']`` (raises on an empty list)."""
    prefs = data['preference']
    total = sum(prefs)
    return total / len(prefs)

def get_preference_diff(data1, data2):
    """Return ``|mean(data1['preference']) - mean(data2['preference'])|``."""
    first = get_avg_pref(data1)
    second = get_avg_pref(data2)
    return abs(first - second)

def merge_preferences(vote_res):
    """Merge per-variant records into one record per raw question id.

    ``vote_res`` is indexed ``[role][perturb][stage][question_id]``. Question ids that
    share the same raw id (text before the first ``-``) are merged. The first record
    seen for a raw id provides all non-list fields. Its ``'preference'`` and
    ``'votes'`` lists are then the concatenation of the first three items of each
    merged record, in iteration order.

    The input is NOT modified: each merged record is a shallow copy with fresh
    ``'preference'`` / ``'votes'`` lists. The previous version stored the caller's
    dict directly and then truncated and extended its lists in place.

    Returns:
        ``{role: {perturb: {stage: {raw_question_id: record}}}}``
    """
    merged_data = {}
    for role, by_perturb in vote_res.items():
        merged_data[role] = {}
        for perturb, by_stage in by_perturb.items():
            merged_data[role][perturb] = {}
            for stage, by_question in by_stage.items():
                merged_stage = {}
                for question_id, record in by_question.items():
                    # Same raw-id normalization used throughout this module.
                    raw_qid = question_id.split("-")[0]
                    # Slicing creates new lists, so the caller's lists are never aliased.
                    prefs = record['preference'][:3]
                    votes = record['votes'][:3]
                    if raw_qid not in merged_stage:
                        entry = dict(record)  # shallow copy; keeps other fields
                        entry['preference'] = prefs
                        entry['votes'] = votes
                        merged_stage[raw_qid] = entry
                    else:
                        merged_stage[raw_qid]['preference'].extend(prefs)
                        merged_stage[raw_qid]['votes'].extend(votes)
                merged_data[role][perturb][stage] = merged_stage
    return merged_data

# if __name__ == '__main__':
#     import pdb
#     data = load_data()
#     li = get_inconsistent_data(data, 'rich_content')
#     # li = get_weak_data(data, 'rich_content')
#     # li = get_inconsistent_qid_list()
#     print(li)


def mix_data(weak_data, strong_data, ratio = 0.5):
    """Mix weak and strong questions per role in the proportion ``ratio : 1 - ratio``.

    For each role, ``int(len(s1) * ratio)`` weak and ``int(len(s1) * (1 - ratio))``
    strong question ids are sampled without replacement. Sampling uses the ``'s1'``
    keys only. The sampled ids are reduced to their raw form (text before the first
    ``-``). Every stage's entries from either source whose raw id was sampled from
    that source are then kept.

    ``ratio`` must lie in ``[0, 1]``; otherwise ``random.sample`` raises ``ValueError``.

    Returns:
        ``{role: {stage: {question_id: entry}}}`` with stages taken from ``weak_data[role]``.
    """
    mixed_data = {}
    for role in weak_data:
        mixed_data[role] = {}
        weak_s1 = list(weak_data[role]['s1'].keys())
        strong_s1 = list(strong_data[role]['s1'].keys())
        # Sample only on the s1 keys, then drop the suffix. Use sets so that the
        # membership tests below are O(1) instead of scanning a list per entry.
        weak_keys = {key.split('-')[0] for key in random.sample(weak_s1, int(len(weak_s1) * ratio))}
        strong_keys = {key.split('-')[0] for key in random.sample(strong_s1, int(len(strong_s1) * (1 - ratio)))}

        for stage in weak_data[role]:
            mixed_data[role][stage] = {}
            for question_id, question in weak_data[role][stage].items():
                if question_id.split('-')[0] in weak_keys:
                    mixed_data[role][stage][question_id] = question
            for question_id, question in strong_data[role][stage].items():
                if question_id.split('-')[0] in strong_keys:
                    mixed_data[role][stage][question_id] = question

    return mixed_data

def normalize(preference_count):
    """Convert per-stage counts into proportions rounded to 4 decimals.

    ``preference_count`` is indexed ``[role][stage][key] -> count``. Each value is
    divided by the total of its own ``[role][stage]`` group. The total is computed
    once per group instead of once per key. A group whose counts sum to 0
    normalizes to all ``0.0`` instead of raising ``ZeroDivisionError``.

    Returns:
        A new mapping with the same nesting; the input is not modified.
    """
    preference_count_normalized = {}
    for role, stages in preference_count.items():
        preference_count_normalized[role] = {}
        for stage, counts in stages.items():
            total = sum(counts.values())
            preference_count_normalized[role][stage] = {
                key: (round(value / total, 4) if total else 0.0)
                for key, value in counts.items()
            }
    return preference_count_normalized

def kl_divergence(p, q):
    """Return the Jensen-Shannon divergence between two count/probability vectors.

    Despite the historical name, this is the JS divergence (natural log), not KL:
    ``JSD(P||Q) = (KL(P||M) + KL(Q||M)) / 2`` with ``M = (P + Q) / 2``.

    ``scipy.spatial.distance.jensenshannon`` returns the JS *distance* (the square
    root of the divergence) and normalizes its inputs itself, so squaring it gives
    the divergence directly. The previous code averaged JS(p, m) and JS(q, m),
    which is not the JS divergence between p and q.

    Symmetric in ``p`` and ``q``; 0 for identical distributions.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return distance.jensenshannon(p, q) ** 2

# import stats

# def kl_divergence(p, q):
#     return scipy.stats.entropy(p, q)

def calculate_kl_matrix(group_data):
    """Return the pairwise ``kl_divergence`` matrix for a list of T/O/P dicts.

    Entry ``[i, j]`` compares the ``[T, O, P]`` vectors of items ``i`` and ``j``; the
    diagonal stays 0. Both directions are evaluated explicitly so that the matrix
    remains correct if ``kl_divergence`` is swapped for an asymmetric measure.
    """
    size = len(group_data)
    matrix = np.zeros((size, size))

    def _vector(k):
        entry = group_data[k]
        return [entry['T'], entry['O'], entry['P']]

    for row in range(size):
        for col in range(row + 1, size):
            forward, backward = _vector(row), _vector(col)
            matrix[row, col] = kl_divergence(forward, backward)
            matrix[col, row] = kl_divergence(backward, forward)
    return matrix

def plot_kl_matrix(data, perturb, role):
    """Show one divergence heatmap per group of ``data[perturb][role]``, sharing a colour scale.

    ``data[perturb][role]`` maps a group name to a list of T/O/P dicts (the layout
    produced by ``get_preference_count_sample``). Each group's matrix comes from
    ``calculate_kl_matrix``. All heatmaps use one ``vmin``/``vmax``, with
    ``vmax`` padded by 10%.

    Raises:
        ValueError: if there are no groups to plot.
    """
    groups = data[perturb][role]

    # Divergence matrix for each group
    kl_matrices = {group_name: calculate_kl_matrix(group_data) for group_name, group_data in groups.items()}
    if not kl_matrices:
        raise ValueError(f"no groups to plot for perturb={perturb!r}, role={role!r}")

    # Shared colour range across all groups
    vmin = min(kl_matrix.min() for kl_matrix in kl_matrices.values())
    vmax = max(kl_matrix.max() for kl_matrix in kl_matrices.values()) * 1.1

    # squeeze=False always yields a 2-D array of Axes, even for a single group;
    # without it a lone Axes is returned and neither zip() nor .ravel() works.
    fig, axes = plt.subplots(1, len(kl_matrices), figsize=(5 * len(kl_matrices), 4), squeeze=False)
    axes = axes.ravel()

    for ax, (group_name, kl_matrix) in zip(axes, kl_matrices.items()):
        im = ax.imshow(kl_matrix, vmin=vmin, vmax=vmax, cmap='plasma')
        ax.set_title(f'{group_name} Group', fontsize = 12)

    # One colorbar spanning every subplot
    fig.colorbar(im, ax=axes.tolist())

    fig.suptitle(f'{role}-{perturb} Heatmaps', fontsize=16)

    plt.show()



# Grouped by perturb and role: preference counts for data that mixes weak and strong questions at different ratios
def get_preference_count_sample(num_ratio = 10):
    """Collect normalized preference counts for weak/strong mixes at ``num_ratio + 1`` ratios.

    For each of the ``reference`` and ``rich_content`` perturbations, weak and strong
    question subsets are mixed at ratios ``0, 1/num_ratio, ..., 1`` via ``mix_data``.
    Preference counts are computed and normalized for each mix.

    Returns:
        ``{perturb: {role: {'Control': [...], 'Experimental': [...]}}}``, with one
        list element per ratio in increasing order.
    """
    data = load_data()
    preference_count_samples = {}

    for perturb in ["reference", "rich_content"]:
        weak_data = get_weak_data(data, perturb)
        strong_data = get_weak_data(data, perturb, weak='original')

        per_ratio = []
        for step in range(num_ratio + 1):
            mixed_data = mix_data(weak_data, strong_data, ratio = step / num_ratio)
            per_ratio.append(normalize(get_preference_count(mixed_data, perturb)))

        preference_count_samples[perturb] = {
            role: {
                "Control": [counts[role]['Control'] for counts in per_ratio],
                'Experimental': [counts[role]['Experimental'] for counts in per_ratio],
            }
            for role in per_ratio[0]
        }
    return preference_count_samples

def get_sankey_data(data, perturb, role):
    """Count how each question's preference class moves from stage 1 to stage 2.

    For ``role``/``perturb``, the mean of each question's combined ``'votes'`` is
    classified per stage as ``pos`` (> 0.5), ``neg`` (< 0.5) or ``tie`` (otherwise,
    including NaN). Which answer "pos" refers to depends on how ``combine_side``
    encodes votes — NOTE(review): confirm against that helper.

    Node ids: 0 pos_ctrl, 1 neg_ctrl, 2 pos_exp, 3 neg_exp, 4 tie_ctrl, 5 tie_exp.
    Link order in ``value``:
        0 pos->pos, 1 pos->neg, 2 neg->pos, 3 neg->neg, 4 pos->tie,
        5 neg->tie, 6 tie->pos, 7 tie->tie, 8 tie->neg

    Returns:
        ``(source, target, value)`` lists suitable for a plotly Sankey ``link``.
    """
    s1_transformed, s2_transformed = get_tranformed_data_between_stage(data[role], perturb=perturb)
    s1_combined = combine_side(s1_transformed)
    s2_combined = combine_side(s2_transformed)

    pos_thresh = 0.5

    source = [0, 0, 1, 1, 0, 1, 4, 4, 4]
    target = [2, 3, 2, 3, 5, 5, 2, 5, 3]
    # Position in `value` of the link for each (stage-1 class, stage-2 class) pair.
    link_index = {
        ('pos', 'pos'): 0,
        ('pos', 'neg'): 1,
        ('neg', 'pos'): 2,
        ('neg', 'neg'): 3,
        ('pos', 'tie'): 4,
        ('neg', 'tie'): 5,
        ('tie', 'pos'): 6,
        ('tie', 'tie'): 7,
        ('tie', 'neg'): 8,
    }

    def _classify(preference):
        if preference > pos_thresh:
            return 'pos'
        if preference < pos_thresh:
            return 'neg'
        return 'tie'

    value = [0] * len(source)
    for question_id, votes in s1_combined.items():
        before = _classify(np.mean(votes['votes']))
        after = _classify(np.mean(s2_combined[question_id]['votes']))
        value[link_index[(before, after)]] += 1

    return source, target, value

# import plotly.graph_objects as go

# def plot_sankey(data, perturb, role):
#     source, target, value = get_sankey_data(data, perturb, role)
#     labels = ['pos_ctrl', 'neg_ctrl', 'pos_exp', 'neg_exp']
#     fig = go.Figure(data=[go.Sankey(
#         node = dict(
#         pad = 15,
#         thickness = 20,
#         line = dict(color = "black", width = 0.5),
#         label = labels,
#         ),
#         link = dict(
#         source = source,
#         target = target,
#         value = value,
#     ))])

#     fig.show()

# import plotly.graph_objects as go

# def plot_sankey(data, perturbs, role):
#     fig = make_subplots(rows=1, cols=len(perturbs), subplot_titles=[f"{perturb}" for perturb in perturbs])
    
#     for idx, perturb in enumerate(perturbs):
#         source, target, value = get_sankey_data(data, perturb, role)
#         labels = ['pos_ctrl', 'neg_ctrl', 'pos_exp', 'neg_exp']
        
#         sankey = go.Sankey(
#             node = dict(
#                 pad = 15,
#                 thickness = 20,
#                 line = dict(color = "black", width = 0.5),
#                 label = labels,
#             ),
#             link = dict(
#                 source = source,
#                 target = target,
#                 value = value,
#             )
#         )

#         # 将 Sankey 图添加到正确的子图位置
#         fig.add_trace(sankey, row=1, col=idx+1)
    
#     fig.update_layout(title_text="Your Main Title Here")
#     fig.show()

import plotly.graph_objects as go

def plot_sankey(data, perturbs, role, judge):
    """Draw side-by-side turnover Sankey diagrams (one per perturbation) and print rates.

    For each perturbation, ``get_sankey_data`` supplies the link counts ``value``:
        0 pos->pos, 1 pos->neg, 2 neg->pos, 3 neg->neg, 4 pos->tie,
        5 neg->tie, 6 tie->pos, 7 tie->tie, 8 tie->neg

    Printed metrics:
      * role ``'turbo_fe_turnover'``: acc@stage 1 = links leaving neg_ctrl / all,
        acc@stage 2 = links entering neg_exp / all, asr = (neg->pos + neg->tie) /
        links leaving neg_ctrl. Aggregated acc1/acc2 are printed at the end.
      * any other role: asr = (neg->pos + tie->pos) / links leaving neg_ctrl or tie_ctrl.
    Ratios with a zero denominator are reported as 0.

    Returns:
        ``(turbo, llama, total)`` for the LAST perturbation: links leaving neg_ctrl,
        links leaving pos_ctrl, and all links. (Presumably the neg/pos sides are the
        turbo/llama answers — confirm against the data.) Previously this raised
        ``UnboundLocalError`` for ``'turbo_fe_turnover'``.

    Raises:
        ValueError: if ``perturbs`` is empty.
    """
    if not perturbs:
        raise ValueError("perturbs must contain at least one perturbation")

    def _flow(value, indices):
        # Total count over the given link positions.
        return sum(value[k] for k in indices)

    def _ratio(numerator, denominator):
        # Avoid ZeroDivisionError when a category is empty.
        return numerator / denominator if denominator else 0

    fig = go.Figure()
    num_plots = len(perturbs)

    # Leave a small gap between adjacent diagrams
    gap = 0.02
    plot_width = (1 - gap * (num_plots - 1)) / num_plots

    labels = ['pos_ctrl', 'neg_ctrl', 'pos_exp', 'neg_exp', 'tie_ctrl', 'tie_exp']
    annotations = []
    pretty_asr = []
    pretty_acc1 = []
    pretty_acc2 = []
    for i, perturb in enumerate(perturbs):
        source, target, value = get_sankey_data(data, perturb, role)

        sankey = go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=labels,
            ),
            link=dict(
                source=source,
                target=target,
                value=value,
            ),
            domain={'x': [i * (plot_width + gap), i * (plot_width + gap) + plot_width], 'y': [0, 1]}
        )
        fig.add_trace(sankey)

        # Per-diagram subtitle, centred over its domain
        annotations.append(dict(
            x=(2 * i + 1) * (plot_width + gap) / 2 - gap / 2,
            y=1.07,
            xref='paper',
            yref='paper',
            text=f'{perturb}',
            showarrow=False,
            font=dict(size=12)
        ))
        print('-'*10)
        print(f'perturb: {perturb}')

        total = sum(value)
        turbo = _flow(value, [2, 3, 5])   # links leaving neg_ctrl
        llama = _flow(value, [0, 1, 4])   # links leaving pos_ctrl

        # acc is only meaningful for the factual-error turnover role.
        if role == 'turbo_fe_turnover':
            to_neg_exp = _flow(value, [1, 3, 8])
            neg_left = _flow(value, [2, 5])

            acc1 = _ratio(turbo, total)
            print(f'acc@stage 1: {turbo} / {total} = {acc1}')
            pretty_acc1.append([turbo, total])

            acc2 = _ratio(to_neg_exp, total)
            print(f'acc@stage 2: {to_neg_exp} / {total} = {acc2}')
            pretty_acc2.append([to_neg_exp, total])

            asr = _ratio(neg_left, turbo)
            print(f'asr: {neg_left} / {turbo} = {asr}')
            pretty_asr.append(str(round(asr, 2)))
        else:
            became_pos = _flow(value, [2, 6])
            non_pos_ctrl = _flow(value, [2, 3, 5, 6, 7, 8])
            asr = _ratio(became_pos, non_pos_ctrl)
            print(f'asr: {became_pos} / {non_pos_ctrl} = {asr}')
            pretty_asr.append(str(round(asr, 2)))

        print('-'*10)

    print(f'total votes: {sum(value)}')
    print('asr:')
    print('\t'.join(pretty_asr))

    if role == 'turbo_fe_turnover':
        print('acc1:')
        acc1 = _ratio(sum(p[0] for p in pretty_acc1), sum(p[1] for p in pretty_acc1))
        print(round(acc1, 2))
        print('acc2:')
        acc2 = _ratio(sum(p[0] for p in pretty_acc2), sum(p[1] for p in pretty_acc2))
        print(round(acc2, 2))

    # Plotly renders HTML-style <br> as a line break in titles; "\n" is ignored.
    fig.update_layout(
        title_text=f"Turnover Plot for {role}<br>judge: {judge}",
        annotations=annotations
    )
    fig.show()
    return turbo, llama, total


# Input: one user's vote data.
# Output: (perturb_pref, perturb_vote), both keyed by perturbation type: the average
# preference / vote count, and the raw votes cast on that perturbation's questions.
def find_perturb(user_data):
    """Aggregate one user's votes by perturbation type.

    The perturbation type is the second ``-``-separated field of each question id.
    Each vote is encoded as 1 when the chosen answer has a non-empty ``'perturb'``,
    0 when it has an empty one, and 0.5 for a tie. Votes other than
    ``left``/``right``/``tie`` are treated as skipped and ignored.

    Returns:
        ``(perturb_pref, perturb_vote)`` where ``perturb_vote[perturb]`` is the list
        of encoded votes (possibly empty if every question was skipped), and
        ``perturb_pref[perturb] = {'avg_pref': mean, 'num_votes': count}``.
        Perturbations with no counted votes are omitted from ``perturb_pref``;
        previously they raised ``ZeroDivisionError``.
    """
    perturb_vote = {}
    for question_id, question in user_data.items():
        perturb = question_id.split('-')[1]
        votes = perturb_vote.setdefault(perturb, [])

        choice = question['evals']['vote']
        if choice == 'left':
            choose_type = question['answers']['answer1']['perturb']
        elif choice == 'right':
            choose_type = question['answers']['answer2']['perturb']
        elif choice == 'tie':
            choose_type = "tie"
        else:
            # skipped
            continue

        if choose_type == "tie":
            votes.append(0.5)
        elif choose_type != "":
            # perturbed answer chosen
            votes.append(1)
        else:
            # original answer chosen
            votes.append(0)

    perturb_pref = {
        perturb: {'avg_pref': sum(votes) / len(votes), "num_votes": len(votes)}
        for perturb, votes in perturb_vote.items()
        if votes
    }
    return perturb_pref, perturb_vote

# Input data is in the by_user format.
# Output: a dictionary keyed by user_id whose values map perturbation type to that
# user's preference stats (avg_pref, num_votes) for that perturbation.
# This function extracts the bias data (strongly one-sided preferences) from the user data.
def get_bias_data(user_data):
    """Find, per user, the perturbations on which their preference is strongly one-sided.

    ``factual_error`` and any perturbation whose name contains ``temperature`` are
    ignored. A perturbation counts as biased when its ``avg_pref`` is ``>= 0.75`` or
    ``<= 0.25``. Users without any biased perturbation are left out.

    Returns:
        ``{user_id: {perturb: {'avg_pref': ..., 'num_votes': ...}}}``
    """
    bias_data = {}
    for user_id, user in user_data.items():
        perturb_pref, _ = find_perturb(user)
        extreme = {
            perturb: pref
            for perturb, pref in perturb_pref.items()
            if perturb != 'factual_error' and 'temperature' not in perturb
            and (pref['avg_pref'] >= 0.75 or pref['avg_pref'] <= 0.25)
        }
        if extreme:
            bias_data[user_id] = extreme
    return bias_data

# Input data is in the by_user format.
# Output: a dictionary keyed by user_id whose values are the perturbation types that user is biased on.
# This function derives the bias keys from the bias data.
def get_bias_key(user_data):
    """Return ``{user_id: [perturb, ...]}`` for biased perturbations backed by >= 5 votes.

    Starts from ``get_bias_data`` and keeps only perturbations with ``num_votes >= 5``.
    Users left with no perturbation are dropped.
    """
    bias_key = {}
    for user_id, prefs in get_bias_data(user_data).items():
        well_sampled = [perturb for perturb, pref in prefs.items() if pref['num_votes'] >= 5]
        if well_sampled:
            bias_key[user_id] = well_sampled
    return bias_key

# data should have the same format as the raw data returned by load_data.
# bias_key should have the format {user_id: [perturb_type]}.
# This function removes the biased votes from the stage-2 human data,
# matching on user_id and perturbation type.
def filter_bias_data(data, bias_key):
    """Delete, in place, stage-2 human votes from users flagged as biased on that perturbation.

    The perturbation of a question is the second ``-``-separated field of its id.
    A vote is removed when its user appears in ``bias_key`` with that perturbation.
    Returns the same (mutated) ``data``.
    """
    for question_id, question in data['human']['s2'].items():
        for user_id, perturb_types in bias_key.items():
            evals = question['evals']
            if user_id not in evals:
                continue
            if question_id.split('-')[1] in perturb_types:
                del evals[user_id]
    return data

def calculate_turnover_rate(data, role, perturb):
    """Return the fraction of stage-1 questions whose preference class turned over.

    Node ids follow ``get_sankey_data``: 0 pos_ctrl, 1 neg_ctrl, 2 pos_exp,
    4 tie_ctrl, 5 tie_exp.

    * ``reference`` / ``rich_content``: flows from neg_ctrl or tie_ctrl into
      pos_exp, divided by all flows leaving neg_ctrl or tie_ctrl.
    * ``factual_error``: flows from pos_ctrl or tie_ctrl into pos_exp OR tie_exp,
      divided by all flows leaving pos_ctrl or tie_ctrl.

    Returns 0 when the denominator is 0.

    Raises:
        ValueError: for any other ``perturb``.
    """
    source, target, value = get_sankey_data(data, perturb, role)
    pos_ctrl, neg_ctrl, pos_exp, tie_ctrl, tie_exp = 0, 1, 2, 4, 5

    if perturb in ['reference', 'rich_content']:
        origins, destinations = (neg_ctrl, tie_ctrl), (pos_exp,)
    elif perturb in ['factual_error']:
        origins, destinations = (pos_ctrl, tie_ctrl), (pos_exp, tie_exp)
    else:
        raise ValueError(f"perturb {perturb} not supported")

    links = list(zip(source, target, value))
    turnover_flows = sum(v for s, t, v in links if s in origins and t in destinations)
    total_flows = sum(v for s, _, v in links if s in origins)

    return turnover_flows / total_flows if total_flows > 0 else 0

def calculate_acc(data, role):
    """Return the share of ``factual_error`` flows that end in neg_exp (node 3).

    Uses the Sankey counts from ``get_sankey_data``; returns 0 when there are no flows.
    """
    _, target, value = get_sankey_data(data, 'factual_error', role)
    neg_exp_index = 3

    to_neg_exp = sum(v for t, v in zip(target, value) if t == neg_exp_index)
    total_flows = sum(value)

    return to_neg_exp / total_flows if total_flows > 0 else 0