import os
import pickle as pkl
from collections import Counter
from itertools import product
from itertools import zip_longest

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import stats
from scipy.stats import wasserstein_distance
from matplotlib.colors import LinearSegmentedColormap, to_hex

# Directories holding the pickled layer-sensitivity dumps to compare. The
# plotting code further down titles the first entry "Pretrained BERT" and the
# second "DPR-Trained BERT".
directories = ["/path/to/inputs/layer_sensitivities/baseline", "/path/to/inputs/layer_sensitivities/dpr_full_train"]
# directories = ["/path/to/inputs/layer_sensitivities/dpr_full_train", "/path/to/inputs/layer_sensitivities/dpr_full_train_edited_added"]
# directories = ["/path/to/inputs/layer_sensitivities/dpr_full_train_edited_added", "/path/to/inputs/layer_sensitivities/dpr_full_train_edited_removed"]

# When True only files whose name contains "train_data" are read; when False
# those files are skipped.
train_data = False
# (_, cluster_labels, _, _, _ := pkl.load(open("training_data_cluster_custom_save.pkl", "rb"))) if train_data else (cluster_labels := pkl.load(open("training_data_cluster_bertopic_val_clusters.pkl", "rb")))
# cluster_labels_passages = np.repeat(cluster_labels, 13)

# Accumulators, one entry per (layer_type, directory) combination in loop
# order: layer_type is the outer loop, directory the inner loop.
all_activations = []
all_top_neurons = []
all_neuron_counters = []
neuron_cluster_counter = []
layer_types = ["intermediate_knowledge", "output_knowledge"]
# layer_types = ["intermediate_knowledge"]
model_type = "question_model"
# The single question whose activations are analysed below.
target_question = 'when did mozart compose his first piece of music'

for layer_type in layer_types:
    for directory in directories:
        print(directory, layer_type)
        activations = []
        activation_file_questions = []
        files = [
            file
            for file in os.listdir(directory)
            if layer_type in file
            and model_type in file
            and ("train_data" in file if train_data else "train_data" not in file)
            and "rank" not in file
            and ".pkl" in file
        ]
        # Order the files by the integer that follows the last "_" in the name
        # (the 4-character extension is stripped before parsing).
        files = sorted(files, key=lambda x: int(x.split("_")[-1][:-4]))
        for file in files:
            # Use a context manager so each file handle is closed right after
            # loading instead of being leaked.
            # NOTE: unpickling can execute arbitrary code; only load trusted dumps.
            with open(f"{directory}/{file}", "rb") as handle:
                activation_file = pkl.load(handle)
            # Element 0 holds the per-example activations, element 1 the
            # matching questions.
            activations.append(activation_file[0])
            activation_file_questions.append(activation_file[1])

        # For every example index, stack that example's tensors from all files
        # on top of each other (file order = row order).
        activations = [torch.vstack([activations[j][i] for j in range(len(activations))]) for i in range(len(activations[0]))]
        # Locate the target question in the last file's question list; raises
        # IndexError if the question is not present.
        activation_index = np.argwhere(np.array(activation_file_questions[-1]) == target_question)[0][0]
        activations = [activations[activation_index]]
        # Sanity check: show which question every file has at that index.
        for file_idx, activation_questions in enumerate(activation_file_questions):
            print(file_idx, activation_questions[activation_index])
        # activations = [activations[149]]
        # activations = [activations[2791]]
        all_activations.append(activations)

        top_neurons = []
        for example in activations:
            # NOTE(review): the threshold is 10% of the signed maximum while the
            # comparison uses absolute values - confirm this is intended.
            example_max = torch.max(example)
            example_close_to_max_loc = torch.argwhere(torch.abs(example) >= example_max * 0.1)
            top_neurons.append(example_close_to_max_loc)
        all_top_neurons.append(top_neurons)

        neuron_counter = Counter()
        for neurons in top_neurons:
            # Convert each tensor to a list of tuples
            rows = map(tuple, neurons.tolist())
            # Update the counter with the rows from this tensor
            neuron_counter.update(rows)
        all_neuron_counters.append(neuron_counter)

    # Only consumed by the commented-out per-cluster counting below.
    all_top_neurons_np = np.array(all_top_neurons, dtype="object")
    # for j in range(len(directories)):
    #     neuron_cluster_counter.append([])
    #     for i in range(30):
    #         current_top_neurons = all_top_neurons_np[j, cluster_labels==i if len(activations)==len(cluster_labels) else cluster_labels_passages==i].tolist()
    #         neuron_counter = Counter()
    #         for neurons in current_top_neurons:
    #             # Convert each tensor to a list of tuples
    #             rows = map(tuple, neurons.tolist())
    #             # Update the counter with the rows from this tensor
    #             neuron_counter.update(rows)
    #         neuron_cluster_counter[j].append(neuron_counter)

# Number of selected neurons per directory: total, first layer type, second layer type.
print(all_top_neurons[0][0].shape[0] + all_top_neurons[2][0].shape[0], all_top_neurons[0][0].shape[0], all_top_neurons[2][0].shape[0])
print(all_top_neurons[1][0].shape[0] + all_top_neurons[3][0].shape[0], all_top_neurons[1][0].shape[0], all_top_neurons[3][0].shape[0])
# NOTE(review): unconditional debugger breakpoint - the script stops here until
# the session is continued. Remove it for non-interactive runs.
import ipdb; ipdb.set_trace()
def print_counter_results_2(counter1, counter2, num):
    """Print the ``num`` most common entries of two counters side by side.

    Each output line has the form ``row1: count1<TAB><TAB>row2: count2``. When
    one counter yields fewer entries than the other, the missing side is
    printed as ``N/A: 0``.

    Args:
        counter1: Counter shown in the left column.
        counter2: Counter shown in the right column.
        num: Number of most common entries requested from each counter.
    """
    placeholder = ('N/A', 0)
    left_entries = counter1.most_common(num)
    right_entries = counter2.most_common(num)
    for left, right in zip_longest(left_entries, right_entries, fillvalue=placeholder):
        left_row, left_count = left
        right_row, right_count = right
        print(f"{left_row}: {left_count}\t\t{right_row}: {right_count}")

def print_counter_results(counter1, num):
    """Print and return the ``num`` most common entries of a counter.

    Args:
        counter1: Counter to summarise.
        num: Number of most common entries to report.

    Returns:
        The list of ``(row, count)`` pairs from ``counter1.most_common(num)``.
    """
    top_entries = counter1.most_common(num)
    for entry in top_entries:
        row, count = entry
        print(f"{row}: {count}")
    return top_entries

def rows2yaml(rows, num_layers=12):
    """Print rows grouped by layer as ``layer_<n>: v1, v2, ...`` lines.

    One line is printed for every layer in ``range(num_layers)``, including
    layers that received no rows. Values keep the order in which they appear
    in ``rows``.

    Args:
        rows: Iterable of indexable items; ``row[0]`` is the layer index and
            ``row[1]`` the value printed for that layer.
        num_layers: Number of layers to list. Defaults to 12, the value that
            was previously hard-coded.

    Raises:
        KeyError: If a row's layer index is not in ``range(num_layers)``.
    """
    most_common_layer = {layer: [] for layer in range(num_layers)}
    for row in rows:
        most_common_layer[row[0]].append(row[1])
    for layer, neurons in most_common_layer.items():
        print(f"layer_{layer}: {', '.join(str(neuron) for neuron in neurons)}")


def counter2yaml(counter, num_layers=12):
    """Print the keys of ``counter`` grouped by layer, values sorted ascending.

    One ``layer_<n>: v1, v2, ...`` line is printed for every layer in
    ``range(num_layers)``. Keys whose layer falls outside that range are
    silently skipped.

    Args:
        counter: Iterable of indexable keys (for example a Counter); ``key[0]``
            is the layer index and ``key[1]`` the value printed for that layer.
        num_layers: Number of layers to list. Defaults to 12, the value that
            was previously hard-coded.
    """
    # Group in a single pass instead of rescanning every key once per layer.
    neurons_by_layer = {layer: [] for layer in range(num_layers)}
    for key in counter:
        bucket = neurons_by_layer.get(key[0])
        if bucket is not None:
            bucket.append(key[1])
    for layer, neurons in neurons_by_layer.items():
        print(f"layer_{layer}: {', '.join(str(neuron) for neuron in sorted(neurons))}")

def exctop(counter, num, num_layers=12, num_neurons=3072):
    """Return every (layer, neuron) pair that is not among the top entries.

    A Counter with count 1 is built for each pair in
    ``range(num_layers) x range(num_neurons)``, and the ``num`` most common
    entries of ``counter`` are subtracted from it. Counter subtraction keeps
    only positive results, so every top entry with a count of at least 1
    disappears from the result.

    Args:
        counter: Counter keyed by ``(layer, neuron)`` tuples.
        num: Number of most common entries to exclude.
        num_layers: Number of layers in the grid. Defaults to 12, the value
            that was previously hard-coded.
        num_neurons: Number of neurons per layer. Defaults to 3072, the value
            that was previously hard-coded.

    Returns:
        A Counter mapping each remaining ``(layer, neuron)`` pair to 1.
    """
    # Only the top entries are subtracted, not the whole original counter.
    top_counter = Counter(dict(counter.most_common(num)))
    # Counter with a count of 1 for every position in the grid.
    all_values_counter = Counter(
        {position: 1 for position in product(range(num_layers), range(num_neurons))}
    )
    remaining_counter = all_values_counter - top_counter
    return remaining_counter


def exclusive_most_common(counter1, counter2, num1, num2, verbose=True):
    """Return top entries of ``counter1`` that are absent from the top of ``counter2``.

    Args:
        counter1: Counter whose ``num1`` most common entries are the candidates.
        counter2: Counter whose ``num2`` most common entries are excluded.
        num1: Number of most common entries taken from ``counter1``.
        num2: Number of most common entries taken from ``counter2``.
        verbose: When True, print a header and one line per resulting row.

    Returns:
        A list of ``(row, count)`` pairs using the counts from ``counter1``,
        ordered by count descending. Rows with equal counts keep the order
        produced by ``counter1.most_common(num1)``.
    """
    top_counter1 = counter1.most_common(num1)
    excluded_rows = {row for row, _ in counter2.most_common(num2)}

    # most_common() is already sorted by count descending, so filtering it
    # keeps that order and makes the order of ties deterministic.
    exclusive_rows = [(row, count) for row, count in top_counter1 if row not in excluded_rows]

    if verbose:
        # Report the limits that were actually used rather than a fixed "50".
        print(f"Rows in the top {num1} of counter1 but not in the top {num2} of counter2:")
        for row, count in exclusive_rows:
            print(f"Row: {row}, Count: {count}")
    return exclusive_rows

# Kullback-Leibler Divergence
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence ``sum(p * log(p / q))``.

    Entries where ``p`` is zero contribute nothing to the sum. The logarithm
    is evaluated only on the non-zero entries of ``p``, so zeros in ``p`` no
    longer trigger divide-by-zero or invalid-value RuntimeWarnings.

    Args:
        p: Array-like of values; lists and tuples are accepted.
        q: Array-like broadcastable against ``p``.

    Returns:
        The divergence as a NumPy float, using the natural logarithm. It is
        infinite when ``q`` is zero at a position where ``p`` is positive.
    """
    p, q = np.broadcast_arrays(np.asarray(p, dtype=float), np.asarray(q, dtype=float))
    support = p != 0
    p_support = p[support]
    return np.sum(p_support * np.log(p_support / q[support]))

# Jensen-Shannon Divergence
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    The result is the average of the KL divergences of ``p`` and of ``q``
    from their element-wise midpoint ``0.5 * (p + q)``.

    Args:
        p: First array of values.
        q: Second array of values, combined element-wise with ``p``.

    Returns:
        ``0.5 * KL(p, m) + 0.5 * KL(q, m)`` with ``m`` the midpoint.
    """
    midpoint = 0.5 * (p + q)
    divergence_from_p = kl_divergence(p, midpoint)
    divergence_from_q = kl_divergence(q, midpoint)
    return 0.5 * divergence_from_p + 0.5 * divergence_from_q

# Total Variation Distance
def total_variation_distance(p, q):
    """Return half the sum of absolute element-wise differences of ``p`` and ``q``.

    Args:
        p: First array of values.
        q: Second array of values, subtracted element-wise from ``p``.

    Returns:
        ``0.5 * sum(|p - q|)``.
    """
    absolute_gap = np.absolute(p - q)
    total_gap = np.sum(absolute_gap)
    return 0.5 * total_gap

# Wasserstein metric
def wasserstein_metric(p, q):
    """Return ``scipy.stats.wasserstein_distance`` of ``p`` and ``q``.

    NOTE(review): SciPy interprets both arguments as observed sample values,
    not as probability masses over a shared support - confirm that callers
    pass samples and not probability vectors.

    Args:
        p: Values forwarded as SciPy's ``u_values``.
        q: Values forwarded as SciPy's ``v_values``.

    Returns:
        The distance computed by SciPy.
    """
    distance = wasserstein_distance(u_values=p, v_values=q)
    return distance

def plot_activations(activations, ax, temperature_scale=True, cluster_num=None):
    """Draw the percentile rank of the mean activation of every neuron on ``ax``.

    The activations are averaged over examples, each averaged value is turned
    into its percentile rank (0-100, ties share the average rank) among all
    averaged values, and the result is drawn with ``ax.matshow`` using one row
    per entry of the first dimension of the averaged tensor.

    Args:
        activations: List of equally shaped torch tensors, one per example.
        ax: Matplotlib axes to draw on.
        temperature_scale: When True, use a blue-white-red colormap and label
            every row with ``"<row>, (<low>, <high>)"`` where ``low`` and
            ``high`` are the fractions of the row below the 10th and above
            the 90th percentile. When False, use matshow's default colormap
            and leave the ticks untouched.
        cluster_num: If given, average only the examples whose cluster label
            equals this value.

    Returns:
        The image object returned by ``ax.matshow``.
    """
    stacked = torch.stack(activations)
    if cluster_num is None:
        mean_activations = stacked.mean(dim=0)
    else:
        # NOTE(review): cluster_labels / cluster_labels_passages are only
        # assigned in commented-out lines of this file, so this branch raises
        # NameError unless they are defined elsewhere.
        labels = cluster_labels if len(activations) == len(cluster_labels) else cluster_labels_passages
        mean_activations = stacked[labels == cluster_num].mean(dim=0)

    values = np.atleast_2d(mean_activations.detach().cpu().numpy())
    flattened = values.reshape(-1)
    # One ranking pass replaces calling percentileofscore(..., 'rank') once
    # per element, which was quadratic in the number of neurons.
    percentiles = stats.rankdata(flattened, method="average") * 100.0 / flattened.size
    # Take the row count from the data instead of hard-coding (12, 3072).
    matrix_percentiles = percentiles.reshape(values.shape[0], -1)

    if temperature_scale:
        cmap = colors.LinearSegmentedColormap.from_list(
            "custom",
            [(0, "blue"), (0.10, "white"), (0.90, "white"), (1, "red")]
        )
        img = ax.matshow(matrix_percentiles, aspect="auto", cmap=cmap)
        percentile_contents = []
        for row in matrix_percentiles:
            low_fraction = round(int(np.count_nonzero(row < 10)) / row.size, 2)
            high_fraction = round(int(np.count_nonzero(row > 90)) / row.size, 2)
            percentile_contents.append((low_fraction, high_fraction))
        ticks = [f"{i}, {j}" for i, j in enumerate(percentile_contents)]
        ax.set_yticks(np.arange(len(ticks)))
        ax.set_yticklabels(ticks)
    else:
        img = ax.matshow(matrix_percentiles, aspect="auto")
    return img

def cluster_percentile_overlap(activations):
    """Plot how much the extreme last-row neurons of 30 clusters overlap.

    For each cluster 0..29 the activations of its examples are averaged and
    converted to percentile ranks (0-100, ties share the average rank). Two
    heatmaps are then shown: one for neurons of the last row above the 90th
    percentile and one for those below the 10th percentile. Cell (i, j) is
    the fraction of cluster i's selected neurons that cluster j also selects;
    it is NaN when cluster i selects no neurons.

    NOTE(review): cluster_labels / cluster_labels_passages are only assigned
    in commented-out lines of this file, so this function raises NameError
    unless they are defined elsewhere.

    Args:
        activations: List of equally shaped torch tensors, one per example.

    Returns:
        A list with one 2-D NumPy array of percentile ranks per cluster.
    """
    import seaborn as sns

    stacked = torch.stack(activations)
    labels = cluster_labels if len(activations) == len(cluster_labels) else cluster_labels_passages

    clusters_activation_percentiles = []
    for cluster_num in range(30):
        cluster_mean = stacked[labels == cluster_num].mean(dim=0)
        values = np.atleast_2d(cluster_mean.detach().cpu().numpy())
        flattened = values.reshape(-1)
        # One ranking pass replaces calling percentileofscore(..., 'rank')
        # once per element, which was quadratic in the number of neurons.
        percentiles = stats.rankdata(flattened, method="average") * 100.0 / flattened.size
        # Take the row count from the data instead of hard-coding (12, 3072).
        clusters_activation_percentiles.append(percentiles.reshape(values.shape[0], -1))

    fig, axs = plt.subplots(2, 1)
    last_rows = [matrix[-1] for matrix in clusters_activation_percentiles]
    panels = (
        (
            axs[0],
            [np.argwhere(row > 90) for row in last_rows],
            "Overlap of Top Percentile Neurons in Layer 11 (Denominator in Bold)",
        ),
        (
            axs[1],
            [np.argwhere(row < 10) for row in last_rows],
            "Overlap of Bottom Percentile Neurons in Layer 11 (Denominator in Bold)",
        ),
    )
    for ax, selected, title in panels:
        # The row cluster is the denominator; guard against an empty selection.
        overlap = [
            [len(np.intersect1d(own, other)) / len(own) if len(own) else float("nan") for other in selected]
            for own in selected
        ]
        ticks = [f"{index} ({chosen.shape[0]})" for index, chosen in enumerate(selected)]
        sns.heatmap(overlap, ax=ax, annot=True, vmin=0, vmax=1, xticklabels=ticks, yticklabels=ticks)
        for label in ax.get_yticklabels():
            label.set_weight("bold")
        ax.set_title(title)

    plt.subplots_adjust(hspace=0.3, top=0.9)
    plt.show()
    return clusters_activation_percentiles


def align_layers(intermediate, output):
    """Interleave the rows of ``intermediate`` with widened rows of ``output``.

    Every column of ``output`` is repeated four times. If the widened array
    still differs in width from ``intermediate``, it is padded on the right
    with zeros by the difference. The result alternates rows:
    intermediate[0], output[0], intermediate[1], output[1], ...

    Args:
        intermediate: 2-D array; its row count and width define the result.
        output: 2-D array with at least as many rows as ``intermediate``.

    Returns:
        An array with twice as many rows as ``intermediate``.
    """
    widened = np.repeat(output, 4, axis=1)

    # Pad on the right so both arrays have the same number of columns.
    missing_columns = intermediate.shape[1] - widened.shape[1]
    if missing_columns != 0:
        widened = np.pad(widened, ((0, 0), (0, missing_columns)), 'constant')

    # Collect the rows in alternating order, then stack them once.
    interleaved = []
    for row in range(intermediate.shape[0]):
        interleaved.append(intermediate[row, np.newaxis])
        interleaved.append(widened[row, np.newaxis])
    return np.vstack(interleaved)


import seaborn as sns


# Turn the neuron counters into dense (layer, neuron) matrices. Counter order
# follows the loading loop: [0]/[1] are the first layer type for the two
# directories, [2]/[3] the second layer type.
pretrained_output_counter_values = np.zeros([12, 768])
dpr_trained_output_counter_values = np.zeros([12, 768])
for neuron_index, neuron_count in all_neuron_counters[2].items():
    pretrained_output_counter_values[neuron_index] = neuron_count
for neuron_index, neuron_count in all_neuron_counters[3].items():
    dpr_trained_output_counter_values[neuron_index] = neuron_count

pretrained_intermediate_counter_values = np.zeros([12, 3072])
dpr_trained_intermediate_counter_values = np.zeros([12, 3072])
for neuron_index, neuron_count in all_neuron_counters[0].items():
    pretrained_intermediate_counter_values[neuron_index] = neuron_count
for neuron_index, neuron_count in all_neuron_counters[1].items():
    dpr_trained_intermediate_counter_values[neuron_index] = neuron_count

# Interleave intermediate and output rows; each output column is repeated four
# times so both kinds of row have the same width.
overall_dpr_trained = align_layers(dpr_trained_intermediate_counter_values, dpr_trained_output_counter_values)
overall_pretrained_trained = align_layers(pretrained_intermediate_counter_values, pretrained_output_counter_values)

# Get the "PRGn" colormap. pyplot.get_cmap is used because
# matplotlib.cm.get_cmap is no longer available in recent Matplotlib releases.
prgn_cmap = plt.get_cmap("PRGn")

# Extract specific colors from the "PRGn" colormap
# 0.0 (start) and 0.8 refer to normalized positions in the colormap
start_color = to_hex(prgn_cmap(0.0))  # Should be a shade of purple
mid_color = "white"  # First color of the custom colormap
end_color = to_hex(prgn_cmap(0.8))  # Should be a shade of green

# Create a custom colormap running white -> purple -> green
custom_cmap = LinearSegmentedColormap.from_list("custom_prgn", [mid_color, start_color, end_color])

# Share one color scale between both heatmaps
vmin = min(np.min(overall_pretrained_trained), np.min(overall_dpr_trained))
vmax = max(np.max(overall_pretrained_trained), np.max(overall_dpr_trained))

fig, axs = plt.subplots(1, 2, figsize=(14, 6), constrained_layout=True)

sns.heatmap(overall_pretrained_trained, ax=axs[0], cmap=custom_cmap, vmin=vmin, vmax=vmax, cbar=False)  # cbar=False to only show color bar on one
sns.heatmap(overall_dpr_trained, ax=axs[1], cmap=custom_cmap, vmin=vmin, vmax=vmax)

# Build the "Intermediate Layer" and "Output Layer" labels with the number of
# non-zero neurons per row
num_layers = 24
tick_labels = []
for i in range(num_layers):
    # For even indices, label as "Intermediate", for odd, "Output".
    # Output rows are repeated four times, hence the division by 4.
    tick_labels.append(f"Intermediate Layer ({np.sum(overall_pretrained_trained[i] > 0)})" if i % 2 == 0 else f"Output Layer ({np.sum(overall_pretrained_trained[i] > 0) // 4})")

# Set the custom tick labels
axs[0].set_yticks(np.arange(num_layers) + 0.5)  # Position ticks in the center of each layer
axs[0].set_yticklabels(tick_labels)

axs[1].set_yticks(np.arange(num_layers) + 0.5)  # Adjust this if your data has a different number of ticks
# The right panel shows the DPR-trained model, so all of its counts - output
# rows included - must come from overall_dpr_trained. A plain loop is used so
# no globals()/locals() workaround is needed for the comprehension scope.
dpr_tick_labels = []
for i in range(num_layers):
    active_neurons = np.sum(overall_dpr_trained[i] > 0)
    dpr_tick_labels.append(f"({active_neurons if i % 2 == 0 else active_neurons // 4})")
axs[1].set_yticklabels(dpr_tick_labels)

for i in range(12):
    # The position is halfway between the 'Intermediate Layer' and 'Output Layer'
    position = (i * 2) + 1.11  # Adjusting this value positions the "0" between the layer labels
    axs[0].text(-0.365, position, str(i), transform=axs[0].get_yaxis_transform(), ha='right', va='center', fontsize='small')

# Add an outer square border to each subplot
for ax in axs:
    ax.patch.set_edgecolor('black')  # Set border color
    ax.patch.set_linewidth(1)  # Set border width

# Remove y-axis tick marks and labels on the right figure
# axs[1].set_yticks([])
# axs[1].set_yticklabels([])

# Remove x-axis tick marks and labels on both figures
for ax in axs:
    ax.tick_params(axis='x', which='both', length=0, labelbottom=False)

# Ensure the tick labels are readable
plt.setp(axs[0].get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
plt.setp(axs[1].get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
axs[0].set_title('Pretrained BERT')
axs[1].set_title('DPR-Trained BERT')
# axs[0].set_title('DPR-Trained BERT - fact added')
# axs[1].set_title('DPR-Trained BERT - fact removed')

# Set a super title for the figure
fig.suptitle('BERT Layer-wise Activations', fontsize=16)
# NOTE(review): the figure was created with constrained_layout=True; confirm
# that this manual adjustment still has the intended effect.
plt.subplots_adjust(left=0.15)
plt.show()


# import IPython; IPython.embed()