import os
import torch
import numpy as np
import tqdm
import argparse
from constants import lang2vocab
from constants import lang2model
from transformers import BertModel
from transformers import BertTokenizer
import constants as c

# Command-line interface:
#   -lang      language code; keys lang2vocab / lang2model and names output files
#   -gpu       value for CUDA_VISIBLE_DEVICES (optional; left untouched if omitted)
#   -multiling "True" selects the multilingual encoder, anything else the
#              monolingual one for -lang
parser = argparse.ArgumentParser()
parser.add_argument("-lang", "--language", type=str, required=True,
                    help="language code of the vocabulary to embed")
parser.add_argument("-gpu", type=str, default=None,
                    help="value for CUDA_VISIBLE_DEVICES (e.g. '0')")
parser.add_argument("-multiling", "--use_multiling_enc", type=lambda string: string == "True",
                    help="'True' to use the multilingual encoder")
args = parser.parse_args()
# os.environ only accepts str values; assigning None (no -gpu given) would raise
# TypeError, so only restrict visible devices when the user asked for it.
# This must happen before CUDA is initialised for it to take effect.
if args.gpu is not None:
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu


# Resolve which encoder to load and where to write the results:
# <ISO_DIR>/<lang>/multilingual/ or <ISO_DIR>/<lang>/monolingual/.
# basedir keeps its trailing slash because output paths are built by
# plain string concatenation further down.
lang = args.language
basedir = c.ISO_DIR + "%s/" % lang
if args.use_multiling_enc:
  model_str = lang2model['mbert']
  basedir += "multilingual/"
else:
  model_str = lang2model[lang]
  basedir += "monolingual/"

os.makedirs(basedir, exist_ok=True)
# PyTorch's CUDA device type is 'cuda'; 'gpu' is not a valid device string
# and makes torch.device raise RuntimeError on every machine with a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# output_hidden_states=True makes the forward pass also return the hidden
# states of every layer, which the extraction loop relies on.
model = BertModel.from_pretrained(model_str, output_hidden_states=True)
tokenizer = BertTokenizer.from_pretrained(model_str)
model.to(device, non_blocking=True)
model.eval()  # inference only: disables dropout

# Candidate words to embed, one per line, with surrounding whitespace stripped.
# Decode explicitly as UTF-8 so non-ASCII vocabularies do not depend on the
# platform's locale encoding.
with open(lang2vocab[lang], "r", encoding="utf-8") as f:
  vocabulary = [line.strip() for line in f]

# Hidden states returned by BERT: the embedding-layer output plus one per
# transformer layer (12 + 1 = 13 for BERT-base / mBERT). Read it from the
# config instead of hard-coding so larger models work too.
num_layers = model.config.num_hidden_layers + 1
# layer index -> list of per-word vectors, in the same order as `vocab`.
layer2embs = {layer: [] for layer in range(num_layers)}
# Words that produced at least one subword token; newline-terminated so they
# can be written out directly with writelines().
vocab = []
with torch.no_grad():
  for entry in tqdm.tqdm(vocabulary, total=len(vocabulary)):
    token_ids = tokenizer.encode(entry, add_special_tokens=False)
    if not token_ids:
      # Nothing to embed (e.g. an empty line); skip so vocab and the
      # embedding rows stay aligned.
      continue
    # A batch containing a single sequence of subword ids.
    encoded = torch.tensor([token_ids], dtype=torch.long, device=device)
    # With output_hidden_states=True the last element of the model output is
    # the per-layer tuple of hidden states.
    all_layers = model(encoded)[-1]
    if len(all_layers) != num_layers:
      raise RuntimeError(
          "expected %d hidden-state layers, got %d" % (num_layers, len(all_layers)))
    vocab.append(entry + "\n")

    for i, layer_states in enumerate(all_layers):
      # Average over the subword positions (dim 1) to get one vector per
      # word, then drop the singleton batch dimension.
      embedding = torch.squeeze(torch.mean(layer_states, dim=1))
      layer2embs[i].append(embedding.cpu().numpy())

# Save one matrix per layer (<lang>_<layer>.npy); row k corresponds to line k
# of the <lang>.vocab file written below.
for layer in range(num_layers):
  with open(basedir + "%s_%s.npy" % (lang, str(layer)), "wb") as f:
    # np.float was a deprecated alias of the builtin float (i.e. float64) and
    # was removed in NumPy 1.24; np.float64 is the identical dtype.
    np.save(f, np.array(layer2embs[layer], dtype=np.float64))

# Written as UTF-8 to match how the input vocabulary is decoded.
with open(basedir + "%s.vocab" % lang, "w", encoding="utf-8") as f:
  f.writelines(vocab)
