import sys
import os.path as o
import os
# Move the working directory one level up; every relative path used later in
# this script is resolved from there.
os.chdir('../')


# Append the parent of this file's directory to sys.path (presumably so that
# the project-level `directories` module imported below can be found).
# NOTE(review): this is evaluated after os.chdir above, so if __file__ is a
# relative path, abspath() resolves it against the new working directory —
# confirm the resulting entry is the intended one.
sys.path.append(o.abspath(o.join(o.dirname(sys.modules[__name__].__file__), "..")))
from directories import *

# Hyper Params
# The WORKER_* helpers used below are not defined in this file; presumably
# they come from the star import of `directories` — confirm.
BACKBONE = "bert"  # backbone name handed to the WORKER_* path helpers
SELECTED_GPU = 1  # index into the list of physical GPUs
TASK = "movie_rationales"  # dataset sub-directory, path-helper argument and key into num_label_tasks
SPLIT = "test"  # split name handed to the saliency/attention path helpers
MODEL_PATH = 'bert-base-uncased'  # identifier given to from_pretrained for tokenizer, config and model
MAX_LENGTH = 512  # max_length passed to movies_to_features
BEST_EPOCH = 5  # selects which "<prefix>{BEST_EPOCH}.npy" files are loaded
# NOTE(review): SEED is not referenced anywhere else in the visible part of this file.
SEED = 42  ### DO NOT CHANGE
# Path prefixes; the epoch number and ".npy" are appended when the files are loaded.
SALIENCIES_PATH = WORKER_SALIENCIES_PATH_to_READ(TASK, BACKBONE, SPLIT)
ATTENTION_PATH = WORKER_ATTENTIONS_PATH_to_READ(TASK, BACKBONE, SPLIT, "attn_wCLSSEP_")
ATTENTION_ROLLOUT_PATH = WORKER_ATTENTIONS_PATH_to_READ(TASK, BACKBONE, SPLIT, "attn_rollout_wCLSSEP_")

# Checkpoint passed to model.load_weights; the trailing "0.01 0.01" presumably
# repeats the GAMMA and PHI values encoded in the run name.
LOAD_MODEL_PATH = WORKER_MODEL_LOAD_PATH(TASK, BACKBONE, "adaptive_lr/run_1_movies_GAMMA0.01_PHI0.01_lr") #0.01 0.01

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, average_precision_score
from scipy.stats import spearmanr

from transformers import (
    BertConfig,
    BertTokenizer,
)

from modeling.modeling_tf_bert import TFBertForSequenceClassification

from utils.rationale_utils import eraser_to_token_level_rationale_data, movies_to_features, eraser_tasks, num_label_tasks

# Restrict TensorFlow to the GPU chosen by SELECTED_GPU and enable memory
# growth on it. Nothing is configured when no GPU is reported.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Look the device up once; an out-of-range SELECTED_GPU still raises
    # IndexError, which the handler below does not catch.
    selected_gpu = gpus[int(SELECTED_GPU)]
    try:
        tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(selected_gpu, True)
        print(selected_gpu)
    except RuntimeError as e:
        # A RuntimeError from the device configuration is reported, not raised.
        print(e)


# Load Tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Load Data
# NOTE: the path is relative to the working directory set at the top of the file.
data_root = os.path.join('../ERASER/eraserbenchmark/data', TASK)
data = eraser_to_token_level_rationale_data(data_root)
# Index with the configured SPLIT (instead of a hard-coded 'test') so the
# examples always come from the same split as the saliency/attention files,
# whose paths are built from SPLIT as well.
test_data = data[SPLIT]

# Convert to Features
test_dataset = movies_to_features(test_data, tokenizer, max_length=MAX_LENGTH, return_all=True).cache()
# Count the examples with one full pass, without keeping them all in a list.
num_test_examples = sum(1 for _ in test_dataset.as_numpy_iterator())
# Batch size 1: the extraction loop handles one example per step.
test_dataset = test_dataset.batch(1)
num_labels = num_label_tasks[TASK]

# Every attribution file is named "<path prefix><BEST_EPOCH>.npy".
epoch_suffix = f"{BEST_EPOCH}.npy"

# Load Saliencies
raw_sals = np.load(SALIENCIES_PATH + epoch_suffix)

# Load Attentions
raw_attns = np.load(ATTENTION_PATH + epoch_suffix)
raw_roll_attns = np.load(ATTENTION_ROLLOUT_PATH + epoch_suffix)


# Sum out axis 2 of both attention arrays (how much each token attended).
raw_attns = raw_attns.sum(axis=2)
raw_roll_attns = raw_roll_attns.sum(axis=2)


# Load Model
# Build the classifier from the pretrained checkpoint, with the number of
# output labels taken from the task.
config = BertConfig.from_pretrained(MODEL_PATH, num_labels=num_labels)
model = TFBertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
# NOTE(review): to_ALR() is a project-specific method of the custom model
# class; presumably it has to run before load_weights so that the model
# structure matches the saved checkpoint — confirm.
model.to_ALR()
model.load_weights(LOAD_MODEL_PATH)
# Overwrite the encoder's `_lambda` variable with a very large value.
# NOTE(review): its meaning is defined in the project's model code — confirm
# the intended effect at inference time.
model.bert.encoder._lambda.assign(1e+10)


# Extract Head explanations & Saliencies
# For every test example this loop:
#   1. runs the model to get the per-layer explanations,
#   2. strips [CLS]/[SEP] and padding from all score arrays and the rationale,
#   3. merges word-piece tokens into words (scores are summed, the label of
#      the first piece is kept),
#   4. normalises every layer's scores to sum to 1,
#   5. binarises the scores with the per-layer ETA threshold.
# Each list below receives one entry per example.
soft_attentions = []
hard_attentions = []
soft_roll_attentions = []
hard_roll_attentions = []
soft_head_explanations = []
hard_head_explanations = []
soft_saliencies = []
hard_saliencies = []
rationales = []
for ex, inputs in enumerate(test_dataset):
    outputs = model(inputs[0], training=False, output_explanations=True)

    # rationales
    rationale = tf.squeeze(inputs[1]['rationale']).numpy()

    # soft head's explanation & soft saliency (intact) and soft attentions
    soft_exp = tf.squeeze(tf.stack(outputs['explanations'])).numpy()
    soft_sal = raw_sals[ex]
    soft_attn = raw_attns[ex]
    soft_roll_attn = raw_roll_attns[ex]

    # truncating: ignore special tokens
    # Index of [SEP]; slicing 1:sep_position drops [CLS], [SEP] and padding.
    # (To keep [SEP], remove the -1 and give it a 0 label in the rationale.)
    sep_position = np.sum(inputs[0]['attention_mask']) - 1

    soft_exp = soft_exp[:, 1:sep_position]
    # Copy the slices of the loaded arrays: the aggregation below writes in
    # place, and writing through a view would modify raw_sals / raw_attns /
    # raw_roll_attns, making a second run of this loop double-count.
    soft_sal = soft_sal[:, 1:sep_position].copy()
    soft_attn = soft_attn[:, 1:sep_position].copy()
    soft_roll_attn = soft_roll_attn[:, 1:sep_position].copy()
    rationale = rationale[1:sep_position]

    # token-level to word-level
    word_indices = tf.squeeze(inputs[1]['word_index']).numpy()[1:sep_position]
    t = 0
    while t < len(word_indices):
        # Collect the following positions that belong to the same word as t.
        deleted_indices = []
        i = t + 1
        while i < len(word_indices) and word_indices[t] == word_indices[i]:
            deleted_indices.append(i)
            i += 1
        if len(deleted_indices) > 0:
            # aggregate: the word's score is the sum over its pieces
            soft_sal[:, t] = np.sum(soft_sal[:, [t] + deleted_indices], axis=1)
            soft_exp[:, t] = np.sum(soft_exp[:, [t] + deleted_indices], axis=1)
            soft_attn[:, t] = np.sum(soft_attn[:, [t] + deleted_indices], axis=1)
            soft_roll_attn[:, t] = np.sum(soft_roll_attn[:, [t] + deleted_indices], axis=1)
            # delete the merged pieces everywhere so all arrays stay aligned
            soft_sal = np.delete(soft_sal, deleted_indices, axis=1)
            soft_exp = np.delete(soft_exp, deleted_indices, axis=1)
            soft_attn = np.delete(soft_attn, deleted_indices, axis=1)
            soft_roll_attn = np.delete(soft_roll_attn, deleted_indices, axis=1)
            rationale = np.delete(rationale, deleted_indices)
            word_indices = np.delete(word_indices, deleted_indices)
        t += 1

    num_words = len(word_indices)

    # normalize
    soft_sal = soft_sal / np.sum(soft_sal[:, :num_words], axis=1, keepdims=True)
    soft_attn = soft_attn / np.sum(soft_attn[:, :num_words], axis=1, keepdims=True)
    soft_roll_attn = soft_roll_attn / np.sum(soft_roll_attn[:, :num_words], axis=1, keepdims=True)
    for l in range(12):
        # Layers whose explanation is all zero are left untouched.
        if np.sum(soft_exp[l]) != 0:
            soft_exp[l] = soft_exp[l] / np.sum(soft_exp[l, :num_words])

    # hard head's explanations and updated saliencies
    hard_exp = np.zeros(soft_exp.shape, dtype=np.int8)
    hard_sal = np.zeros(soft_sal.shape, dtype=np.int8)
    hard_attn = np.zeros(soft_attn.shape, dtype=np.int8)
    hard_roll_attn = np.zeros(soft_roll_attn.shape, dtype=np.int8)
    for l in range(12):
        # Read the layer's threshold once and reuse it for all four scores.
        eta = model.bert.encoder.ETA[l].numpy()
        hard_sal[l] = (soft_sal[l] >= eta / num_words)
        hard_attn[l] = (soft_attn[l] >= eta / num_words)
        hard_roll_attn[l] = (soft_roll_attn[l] >= eta / num_words)
        if np.sum(soft_exp[l]) != 0:
            # Head explanations are thresholded relative to the number of
            # entries that are still non-negligible (> 1e-6) in this layer.
            hard_exp[l] = (soft_exp[l] >= eta / np.sum(soft_exp[l] > 1e-6))

    # appending to lists
    soft_head_explanations.append(soft_exp)
    hard_head_explanations.append(hard_exp)
    soft_roll_attentions.append(soft_roll_attn)
    hard_roll_attentions.append(hard_roll_attn)
    soft_saliencies.append(soft_sal)
    hard_saliencies.append(hard_sal)
    soft_attentions.append(soft_attn)
    hard_attentions.append(hard_attn)
    rationales.append(rationale)


# plt.rcParams.update({'font.size': 12})
def plot_show(x, Y, C, labels, linestyles, markers, linewidth, xlabel, ylabel, title, legend):
    """Draw several curves over the same x values and show the figure.

    Args:
        x: Shared x values.
        Y: One sequence of y values per curve.
        C: One colour per curve.
        labels: One legend label per curve.
        linestyles: One line style per curve.
        markers: One marker per curve.
        linewidth: Line width used for every curve.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
        legend: Draw the legend when truthy.
    """
    curve_specs = zip(Y, C, labels, linestyles, markers)
    for values, colour, name, style, marker in curve_specs:
        plt.plot(
            x,
            values,
            c=colour,
            label=name,
            linestyle=style,
            marker=marker,
            linewidth=linewidth,
        )
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if legend:
        plt.legend()
    plt.title(title)
    plt.show()


# calculate word-level mAP
# One (example, layer) table of average-precision values per attribution method.
ap_sal = np.zeros((num_test_examples, 12))
ap_attn = np.zeros((num_test_examples, 12))
ap_roll_attn = np.zeros((num_test_examples, 12))
ap_headexp = np.zeros((num_test_examples, 12))

# Each table is paired with the soft scores it is computed from.
ap_targets = (
    (ap_sal, soft_saliencies),
    (ap_attn, soft_attentions),
    (ap_roll_attn, soft_roll_attentions),
    (ap_headexp, soft_head_explanations),
)
for ex in range(num_test_examples):
    # Any positive rationale value counts as a positive label.
    bool_rationale = (rationales[ex] > 0).astype(int)
    for l in range(12):
        for ap_table, soft_scores in ap_targets:
            ap_table[ex, l] = average_precision_score(bool_rationale, soft_scores[ex][l])

# Mean over examples gives one mAP value per layer.
mAP_sal = ap_sal.mean(axis=0)
mAP_attn = ap_attn.mean(axis=0)
mAP_roll_attn = ap_roll_attn.mean(axis=0)
mAP_headexp = ap_headexp.mean(axis=0)

plot_show(
    x=range(1, 13),
    Y=[mAP_sal, mAP_attn, mAP_roll_attn, mAP_headexp],
    C=["green", "darkorange", "crimson", "navy"],
    labels=["Saliency", "Attention", "Attention Rollout", "CP"],
    linestyles=['dashdot', 'dashdot', 'dashdot', 'solid'],
    markers=[5, 5, 5, 'o'],
    linewidth=3,
    xlabel="Layer",
    ylabel="mAP",
    title="",
    legend=False,
)

# FPr
def compute_FPr_layerwise(score, R, num_layers=12):
    """Return the per-layer false-positive rate of binary selections.

    For every layer, the positions that are selected (``score == 1``) while
    their rationale label is 0 are counted over all examples and divided by
    the total number of label-0 positions.

    The examples are taken from the arguments themselves, so the function no
    longer depends on the module-level ``num_test_examples``.

    Args:
        score: Sequence with one entry per example; ``score[ex][l]`` is the
            1-D selection vector of layer ``l`` (1 means selected).
        R: Sequence aligned with ``score`` holding one 1-D label array per
            example; 0 marks a non-rationale position.
        num_layers: Number of layers to evaluate. Defaults to 12.

    Returns:
        Float array of shape ``(num_layers,)``. Every entry is NaN when the
        inputs contain no label-0 position (the rate is undefined then).
    """
    fpr = np.full(num_layers, np.nan)
    examples = [(ex_score, np.asarray(ex_labels)) for ex_score, ex_labels in zip(score, R)]

    # The denominator does not depend on the layer, so compute it once.
    total_negatives = sum(int(np.count_nonzero(ex_labels == 0)) for _, ex_labels in examples)
    if total_negatives == 0:
        return fpr

    for l in range(num_layers):
        false_positives = 0
        for ex_score, ex_labels in examples:
            selected = np.flatnonzero(np.asarray(ex_score[l]) == 1)
            false_positives += int(np.count_nonzero(ex_labels[selected] == 0))
        fpr[l] = false_positives / total_negatives
    return fpr

# Layer-wise FPr of every binarised attribution method against the rationales.
fpr_sal, fpr_attn, fpr_roll_attn, fpr_headexp = (
    compute_FPr_layerwise(hard_scores, rationales)
    for hard_scores in (
        hard_saliencies,
        hard_attentions,
        hard_roll_attentions,
        hard_head_explanations,
    )
)

plot_show(
    x=range(1, 13),
    Y=[fpr_sal, fpr_attn, fpr_roll_attn, fpr_headexp],
    C=["green", "darkorange", "crimson", "navy"],
    labels=["Saliency", "Attention", "Attention Rollout", "CP"],
    linestyles=['dashdot', 'dashdot', 'dashdot', 'solid'],
    markers=[5, 5, 5, 'o'],
    linewidth=3,
    xlabel="Layer",
    ylabel="FPr",
    title="",
    legend=True,
)



import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 120

# Combined figure: mAP (top) and FPr (bottom) per layer, sharing the x axis.
x = range(1, 13)
C = ["green", "darkorange", "crimson", "navy"]
labels = ["Saliency", "Attention", "Attention Rollout", "CP"]
linestyles = ['dashed', 'dashdot', 'dashdot', 'solid']
markers = ['1', '3', '4', '.']
linewidth = 2
xlabel = "Layer"

fig, axs = plt.subplots(2, sharex=True, figsize=(4, 5))
plt.subplots_adjust(hspace=0.05)

# One (axes, y label, curves) entry per panel; both are drawn the same way.
panels = [
    (axs[0], "mAP", [mAP_sal, mAP_attn, mAP_roll_attn, mAP_headexp]),
    (axs[1], "FPr", [fpr_sal, fpr_attn, fpr_roll_attn, fpr_headexp]),
]
for ax, ylabel, Y in panels:
    for y, c, l, s, m in zip(Y, C, labels, linestyles, markers):
        ax.plot(x, y, c=c, label=l, linestyle=s, marker=m, linewidth=linewidth)
    ax.set_ylabel(ylabel)

# Only the bottom panel carries the x label and the legend. The legend is
# created once, on the axes object itself, instead of being created and then
# replaced through pyplot's implicit "current axes".
axs[1].set_xlabel(xlabel)
axs[1].legend(fontsize=7)
plt.show()

# # TPr
# def compute_TPr_layerwise(score, R):
#     tpr = np.zeros(12)
#     for l in range(12):
#         TP = 0.0
#         P = 0
#         for ex in range(num_test_examples):
#             predicted_indicies = np.argwhere(score[ex][l] == 1)
#             TP += np.sum(R[ex][predicted_indicies] > 0)
#             P += np.sum(R[ex] > 0)
#         tpr[l] = TP / P
#     return tpr

# tpr_sal = compute_TPr_layerwise(hard_saliencies, rationales)
# tpr_attn = compute_TPr_layerwise(hard_attentions, rationales)
# tpr_headexp = compute_TPr_layerwise(hard_head_explanations, rationales)

# plot_show(x=range(1, 13), Y=[tpr_sal, tpr_attn, tpr_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="TPr", title="")


# # PLr
# def compute_PLr_layerwise(score, R):
#     plr = np.zeros(12)
#     for l in range(12):
#         TP = 0.0
#         FP = 0.0
#         P = 0
#         N = 0
#         for ex in range(num_test_examples):
#             predicted_indicies = np.argwhere(score[ex][l] == 1)
#             TP += np.sum(R[ex][predicted_indicies] > 0)
#             FP += np.sum(R[ex][predicted_indicies] == 0)
#             P += np.sum(R[ex] > 0)
#             N += np.sum(R[ex] == 0)
#         plr[l] = (TP / P) / (FP / N)

#     return plr

# plr_sal = compute_PLr_layerwise(hard_saliencies, rationales)
# plr_attn = compute_PLr_layerwise(hard_attentions, rationales)
# plr_headexp = compute_PLr_layerwise(hard_head_explanations, rationales)

# plot_show(x=range(1, 13), Y=[plr_sal, plr_attn, plr_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="PLr (Positive likelihood ratio)", title="")


# # PPV
# def compute_PPV_layerwise(score, R):
#     ppv = np.zeros(12)
#     for l in range(12):
#         TP = 0.0
#         PP = 0
#         for ex in range(num_test_examples):
#             predicted_indicies = np.argwhere(score[ex][l] == 1)
#             TP += np.sum(R[ex][predicted_indicies] > 0)
#             PP += len(predicted_indicies)
#         ppv[l] = TP / PP
#     return ppv

# ppv_sal = compute_PPV_layerwise(hard_saliencies, rationales)
# ppv_attn = compute_PPV_layerwise(hard_attentions, rationales)
# ppv_headexp = compute_PPV_layerwise(hard_head_explanations, rationales)

# plot_show(x=range(1, 13), Y=[ppv_sal, ppv_attn, ppv_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="PPV", title="")

# # FDR
# def compute_FDR_layerwise(score, R):
#     fdr = np.zeros(12)
#     for l in range(12):
#         FP = 0.0
#         PP = 0
#         for ex in range(num_test_examples):
#             predicted_indicies = np.argwhere(score[ex][l] == 1)
#             FP += np.sum(R[ex][predicted_indicies] == 0)
#             PP += len(predicted_indicies)
#         fdr[l] = FP / PP
#     return fdr

# fdr_sal = compute_FDR_layerwise(hard_saliencies, rationales)
# fdr_attn = compute_FDR_layerwise(hard_attentions, rationales)
# fdr_headexp = compute_FDR_layerwise(hard_head_explanations, rationales)

# plot_show(x=range(1, 13), Y=[fdr_sal, fdr_attn, fdr_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="FDR", title="")

# # PPV / FDR
# plot_show(x=range(1, 13), Y=[ppv_sal/fdr_sal, ppv_attn/fdr_attn, ppv_headexp/fdr_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="PPV/FDR", title="")

# # calculate coverage of evidences
# def compute_coverage_1_layerwise(score, R):
#     coverage = np.zeros(12)
#     for l in range(12):
#         num_total_evidences = 0
#         num_true_predicted_evidences = 0
#         for ex in range(num_test_examples):
#             true_predicted_evidences = np.unique(R[ex][np.argwhere(score[ex][l])])
#             num_true_predicted_evidences += len(true_predicted_evidences[true_predicted_evidences != 0])
#             total_evidences = np.unique(R[ex])
#             num_total_evidences += len(total_evidences[total_evidences != 0])

        
#         coverage[l] = num_true_predicted_evidences / num_total_evidences

#     return coverage

# def compute_coverage_2_layerwise(score, R):
#     coverage = np.zeros(12)
#     for l in range(12):
#         num_total_rationale_tokens = 0
#         num_true_predicted_tokens = 0
#         num_false_predicted_tokens = 0
#         for ex in range(num_test_examples):
#             true_predicted_evidences = np.squeeze(R[ex][np.argwhere(score[ex][l])])
#             num_true_predicted_tokens += np.sum(true_predicted_evidences != 0)
#             num_false_predicted_tokens += np.sum(true_predicted_evidences == 0)
#             total_evidences = np.unique(R[ex])
#             num_total_rationale_tokens += len(total_evidences[total_evidences != 0])

        
#         coverage[l] = num_true_predicted_tokens / (num_total_rationale_tokens + num_false_predicted_tokens)

#     return coverage

# coverage_1_sal = compute_coverage_1_layerwise(hard_saliencies, rationales)
# coverage_1_attn = compute_coverage_1_layerwise(hard_attentions, rationales)
# coverage_1_headexp = compute_coverage_1_layerwise(hard_head_explanations, rationales)

# coverage_2_sal = compute_coverage_2_layerwise(hard_saliencies, rationales)
# coverage_2_attn = compute_coverage_2_layerwise(hard_attentions, rationales)
# coverage_2_headexp = compute_coverage_2_layerwise(hard_head_explanations, rationales)

# plot_show(x=range(1, 13), Y=[coverage_1_sal, coverage_1_attn, coverage_1_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="Coverage 1", title="")
# plot_show(x=range(1, 13), Y=[coverage_2_sal, coverage_2_attn, coverage_2_headexp], labels=["Saliency", "Attention", "Head's Explanations"], xlabel="Layer", ylabel="Coverage 2", title="")

# # np.sum([3, 5, 7, 2]) / (np.sum([3, 6, 5, 7, 2]) + 50)
# #num dilute rationale tokens in predicted / all rationale tokens in labels + other predicted tokens