from sklearn.decomposition import PCA

from utils import *
from transformers import RobertaModel,RobertaTokenizer,AlbertModel,AlbertTokenizer
from tqdm import tqdm
import torch
# Local checkpoint directory that provides both the ALBERT weights and the tokenizer files.
ALBERT_DIR = 'model/ALBERT'
model_des = AlbertModel.from_pretrained(ALBERT_DIR)
tokenizer = AlbertTokenizer.from_pretrained(ALBERT_DIR)

# Dataset selection: exactly one (entity, relation) description-file pair is active.
# Pairs used for the other datasets:
#   FB15k-237: 'data/FB15k-237/entity2textlong_filter.txt', 'data/FB15k-237/reverse_relation2text.txt'
#   umls:      'data/umls/entity2textlong.txt', 'data/umls/reverse_relation2text.txt'
ENTITY_TEXT_PATH = 'data/WN18RR/entity2text_filter.txt'
RELATION_TEXT_PATH = 'data/WN18RR/relation2text_reverse.txt'
entity2text = read_text_from_txt(ENTITY_TEXT_PATH)
relation2text = read_text_from_txt(RELATION_TEXT_PATH)

# Report how many descriptions were loaded from each file (entities first, then relations).
for loaded in (entity2text, relation2text):
    print(len(loaded))

# Merge both lookups into one; relations are applied last, so they win if a key appears in both.
total2text = dict(entity2text)
total2text.update(relation2text)

# Run on the first GPU when one is available; otherwise fall back to the CPU so the
# script still completes (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_des = model_des.to(device)
model_des.eval()  # inference only: keep dropout disabled so embeddings are deterministic

# Width of every stored embedding. The encoder's hidden vector is folded down to this
# size by averaging consecutive EMBED_DIM-wide chunks (e.g. 768 -> 6 chunks of 128).
EMBED_DIM = 128

# Maps each key of total2text to its embedding, stored as a nested list of shape
# (1, EMBED_DIM) so that it can be serialised to JSON.
word_embeddings = {}

# No gradients are needed; no_grad() avoids building an autograd graph per description.
with torch.no_grad():
    for key, des in tqdm(total2text.items()):
        input_des = tokenizer(
            des,
            return_tensors='pt',
            max_length=128,
            truncation=True,
            padding='max_length',
        ).to(device)

        output = model_des(**input_des)
        hidden_states = output.last_hidden_state

        # Masked mean pooling: the input is padded to max_length, so a plain mean over
        # the sequence axis would average in the hidden states of padding positions.
        # Weight each position by the attention mask and divide by the real token count.
        # (Alternatives that were tried here: output['pooler_output'], the [CLS] state
        # last_hidden_state[:, 0, :], and max pooling over the sequence axis.)
        token_mask = input_des['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
        token_count = token_mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        mean_pooler = (hidden_states * token_mask).sum(dim=1) / token_count

        # Dimensionality reduction by chunk averaging (this is NOT PCA): split the hidden
        # vector into consecutive EMBED_DIM-wide chunks and average them element-wise.
        # view() raises if the hidden size is not a multiple of EMBED_DIM.
        mean_pooler = mean_pooler.view(mean_pooler.size(0), -1, EMBED_DIM)
        transform_mean_pooler = mean_pooler.mean(dim=1)

        word_embeddings[key] = transform_mean_pooler.cpu().numpy().tolist()


import json

# Destination for the embeddings of the currently selected dataset.
# Paths used for the other datasets:
#   'data/word_embeddings/FB15k237_ALBERT_embedding.json'
#   'data/word_embeddings/umls_ALBERT_embedding.json'
EMBEDDING_OUTPUT_PATH = 'data/word_embeddings/WN18RR_ALBERT_embedding.json'

# Serialise the whole key -> embedding mapping as a single JSON object.
with open(EMBEDDING_OUTPUT_PATH, 'w') as out_file:
    json.dump(word_embeddings, out_file)
