import torch, six, os, configparser
from tqdm import tqdm
import numpy as np

def get_embedding(embedding_file):
    """Parse a space-separated text embedding file into a dict.

    Each useful line looks like ``word v1 v2 ... vd``. The line is decoded as
    UTF-8, stripped of surrounding whitespace and split on single spaces.
    Lines with two or fewer fields are skipped. This covers:

    * blank lines;
    * lone tokens;
    * two-field header lines such as word2vec's ``<vocab_size> <dim>``.

    Args:
        embedding_file: path to the embedding text file.

    Returns:
        dict mapping each word (str) to its vector as a list of floats.
    """
    embs = dict()
    # ``with`` guarantees the file handle is closed, even if parsing raises.
    with open(embedding_file, 'rb') as fh:
        for line in tqdm(fh):
            fields = line.decode('utf8').strip().split(' ')
            # A blank line splits to [''], and would otherwise add a bogus
            # '' -> [] entry. A 2-field line is a header, not a vector.
            if len(fields) <= 2:
                continue
            embs[fields[0]] = [float(em) for em in fields[1:]]

    return embs

def match_embeddings(vocab, emb):
    """Build an embedding matrix whose rows are aligned with ``vocab`` ids.

    Args:
        vocab: mapping-like object. Iterating it yields words, ``vocab[w]``
            gives the row index of ``w``, and ``len(vocab)`` is the row count.
        emb: dict mapping word -> vector (sequence of floats). The embedding
            dimension is taken from the first vector; all vectors are
            presumably the same length (not checked).

    Returns:
        torch.Tensor of shape ``(len(vocab), dim)``. Rows for words missing
        from ``emb`` are left as zeros.

    Raises:
        ValueError: if ``emb`` is empty, so the dimension cannot be inferred.
    """
    if not emb:
        # Without this guard the next() call below leaks a StopIteration.
        raise ValueError("emb is empty; cannot infer embedding dimension")
    dim = len(next(iter(emb.values())))
    filtered_embeddings = np.zeros((len(vocab), dim))
    for w in tqdm(vocab):
        w_id = vocab[w]
        if w in emb:
            filtered_embeddings[w_id] = emb[w]

    return torch.Tensor(filtered_embeddings)

def load_pretrained_embedding(path, vocab):
    """Return the vocab-aligned embedding tensor, using a ``.pt`` cache.

    The cache file name is ``path['SAVED_WORD_EMB'] + path['WORD_DIM'] + ".pt"``.
    When that file exists it is loaded with ``torch.load``. Otherwise:

    1. the raw embeddings are parsed from ``path['WORD_EMB']``;
    2. they are matched against ``vocab``;
    3. the result is written to the cache file before being returned.

    Args:
        path: mapping (e.g. a configparser section) providing the string
            entries ``SAVED_WORD_EMB``, ``WORD_DIM`` and ``WORD_EMB``.
        vocab: vocabulary passed to ``match_embeddings`` on a cache miss.
    """
    cache_file = path['SAVED_WORD_EMB'] + path['WORD_DIM'] + ".pt"

    # Fast path: a previously saved tensor is available.
    if os.path.exists(cache_file):
        print('Loading pretrained word embeddings from ', cache_file)
        return torch.load(cache_file)

    # Slow path: build from the raw text file and cache the result.
    print("Start to process pretrained word embeddings ")
    raw_vectors = get_embedding(path['WORD_EMB'])
    aligned = match_embeddings(vocab, raw_vectors)
    torch.save(aligned, cache_file)
    return aligned

def parse_data_path_cfg():
    """Read ``utils/dataset.cfg`` and return its ``[PATH]`` section.

    The path is relative to the current working directory.
    ``ConfigParser.read`` silently ignores a missing file. In that case the
    section lookup raises ``KeyError``.
    """
    parser = configparser.ConfigParser()
    parser.read(["utils/dataset.cfg"])
    return parser['PATH']

def parse_model_path_cfg():
    """Read ``utils/model.cfg`` and return its ``[MODEL]`` section.

    The path is relative to the current working directory.
    ``ConfigParser.read`` silently ignores a missing file. In that case the
    section lookup raises ``KeyError``.
    """
    parser = configparser.ConfigParser()
    parser.read(["utils/model.cfg"])
    return parser['MODEL']

def make_src(data, tag, gpu, max_len=200):
    """Build a one-hot source tensor from per-example token-id sequences.

    Args:
        data: sequence of examples. ``example[tag]`` is a sequence of
            non-negative int token ids.
        tag: key selecting the id sequence in each example.
        gpu: if truthy, the result is moved to CUDA with ``.cuda()``.
        max_len: size of the time axis. The default is 200, the previously
            hard-coded value.

    Returns:
        Float tensor of shape ``(max_len, len(data), vocab_size)``.
        ``out[j, i, t] == 1`` when the j-th token of example i has id t;
        every other entry is 0.

    Raises:
        IndexError: if a sequence has more than ``max_len`` tokens.
    """
    # Ids index the last axis directly and are 0-based. The axis therefore
    # needs ``max id + 1`` slots to hold the largest id. Empty sequences are
    # skipped so that ``max`` is not called on them.
    per_example_max = [max(ex[tag]) for ex in data if len(ex[tag]) > 0]
    src_vocab_size = (max(per_example_max) + 1) if per_example_max else 0
    alignment = torch.zeros(max_len, len(data), src_vocab_size)
    for i, ex in enumerate(data):
        ids = ex[tag]
        if len(ids) > max_len:
            raise IndexError(
                "sequence {} has {} tokens; max_len is {}".format(i, len(ids), max_len))
        for j, t in enumerate(ids):
            alignment[j, i, t] = 1
    if gpu:
        alignment = alignment.cuda()
    return alignment

def make_tgt(data, tag, gpu, max_len=52):
    """Stack per-example target id sequences into a zero-padded LongTensor.

    Args:
        data: sequence of examples. ``example[tag]`` is a sequence of int ids.
        tag: key selecting the id sequence in each example.
        gpu: if truthy, the result is moved to CUDA with ``.cuda()``.
        max_len: size of the time axis. The default is 52, the previously
            hard-coded value. Pass ``None`` to size it to the longest
            sequence in ``data``.

    Returns:
        Long tensor of shape ``(max_len, len(data))``. Column i holds
        example i's ids, followed by zero padding.

    A sequence longer than ``max_len`` makes torch raise a shape-mismatch
    error on assignment.
    """
    if max_len is None:
        max_len = max([len(ex[tag]) for ex in data] or [0])
    alignment = torch.zeros(max_len, len(data)).long()
    for i, ex in enumerate(data):
        ids = ex[tag]
        alignment[:len(ids), i] = torch.as_tensor(ids, dtype=torch.long)
    if gpu:
        alignment = alignment.cuda()
    return alignment