import time
import argparse
import logging
from collections import defaultdict

import numpy as np
from nltk.corpus import wordnet as wn
from nltk import word_tokenize
from tqdm import tqdm

from wsd_models.util import *

logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%d-%b-%y %H:%M:%S')

class GLossEncoder(torch.nn.Module):
    """Wrap a pretrained encoder and expose a gradient-free gloss encoding pass.

    Only the gloss side is used in this module; the encoder call inside
    ``gloss_forward`` runs under ``torch.no_grad()``.
    """

    def __init__(self, encoder_name):
        super(GLossEncoder, self).__init__()

        # load_pretrained_model (from wsd_models.util) returns the model and a hidden size.
        # The size is kept as ``context_hdim`` for compatibility; it is not used here.
        self.gloss_encoder, self.context_hdim = load_pretrained_model(encoder_name)

    def gloss_forward(self, input_ids, attn_mask, num_layers=4):
        """Encode a batch of glosses as the mean of the encoder's last hidden layers.

        Args:
            input_ids: token-id tensor for the batch (presumably (bsz, seq_len) — TODO confirm).
            attn_mask: attention mask aligned with ``input_ids``.
            num_layers: number of final hidden layers to average (default 4).

        Returns:
            Per-token vectors with position 0 removed. Position 0 holds the
            CLS token prepended by the caller. If each layer is (bsz, seq_len, hdim),
            the result is (bsz, seq_len - 1, hdim).

        Raises:
            ValueError: if ``num_layers`` is smaller than 1.
        """
        if num_layers < 1:
            raise ValueError('num_layers must be >= 1, got %r' % (num_layers,))

        with torch.no_grad():
            # NOTE(review): the last element of the encoder output is assumed to be the
            # tuple of all hidden states (needs output_hidden_states=True) — confirm.
            hidden_layers = self.gloss_encoder(input_ids, attention_mask=attn_mask)[-1][-num_layers:]

        # Element-wise average of the selected layers.
        gloss_output = torch.stack(tuple(hidden_layers), dim=0).mean(0)

        # Drop position 0 (CLS); the caller averages the remaining positions itself.
        return gloss_output[:, 1:, :]


def chunks(l, n):
    """Lazily yield consecutive slices of ``l`` of length ``n``.

    The last slice holds the remainder and may be shorter. Each slice has the
    same type as ``l`` (list slices for a list, substrings for a str, ...).
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))


def wn_synset2keys(synset):
    """Return the distinct sense keys of a synset's lemmas.

    Args:
        synset: a WordNet synset object, or a synset name string that is
            resolved with ``wn.synset``.

    Returns:
        A list of unique sense-key strings in the order of ``synset.lemmas()``.
    """
    if isinstance(synset, str):
        synset = wn.synset(synset)
    # dict.fromkeys dedups but keeps first-seen order. A set would give a hash-dependent
    # order that changes between interpreter runs.
    return list(dict.fromkeys(lemma.key() for lemma in synset.lemmas()))


def fix_lemma(lemma):
    """Turn a WordNet lemma name into plain text by swapping each '_' for a space."""
    pieces = lemma.split('_')
    return ' '.join(pieces)


def _load_extra_examples(pos_tags=('n', 'r', 'v', 'a')):
    """Load optional extra example sentences, one pickle file per POS tag.

    For each tag this reads ``./sentence_dict_<tag>``. Each file is expected to
    hold a mapping from synset name to an iterable of sentences (presumably
    strings — confirm against whatever produces these files). Entries with an
    empty value are dropped. A missing or unloadable file yields an empty
    mapping for that tag instead of aborting the run.

    Returns:
        dict mapping every tag in ``pos_tags`` to its (possibly empty) mapping.
    """
    import pickle

    extra = {}
    for pos in pos_tags:
        path = './sentence_dict_%s' % pos
        try:
            # NOTE: pickle.load can execute arbitrary code; only load trusted, locally built files.
            with open(path, 'rb') as f:
                raw = pickle.load(f)
            examples = {name: list(sents) for name, sents in raw.items() if sents}
        except FileNotFoundError:
            logging.info('No extra example file at %s; skipping', path)
            examples = {}
        except Exception as err:  # best-effort: a bad optional file must not abort embedding
            logging.warning('Could not load extra examples from %s: %r', path, err)
            examples = {}
        else:
            print('%s sentences loaded!' % pos)
        extra[pos] = examples
    return extra


def get_sense_data(emb_strategy, tokenizer, max_len=512):
    """Build one tokenized gloss input per WordNet synset.

    For every synset from ``wn.all_synsets()`` the input text is built from:
      * all of the synset's lemma names, with underscores replaced by spaces;
      * the word-tokenized definition;
      * any extra example sentences loaded from ``./sentence_dict_<pos>``;
      * the synset's own WordNet examples, but only when ``'examples'`` occurs in
        ``emb_strategy``.
    The text is encoded with ``tokenizer`` and wrapped in the tokenizer's CLS and
    SEP tokens.

    Args:
        emb_strategy: strategy string; it is only checked for the substring ``'examples'``.
        tokenizer: object providing ``encode``, ``cls_token`` and ``sep_token``.
        max_len: maximum number of token positions per input (default 512).
            Longer inputs are truncated but keep their closing SEP token.

    Returns:
        A list of ``[synset, sense_key, [ids, attn_mask, fake_mask]]`` sorted by
        synset, where:
          * ``sense_key`` is the key of the synset's first lemma;
          * ``ids`` is a list of id tensors, one per position;
          * ``attn_mask`` is a list of 1s and ``fake_mask`` a list of -1s, both as
            long as ``ids``.
    """
    extra_examples = _load_extra_examples()
    # The digit after '%' in a sense key is the WordNet synset type. Type 5 (adjective
    # satellite) shares the adjective example file.
    type2pos = {1: 'n', 2: 'v', 3: 'a', 4: 'r', 5: 'a'}
    # NOTE(review): assumes encode() of a bare special token returns exactly one id
    # (no extra special tokens added); the one-tensor-per-position list relies on it.
    cls = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])]
    sep = [torch.tensor([tokenizer.encode(tokenizer.sep_token)])]
    use_wn_examples = 'examples' in emb_strategy

    data = []
    for synset in tqdm(wn.all_synsets()):
        lemmas = synset.lemmas()
        sense_key = lemmas[0].key()
        synset_type = int(sense_key.split('%')[1][0])
        pos_examples = extra_examples[type2pos[synset_type]]

        all_lemmas = [fix_lemma(lemma.name()) for lemma in lemmas]
        gloss = ' '.join(word_tokenize(synset.definition()))

        example_parts = []
        if synset.name() in pos_examples:
            example_parts.append(' '.join(word_tokenize(' '.join(pos_examples[synset.name()]))))
        if use_wn_examples:
            example_parts.append(' '.join(word_tokenize(' '.join(synset.examples()))))
        # Space-separate the two example sources. Plain concatenation would fuse the
        # last word of one with the first word of the other.
        examples = ' '.join(part for part in example_parts if part)

        d_str = ' '.join(all_lemmas) + ' ' + gloss + ' ' + examples
        g_ids = cls + [torch.tensor([[x]]) for x in tokenizer.encode(d_str)] + sep
        if len(g_ids) > max_len:
            # Keep the closing SEP when truncating so every input stays well-formed.
            g_ids = g_ids[:max_len - 1] + sep
        c_attn_mask = [1] * len(g_ids)
        g_fake_mask = [-1] * len(g_ids)
        # One entry per synset, keyed by its first lemma. The synset object is kept so
        # callers can reach all of its lemmas.
        data.append([synset, sense_key, [g_ids, c_attn_mask, g_fake_mask]])

    data.sort(key=lambda entry: entry[0])
    return data


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Creates sense embeddings based on glosses and lemmas.')
    parser.add_argument('-batch_size', type=int, default=128, help='Batch size (BERT)', required=False)
    parser.add_argument('-encoder_name', type=str, default='bert-large', help='model name', required=False)
    parser.add_argument('-emb_strategy', type=str, default='aug_gloss+examples',
                        help='different components to learn the basic sense embeddings', required=False)
    parser.add_argument('-out_path', help='Path to resulting vector set', required=False,
                        default='data/vectors/%s-%s.txt')
    args = parser.parse_args()

    # Explicit eval() so dropout (if any) is off and embeddings are deterministic,
    # regardless of how load_pretrained_model left the module.
    model = GLossEncoder(args.encoder_name).cuda().eval()
    tokenizer = load_tokenizer(args.encoder_name)
    # Same for every gloss, so compute it once instead of once per batch item.
    pad_id = tokenizer.encode(tokenizer.pad_token)[0]

    logging.info('Preparing Gloss Data ...')
    glosses = get_sense_data(args.emb_strategy, tokenizer)
    glosses_vecs = dict()

    logging.info('Embedding Senses ...')
    t0 = time.time()
    for batch_idx, glosses_batch in enumerate(tqdm(chunks(glosses, args.batch_size))):
        dfns = [e[-1] for e in glosses_batch]
        # Positions averaged per gloss: everything after the leading CLS
        # (gloss_forward strips position 0), measured before padding.
        len_list = [len(i[0]) - 1 for i in dfns]
        max_len = max(len(i[0]) for i in dfns)
        for b_index, id_mask in enumerate(dfns):
            dfns[b_index][0], dfns[b_index][1], _ = normalize_length(id_mask[0], id_mask[1], id_mask[2], max_len,
                                                                     pad_id)
        gloss_ids = torch.cat([torch.cat(x, dim=-1) for x, _, _ in dfns], dim=0).cuda()
        gloss_attn_mask = torch.cat([torch.tensor(x).unsqueeze(dim=0) for _, x, _ in dfns], dim=0).cuda()
        gloss_out = model.gloss_forward(gloss_ids, gloss_attn_mask).cpu()

        assert len(gloss_out) == len(glosses_batch)
        for index, (synset, _first_key, _dfn) in enumerate(glosses_batch):
            # Average once per synset; every lemma of the synset gets its own copy.
            mean_vec = gloss_out[index][:len_list[index], :].mean(0)
            for lemma in synset.lemmas():
                glosses_vecs[lemma.key()] = np.array(mean_vec)

        t_span = time.time() - t0
        # The last batch may be partial, so cap the processed count at the total.
        n = min((batch_idx + 1) * args.batch_size, len(glosses))
        logging.info('%d/%d at %.3f per sec' % (n, len(glosses), n/t_span))

    logging.info('Writing Vectors ...')
    import pickle
    # Close (and flush) the output file deterministically.
    with open(args.out_path % (str(args.emb_strategy), args.encoder_name), 'wb') as out_file:
        pickle.dump(glosses_vecs, out_file, -1)
