import os
import pickle

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def generate_pd(**kwargs):
    """Load pickled result tables and stack them into one long-format DataFrame.

    Each keyword is ``'<model>_<attn>'`` (e.g. ``LSTM_no``, ``GRU_dot``) and its
    value is the path of a pickle whose content ``pd.DataFrame`` can consume.
    The two halves of the key are stored in the ``model`` and ``type`` columns
    of every row loaded from that file.

    The legacy column name ``e_h-befor_area`` is renamed to ``e_h-before_area``
    per file, so files written with and without the typo can be mixed.

    Three derived columns are appended, each an area column divided by
    ``length``:

    - ``embedding_density``, from ``e_h-before_area``;
    - ``hidden_density``, from ``e_h-after_area``;
    - ``attn_density``, from ``h_a-after_area``.

    Returns:
        The concatenated DataFrame with a fresh 0..n-1 index, in keyword order.

    Raises:
        ValueError: if no files are given or a keyword does not contain
            exactly one ``'_'``.
    """
    if not kwargs:
        raise ValueError("generate_pd() needs at least one '<model>_<attn>=<path>' argument")

    frames = []
    for key, path in kwargs.items():
        parts = key.split('_')
        if len(parts) != 2:
            raise ValueError(f"keyword {key!r} must look like '<model>_<attn>'")
        rc_type, attn = parts

        # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
        with open(path, 'rb') as f:
            frame = pd.DataFrame(pickle.load(f))

        # Rename before concatenating. Renaming afterwards would leave two
        # half-empty columns that collide under the same label once renamed.
        frame = frame.rename(columns={"e_h-befor_area": "e_h-before_area"})
        frame['model'] = rc_type  # e.g. LSTM / RNN / GRU
        frame['type'] = attn      # e.g. no / dot
        frames.append(frame)

    all_pd = pd.concat(frames, axis=0, ignore_index=True)

    # NOTE(review): a zero `length` would yield inf here -- confirm it cannot occur.
    all_pd['embedding_density'] = all_pd['e_h-before_area'] / all_pd['length']
    all_pd['hidden_density'] = all_pd['e_h-after_area'] / all_pd['length']
    all_pd['attn_density'] = all_pd['h_a-after_area'] / all_pd['length']
    return all_pd


def sst_bi(base_path='./sst'):
    """Load the SST results of the bidirectional recurrent models.

    One run folder per (cell, attention) pair is selected below. The
    ``calculate.pkl`` inside each folder is passed to ``generate_pd``.

    Args:
        base_path: Directory containing the SST run folders.

    Returns:
        The combined DataFrame produced by ``generate_pd``.
    """
    # Insertion order (LSTM, GRU, RNN; 'no' before 'dot') fixes the row order
    # of the concatenated frame. An explicit table replaces the eval() lookup
    # of local variables.
    runs = {
        'LSTM_no': 'BiLSTM_18_noAttn',
        'LSTM_dot': 'BiLSTM_16_dotAttn',
        'GRU_no': 'BiGRU_17_noAttn',
        'GRU_dot': 'BiGRU_13_dotAttn',
        'RNN_no': 'BiRNN_19_noAttn',
        'RNN_dot': 'BiRNN_15_dotAttn',
    }
    paths = {key: f'{base_path}/{run}/calculate.pkl' for key, run in runs.items()}
    return generate_pd(**paths)


def agnews_bi(base_path='./ag_news'):
    """Load the AG News results of the bidirectional recurrent models.

    One run folder per (cell, attention) pair is selected below. The
    ``calculate.pkl`` inside each folder is passed to ``generate_pd``.

    Args:
        base_path: Directory containing the AG News run folders.

    Returns:
        The combined DataFrame produced by ``generate_pd``.
    """
    # Insertion order (LSTM, GRU, RNN; 'no' before 'dot') fixes the row order
    # of the concatenated frame. An explicit table replaces the eval() lookup
    # of local variables.
    runs = {
        'LSTM_no': 'BiLSTM_17_noAttn',
        'LSTM_dot': 'BiLSTM_17_dotAttn',
        'GRU_no': 'BiGRU_19_noAttn',
        'GRU_dot': 'BiGRU_19_dotAttn',
        'RNN_no': 'BiRNN_18_noAttn',
        'RNN_dot': 'BiRNN_18_dotAttn',
    }
    paths = {key: f'{base_path}/{run}/calculate.pkl' for key, run in runs.items()}
    return generate_pd(**paths)


def sst_single(base_path='./sst'):
    """Load the SST results of the unidirectional recurrent models.

    One run folder per (cell, attention) pair is selected below. The
    ``calculate.pkl`` inside each folder is passed to ``generate_pd``.

    Args:
        base_path: Directory containing the SST run folders.

    Returns:
        The combined DataFrame produced by ``generate_pd``.
    """
    # Insertion order (LSTM, GRU, RNN; 'no' before 'dot') fixes the row order
    # of the concatenated frame. An explicit table replaces the eval() lookup
    # of local variables.
    runs = {
        'LSTM_no': 'LSTM_19_noAttn',
        'LSTM_dot': 'LSTM_18_dotAttn',
        'GRU_no': 'GRU_18_noAttn',
        'GRU_dot': 'GRU_16_dotAttn',
        'RNN_no': 'RNN_19_noAttn',
        'RNN_dot': 'RNN_18_dotAttn',
    }
    paths = {key: f'{base_path}/{run}/calculate.pkl' for key, run in runs.items()}
    return generate_pd(**paths)


def agnews_single(base_path='./ag_news'):
    """Load the AG News results of the unidirectional recurrent models.

    One run folder per (cell, attention) pair is selected below. The
    ``calculate.pkl`` inside each folder is passed to ``generate_pd``.

    Args:
        base_path: Directory containing the AG News run folders.

    Returns:
        The combined DataFrame produced by ``generate_pd``.
    """
    # Insertion order (LSTM, GRU, RNN; 'no' before 'dot') fixes the row order
    # of the concatenated frame. An explicit table replaces the eval() lookup
    # of local variables.
    runs = {
        'LSTM_no': 'LSTM_17_noAttn',
        'LSTM_dot': 'LSTM_17_dotAttn',
        'GRU_no': 'GRU_19_noAttn',
        'GRU_dot': 'GRU_15_dotAttn',
        'RNN_no': 'RNN_18_noAttn',
        'RNN_dot': 'RNN_14_dotAttn',
    }
    paths = {key: f'{base_path}/{run}/calculate.pkl' for key, run in runs.items()}
    return generate_pd(**paths)


def draw_pic_e_h(all_rnn, save_path='1.png'):
    """Draw a 2x3 grid of box plots, one panel per embedding/hidden metric.

    Each panel has:

    - one metric column on the x-axis;
    - the ``model`` column on the y-axis;
    - ``type`` as the hue, with ``'no'`` in blue and ``'dot'`` in orange.

    One shared legend is drawn instead of one per panel. The figure is
    closed after saving.

    Args:
        all_rnn: DataFrame with ``model`` and ``type`` columns plus the six
            columns listed in ``metrics`` (e.g. the output of ``generate_pd``).
        save_path: Output file. A relative path is resolved against the
            current directory; an absolute path is used as-is.
    """
    metrics = ['embedding_density', 'hidden_density', 'e_h-centroid_distance',
               'e_h-area_precision', 'e_h-area_recall', 'e_h-area_f1']
    # x-axis labels, one per entry of `metrics`, in the same order.
    labels = [r"$\mathtt{DSU(\mathit{\varepsilon}^{type})}$", r"$\mathtt{DSU(\mathit{H}^{type})}$",
              r"$\mathtt{CIO(\mathit{\varepsilon}^{type}, \mathit{H}^{type})}$",
              r"$\mathtt{SCP(\mathit{\varepsilon}^{type}, \mathit{H}^{type})}$",
              r"$\mathtt{SCR(\mathit{\varepsilon}^{type}, \mathit{H}^{type})}$",
              r"$\mathtt{SCF(\mathit{\varepsilon}^{type}, \mathit{H}^{type})}$"]

    sns.set_theme(style="ticks")
    fig = plt.figure(figsize=(18, 7))

    for pos, (metric, label) in enumerate(zip(metrics, labels), start=1):
        ax = plt.subplot(2, 3, pos)
        sns.boxplot(x=metric, y='model', hue='type', data=all_rnn, hue_order=['no', 'dot'],
                    ax=ax, palette=['#44c2fd', '#fb7756'], linewidth=1, fliersize=2, width=0.55, whis=2)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Per-panel legends are dropped; a single shared one is added below.
        ax.legend_.remove()
        ax.set_xlabel(label)
        ax.set_ylabel('')

    # Same spacing for the whole grid, so apply it once after all panels exist.
    plt.subplots_adjust(hspace=0.22)

    patches = [mpatches.Patch(facecolor='#44c2fd', label=' --', edgecolor='#000000', linewidth=0.5),
               mpatches.Patch(facecolor='#fb7756', label='attn', edgecolor='#000000', linewidth=0.5)]

    # plt.legend attaches to the current (last, bottom-right) panel; the anchor
    # is expressed relative to that panel.
    plt.legend(handles=patches, fontsize=14, title='type',
               loc='upper center', ncol=1, bbox_to_anchor=(-0.5, 2.2))

    # os.path.join keeps relative paths under './' but, unlike f'./{path}',
    # does not mangle absolute paths.
    plt.savefig(os.path.join('.', save_path))
    # Release the figure so repeated calls do not pile up open figures.
    plt.close(fig)


def draw_pic_cross_delta_all(data, save_path=''):
    """Plot CIO and DSU per model, comparing ``'no'`` and ``'dot'`` rows.

    For rows with ``type == 'no'``, the hidden-level columns stand in for the
    attention-level ones, so both hue groups share one x-axis per panel:

    - ``attn_density`` is replaced by ``hidden_density``;
    - ``e_a-centroid_distance`` is replaced by ``e_h-centroid_distance``.

    The substitution happens on a copy, so the caller's DataFrame is left
    unchanged. The figure is closed after saving.

    Args:
        data: DataFrame with ``model``, ``type``, ``hidden_density``,
            ``attn_density`` and ``e_h-centroid_distance`` columns (e.g. the
            output of ``generate_pd``). ``e_a-centroid_distance`` is expected
            for the ``'dot'`` rows.
        save_path: Output file. A relative path is resolved against the
            current directory; an absolute path is used as-is.
            NOTE(review): the default ``''`` passes ``'./'`` to ``savefig``,
            which has no file name -- callers should always pass one.
    """
    # Work on a copy: the original version overwrote columns of the caller's frame.
    df = data.copy()
    no_attn = df['type'] == 'no'
    df.loc[no_attn, 'attn_density'] = df.loc[no_attn, 'hidden_density']
    df.loc[no_attn, 'e_a-centroid_distance'] = df.loc[no_attn, 'e_h-centroid_distance']

    panels = (('e_a-centroid_distance', 'CIO'), ('attn_density', 'DSU'))

    fig = plt.figure(figsize=(5, 8))
    sns.set_theme(style="ticks")
    for pos, (column, label) in enumerate(panels, start=1):
        ax = plt.subplot(2, 1, pos)

        # hue_order pins 'no' -> blue and 'dot' -> orange, so the colours match
        # the hand-built legend regardless of the row order in `df`.
        sns.boxenplot(x=column, y="model", orient='h', hue="type", hue_order=['no', 'dot'],
                      data=df, palette=['#44c2fd', '#fb7756'], width=0.45, ax=ax, showfliers=False)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Per-panel legends are dropped; a single shared one is added below.
        ax.legend_.remove()
        ax.set_xlabel(label)
        ax.set_ylabel('')

    # NOTE(review): panels are stacked vertically, so wspace has no visible
    # effect; hspace is probably what was meant. Kept as-is to preserve output.
    plt.subplots_adjust(wspace=0.15)

    patches = [mpatches.Patch(facecolor='#44c2fd', label=' --', edgecolor='#000000', linewidth=0.5),
               mpatches.Patch(facecolor='#fb7756', label='attn', edgecolor='#000000', linewidth=0.5)]

    plt.legend(handles=patches, fontsize=14, title='type',
               loc='upper center', ncol=1, bbox_to_anchor=(0.75, 1))

    # os.path.join keeps relative paths under './' but does not mangle absolute ones.
    plt.savefig(os.path.join('.', save_path))
    # Release the figure so repeated calls do not pile up open figures.
    plt.close(fig)


if __name__ == '__main__':
    # Explicit loader table instead of eval(f'...'): no string-built code, and
    # a typo fails at definition time rather than inside eval.
    loaders = {
        'bi': (('sst', sst_bi), ('agnews', agnews_bi)),
        'single': (('sst', sst_single), ('agnews', agnews_single)),
    }

    for variant, datasets in loaders.items():
        for dataset, load in datasets:
            # Load each dataset once and reuse it for both figures. The delta
            # plot gets its own copy so nothing it does to its input can leak
            # into the second figure.
            data = load()
            draw_pic_cross_delta_all(data.copy(), f'delta_{dataset}_{variant}.pdf')
            draw_pic_e_h(data, f'{dataset}_{variant}.pdf')




