import numpy as np
import tensorflow as tf
import pickle
import _pickle as cPickle
from os import path
from google_similarity import gram_linear, cka
import random
from scipy.stats import norm
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import cosine, squareform
from transformers import TFBertModel, BertTokenizer#, TFGPT2Model, GPT2Tokenizer, TFXLNetModel, XLNetTokenizer, TFT5EncoderModel, T5Tokenizer

#Note: google_similarity is copied from the notebook at: https://colab.research.google.com/github/google-research/google-research/blob/master/representation_similarity/Demo.ipynb
"""
@inproceedings{pmlr-v97-kornblith19a,
  title = {Similarity of Neural Network Representations Revisited},
  author = {Kornblith, Simon and Norouzi, Mohammad and Lee, Honglak and Hinton, Geoffrey},
  booktitle = {Proceedings of the 36th International Conference on Machine Learning},
  pages = {3519--3529},
  year = {2019},
  volume = {97},
  month = {09--15 Jun},
  publisher = {PMLR}}
"""

# Hugging Face model identifier shared by the model and the tokenizer below.
MODEL_ID_BERT = 'bert-base-cased'
# Both objects are created at import time. Hidden states are requested because
# get_embeddings() reads per-layer vectors from the last element of the model
# output; attentions are switched off so they are not part of that output.
MODEL_BERT = TFBertModel.from_pretrained(MODEL_ID_BERT, output_hidden_states = True, output_attentions = False)
MODEL_TOKENIZER_BERT = BertTokenizer.from_pretrained(MODEL_ID_BERT)

# Word lists grouped by category. None of them is referenced by the code
# visible in this file.
# NOTE(review): presumably these are the attribute/target stimulus sets meant
# to be embedded and passed to the SV_WEAT functions below - confirm.
# NOTE(review): the name `math` would shadow the standard-library module if it
# were ever imported in this file.
pleasant = ['caress','freedom','health','love','peace','cheer','friend','heaven','loyal','pleasure','diamond','gentle','honest','lucky','rainbow','diploma','gift','honor','miracle','sunrise','family','happy','laughter','paradise','vacation']
unpleasant = ['abuse','crash','filth','murder','sickness','accident','death','grief','poison','stink','assault','disaster','hatred','pollute','tragedy','divorce','jail','poverty','ugly','cancer','kill','rotten','vomit','agony','prison']
pleasant_2 = ['joy','love','peace','wonderful','pleasure','friend','laughter','happy']
unpleasant_2 = ['agony','terrible','horrible','nasty','evil','war','awful','failure']
career = ['executive','management','professional','corporation','salary','office','business','career']
domestic = ['home','parents','children','family','cousins','marriage','wedding','relatives']
math = ['math','algebra','geometry','calculus','equations','computation','numbers','addition']
art = ['poetry','art','dance','literature','novel','symphony','drama','sculpture']
science = ['science','technology','physics','chemistry','Einstein','NASA','experiment','astronomy']
art_2 = ['poetry','art','Shakespeare','dance','literature','novel','symphony','drama']

def std_deviation(J):
    """Return the sample standard deviation of J (divisor is len(J) - 1)."""
    center = np.mean(J)

    # Accumulate the squared deviations from the mean, starting from 0.
    squared_total = 0
    for value in J:
        squared_total += (value - center) ** 2

    return np.sqrt(squared_total / (len(J) - 1))

def create_permutation(a, b):
    """Shuffle the concatenation of a and b and split it into two halves.

    a and b are joined with ``+``. The first returned list holds the first
    int(n * .5) shuffled items and the second holds the remainder, where n is
    the combined length.
    """
    pooled = a + b
    shuffled = random.sample(pooled, len(pooled))
    midpoint = int(len(shuffled) * .5)
    return shuffled[:midpoint], shuffled[midpoint:]

def load_term_object(target, directory):
    """Unpickle and return the object stored at <directory>/<target>-object.pkl."""
    pickle_path = path.join(directory, target + '-object.pkl')

    with open(pickle_path, 'rb') as object_reader:
        term_object = pickle.load(object_reader)

    return term_object

def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors a and b."""
    dot_ab = np.dot(a, b)
    norm_a = np.sqrt(np.dot(a, a))
    norm_b = np.sqrt(np.dot(b, b))
    return dot_ab / (norm_a * norm_b)

def calculate_self_similarity(term, subtoken_type, layers):
    """Return the mean pairwise cosine similarity of a term's vectors per layer.

    Args:
        term: object exposing ``subtoken_cwe_map[subtoken_type][layer]``, a
            sequence of equally sized vectors for that layer.
        subtoken_type: key selecting the embedding variant in
            ``term.subtoken_cwe_map``.
        layers: iterable of layer keys to evaluate.

    Returns:
        A list with one value per requested layer: the average cosine
        similarity over all unordered pairs of distinct vectors. A layer with
        fewer than two vectors has no pairs and yields NaN.
    """
    cwe_map = term.subtoken_cwe_map[subtoken_type]
    self_similarity_by_layer = []

    for layer in layers:
        cwe_current = np.asarray(cwe_map[layer], dtype=float)

        # Number of unordered pairs among the vectors actually being compared,
        # i.e. the number of entries strictly above the diagonal.
        num_vectors = len(cwe_current)
        comparisons = num_vectors * (num_vectors - 1) // 2
        if comparisons == 0:
            self_similarity_by_layer.append(float('nan'))
            continue

        # Normalise each row to unit length; all-zero rows are left as zeros so
        # they contribute a similarity of 0 instead of NaN.
        norms = np.linalg.norm(cwe_current, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        unit_vectors = cwe_current / norms

        # Clip away floating-point overshoot outside the valid cosine range.
        cos_sims = np.clip(unit_vectors @ unit_vectors.T, -1.0, 1.0)

        # k=1 keeps only the strictly upper triangle: each pair once, no
        # self-comparisons.
        cos_sim_sum = np.sum(np.triu(cos_sims, k=1))
        self_similarity_by_layer.append(cos_sim_sum / comparisons)

    return self_similarity_by_layer

def measure_linear_cka_from_initial(term, subtoken_type, layers):
    """Return the linear CKA between layer 0 and each requested layer.

    The vectors stored under key 0 of ``term.subtoken_cwe_map[subtoken_type]``
    are compared with the vectors of every layer in ``layers`` by passing the
    linear Gram matrices of both to ``cka``. One value is returned per layer,
    in the order given.

    NOTE(review): both matrices are transposed before ``gram_linear``. If each
    layer entry is laid out as (contexts, dimensions), the Gram matrices are
    taken over dimensions rather than contexts - confirm this is intended.
    """
    vectors_by_layer = term.subtoken_cwe_map[subtoken_type]
    reference_vectors = vectors_by_layer[0]

    scores = []
    for layer in layers:
        compared_vectors = vectors_by_layer[layer]
        reference_matrix, compared_matrix = (
            np.array(vectors).T for vectors in (reference_vectors, compared_vectors)
        )
        scores.append(cka(gram_linear(reference_matrix), gram_linear(compared_matrix)))

    return scores

def SV_WEAT_permutation_test(w, A, B, test_stat, permutations):
    """Return a one-sided p-value for test_stat from a permutation null.

    The attribute vectors in A and B are pooled and randomly re-split
    ``permutations`` times. The association of ``w`` is recomputed for each
    split, and a normal distribution is fitted to those values (mean and sample
    standard deviation). The p-value is the upper-tail probability of
    ``test_stat`` under that normal distribution.

    Args:
        w: target vector.
        A, B: attribute vector collections. create_permutation joins them with
            ``+``, so they are expected to be lists (for NumPy arrays ``+``
            would add element-wise instead of concatenating).
        test_stat: observed association of ``w`` with A versus B.
        permutations: number of random re-splits; at least 2 are needed for
            the sample standard deviation to be defined.

    Returns:
        The probability that the fitted null exceeds ``test_stat``.
    """
    distribution = []
    for _ in range(permutations):
        j, k = create_permutation(A, B)
        distribution.append(SV_WEAT_association(w, j, k))

    dist_mean = np.mean(distribution)
    dist_dev = std_deviation(distribution)

    # The survival function equals 1 - cdf but is evaluated directly, so very
    # small p-values are not rounded to exactly 0 by floating-point
    # cancellation.
    p_value = norm.sf(test_stat, dist_mean, dist_dev)

    return p_value

def SV_WEAT(target_w, Attr_A, Attr_B, permutations, full_output=False):
    """Run the single-value WEAT for one target vector.

    Args:
        target_w: target vector.
        Attr_A, Attr_B: collections of attribute vectors.
        permutations: number of permutations for the p-value; only used when
            ``full_output`` is true.
        full_output: when false (the default) only the effect size is
            returned and the permutation test is skipped, because its result
            would be discarded. When true, the permutation test is run and all
            three statistics are returned.

    Returns:
        The effect size, or ``(effect_size, variance, p_value)`` when
        ``full_output`` is true. ``variance`` is the second value returned by
        SV_WEAT_effect_size, which is the sample standard deviation of the
        joint similarity distribution.
    """
    effect_size, variance = SV_WEAT_effect_size(target_w, Attr_A, Attr_B)

    if not full_output:
        return effect_size

    test_statistic = SV_WEAT_association(target_w, Attr_A, Attr_B)
    p_value = SV_WEAT_permutation_test(target_w, Attr_A, Attr_B, test_statistic, permutations)

    return effect_size, variance, p_value

def SV_WEAT_effect_size(w, A, B):
    """Return the effect size of w's association with A versus B.

    The effect size is the difference between the mean cosine similarity of
    ``w`` to the vectors in A and to the vectors in B, divided by the sample
    standard deviation of all those similarities taken together. That standard
    deviation is returned as the second element of the tuple.
    """
    w_norm = np.sqrt(np.dot(w, w))

    def _cosines_with_w(vectors):
        # Cosine similarity between w and every vector in the collection.
        return [(np.dot(v, w) / (w_norm * np.sqrt(np.dot(v, v)))) for v in vectors]

    distribution_a = _cosines_with_w(A)
    distribution_b = _cosines_with_w(B)

    pooled_deviation = std_deviation(distribution_a + distribution_b)
    mean_difference = np.mean(distribution_a) - np.mean(distribution_b)

    return (mean_difference / pooled_deviation), pooled_deviation

def SV_WEAT_association(w, A, B):
    """Return mean cosine similarity of w to A minus mean cosine similarity to B."""
    w_norm = np.sqrt(np.dot(w, w))

    def _mean_cosine(vectors):
        # Average cosine similarity between w and the vectors of one set.
        similarities = []
        for v in vectors:
            similarities.append(np.dot(w, v) / (w_norm * np.sqrt(np.dot(v, v))))
        return np.mean(similarities)

    association_a = _mean_cosine(A)
    association_b = _mean_cosine(B)
    return association_a - association_b

def get_embeddings(term, context, model, tokenizer, mean_pool = True):
    """Return the embedding of term inside context for layers 0-12.

    The term and the context are tokenized separately. The first occurrence of
    the term's token ids inside the context's token ids gives the positions
    whose vectors are read from the model output.

    Args:
        term: word whose embedding is wanted.
        context: sentence containing the term.
        model: callable applied to the tokenizer output. The last element of
            its result is indexed as ``[layer][0][position]``.
            NOTE(review): this assumes the last output element is the tuple of
            per-layer hidden states (model loaded with
            output_hidden_states=True and output_attentions=False) - confirm
            for any model other than MODEL_BERT.
        tokenizer: tokenizer providing ``encode`` and ``__call__``.
        mean_pool: when true, the subtoken vectors of each layer are averaged
            into a single tensor; otherwise each layer entry is the list of
            subtoken vectors.

    Returns:
        A list with one entry per layer (13 entries).

    Raises:
        ValueError: if the term's token ids do not occur in the context.
    """
    layers = range(13)
    encoding = tokenizer.encode(term, add_special_tokens = False, add_prefix_space=True)
    encoded_context = tokenizer.encode(context, add_special_tokens=True)

    # Locate the first occurrence of the term's token ids in the context.
    span = len(encoding)
    positions = []
    for start in range(len(encoded_context) - span + 1):
        if encoded_context[start:start + span] == encoding:
            positions = list(range(start, start + span))
            break

    # Fail before running the model: without positions the pooled result would
    # be NaN (or an empty list per layer), which is easy to miss downstream.
    if not positions:
        raise ValueError(f'Term {term!r} was not found in context {context!r}')

    inputs = tokenizer(context, return_tensors = 'tf')
    output_ = model(inputs)
    hidden_states = output_[-1]

    embeddings = []

    for layer in layers:

        # Index 0 selects the single sequence in the batch.
        target_embedding = [hidden_states[layer][0][position] for position in positions]

        #Get mean of Subtokens
        if mean_pool:
            mean_cwe = np.mean(np.array(target_embedding), axis = 0)
            target_embedding = tf.convert_to_tensor(mean_cwe)

        embeddings.append(target_embedding)

    return embeddings

#Example Run
# For every name in names.pkl: load its pickled term object, compute per-layer
# self-similarity and linear CKA, embed the name in a bleached context, and
# finally pickle the three result dictionaries (keyed by name).

LOAD_PATH = ''
WRITE_PATH = ''

with open('names.pkl', 'rb') as pkl_reader:
    names = cPickle.load(pkl_reader)

write_model = 'bert'
model_id = 'tf_bert_model'
layers = list(range(13))
bleached_name_context = 'This person\'s name is X'

model_similarity = {}
model_cka = {}
model_embeddings = {}

for name in names:
    term_pickle_path = path.join(LOAD_PATH, f'{name}_{model_id}-object.pkl')
    with open(term_pickle_path, 'rb') as pkl_reader:
        name_term = cPickle.load(pkl_reader)

    # Called before the 'Concat' entries are read by the two measures below.
    name_term.form_subtoken_cwes('Concat')

    model_similarity[name] = calculate_self_similarity(name_term, 'Concat', layers)
    model_cka[name] = measure_linear_cka_from_initial(name_term, 'Concat', layers)

    # Substitute the name for the placeholder X in the template sentence.
    bleached_context = bleached_name_context.replace('X', name)
    model_embeddings[name] = get_embeddings(name, bleached_context, MODEL_BERT, MODEL_TOKENIZER_BERT)

# Write each result dictionary to its own pickle file.
for output_name, payload in (
    (f'{write_model}_similarity.pkl', model_similarity),
    (f'{write_model}_cka.pkl', model_cka),
    (f'{write_model}_bleached_embeddings.pkl', model_embeddings),
):
    with open(path.join(WRITE_PATH, output_name), 'wb') as pkl_writer:
        cPickle.dump(payload, pkl_writer)