import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from utils import ROLE2TITLE, PERTURB2TITLE
# import scienceplots
# plt.style.use('science')

def get_color(roles):
    """Map each role to a distinct colour sampled from the ``tab10`` colormap.

    Args:
        roles: Iterable of role names (e.g. ``dict.keys()``). The colormap is
            resampled to one entry per role, in iteration order.

    Returns:
        dict mapping each role to an RGBA tuple; empty if ``roles`` is empty.
    """
    roles = list(roles)
    if not roles:
        return {}
    # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # ``plt.get_cmap(name, lut)`` takes the same arguments and still exists.
    colormap = plt.get_cmap('tab10', len(roles))
    return {role: colormap(i) for i, role in enumerate(roles)}

def plot_shift(data, perturb, title, save_name=None, levels=None,
               save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Plot per-level preference bands for each role on two side-by-side axes.

    The left axes shows every role. The right axes shows only roles whose
    name does not contain ``"cot"``.

    Args:
        data: Nested mapping ``data[role][perturb][level]``. Each leaf is
            indexed ``[0]``/``[1]`` for the shaded band and averaged with
            ``np.mean`` for the line. Presumably a ``(lower, upper)`` pair;
            confirm with the caller.
        perturb: Key selecting the perturbation inside each role's entry.
        title: Title used on both axes.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        levels: Ordered x-axis categories. Defaults to 'Remembering' through
            'Creating'.
        save_dir: Output directory used when ``save_name`` is given.
    """
    fig, axs = plt.subplots(1, 2, figsize=(20, 5))
    colors = get_color(data.keys())
    if levels is None:
        levels = ['Remembering', 'Understanding', 'Applying', 'Analyzing', 'Evaluating', 'Creating']

    for role in data:
        stats = data[role][perturb]
        means = [np.mean(stats[key]) for key in levels]
        lower_bounds = [stats[key][0] for key in levels]
        upper_bounds = [stats[key][1] for key in levels]

        # Every role goes on the left panel; non-CoT roles are repeated on the right.
        targets = [axs[0]] if "cot" in role else [axs[0], axs[1]]
        for ax in targets:
            ax.plot(levels, means, '-o', label=role, color=colors[role], linewidth=2.0)
            ax.fill_between(levels, lower_bounds, upper_bounds, alpha=0.5, color=colors[role])

    for ax in axs:
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Levels', fontsize=14)
        ax.set_ylabel('Preference', fontsize=14)
        ax.grid(True)
        # One legend call: a second ``legend()`` call would replace this one
        # and silently drop the font size.
        ax.legend(fontsize=12, loc='upper left', bbox_to_anchor=(1, 1))

    fig.tight_layout()
    if save_name is not None:
        # bbox_inches='tight' keeps the legend (placed outside the axes) in the saved file.
        fig.savefig(f"{save_dir}/{save_name}.pdf", bbox_inches='tight')
    plt.show()

def plot_shift_non_cot(data, perturb, title, save_name=None, levels=None,
                       save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Plot per-level preference-shift bands for every role whose name lacks ``"cot"``.

    Args:
        data: Nested mapping ``data[role][perturb][level]``. Each leaf is
            indexed ``[0]``/``[1]`` for the shaded band and averaged with
            ``np.mean`` for the line. Presumably a ``(lower, upper)`` pair;
            confirm with the caller.
        perturb: Key selecting the perturbation inside each role's entry.
        title: Axes title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        levels: Ordered x-axis categories. Defaults to 'Remembering' through
            'Creating'.
        save_dir: Output directory used when ``save_name`` is given.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    # Colours are assigned over all roles (CoT ones included), so each role
    # gets the same colour as in the other plots built from get_color(data.keys()).
    colors = get_color(data.keys())
    if levels is None:
        levels = ['Remembering', 'Understanding', 'Applying', 'Analyzing', 'Evaluating', 'Creating']

    for role in data:
        # Skip CoT roles before reading their data.
        if "cot" in role:
            continue
        stats = data[role][perturb]
        means = [np.mean(stats[key]) for key in levels]
        lower_bounds = [stats[key][0] for key in levels]
        upper_bounds = [stats[key][1] for key in levels]
        ax.plot(levels, means, '-o', label=role, color=colors[role], linewidth=2.0)
        ax.fill_between(levels, lower_bounds, upper_bounds, alpha=0.5, color=colors[role])

    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Levels', fontsize=14)
    ax.set_ylabel('Preference Shift', fontsize=14)
    ax.grid(True)
    # Single legend call, so the font size is not discarded by a second call.
    ax.legend(fontsize=12, loc='upper left', bbox_to_anchor=(1, 1))

    fig.tight_layout()
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.pdf", bbox_inches='tight')
    plt.show()

def plot_shift_cot(data, perturb, title, save_name=None, levels=None,
                   save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Plot per-level preference bands for roles whose name contains neither
    ``"human"`` nor ``"claude"``.

    NOTE(review): despite the function name, the filter does not test for
    ``"cot"``. Confirm this selection is intended.

    Args:
        data: Nested mapping ``data[role][perturb][level]``. Each leaf is
            indexed ``[0]``/``[1]`` for the shaded band and averaged with
            ``np.mean`` for the line. Presumably a ``(lower, upper)`` pair;
            confirm with the caller.
        perturb: Key selecting the perturbation inside each role's entry.
        title: Axes title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        levels: Ordered x-axis categories. Defaults to 'Remembering' through
            'Creating'. Added for parity with the other ``plot_shift*`` helpers.
        save_dir: Output directory used when ``save_name`` is given.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    colors = get_color(data.keys())
    if levels is None:
        levels = ['Remembering', 'Understanding', 'Applying', 'Analyzing', 'Evaluating', 'Creating']

    for role in data:
        if "human" in role or "claude" in role:
            continue
        stats = data[role][perturb]
        means = [np.mean(stats[key]) for key in levels]
        lower_bounds = [stats[key][0] for key in levels]
        upper_bounds = [stats[key][1] for key in levels]
        ax.plot(levels, means, '-o', label=role, color=colors[role], linewidth=2.0)
        ax.fill_between(levels, lower_bounds, upper_bounds, alpha=0.5, color=colors[role])

    # There is a single axes, so it is styled once (no per-subplot loop needed).
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Levels', fontsize=14)
    ax.set_ylabel('Preference', fontsize=14)
    ax.grid(True)
    ax.legend(fontsize=12, loc='upper left', bbox_to_anchor=(1, 1))

    fig.tight_layout()
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.pdf", bbox_inches='tight')
    plt.show()

def plot_preference_count(role_preference_count, perturb, save_name=None, save_dir="../images"):
    """Draw a grid of bar charts of preference counts.

    Each column is a role and each row is a stage.

    Args:
        role_preference_count: Mapping ``role -> stage -> {preference: count}``
            with preference keys ``"O"``, ``"T"``, ``"P"`` (missing keys count
            as 0). The stages are taken from the first role.
        perturb: Name inserted into the figure title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', 'P']
    legend = ["Original", "Tie", "Perturbation"]
    colors = ['skyblue', 'gold', 'lightcoral']

    roles = list(role_preference_count.keys())
    stages = list(role_preference_count[roles[0]].keys())

    # squeeze=False keeps ``axs`` two-dimensional even with a single role or
    # stage; otherwise ``axs[j, i]`` would fail on a 1-D array or a bare Axes.
    fig, axs = plt.subplots(len(stages), len(roles),
                            figsize=(3 * len(roles), 3 * len(stages)),
                            sharey=True, squeeze=False)

    for i, role in enumerate(roles):
        for j, stage in enumerate(stages):
            ax = axs[j, i]
            values = [role_preference_count[role][stage].get(p, 0) for p in preferences]
            bars = ax.bar(legend, values, color=colors)
            ax.bar_label(bars, label_type='edge', fontsize=10)
            ax.grid(True)
            ax.set_title(f'{role} - {stage}', fontsize=10, fontweight='bold')
            ax.set_xlabel('Preference', fontsize=10)
            ax.set_ylabel('Count', fontsize=10)
            ax.set_xticks(range(len(legend)))
            ax.set_xticklabels(legend, fontsize=8)

    # Figure-level legend below the grid.
    patches = [mpatches.Patch(color=color, label=name) for color, name in zip(colors, legend)]
    fig.legend(handles=patches, labels=legend, loc='lower center', fontsize=14,
               ncol=len(legend), bbox_to_anchor=(0.5, 0 - 0.05))

    fig.suptitle(f'Preference Plot for {perturb}', fontsize=15, fontweight='bold')
    fig.subplots_adjust(wspace=0.3, hspace=0.3)
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.pdf", bbox_inches='tight')
    plt.show()


import matplotlib.lines as lines

def plot_preference_count_all(role_preference_count_all, perturbs, save_name=None, save_dir="../images"):
    """Draw vote-count bar charts for every (perturbation, role, stage) combination.

    Columns are grouped by perturbation, ``len(roles)`` columns per group.
    Groups are separated by dashed vertical lines and labelled with the
    perturbation title. Rows are stages.

    Args:
        role_preference_count_all: Sequence aligned with ``perturbs``. Item
            ``k`` maps ``role -> stage -> {preference: count}`` for
            ``perturbs[k]``. Roles and stages are taken from item 0.
        perturbs: Perturbation keys, one per entry of ``role_preference_count_all``.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.png``.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', 'P']
    legend = ["U-Unperturbed", "T-Tie", "P-Perturbed"]
    xlabel = ["U", "T", "P"]
    colors = ['skyblue', 'gold', 'lightcoral']

    perturb2title = PERTURB2TITLE
    role2title = ROLE2TITLE

    roles = list(role_preference_count_all[0].keys())
    stages = list(role_preference_count_all[0][roles[0]].keys())
    n_roles = len(roles)

    cell_size = 3  # inches per subplot in each direction
    # squeeze=False keeps ``axs`` 2-D even with one stage or one column.
    fig, axs = plt.subplots(len(stages), n_roles * len(perturbs),
                            figsize=(cell_size * n_roles * len(perturbs), cell_size * len(stages)),
                            sharey=True, squeeze=False)
    # Apply the final spacing before reading axes positions below, so the
    # separators and group titles are placed against the final layout.
    fig.subplots_adjust(wspace=0.2, hspace=0.5)

    for k, perturb in enumerate(perturbs):
        role_preference_count = role_preference_count_all[k]
        first_col = k * n_roles
        last_col = first_col + n_roles - 1

        # Dashed separator between this perturbation group and the previous one,
        # spanning from the bottom of the last row to the top of the first row.
        if k > 0:
            left_pos = axs[0, first_col - 1].get_position()
            right_pos = axs[0, first_col].get_position()
            bottom_pos = axs[-1, first_col].get_position()
            separator_x = (left_pos.x1 + right_pos.x0) / 2
            fig.add_artist(lines.Line2D([separator_x, separator_x], [bottom_pos.y0, left_pos.y1],
                                        transform=fig.transFigure, color='grey', linestyle='--', lw=2))

        # Group title centred horizontally over this perturbation's columns.
        first_pos = axs[0, first_col].get_position()
        last_pos = axs[0, last_col].get_position()
        fig.text((first_pos.x0 + last_pos.x1) / 2, 1 - 0.02, perturb2title.get(perturb, perturb),
                 ha="center", va="center", fontsize=17, fontweight='bold')

        for i, role in enumerate(roles):
            for j, stage in enumerate(stages):
                ax = axs[j, first_col + i]
                values = [role_preference_count[role][stage].get(p, 0) for p in preferences]
                bars = ax.bar(xlabel, values, color=colors, width=0.5)
                ax.bar_label(bars, label_type='edge', fontsize=10)
                ax.grid(True)
                # Fall back to the raw key when no display title is registered.
                ax.set_title(f'{role2title.get(role, role)}\n{stage}', fontsize=14, fontweight='bold')
                # y-axes are shared, so only the leftmost column carries the label.
                if k == 0 and i == 0:
                    ax.set_ylabel('Count', fontsize=14)
                ax.set_xticks(range(len(xlabel)))
                ax.set_xticklabels(xlabel, fontsize=15)

    # Figure-level legend using circular marker swatches.
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10)
               for color in colors]
    fig.legend(handles=handles, labels=legend, loc='lower center', fontsize=14,
               ncol=len(legend), bbox_to_anchor=(0.5, 0 - 0.05))

    fig.suptitle('Preference Plot for All Perturbations', fontsize=20, fontweight='bold', y=1.08)
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.png", bbox_inches='tight')
    plt.show()

def plot_dist_all(role_preference_count_all, perturbs, save_name=None, ratio=1,
                  save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Draw preference-distribution bar charts for every (perturbation, role, stage).

    The layout matches the count grid: columns are grouped by perturbation,
    ``len(roles)`` per group, with dashed separators and group titles. Rows
    are stages. The y-axis is fixed to [0, 1].

    Args:
        role_preference_count_all: Sequence aligned with ``perturbs``. Item
            ``k`` maps ``role -> stage -> {preference: value}``. Values are
            presumably fractions in [0, 1] given the fixed y-limit; confirm
            with the caller. Roles and stages are taken from item 0.
        perturbs: Perturbation keys, one per entry of ``role_preference_count_all``.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.png``.
        ratio: Fraction shown in the title as a percentage.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', 'P']
    legend = ["U-Unperturbed", "T-Tie", "P-Perturbed"]
    xlabel = ["U", "T", "P"]
    colors = ['skyblue', 'gold', 'lightcoral']
    perturb2title = {
        "factual_error": "Factual Error",
        "reference": "Reference",
        "rich_content": "Rich Content",
    }
    role2title = {
        'human': 'Human',
        'gpt4': 'GPT-4',
        'gpt-turbo': 'GPT-Turbo',
        'claude-2': 'Claude-2',
    }

    roles = list(role_preference_count_all[0].keys())
    stages = list(role_preference_count_all[0][roles[0]].keys())
    n_roles = len(roles)

    cell_size = 3  # inches per subplot in each direction
    # squeeze=False keeps ``axs`` 2-D even with one stage or one column.
    fig, axs = plt.subplots(len(stages), n_roles * len(perturbs),
                            figsize=(cell_size * n_roles * len(perturbs), cell_size * len(stages)),
                            sharey=True, squeeze=False)
    # Apply the final spacing before reading axes positions below.
    fig.subplots_adjust(wspace=0.2, hspace=0.5)

    for k, perturb in enumerate(perturbs):
        role_preference_count = role_preference_count_all[k]
        first_col = k * n_roles
        last_col = first_col + n_roles - 1

        # Dashed separator between this perturbation group and the previous one.
        if k > 0:
            left_pos = axs[0, first_col - 1].get_position()
            right_pos = axs[0, first_col].get_position()
            bottom_pos = axs[-1, first_col].get_position()
            separator_x = (left_pos.x1 + right_pos.x0) / 2
            fig.add_artist(lines.Line2D([separator_x, separator_x], [bottom_pos.y0, left_pos.y1],
                                        transform=fig.transFigure, color='grey', linestyle='--', lw=2))

        # Group title centred horizontally over this perturbation's columns.
        first_pos = axs[0, first_col].get_position()
        last_pos = axs[0, last_col].get_position()
        fig.text((first_pos.x0 + last_pos.x1) / 2, 1 - 0.02, perturb2title.get(perturb, perturb),
                 ha="center", va="center", fontsize=17, fontweight='bold')

        for i, role in enumerate(roles):
            for j, stage in enumerate(stages):
                ax = axs[j, first_col + i]
                values = [role_preference_count[role][stage].get(p, 0) for p in preferences]
                bars = ax.bar(xlabel, values, color=colors, width=0.5)
                ax.bar_label(bars, label_type='edge', fontsize=10)
                ax.grid(True)
                ax.set_title(f'{role2title.get(role, role)}\n{stage}', fontsize=14, fontweight='bold')
                if k == 0 and i == 0:
                    ax.set_ylabel('Count', fontsize=14)
                ax.set_xticks(range(len(xlabel)))
                ax.set_xticklabels(xlabel, fontsize=15)
                ax.set_ylim([0, 1])

    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10)
               for color in colors]
    fig.legend(handles=handles, labels=legend, loc='lower center', fontsize=14,
               ncol=len(legend), bbox_to_anchor=(0.5, 0 - 0.05))

    # ':g' avoids float noise such as "10.000000000000002%" for ratio=0.1.
    fig.suptitle(f'Preference Distribution for {ratio * 100:g}% Weak Data',
                 fontsize=20, fontweight='bold', y=1.08)
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.png", bbox_inches='tight')
    plt.show()

def plot_preference_count_dif(role_preference_count, perturb, save_name=None,
                              base_stage='Control', compare_stage='Experimental',
                              save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Draw per-stage preference counts per role, plus a bottom row of differences.

    Each column is a role. The first ``len(stages)`` rows hold the counts per
    stage. The last row shows ``compare_stage - base_stage`` for each
    preference.

    Args:
        role_preference_count: Mapping ``role -> stage -> {preference: count}``
            with preference keys ``"O"``, ``"T"``, ``"P"`` (missing keys count
            as 0). Stages are taken from the first role and must include
            ``base_stage`` and ``compare_stage``.
        perturb: Name inserted into the figure title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        base_stage: Stage subtracted in the difference row.
        compare_stage: Stage the base is subtracted from.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', 'P']
    legend = ["Original", "Tie", "Perturbation"]
    colors = ['skyblue', 'gold', 'lightcoral']

    roles = list(role_preference_count.keys())
    stages = list(role_preference_count[roles[0]].keys())
    diff_row = len(stages)

    # squeeze=False keeps ``axs`` 2-D even for a single role.
    fig, axs = plt.subplots(len(stages) + 1, len(roles),
                            figsize=(5 * len(roles), 5 * (len(stages) + 1)),
                            sharey=True, squeeze=False)

    for i, role in enumerate(roles):
        counts = role_preference_count[role]
        for j, stage in enumerate(stages):
            ax = axs[j, i]
            values = [counts[stage].get(p, 0) for p in preferences]
            ax.bar(legend, values, color=colors)
            ax.grid(True)
            ax.set_title(f'{role} - {stage}', fontsize=20, fontweight='bold')
            ax.set_xlabel('Preference', fontsize=16)
            ax.set_ylabel('Count', fontsize=16)
            ax.set_xticks(range(len(legend)))
            ax.set_xticklabels(legend, fontsize=12)

        # Bottom row: per-preference difference compare_stage - base_stage.
        base = counts[base_stage]
        compare = counts[compare_stage]
        diff_values = [compare.get(p, 0) - base.get(p, 0) for p in preferences]
        ax = axs[diff_row, i]
        ax.bar(legend, diff_values, color=colors)
        ax.grid(True)
        ax.set_title(f'{role} - Difference', fontsize=20, fontweight='bold')
        ax.set_xlabel('Preference', fontsize=16)
        ax.set_ylabel('Count', fontsize=16)
        ax.set_xticks(range(len(legend)))
        ax.set_xticklabels(legend, fontsize=12)

    patches = [mpatches.Patch(color=color, label=name) for color, name in zip(colors, legend)]
    fig.legend(handles=patches, loc='upper right', bbox_to_anchor=(1, 0.95), fontsize='large')

    fig.suptitle(f'Preference Plot for {perturb}', fontsize=30, fontweight='bold')
    fig.subplots_adjust(wspace=0.5, hspace=0.5)
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.pdf")
    plt.show()

def plot_preference_count_compact(role_preference_count, perturb, save_name=None,
                                  save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Draw grouped bar charts of preference counts, one subplot row per stage.

    Within each stage subplot, each preference category gets a cluster of
    adjacent bars, one per role, coloured by role.

    Args:
        role_preference_count: Mapping ``role -> stage -> {preference: count}``
            with preference keys ``"O"``, ``"T"``, ``"P"`` (missing keys count
            as 0). Stages are taken from the first role.
        perturb: Name inserted into the figure title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', 'P']
    legend = ["Original", "Tie", "Perturbation"]
    colors = ['skyblue', 'gold', 'lightcoral', 'green', 'purple']

    roles = list(role_preference_count.keys())
    stages = list(role_preference_count[roles[0]].keys())
    # Cycle the palette so more than five roles no longer raise IndexError
    # (colours repeat after the fifth role).
    role_colors = [colors[i % len(colors)] for i in range(len(roles))]

    # One row per stage. squeeze=False keeps ``axs`` indexable even with a single stage.
    fig, axs = plt.subplots(len(stages), 1, figsize=(3 * len(roles), 3 * (len(stages) + 1)),
                            squeeze=False)

    width = 0.05  # width of each bar
    # Offset that centres each tick under its cluster of len(roles) bars.
    tick_offset = width * (len(roles) / 2 - 0.5)

    for j, stage in enumerate(stages):
        ax = axs[j, 0]
        for i, role in enumerate(roles):
            values = [role_preference_count[role][stage].get(p, 0) for p in preferences]
            ax.bar([x + i * width for x in range(len(legend))], values, width, color=role_colors[i])

        ax.grid(True)
        ax.set_title(f'{stage}', fontsize=20, fontweight='bold')
        ax.set_xlabel('Preference', fontsize=16)
        ax.set_ylabel('Count', fontsize=16)
        ax.set_xticks([x + tick_offset for x in range(len(legend))])
        ax.set_xticklabels(legend, fontsize=12)

    # Figure-level legend: one swatch per role.
    patches = [mpatches.Patch(color=color, label=role) for color, role in zip(role_colors, roles)]
    fig.legend(handles=patches, loc='upper right', bbox_to_anchor=(1, 0.95), fontsize='large')

    fig.suptitle(f'Preference Plot for {perturb}', fontsize=30, fontweight='bold')
    fig.subplots_adjust(wspace=0.2, hspace=0.6)
    if save_name is not None:
        fig.savefig(f"{save_dir}/{save_name}.pdf")
    plt.show()

def plot_preference_count_stack(role_preference_count, perturb, save_name=None, prefix="",
                                save_dir="/workspace2/guiming/workspace/comparative/images"):
    """Draw horizontal stacked bars of preference counts, one subplot row per stage.

    Each role is one horizontal bar. Its ``"O"``, ``"T"`` and ``"P"`` values
    are stacked in that order: non-negative values grow rightwards from 0 and
    negative values grow leftwards from 0.

    Args:
        role_preference_count: Mapping ``role -> stage -> {preference: value}``
            (missing keys count as 0). Values may be negative. Stages are
            taken from the first role.
        perturb: Name inserted into the figure title.
        save_name: If given, save the figure as ``{save_dir}/{save_name}.pdf``.
        prefix: Text prepended to the figure title.
        save_dir: Output directory used when ``save_name`` is given.
    """
    preferences = ["O", 'T', "P"]
    legend = ["Original", "Tie", "Perturbation"]
    colors = ['skyblue', 'gold', 'lightcoral', 'green', 'purple']

    roles = list(role_preference_count.keys())
    stages = list(role_preference_count[roles[0]].keys())

    # One row per stage. squeeze=False keeps ``axs`` indexable even with a single stage.
    fig, axs = plt.subplots(len(stages), 1, figsize=(3 * len(roles), 3 * (len(stages) + 1)),
                            squeeze=False)

    width = 0.35  # bar thickness

    for j, stage in enumerate(stages):
        ax = axs[j, 0]
        for i, role in enumerate(roles):
            # Separate cursors for each side of zero. With a single shared
            # cursor, a negative segment was drawn on top of the positive
            # segments already stacked.
            pos_left = 0
            neg_left = 0
            for k, preference in enumerate(preferences):
                value = role_preference_count[role][stage].get(preference, 0)
                if value >= 0:
                    ax.barh(i, value, height=width, color=colors[k], left=pos_left)
                    pos_left += value
                else:
                    neg_left += value
                    ax.barh(i, -value, height=width, color=colors[k], left=neg_left)

        ax.grid(True)
        ax.set_title(f'{stage}', fontsize=20, fontweight='bold')
        ax.set_xlabel('Count', fontsize=16)
        ax.set_ylabel('Role', fontsize=16)
        ax.set_yticks(range(len(roles)))
        ax.set_yticklabels(roles, fontsize=12)

    # Label the swatches with the readable names rather than the raw "O"/"T"/"P" keys.
    patches = [mpatches.Patch(color=color, label=name) for color, name in zip(colors, legend)]
    fig.legend(handles=patches, loc='upper right', bbox_to_anchor=(1, 0.95), fontsize='large')

    fig.suptitle(f'{prefix} Preference Plot for {perturb}', fontsize=30, fontweight='bold')
    fig.subplots_adjust(wspace=0.2, hspace=0.6)
    if save_name is not None:
        save_path = f"{save_dir}/{save_name}.pdf"
        print("saving to ", save_path)
        fig.savefig(save_path)
    plt.show()




