import argparse
import os
from collections import Counter
from itertools import chain
from random import shuffle

import numpy as np
import torch
import tqdm
from transformers import BertModel
from transformers import BertTokenizer

from util.encoding import *  # expected to provide encode_BERT
import constants as c
from constants import lang2model
from constants import lang2vocab
from util.argparse import str2bool

# Command-line interface; every option is mandatory.
parser = argparse.ArgumentParser()
parser.add_argument("-lang", "--language", type=str, required=True)
parser.add_argument("-gpu", type=str, required=True)
parser.add_argument("-cs", "--context_size", type=int, required=True)
parser.add_argument("-multiling", "--use_multiling_enc", type=str2bool, required=True)
args = parser.parse_args()

# Restrict the visible GPU(s); this only has an effect if set before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
lang = args.language
context_size = args.context_size

# Output layout: <AOC_DIR>cs=<context_size>/<multilingual|monolingual>/<lang>/
encoder_kind = "multilingual" if args.use_multiling_enc else "monolingual"
basedir = c.AOC_DIR + "cs={}/{}/{}/".format(context_size, encoder_kind, lang)
print(basedir)
os.makedirs(basedir, exist_ok=True)

# Target vocabulary, one term per line. Blank lines are dropped so that a
# whitespace-only file counts as "no target vocabulary" (keep every word)
# instead of yielding {""}, which would filter out every token.
# NOTE(review): assumes vocab and corpus files are UTF-8 — confirm.
with open(lang2vocab[lang], "r", encoding="utf-8") as f:
  tgt_vocab_set = {term for term in (line.strip() for line in f) if term}

# Corpus: one sentence per line, later tokenized with str.split().
with open(c.AOC_DOC_DIR + "%s-low-1m.txt" % lang, "r", encoding="utf-8") as f:
  raw_corpus = [line.strip() for line in f]

# Key under which the multilingual encoder is registered in `lang2model`.
multiling_key = 'mbert'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shared multilingual encoder, or the language-specific one.
model_str = lang2model[multiling_key] if args.use_multiling_enc else lang2model[lang]

encode = encode_BERT  # NOTE(review): presumably provided by `from util.encoding import *`
tokenizer = BertTokenizer.from_pretrained(model_str)


def get_model(model_string):
  """Load the BERT encoder `model_string`, configured to also return all hidden states.

  Uses its argument (the previous lambda silently ignored it and read the
  global `model_str`).
  """
  return BertModel.from_pretrained(model_string, output_hidden_states=True)

print("shuffling corpus")
# Shuffle in place: the indexing pass keeps only the first `max_elems`
# occurrences of each word in corpus order, so shuffling makes those a random
# sample of sentences rather than the first ones in the file.
# NOTE: no seed is set in this file, so runs are not reproducible.
shuffle(raw_corpus)

print("creating index")
# index:        term -> [(doc_id, first position of term in doc), ...], at most max_elems entries
# doc2termlist: doc_id -> [(term, first position), ...] for exactly the occurrences recorded in `index`
index = {}
doc2termlist = {}
buffer = 10  # for out of position cases: extra contexts in case some occurrences cannot be embedded later
max_elems = context_size + buffer if context_size > 0 else int(1e10)  # int(1e10) = Inf
for did, sent in tqdm.tqdm(enumerate(raw_corpus), total=len(raw_corpus)):
  # Map each distinct token to its first position in a single pass. This replaces
  # a per-word `list.index` scan (quadratic in sentence length) and visits words
  # in order of first appearance instead of hash-seed-dependent set order.
  first_positions = {}
  for position, word in enumerate(sent.split()):
    first_positions.setdefault(word, position)

  words_positions = []
  for word, position in first_positions.items():
    # skip if we have a non-empty target vocabulary and word is not contained
    if tgt_vocab_set and word not in tgt_vocab_set:
      continue
    postings = index.setdefault(word, [])
    # keep at most max_elems occurrences per word (the earliest in corpus order)
    if len(postings) < max_elems:
      postings.append((did, position))
      words_positions.append((word, position))
  doc2termlist[did] = words_positions

print("context-count distribution:", end=" ")
print(Counter(len(postings) for postings in index.values()))

# Per-layer accumulators, both keyed by layer index:
#   layer2term2emblist[layer]: term -> context embeddings still being collected
#   layer2term2emb[layer]:     term -> finished (averaged) embedding
layer2term2emblist = {}
layer2term2emb = {}

# Only documents holding at least one indexed occurrence need to be encoded.
selected_docs = {doc_id for postings in index.values() for doc_id, _ in postings}
print("effective corpus size: %s (full size: %s)" % (len(selected_docs), len(raw_corpus)))
print("effective vocabulary size: %s" % len(index))

model = get_model(model_str)
model.to(device, non_blocking=True)
model.eval()

skipped_sequences = 0  # occurrences skipped because their position lies outside the encoded word sequence
contextcounts = {}  # term -> number of contexts averaged into its embedding
print("embedd corpus")
for did in tqdm.tqdm(selected_docs, total=len(selected_docs)):
  # Document encoding
  doc = raw_corpus[did]
  sent, length, mask = encode(doc, tokenizer, model_str)
  # Put the single-sentence batch on the model's device; a hard-coded `.cuda()`
  # crashes whenever `device` fell back to the CPU.
  sent = torch.tensor([sent]).to(device)
  with torch.no_grad():
    # NOTE(review): with output_hidden_states=True the last output element is
    # presumably the tuple of per-layer hidden states — confirm for the installed transformers version
    all_layers = model(sent)[-1]

  # turn word piece embeddings into word embeddings: split the only sequence of
  # each layer into the groups given by `mask` and average every group.
  # NOTE(review): assumes `mask` lists the number of word pieces per word — confirm in util.encoding
  all_layers = [
      [torch.mean(wp_embs, dim=0).detach().cpu().numpy()
       for wp_embs in torch.split(layer[0], mask)]
      for layer in all_layers
  ]

  # embedding updates
  for i, emb_seq in enumerate(all_layers):
    term2emblist = layer2term2emblist.setdefault(i, {})
    term2emb = layer2term2emb.setdefault(i, {})
    for term, position in doc2termlist[did]:
      # skip if term is done already
      if term in term2emb:
        continue
      if position >= len(emb_seq):
        # Every layer is split by the same `mask`, so it has the same number of
        # words; count the miss once (on layer 0) rather than once per layer.
        if i == 0:
          skipped_sequences += 1
        continue
      # Register the list before the completeness check so the removal below is
      # valid even when context_size == 1 (the first context completes the term).
      emblist = term2emblist.setdefault(term, [])
      emblist.append(emb_seq[position])
      # all contexts collected -> emb done
      if len(emblist) == context_size:
        term2emb[term] = np.mean(emblist, axis=0)
        contextcounts[term] = context_size
        del term2emblist[term]

print("skipped occurrences (position outside encoded sequence): %d" % skipped_sequences)

# Terms that never reached `context_size` contexts (every term when
# context_size <= 0, since a non-empty list never has length <= 0) still get
# an embedding: the mean of whatever contexts were collected.
for i, term2emb in layer2term2emb.items():
  for term, emblist in layer2term2emblist.get(i, {}).items():
    assert term not in term2emb
    # Use a dedicated name so the global `context_size` CLI value is not clobbered.
    n_contexts = len(emblist)
    term2emb[term] = np.mean(emblist, axis=0) if n_contexts > 1 else emblist[0]
    contextcounts[term] = n_contexts

# Fix one term order (taken from layer 0) so that row k of every layer's matrix
# and line k of the vocab file refer to the same term. `.get` avoids a KeyError
# when nothing was embedded (then only an empty vocab file is written).
vocab = list(layer2term2emb.get(0, {}))
for layer, term2emb in layer2term2emb.items():
  # np.float (an alias of the builtin float) was removed in NumPy 1.24; float64 is the same dtype.
  embedding_table = np.array([term2emb[term] for term in vocab], dtype=np.float64)
  with open(basedir + "%s_%s.npy" % (lang, str(layer)), "wb") as f:
    np.save(f, embedding_table)

# One term per line, in the same order as the embedding rows.
with open(basedir + "%s.vocab" % lang, "w", encoding="utf-8") as f:
  f.writelines(term + "\n" for term in vocab)
