import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# import json
# import argparse
import matplotlib
from numpy.core.fromnumeric import mean
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from numpy import dot
from numpy.linalg import norm
from nltk.corpus import wordnet as wn
# from scipy import stats
import pandas as pd
import torch
# import tensorflow as tf
# from tqdm import tqdm
from collections import defaultdict
from transformers import BertTokenizer, BertModel
from transformers import AutoModelForMaskedLM, AutoTokenizer
# import sys

# sys.path.append('sense_bert')
# import sensebert
# from sense_bert import tokenization, sensebert


# def parse_args():
# 	parser = argparse.ArgumentParser()
# 	parser.add_argument('--data', type=str, required=True,
# 						choices=['cp', 'ss',
# 							'sssb_gender', 'sssb_race', 'sssb_nationality'],
# 						help='Path to evaluation dataset.')
# 	parser.add_argument('--model', type=str, required=True,
# 						choices=['sense-base', 'sense-large',
# 								 'bert-base', 'bert-large'])
# 	# parser.add_argument('--method', type=str, required=True,
# 	#                     choices=['aul', 'cps', 'sss'])
# 	args = parser.parse_args()

# 	return args

def get_sk_lemma(sensekey):
	"""Return the lemma part of a WordNet sense key, i.e. the text before the first '%'.

	A string without '%' is returned unchanged.
	"""
	lemma, _separator, _rest = sensekey.partition('%')
	return lemma

def plot_wat_dim():
	"""Plot WAT bias scores against embedding dimensionality and save the figure.

	The three x positions are the dimensionalities of the lmms-1024,
	lmms-2048 and lmms-2348 embeddings. One line is drawn for sense
	embeddings and one for word embeddings; the result is written to
	``figures/wat-dimensionality.png``.
	"""
	x_dimensions = [1024, 2048, 2348]
	y_sense_values = [0.08, 0.1886, 0.4087]
	y_word_values = [0.4587, 0.4587, 0.5276]

	# Use a dedicated figure: drawing through the implicit pyplot state would
	# add these lines to whatever axes an earlier plot function left current.
	fig, ax = plt.subplots()
	ax.plot(x_dimensions, y_sense_values, label='sense embeddings')
	ax.plot(x_dimensions, y_word_values, label='word embeddings')
	ax.set_xlabel('Dimensionality')
	ax.set_ylabel('Bias Scores')
	ax.set_title('WAT\nEffect of dimensionality')
	ax.legend()

	path = 'figures/wat-dimensionality.png'
	# savefig does not create missing directories.
	os.makedirs(os.path.dirname(path), exist_ok=True)
	fig.savefig(path, format='png')
	# Release the figure so repeated calls do not accumulate open figures.
	plt.close(fig)
	print('Saved figure to %s ' % path)


def plot_weat_dim():
	"""Plot WEAT bias scores against embedding dimension and save the figure.

	Two side-by-side panels are drawn, one for sense embeddings and one for
	word embeddings, sharing a single legend placed below the panels. The
	three x positions are the dimensionalities of the lmms-1024, lmms-2048
	and lmms-2348 embeddings. The result is written to
	``figures/weat-dimensionality.png``.
	"""
	x_dimensions = [1024, 2048, 2348]

	# Each entry: (legend label, sense-embedding scores, word-embedding scores).
	# NOTE (from the original author): the first elements need to be checked.
	series = [
		('Flowers vs Insects', [1.4936, 1.9104, 2], [1.4291, 1.4688, 1.6263]),
		('Instruments vs Weapons', [1.34, 1.8368, 2], [1.0491, 1.2376, 1.4163]),
		('Math vs Art', [0.336, 0.98, 1.8288], [0.8696, 1.0912, 1.5186]),
		('Science vs Art', [-0.1608, 0.7464, 1.6552], [0.7638, 1.045, 1.3749]),
		# Deliberately not plotted: Physical vs Mental condition
		#   sense: [0.2235, -0.4232, 0.6392]   word: [-0.4396, -0.5912, 0.4168]
	]

	fig = plt.figure()
	ax1 = fig.add_subplot(1, 2, 1)
	ax2 = fig.add_subplot(1, 2, 2)

	# (axes, panel title, index of the score list inside a series entry)
	panels = [(ax1, 'Sense Embeddings', 1), (ax2, 'Word Embeddings', 2)]
	for ax, title, column in panels:
		for entry in series:
			ax.plot(x_dimensions, entry[column])
		ax.set_xlabel('Dimension', fontsize=10)
		ax.set_ylabel('Bias Scores', fontsize=10, labelpad=-5)
		ax.set_title(title, fontsize=12)
		ax.set_ylim([-0.5, 2.0])

	# Shrink each panel to 85% of its height (lifting its bottom edge by 15%)
	# and to 95% of its width, leaving room for the legend underneath.
	for ax, _title, _column in panels:
		box = ax.get_position()
		ax.set_position([box.x0, box.y0 + box.height * 0.15,
			box.width * 0.95, box.height * 0.85])

	# One legend below both panels; it is anchored in the coordinates of the
	# right-hand axes and stretched leftwards so that it spans both panels.
	ax2.legend([entry[0] for entry in series], loc='upper center',
		bbox_to_anchor=(-1.1, -0.15, 1.9, 0),
		fancybox=True, shadow=True, ncol=2, mode="expand", fontsize=8.5)
	# fig.suptitle('WEAT Effect of dimensionality', fontsize=14)

	path = 'figures/weat-dimensionality.png'
	# savefig does not create missing directories.
	os.makedirs(os.path.dirname(path), exist_ok=True)
	fig.savefig(path, format='png')
	# Release the figure so repeated calls do not accumulate open figures.
	plt.close(fig)
	print('Saved figure to %s ' % path)


def plot_sssb_dim():
	"""Plot SSSB bias scores against embedding dimensionality and save the figure.

	Two side-by-side panels are drawn, one for sense embeddings and one for
	word embeddings, sharing a single legend placed below the panels. The
	x positions are 1024, 2048 and 2348. The result is written to
	``figures/sssb-dimensionality.png``.
	"""
	x_dimensions = [1024, 2048, 2348]

	# Each entry: (legend label, sense-embedding scores, word-embedding scores).
	series = [
		('ethnicity: black', [3.076, 7.3429, 4.6378], [3.1859, 6.6336, 5.3613]),
		('colour: black', [2.2175, 3.6947, 1.6394], [3.1782, 6.573, 5.347]),
		('nationality', [3.9172, 8.9415, 8.225], [4.003, 8.8036, 7.7839]),
		('language', [3.8881, 8.3998, 7.0133], [3.9965, 8.7854, 7.7779]),
		# Deliberately not plotted (and therefore not in the legend either):
		#   noun  sense: [0.1297, 0.2125, 0.3856]   word: [0.0768, 0.132, 0.3387]
		#   verb  sense: [0.0409, 0.0568, 0.2622]   word: [0.0775, 0.1308, 0.3386]
	]

	fig = plt.figure()
	ax1 = fig.add_subplot(1, 2, 1)
	ax2 = fig.add_subplot(1, 2, 2)

	# (axes, panel title, index of the score list inside a series entry)
	panels = [(ax1, 'Sense Embeddings', 1), (ax2, 'Word Embeddings', 2)]
	for ax, title, column in panels:
		for entry in series:
			ax.plot(x_dimensions, entry[column])
		ax.set_xlabel('Dimensionality', fontsize=10)
		ax.set_ylabel('Bias Scores', fontsize=10)
		ax.set_title(title, fontsize=12)

	# Shrink each panel to 85% of its height (lifting its bottom edge by 15%),
	# leaving room for the legend underneath.
	for ax, _title, _column in panels:
		box = ax.get_position()
		ax.set_position([box.x0, box.y0 + box.height * 0.15,
			box.width, box.height * 0.85])

	# One legend below both panels, with exactly one label per plotted line.
	# It is anchored in the coordinates of the right-hand axes and stretched
	# leftwards so that it spans both panels.
	ax2.legend([entry[0] for entry in series], loc='upper center',
		bbox_to_anchor=(-1.1, -0.13, 2.0, 0),
		fancybox=True, shadow=True, ncol=3, mode="expand")
	fig.suptitle('SSSB Effect of dimensionality', fontsize=14)

	path = 'figures/sssb-dimensionality.png'
	# savefig does not create missing directories.
	os.makedirs(os.path.dirname(path), exist_ok=True)
	fig.savefig(path, format='png')
	# Release the figure so repeated calls do not accumulate open figures.
	plt.close(fig)
	print('Saved figure to %s ' % path)


def cosine_similarity(a, b):
	"""Return the cosine similarity of vectors a and b.

	Computed as the dot product divided by the product of the two norms.
	There is no guard for zero-norm inputs: the denominator is then zero.
	"""
	magnitude_product = norm(a) * norm(b)
	return dot(a, b) / magnitude_product


def _mean_sense_vector_for_synsets(WE, synsets):
	"""Return the mean embedding over the sense keys of every lemma of every synset in *synsets*."""
	sense_keys = [lemma.key() for synset in synsets for lemma in synset.lemmas()]
	return np.mean(np.stack([WE.get_vector(key) for key in sense_keys]), axis=0)


def get_gender_vector(WE):
	"""Estimate a gender direction as the mean male-minus-female difference vector.

	Male and female words are read line by line from
	``./data/male_word_file.txt`` and ``./data/female_word_file.txt`` and
	paired by position (extra lines in the longer file are ignored). A pair is
	skipped when WordNet returns no synsets for either word.

	For every remaining pair two differences are added together:

	* the "average sense" difference: for each word, the mean of the
	  embeddings of all sense keys reachable through its synsets;
	* the plain-word difference ``WE.get_vector(male) - WE.get_vector(female)``.

	``WE.get_vector`` returns a zero vector for unknown labels, so whichever
	kind of label is absent from the loaded embedding contributes nothing.
	NOTE(review): presumably only one of the two is meant to be non-zero for
	a given embedding file (sense-keyed vs. word-keyed) -- confirm.

	Args:
		WE: object exposing ``get_vector(label)``.

	Returns:
		The sum of the per-pair differences divided by the number of pairs
		that actually contributed. Skipped pairs no longer inflate the
		denominator.

	Raises:
		ValueError: if no pair could be used, instead of silently returning
			a scalar or failing with a ZeroDivisionError.
	"""
	with open("./data/male_word_file.txt") as male_file:
		male_words = [line.strip() for line in male_file]
	with open("./data/female_word_file.txt") as female_file:
		female_words = [line.strip() for line in female_file]

	gender_vects = []
	for male, female in zip(male_words, female_words):
		male_synsets = wn.synsets(male)
		female_synsets = wn.synsets(female)
		if not male_synsets or not female_synsets:
			continue

		# Average-sense difference (non-zero for sense-keyed embeddings).
		sense_diff = (_mean_sense_vector_for_synsets(WE, male_synsets)
			- _mean_sense_vector_for_synsets(WE, female_synsets))
		# Plain-word difference (non-zero for word-keyed embeddings such as GloVe).
		word_diff = WE.get_vector(male) - WE.get_vector(female)
		gender_vects.append(sense_diff + word_diff)

	if not gender_vects:
		raise ValueError("No usable male/female word pairs to build the gender vector")

	gender_vects_sum = sum(gender_vects)
	print('gender_vects_sum', gender_vects_sum)
	return gender_vects_sum / len(gender_vects)


def _lemma_average_embedding(WE, sensekey):
	"""Return the mean embedding over all sense keys reachable from the lemma of *sensekey*.

	The lemma is looked up in WordNet without a part-of-speech filter, and
	the sense key of every lemma of every returned synset is included.
	"""
	lemma = get_sk_lemma(sensekey)
	sense_keys = [
		wn_lemma.key()
		for synset in wn.synsets(lemma)
		for wn_lemma in synset.lemmas()
	]
	return np.mean(np.stack([WE.get_vector(key) for key in sense_keys]), axis=0)


def eval_gender_bias(WE, sense_bias=False):
	"""Score noun and verb senses of occupation words against the gender direction.

	For each occupation the cosine similarity between the gender vector and
	an embedding of its noun sense and of its verb sense is computed.

	Args:
		WE: embedding object exposing ``embed`` (label -> vector mapping)
			and ``get_vector(label)``.
		sense_bias: when True, use the embedding of the listed sense key
			itself; when False (default, the previous hard-coded behaviour),
			use the average over all senses reachable from the key's lemma.

	Returns:
		A DataFrame with one row per occupation plus a final ``mean`` row
		(mean absolute score per column) and the columns ``noun_bias`` and
		``verb_bias``. A score is NaN when the sense key is not present in
		``WE.embed``; NaN entries are ignored by the mean.
	"""
	gender_vec = get_gender_vector(WE)

	# (word, noun sense key, verb sense key) in LMMS/WordNet sense-key format.
	# For sense-insensitive word embeddings the plain word would be used in
	# all three positions instead, e.g. ("engineer", "engineer", "engineer").
	occupations = [("engineer", "engineer%1:18:00::", "engineer%2:31:01::"),
				   ("carpenter", "carpenter%1:18:00::", "carpenter%2:41:00::"),
				   ("guide", "guide%1:18:00::", "guide%2:38:00::"),
				   ("mentor", "mentor%1:18:00::", "mentor%2:32:00::"),
				   ("judge", "judge%1:18:00::", "judge%2:31:02::"),
				   ("nurse", "nurse%1:18:00::", "nurse%2:29:00::")]

	res = {}
	for word, noun_sid, verb_sid in occupations:
		res[word] = {}
		for column, pos_label, sid in (("noun_bias", "Noun", noun_sid),
				("verb_bias", "Verb", verb_sid)):
			if sid not in WE.embed:
				print("%s Sense Embedding Not Found for =" % pos_label, word)
				# Record an explicit missing value rather than reusing the
				# score left over from a previous sense or occupation.
				res[word][column] = np.nan
				continue
			if sense_bias:
				emb = WE.get_vector(sid)
			else:
				emb = _lemma_average_embedding(WE, sid)
			res[word][column] = cosine_similarity(gender_vec, emb)

	df = pd.DataFrame(data=res)
	avg = df.copy()
	# Mean absolute bias across occupations, separately for nouns and verbs.
	avg['mean'] = df.T.abs().mean(numeric_only=True)
	print(avg.T)
	return avg.T


class WordEmbedding(object):
	"""Lookup table from labels (sense keys or words) to embedding vectors.

	After construction ``embed`` is a dict mapping each label to its numpy
	vector and ``dim`` is the vector dimensionality, which is used to build
	the zero vector returned for unknown labels.
	"""

	def __init__(self, fname, fmt="lmms"):
		"""
		Load the embeddings from fname.

		fmt selects the file format:
		  "lmms"  - .npz archive with 'labels' and 'vectors' arrays (default),
		  "ares"  - text file whose first line is a header and is skipped,
		  "glove" - text file with one "label v1 v2 ..." entry per line.
		Raises ValueError for any other fmt.
		"""
		loaders = {
			"lmms": self.load_lmms,
			"ares": self.load_ares_txt,
			"glove": self.load_glove,
		}
		if fmt not in loaders:
			raise ValueError("Unknown embedding format %r (expected one of %s)"
				% (fmt, ", ".join(sorted(loaders))))
		self.embed = loaders[fmt](fname)
		print("Total number of vectors =", len(self.embed))


	def load_lmms(self, npz_vecs_path):
		"""Load an .npz archive holding parallel 'labels' and 'vectors' arrays."""
		# The archive keeps a file handle open until it is closed; the arrays
		# read from it stay valid afterwards.
		with np.load(npz_vecs_path) as loader:
			labels = loader['labels'].tolist()
			vectors = loader['vectors']
		self.dim = vectors[0].shape[0]
		return dict(zip(labels, vectors))


	def _load_text_vectors(self, path, dtype, skip_header):
		"""Parse a space-separated "label v1 v2 ..." text file into a dict.

		The dimensionality is taken from the first data line. Every line is
		then split from the right into that many values, so a label that
		itself contains spaces is kept intact. Trailing whitespace and blank
		lines are ignored. self.dim is only set when at least one vector was
		read.
		"""
		vectors = {}
		dim = None
		with open(path, 'r', encoding='utf-8') as handle:
			for idx, line in enumerate(handle):
				if skip_header and idx == 0:
					continue
				line = line.rstrip()
				if not line:
					continue
				if dim is None:
					dim = line.count(' ')
				label, *values = line.rsplit(' ', dim)
				vectors[label] = np.array(values, dtype=dtype)
		if dim is not None:
			self.dim = dim
		return vectors


	def load_ares_txt(self, path):
		"""Load an ARES-style text file (first line skipped as header) as float vectors."""
		return self._load_text_vectors(path, float, skip_header=True)


	def load_glove(self, path):
		"""Load a GloVe-style text file (no header) as float32 vectors."""
		return self._load_text_vectors(path, 'float32', skip_header=False)


	def get_vector(self, label):
		"""
		Return the vector stored for label, which may be a sense-id or, for
		sense-insensitive static word embeddings, a plain word. You will need
		to modify this function according to the word embedding you want to
		evaluate. If the label is not in the embedding, return a zero vector
		of the same dimensionality.
		"""
		vector = self.embed.get(label)
		if vector is None:
			# Build the fallback only when it is actually needed.
			return np.zeros(self.dim)
		return vector



if __name__ == "__main__":
	# Script entry point. Only the WEAT dimensionality figure is generated;
	# the commented-out lines below are manual switches for the other plots
	# and for the embedding-based gender-bias evaluation.
	# args = parse_args()
	# plot_wat_dim()
	# plot_sssb_dim()
	plot_weat_dim()
	# Alternative embedding files for eval_gender_bias (enable exactly one WE).
	# WE = WordEmbedding("../eval_static_emb_bias/data/lmms_2348.bert-large-cased.fasttext-commoncrawl.npz")   
	# WE = WordEmbedding("./data/lmms_2048.bert-large-cased.npz") 
	# WE = WordEmbedding('../senseEmbeddings/external/lmms/lmms_1024.bert-large-cased.npz')
	# WE = WordEmbedding("../senseEmbeddings/external/ares/ares_bert_large.txt")
	# WE = WordEmbedding('../senseEmbeddings/external/glove/glove.840B.300d.txt')
	# eval_gender_bias(WE)

	# NOTE(review): the lines below reference names that are not defined in the
	# visible part of this file (args.emb_dim,
	# eval_gender_bias_contextualised_embeddings, evaluate_gender_bias_mlm);
	# they cannot be enabled as-is.
	# tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
	# model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True)
	# save_embeddings_path = 'data/vectors/semcor_linear_last4layers_multiword_{}_50.npz'.format(args.emb_dim)
	# eval_gender_bias_contextualised_embeddings(WE)
	# model = 'bert-large-uncased'
	# evaluate_gender_bias_mlm(args)
   
