import numpy as np
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer, TFGPT2Model, GPT2Tokenizer, TFXLNetModel, XLNetTokenizer, TFT5EncoderModel, T5Tokenizer

from os import path
import _pickle as cPickle
from cwe_term import cwe_term

# Pretrained checkpoints and their tokenizers, one pair per architecture.
# Every model is requested with output_hidden_states=True and
# output_attentions=False. All four pairs are instantiated as soon as this
# script runs, regardless of which pair is selected further down.

# --- BERT ---
MODEL_ID_BERT = 'bert-base-cased'
MODEL_BERT = TFBertModel.from_pretrained(
    MODEL_ID_BERT,
    output_hidden_states=True,
    output_attentions=False,
)
MODEL_TOKENIZER_BERT = BertTokenizer.from_pretrained(MODEL_ID_BERT)

# --- GPT-2 ---
MODEL_ID_GPT2 = 'gpt2'
MODEL_GPT2 = TFGPT2Model.from_pretrained(
    MODEL_ID_GPT2,
    output_hidden_states=True,
    output_attentions=False,
)
MODEL_TOKENIZER_GPT2 = GPT2Tokenizer.from_pretrained(MODEL_ID_GPT2)

# --- XLNet ---
MODEL_ID_XLNET = 'xlnet-base-cased'
MODEL_XLNET = TFXLNetModel.from_pretrained(
    MODEL_ID_XLNET,
    output_hidden_states=True,
    output_attentions=False,
)
MODEL_TOKENIZER_XLNET = XLNetTokenizer.from_pretrained(MODEL_ID_XLNET)

# --- T5 (encoder only) ---
MODEL_ID_T5 = 't5-base'
MODEL_T5 = TFT5EncoderModel.from_pretrained(
    MODEL_ID_T5,
    output_hidden_states=True,
    output_attentions=False,
)
MODEL_TOKENIZER_T5 = T5Tokenizer.from_pretrained(MODEL_ID_T5)

# Word sets grouped by category. The "_2" variants are alternative versions
# of the same category with a different (or shorter) vocabulary.
# NOTE(review): these look like attribute/target word sets for an
# embedding-association test - confirm against the code that consumes them.

# Valence: 25-word pleasant / unpleasant sets.
pleasant = [
    'caress', 'freedom', 'health', 'love', 'peace',
    'cheer', 'friend', 'heaven', 'loyal', 'pleasure',
    'diamond', 'gentle', 'honest', 'lucky', 'rainbow',
    'diploma', 'gift', 'honor', 'miracle', 'sunrise',
    'family', 'happy', 'laughter', 'paradise', 'vacation',
]
unpleasant = [
    'abuse', 'crash', 'filth', 'murder', 'sickness',
    'accident', 'death', 'grief', 'poison', 'stink',
    'assault', 'disaster', 'hatred', 'pollute', 'tragedy',
    'divorce', 'jail', 'poverty', 'ugly', 'cancer',
    'kill', 'rotten', 'vomit', 'agony', 'prison',
]

# Valence: 8-word pleasant / unpleasant sets.
pleasant_2 = [
    'joy', 'love', 'peace', 'wonderful',
    'pleasure', 'friend', 'laughter', 'happy',
]
unpleasant_2 = [
    'agony', 'terrible', 'horrible', 'nasty',
    'evil', 'war', 'awful', 'failure',
]

# Career vs. domestic vocabulary.
career = [
    'executive', 'management', 'professional', 'corporation',
    'salary', 'office', 'business', 'career',
]
domestic = [
    'home', 'parents', 'children', 'family',
    'cousins', 'marriage', 'wedding', 'relatives',
]

# Math vs. art vocabulary.
# NOTE: the name `math` would shadow the standard-library module of the same
# name if that module were ever imported in this file.
math = [
    'math', 'algebra', 'geometry', 'calculus',
    'equations', 'computation', 'numbers', 'addition',
]
art = [
    'poetry', 'art', 'dance', 'literature',
    'novel', 'symphony', 'drama', 'sculpture',
]

# Science vs. art vocabulary (this art set includes 'Shakespeare' and omits
# 'sculpture').
science = [
    'science', 'technology', 'physics', 'chemistry',
    'Einstein', 'NASA', 'experiment', 'astronomy',
]
art_2 = [
    'poetry', 'art', 'Shakespeare', 'dance',
    'literature', 'novel', 'symphony', 'drama',
]

# Load the inputs for the name embeddings: the collection of names and the
# context strings in which 'Taylor' is later replaced by each name.
# LOAD_PATH is the directory containing both pickle files; with the empty
# string they are resolved relative to the current working directory.
LOAD_PATH = ''

# NOTE: unpickling can execute arbitrary code, so only load pickle files that
# come from a trusted source.
with open(path.join(LOAD_PATH, 'names.pkl'), 'rb') as pkl_reader:
    names = cPickle.load(pkl_reader)

# Plain string literal: the original carried an f-prefix with no placeholders.
with open(path.join(LOAD_PATH, 'Taylor_contexts.pkl'), 'rb') as pkl_reader:
    contexts = cPickle.load(pkl_reader)

# Model/tokenizer pair handed to cwe_term for every name.
MODEL, TOKENIZER = MODEL_BERT, MODEL_TOKENIZER_BERT
# Short label for the selected model; not referenced in the lines shown here.
WRITE_MODEL = 'bert'
# Passed to cwe_term as its last argument.
# NOTE(review): presumably the output directory - confirm in cwe_term.
WRITE_PATH = ''

for name in names:
    # Re-target every context string by swapping 'Taylor' for the current name.
    new_contexts = [sentence.replace('Taylor', name) for sentence in contexts]
    # NOTE(review): presumably builds (and writes out) the contextualized
    # embedding for `name` - confirm against the cwe_term implementation.
    term = cwe_term(name, new_contexts, MODEL, TOKENIZER, WRITE_PATH)