'''Make the plots for Figure 1 from the outputs of explore_representations.py.

The final plots were generated with the shell loop below, where
acl_auth_resp_results is the directory the output of
explore_representations.py was saved to. Each run writes
token_representations.pdf and attn_representations.pdf to the current
directory, so the loop renames them per mode:

    for mode in {lower,cased}; do
        python3 make_plots.py acl_auth_resp_results/${mode}/
        mv token_representations.pdf token_representations_${mode}.pdf
        mv attn_representations.pdf attn_representations_${mode}.pdf
    done

'''
import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np

import scipy.stats

import seaborn as sns

import statsmodels.stats.api as sms


# Apply seaborn's default plot styling globally, so every matplotlib figure
# created below picks it up.
sns.set()

def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with one attribute, ``results_dir``: the directory
        that contains the ``baseline`` and ``shuff`` result sub-directories.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('results_dir')
    namespace = arg_parser.parse_args()
    return namespace


def ci_width(values, alpha=None):
    """Return the half-width of the error bar drawn around the mean of ``values``.

    Despite the name, the default is the standard error of the mean
    (``scipy.stats.sem``, i.e. sample std with ddof=1 divided by sqrt(n)).
    When ``alpha`` is given, the half-width of the two-sided Student-t
    ``(1 - alpha)`` confidence interval for the mean is returned instead:
    ``t.ppf(1 - alpha / 2, n - 1) * SEM``.

    Args:
        values: 1-D sequence of numeric samples.
        alpha: optional significance level (e.g. ``0.05`` for a 95% CI).
            ``None`` (the default) returns the plain SEM.

    Returns:
        The half-width as a float. This is NaN when fewer than two samples are
        given, because the SEM is undefined then.
    """
    sem = scipy.stats.sem(values)
    if alpha is None:
        return sem
    # Degrees of freedom n - 1 match the t-interval for a sample mean.
    return scipy.stats.t.ppf(1 - alpha / 2, len(values) - 1) * sem


# Colours shared by both figures so each model keeps the same hue throughout.
_BERT_COLOR = '#e66101'
_ROBERTA_COLOR = '#5e3c99'

# Result sub-directories in drawing order: the dashed 'baseline' curves are
# drawn first, then the bold, labelled 'shuff' curves on top of them.
_MODES = ('baseline', 'shuff')


def _find_stats_file(results_dir, mode, key):
    """Return the path of the first file in ``results_dir/mode`` whose name contains ``key``.

    Only the file name (not the directory part) is matched. Names are scanned
    in sorted order, so the pick is deterministic when several files match.

    Raises:
        FileNotFoundError: if no file name in the directory contains ``key``.
    """
    mode_dir = os.path.join(results_dir, mode)
    for name in sorted(os.listdir(mode_dir)):
        if key in name:
            return os.path.join(mode_dir, name)
    raise FileNotFoundError(
        'no file containing {!r} found in {}'.format(key, mode_dir))


def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path) as f:
        return json.load(f)


def _summarize(y_dists):
    """Return ``(means, errors)``: the mean and ``ci_width`` of each sample list."""
    return [np.mean(y) for y in y_dists], [ci_width(y) for y in y_dists]


def _plot_token_identifiability(results_dir):
    """Draw the token-identifiability figure and save it as token_representations.pdf.

    Reads the ``token_identifiability_statistics_roberta`` / ``..._bert`` JSON
    files (keys ``xs`` and ``ydists``) from each mode directory.
    """
    plt.figure(figsize=(5, 5))
    for mode in _MODES:
        roberta_path = _find_stats_file(
            results_dir, mode, 'token_identifiability_statistics_roberta')
        bert_path = _find_stats_file(
            results_dir, mode, 'token_identifiability_statistics_bert')
        print(roberta_path)
        print(bert_path)
        roberta_stats = _load_json(roberta_path)
        bert_stats = _load_json(bert_path)

        # Both models are plotted against RoBERTa's x values.
        xs = roberta_stats['xs']
        print([len(y) for y in roberta_stats['ydists']])
        roberta_ys, roberta_yerr = _summarize(roberta_stats['ydists'])
        bert_ys, bert_yerr = _summarize(bert_stats['ydists'])

        if mode == 'shuff':
            for ys, yerr, label, color in (
                    (bert_ys, bert_yerr, 'BERT', _BERT_COLOR),
                    (roberta_ys, roberta_yerr, 'RoBERTa', _ROBERTA_COLOR)):
                _, caps, bars = plt.errorbar(
                    xs, ys, yerr=yerr, label=label, linewidth=5, color=color)
                # Fade only the error-bar whiskers and caps; the line stays opaque.
                for artist in (*bars, *caps):
                    artist.set_alpha(0.7)
        else:
            # Nudge the baseline 0.1 to the right, presumably so its error
            # bars do not sit exactly on top of the shuffled ones.
            shifted_xs = np.array(xs) + .1
            # Each model gets its own error bars (BERT previously reused
            # RoBERTa's by mistake).
            plt.errorbar(
                shifted_xs, bert_ys, yerr=bert_yerr, linewidth=2,
                color=_BERT_COLOR, alpha=.5, linestyle='--')
            plt.errorbar(
                shifted_xs, roberta_ys, yerr=roberta_yerr, linewidth=2,
                color=_ROBERTA_COLOR, alpha=.5, linestyle='--')

    plt.legend(fontsize=14)
    plt.xlabel('Layer Index', fontsize=18)
    plt.ylabel('Token Identifiability', fontsize=18)
    plt.tight_layout()
    plt.savefig('token_representations.pdf')
    plt.close()


def _plot_attention_distance(results_dir):
    """Draw the attention-distance figure and save it as attn_representations.pdf.

    Reads the ``deshuffled_attn_statistics_roberta`` / ``..._bert-base`` JSON
    files (keys ``xs``, ``ydists``, ``per_head_xs``, ``per_head_ys``) from each
    mode directory. Per-head points are scattered only for the 'shuff' run.
    """
    plt.figure(figsize=(5, 5))
    axis_limits = None
    for mode in _MODES:
        roberta_stats = _load_json(_find_stats_file(
            results_dir, mode, 'deshuffled_attn_statistics_roberta'))
        bert_stats = _load_json(_find_stats_file(
            results_dir, mode, 'deshuffled_attn_statistics_bert-base'))

        xs = roberta_stats['xs']
        roberta_ys, roberta_yerr = _summarize(roberta_stats['ydists'])
        bert_ys, bert_yerr = _summarize(bert_stats['ydists'])

        if mode == 'shuff':
            plt.errorbar(
                xs, bert_ys, yerr=bert_yerr, label='BERT', linewidth=5,
                color=_BERT_COLOR)
            plt.errorbar(
                xs, roberta_ys, yerr=roberta_yerr, label='RoBERTa',
                linewidth=5, color=_ROBERTA_COLOR)

            # Record the limits implied by the curves drawn so far. They are
            # restored after the loop, so the per-head scatter below cannot
            # stretch the axes.
            ax = plt.gca()
            axis_limits = (ax.get_xlim(), ax.get_ylim())

            plt.scatter(
                bert_stats['per_head_xs'], bert_stats['per_head_ys'],
                alpha=.3, color=_BERT_COLOR)
            plt.scatter(
                roberta_stats['per_head_xs'], roberta_stats['per_head_ys'],
                alpha=.3, color=_ROBERTA_COLOR)
        else:
            # Each model gets its own error bars (BERT previously reused
            # RoBERTa's by mistake).
            plt.errorbar(
                xs, bert_ys, yerr=bert_yerr, linewidth=2, color=_BERT_COLOR,
                alpha=.5, linestyle='--')
            plt.errorbar(
                xs, roberta_ys, yerr=roberta_yerr, linewidth=2,
                color=_ROBERTA_COLOR, alpha=.5, linestyle='--')

    if axis_limits is not None:
        xlim, ylim = axis_limits
        plt.xlim(xlim)
        plt.ylim(ylim)
    plt.legend(fontsize=14)
    plt.xlabel('Layer Index', fontsize=20)
    plt.ylabel('Attention Distance', fontsize=20)
    plt.tight_layout()
    plt.savefig('attn_representations.pdf')
    plt.close()


def main():
    """Build both Figure 1 plots from the results directory given on the command line.

    Writes token_representations.pdf and attn_representations.pdf to the
    current working directory.
    """
    args = parse_args()
    _plot_token_identifiability(args.results_dir)
    _plot_attention_distance(args.results_dir)


if __name__ == '__main__':
    # Script entry point: build both figures from the command-line arguments.
    main()
