import numpy as np
from transformers import GPT2Tokenizer, BertTokenizer, XLNetTokenizer, T5Tokenizer
from scipy.stats import spearmanr
from matplotlib import pyplot as plt
import _pickle as cPickle
from bias_similarity_functions import SV_WEAT

# Checkpoint names handed to each tokenizer's from_pretrained().
MODEL_ID_BERT = 'bert-base-cased'
MODEL_ID_GPT2 = 'gpt2'
MODEL_ID_XLNET = 'xlnet-base-cased'
MODEL_ID_T5 = 't5-base'

# One tokenizer per checkpoint; all four are instantiated when the script
# starts, regardless of which one is selected as CURRENT_TOKENIZER later.
MODEL_TOKENIZER_BERT = BertTokenizer.from_pretrained(MODEL_ID_BERT)
MODEL_TOKENIZER_GPT2 = GPT2Tokenizer.from_pretrained(MODEL_ID_GPT2)
MODEL_TOKENIZER_XLNET = XLNetTokenizer.from_pretrained(MODEL_ID_XLNET)
MODEL_TOKENIZER_T5 = T5Tokenizer.from_pretrained(MODEL_ID_T5)

# WEAT word lists. Each list is a candidate attribute set; the pair actually
# used for scoring is chosen through A_TARGET / B_TARGET further down.
pleasant = [
    'caress', 'freedom', 'health', 'love', 'peace',
    'cheer', 'friend', 'heaven', 'loyal', 'pleasure',
    'diamond', 'gentle', 'honest', 'lucky', 'rainbow',
    'diploma', 'gift', 'honor', 'miracle', 'sunrise',
    'family', 'happy', 'laughter', 'paradise', 'vacation',
]
unpleasant = [
    'abuse', 'crash', 'filth', 'murder', 'sickness',
    'accident', 'death', 'grief', 'poison', 'stink',
    'assault', 'disaster', 'hatred', 'pollute', 'tragedy',
    'divorce', 'jail', 'poverty', 'ugly', 'cancer',
    'kill', 'rotten', 'vomit', 'agony', 'prison',
]

# Shorter pleasant / unpleasant variants (eight words each).
pleasant_2 = [
    'joy', 'love', 'peace', 'wonderful',
    'pleasure', 'friend', 'laughter', 'happy',
]
unpleasant_2 = [
    'agony', 'terrible', 'horrible', 'nasty',
    'evil', 'war', 'awful', 'failure',
]

# Career vs. domestic terms.
career = [
    'executive', 'management', 'professional', 'corporation',
    'salary', 'office', 'business', 'career',
]
domestic = [
    'home', 'parents', 'children', 'family',
    'cousins', 'marriage', 'wedding', 'relatives',
]

# Mathematics vs. art terms.
mathematics = [
    'math', 'algebra', 'geometry', 'calculus',
    'equations', 'computation', 'numbers', 'addition',
]
art = [
    'poetry', 'art', 'dance', 'literature',
    'novel', 'symphony', 'drama', 'sculpture',
]

# Science vs. art terms (art_2 swaps in 'Shakespeare' and drops 'sculpture').
science = [
    'science', 'technology', 'physics', 'chemistry',
    'Einstein', 'NASA', 'experiment', 'astronomy',
]
art_2 = [
    'poetry', 'art', 'Shakespeare', 'dance',
    'literature', 'novel', 'symphony', 'drama',
]

# Containers for the two raw datasets; final_names is filled in later on.
name_data = []
yob_data = []
final_names = []

# firstnames.csv: the first line is skipped as a header. Every other line
# becomes [capitalized first field, float(field), float(field), ...].
with open('D:\\datasets\\firstnames.csv', 'r', encoding='utf8') as name_doc:
    next(name_doc, None)
    for raw_line in name_doc:
        first_field, *numeric_fields = raw_line.strip().split(',')
        name_data.append(
            [first_field.capitalize(), *(float(value) for value in numeric_fields)]
        )

# yob1990.txt: no header line. The first field is capitalized and the third
# field is parsed as an int; every other field is kept as the raw string.
with open('D:\\datasets\\yob1990.txt', 'r', encoding='utf8') as yob_doc:
    for raw_line in yob_doc:
        fields = raw_line.strip().split(',')
        yob_data.append(
            [fields[0].capitalize(), fields[1], int(fields[2]), *fields[3:]]
        )

# Maps a position within a name's numeric firstnames.csv columns to a race
# label; it is indexed further down with the position of the row's largest
# value.
race_dict = {0: 'Hispanic', 1: 'White', 2: 'Black', 3: 'Asian', 4: 'Native_American', 5: 'Mixed_Race'}

first_names = [i[0] for i in name_data]
yob_names = [i[0] for i in yob_data]

# Keep only names present in both datasets, in firstnames.csv order.
# Membership is tested against a set built once: testing against the
# yob_names list made this step quadratic in the size of the datasets.
_yob_name_set = set(yob_names)
final_names = [name for name in first_names if name in _yob_name_set]

# Collapse yob_data to one entry per name, stored as the row minus its name
# (i.e. [label, count, ...]). When a name shows up again the counts are
# summed; the stored label is kept only if its count is strictly larger than
# the incoming count, otherwise the incoming row's label replaces it.
yob_dict = {}

for record in yob_data:
    key = record[0]
    stored = yob_dict.get(key)

    if stored is None:
        # First occurrence of this name.
        yob_dict[key] = record[1:]
    elif stored[1] > record[2]:
        # Stored entry dominates: fold the incoming count into it.
        stored[1] += record[2]
    else:
        # Incoming row dominates (or ties): it absorbs the stored count and
        # replaces the entry. The row inside yob_data is updated in place.
        record[2] += stored[1]
        yob_dict[key] = record[1:]

# Name -> numeric columns of its firstnames.csv row.
name_dict = {row[0]: row[1:] for row in name_data}

# Merge both datasets into one record per shared name:
#   [name, *yob_dict[name], race label]
# The race label is the race_dict entry for the position of the largest value
# among the name's numeric firstnames.csv columns (first position on ties,
# as list.index returns the first match).
final_name_list = []

for name in final_names:
    # A dedicated local name is used for the per-name columns so that the
    # module-level yob_data / name_data lists are not overwritten by the loop.
    race_scores = name_dict[name]
    race_target = race_scores.index(max(race_scores))
    final_name_list.append([name, *yob_dict[name], race_dict[race_target]])

# Order by element 2 of each record, largest first.
# NOTE(review): element 2 is presumably the summed count from yob1990.txt
# (rows shaped name,label,count) -- confirm against the data file.
final_name_list = sorted(final_name_list, key=lambda record: record[2], reverse=True)
final_name_dict = {record[0]: record[1:] for record in final_name_list}
final_names = [record[0] for record in final_name_list]
name_frequencies = [record[2] for record in final_name_list]

# Persist the ordered name list.
with open('D:\\names.pkl', 'wb') as pkl_writer:
    cPickle.dump(final_names, pkl_writer)

# --- Model selection ---------------------------------------------------
# Tokenizer used for the single- vs. multi-token split.
CURRENT_TOKENIZER = MODEL_TOKENIZER_BERT
# Path fragment identifying the model in the pickle file names read below.
MODEL_WRITE = 'bert'
# Model name shown in the plot titles.
CHART_MODEL = 'BERT'
# Not referenced anywhere else in the visible part of this script.
CURRENT_MODEL_ID = 'tf_bert_model'
# Index into each stored per-name embedding / similarity / CKA sequence.
TARGET_LAYER = 9

# --- Bias test selection -----------------------------------------------
# Key of all_bundle_dict whose names enter the bias/frequency correlation.
BIAS_GROUP = 'Minority'
# Word lists whose embeddings form the A and B sets passed to SV_WEAT.
A_TARGET, B_TARGET = pleasant, unpleasant

# Load previously harvested embeddings (bleached sentence templates) for the
# names and for both attribute word lists.
def _load_layer_embedding(term):
    """Return the stored TARGET_LAYER embedding for *term* as a squeezed array.

    Reads ``{term}_{MODEL_WRITE}_bleached.pkl`` from the model's ``mean``
    directory, indexes the unpickled object with ``TARGET_LAYER`` and removes
    singleton dimensions with ``np.squeeze``.

    Raises whatever ``open`` / ``cPickle.load`` raise when the file is
    missing or unreadable.
    """
    path = f'D:\\lexicon_terms\\bleached_terms\\{MODEL_WRITE}\\mean\\{term}_{MODEL_WRITE}_bleached.pkl'
    # NOTE: unpickling can execute arbitrary code; these files must come from
    # a trusted source (here: produced by an earlier harvesting step).
    with open(path, 'rb') as pkl_file:
        embeddings = cPickle.load(pkl_file)
    return np.squeeze(np.array(embeddings[TARGET_LAYER]))


# Name -> embedding at TARGET_LAYER.
name_vector_dict = {name: _load_layer_embedding(name) for name in final_names}
# Kept for compatibility; nothing in the visible script fills or reads it.
mean_cka_dict = {}
# Attribute-set embeddings, in the order of the A_TARGET / B_TARGET lists.
A = [_load_layer_embedding(term) for term in A_TARGET]
B = [_load_layer_embedding(term) for term in B_TARGET]

# Score every name's embedding against the A and B attribute sets with
# SV_WEAT, keeping both an ordered list (aligned with final_names) and a
# by-name lookup.
name_associations = [
    SV_WEAT(name_vector_dict[current_name], A, B) for current_name in final_names
]
name_associations_dict = dict(zip(final_names, name_associations))

def _read_pickle(path):
    """Unpickle and return the object stored at *path*."""
    with open(path, 'rb') as pkl_reader:
        return cPickle.load(pkl_reader)


# Per-name metrics precomputed for the selected model.
dispersion_dict = _read_pickle(f'D:\\1k_files\\{MODEL_WRITE}_dispersion.pkl')
concat_cka_dict = _read_pickle(f'D:\\1k_files\\{MODEL_WRITE}_cka.pkl')
similarity_dict = _read_pickle(f'D:\\1k_files\\{MODEL_WRITE}_similarity.pkl')
name_frequencies_dict = _read_pickle(f'D:\\{MODEL_WRITE}_name_frequencies.pkl')

# Flatten the lookups into lists aligned with final_names; the similarity and
# CKA entries are additionally indexed by TARGET_LAYER.
final_name_associations = [name_associations_dict[key] for key in final_names]
final_name_similarities = [similarity_dict[key][TARGET_LAYER] for key in final_names]
concat_ckas = [concat_cka_dict[key][TARGET_LAYER] for key in final_names]
corpus_name_frequencies = [name_frequencies_dict[key] for key in final_names]

# Spearman correlation of corpus name frequency with each similarity index.
self_sim_correlation = spearmanr(corpus_name_frequencies, final_name_similarities)
cka_correlation = spearmanr(corpus_name_frequencies, concat_ckas)
for correlation_result in (self_sim_correlation, cka_correlation):
    print(correlation_result)

# Frequencies used for grouping and plotting: raw corpus counts by default.
# Swap in the commented line below to use their base-2 logarithm instead.
name_frequencies = corpus_name_frequencies
#name_frequencies = [np.log2(i) for i in corpus_name_frequencies]

# One tuple per name:
#   (name, association, frequency, CKA, self-similarity)
names_bundle = list(
    zip(final_names, final_name_associations, name_frequencies, concat_ckas, final_name_similarities)
)

# Group label -> list of tuples; 'All' holds a copy of the full list.
all_bundle = list(names_bundle)
all_bundle_dict = {'All': all_bundle}

def _gender_label(entry):
    """Return the first attribute stored in final_name_dict for the entry's name."""
    return final_name_dict[entry[0]][0]


def _race_label(entry):
    """Return the last attribute stored in final_name_dict for the entry's name."""
    return final_name_dict[entry[0]][-1]


def _with_gender(bundle, code):
    """Return the entries of *bundle* whose gender label equals *code*, in order."""
    return [entry for entry in bundle if _gender_label(entry) == code]


# Gender groups
male_bundle = _with_gender(all_bundle, 'M')
female_bundle = _with_gender(all_bundle, 'F')

all_bundle_dict.update({'Male': male_bundle, 'Female': female_bundle})

# Race groups ('Minority' is every name whose label is not 'White')
black_bundle = [entry for entry in all_bundle if _race_label(entry) == 'Black']
hispanic_bundle = [entry for entry in all_bundle if _race_label(entry) == 'Hispanic']
asian_bundle = [entry for entry in all_bundle if _race_label(entry) == 'Asian']
minority_bundle = [entry for entry in all_bundle if not _race_label(entry) == 'White']
white_bundle = [entry for entry in all_bundle if _race_label(entry) == 'White']

all_bundle_dict.update({
    'Black': black_bundle,
    'Hispanic': hispanic_bundle,
    'Asian': asian_bundle,
    'Minority': minority_bundle,
    'White': white_bundle,
})

# Intersectional groups: each race group narrowed to one gender. Filtering the
# race bundles keeps the same entries in the same order as filtering all_bundle
# on both conditions.
bm_bundle = _with_gender(black_bundle, 'M')
hm_bundle = _with_gender(hispanic_bundle, 'M')
am_bundle = _with_gender(asian_bundle, 'M')
mm_bundle = _with_gender(minority_bundle, 'M')
wm_bundle = _with_gender(white_bundle, 'M')

all_bundle_dict.update({
    'BM': bm_bundle,
    'HM': hm_bundle,
    'AM': am_bundle,
    'MM': mm_bundle,
    'WM': wm_bundle,
})

bf_bundle = _with_gender(black_bundle, 'F')
hf_bundle = _with_gender(hispanic_bundle, 'F')
af_bundle = _with_gender(asian_bundle, 'F')
mf_bundle = _with_gender(minority_bundle, 'F')
wf_bundle = _with_gender(white_bundle, 'F')

all_bundle_dict.update({
    'BF': bf_bundle,
    'HF': hf_bundle,
    'AF': af_bundle,
    'MF': mf_bundle,
    'WF': wf_bundle,
})

# Group-label lists by category, matching the keys registered above.
intersectional = ['BM', 'HM', 'AM', 'MM', 'WM', 'BF', 'HF', 'AF', 'MF', 'WF']
race = ['Black', 'Hispanic', 'Asian', 'Minority', 'White']
gender = ['Male', 'Female']

# Report the median frequency (tuple element 2) of each intersectional group.
for group in intersectional:
    group_frequencies = [entry[2] for entry in all_bundle_dict[group]]
    median_frequency = np.median(group_frequencies)
    print(group)
    print(median_frequency)

# Spearman correlation between frequency and association score within the
# group selected by BIAS_GROUP.
bias_group = all_bundle_dict[BIAS_GROUP]
bias_scores = []
bias_frequencies = []
for entry in bias_group:
    bias_scores.append(entry[1])
    bias_frequencies.append(entry[2])

bias_frequency_correlation = spearmanr(bias_frequencies, bias_scores)
print(bias_frequency_correlation)

# Split every group into names the current tokenizer encodes as exactly one
# token and names it encodes as anything else.
single_bundle_dict = {}
multi_bundle_dict = {}

all_num, single_num, multi_num = [], [], []

# Name -> number of token ids. Each name is encoded once and the result is
# reused, because the same names recur across the overlapping groups.
_token_count_cache = {}


def _is_single_token(name):
    """Return True when CURRENT_TOKENIZER encodes *name* as exactly one token.

    Encoding uses no special tokens and passes ``add_prefix_space=True``, the
    same arguments for every name. Results are memoized per name, which
    assumes the tokenizer is deterministic.
    """
    if name not in _token_count_cache:
        token_ids = CURRENT_TOKENIZER.encode(name, add_special_tokens = False, add_prefix_space = True)
        _token_count_cache[name] = len(token_ids)
    return _token_count_cache[name] == 1


for name_group, bundle in all_bundle_dict.items():

    # Single pass: route each entry to one of the two lists, keeping order.
    single, multi = [], []
    for entry in bundle:
        (single if _is_single_token(entry[0]) else multi).append(entry)

    #Print percent singly tokenized by group
    print(name_group)
    group_size = len(single) + len(multi)
    # An empty group has no defined ratio; print nan instead of raising
    # ZeroDivisionError and aborting the remaining groups.
    print(len(single) / group_size if group_size else float('nan'))

    single_bundle_dict[name_group] = single
    multi_bundle_dict[name_group] = multi

    # Counts are recorded for the intersectional groups only.
    if name_group in intersectional:
        single_num.append(len(single))
        multi_num.append(len(multi))
        all_num.append(group_size)

def _column(rows, index):
    """Return element *index* of every tuple in *rows*, in order."""
    return [row[index] for row in rows]


# Per-metric lists for the singly and multiply tokenized names of the 'All'
# group. Tuple layout: (name, association, frequency, CKA, self-similarity).
_single_rows = single_bundle_dict['All']
_multi_rows = multi_bundle_dict['All']

single_associations = _column(_single_rows, 1)
single_frequencies = _column(_single_rows, 2)
single_ckas = _column(_single_rows, 3)
single_similarities = _column(_single_rows, 4)

multi_names = _column(_multi_rows, 0)
multi_associations = _column(_multi_rows, 1)
multi_frequencies = _column(_multi_rows, 2)
multi_ckas = _column(_multi_rows, 3)
multi_similarities = _column(_multi_rows, 4)

# Correlate a 1/0 "singly tokenized" indicator with frequency: singles first,
# then multis, with both lists concatenated in that same order.
single_tokenized = [1] * len(single_frequencies)
multiple_tokenized = [0] * len(multi_frequencies)

tokenized_list = single_tokenized + multiple_tokenized
frequency_list = single_frequencies + multi_frequencies

tokenization_correlation = spearmanr(tokenized_list, frequency_list)
print(tokenization_correlation)

# Print, for each metric, its label followed by the mean over singly
# tokenized names and then the mean over multiply tokenized names.
for metric_label, single_values, multi_values in (
    ('similarity', single_similarities, multi_similarities),
    ('cka', single_ckas, multi_ckas),
):
    print(metric_label)
    print(np.mean(single_values))
    print(np.mean(multi_values))

def _scatter_by_tokenization(multi_values, single_values, y_label, title, y_limits=None):
    """Show a frequency scatter plot with multi- and single-token names overlaid.

    *multi_values* / *single_values* are plotted against multi_frequencies /
    single_frequencies. The y-axis range is left automatic unless *y_limits*
    is given.

    NOTE(review): the x-axis is fixed to [0, 25] with tick labels written as
    powers of two, which presumably expects log2 frequencies -- confirm that
    name_frequencies is set to the log2 variant before relying on the figure.
    """
    plt.scatter(multi_frequencies, multi_values, c = 'lightgray', marker = 'o', alpha = 0.5, label = 'Multiply Tokenized')
    plt.gray()
    plt.scatter(single_frequencies, single_values, c = 'gray', marker = '+', alpha = 0.5, label = 'Singly Tokenized')
    plt.gray()
    plt.xlabel('Training Corpus Name Frequency')
    plt.ylabel(y_label)
    plt.xlim([0, 25])
    plt.xticks([0,5,10,15,20,25],['0', '2⁵', '2¹⁰', '2¹⁵', '2²⁰', '2²⁵'])
    if y_limits is not None:
        plt.ylim(y_limits)
    plt.legend(loc='best')
    plt.title(title)
    plt.gray()
    plt.show()


# Frequency vs. self-similarity (y-axis range left automatic).
_scatter_by_tokenization(
    multi_similarities,
    single_similarities,
    'Contextualized Word Embedding Self-Similarity',
    f'{CHART_MODEL} Frequency vs. Self-Similarity',
)

# Frequency vs. CKA similarity to the initial representation.
_scatter_by_tokenization(
    multi_ckas,
    single_ckas,
    'Linear CKA Similarity to Initial Representation',
    f'{CHART_MODEL} Frequency vs. Similarity to Initial',
    y_limits=[0.0, 0.5],
)