'''
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
'''

from torch.nn import functional as F
import torch.nn as nn
import sys

from wsd_models.util import *
from synset_expand import *

# Command-line interface for the script.
# NOTE(review): the functions below also read attributes such as `args.ckpt`,
# `args.freeze_gloss`, `args.dev_sent` and `args.context_bsz` that are not
# registered here -- confirm where those options are declared.
parser = argparse.ArgumentParser(description='Gloss Informed Bi-encoder for WSD')

# training arguments
# --context_lens: how many score-selected sentences are added as extra context.
parser.add_argument('--context_lens', type=int, default=0)
# --context_lenw: half-width of the neighbouring-sentence window.
parser.add_argument('--context_lenw', type=int, default=0)
# The keyword was previously misspelled as `hoices`, which made argparse raise
# a TypeError at import time; it must be `choices`.
parser.add_argument('--context_mode', type=str, default='wo-select',
                    choices=['wo-select', 'tfidfwo-select', 'gewo-select', 'global'])
parser.add_argument('--task', type=str, default='wsd-kb',
                    choices=['wsd-kb', 'wsd-sup', 'wsd-sup-lmms', 'wsd-sup-ares'])
parser.add_argument('--sec_wsd', action='store_true')
parser.add_argument('--context_max_length', type=int, default=128)
parser.add_argument('--gloss_max_length', type=int, default=32)
parser.add_argument('--gloss-bsz', type=int, default=256)
parser.add_argument('--encoder-name', type=str, default='bert-large',
    choices=['bert-base', 'bert-large'])
parser.add_argument('--data-path', type=str, default='./data/wsd_eval/WSD_Evaluation_Framework',
    help='Location of top-level directory for the Unified WSD Framework')


class ContextEncoder(torch.nn.Module):
    """Wraps a pretrained encoder and returns vectors for the masked-in tokens.

    The encoder is always run without gradient tracking; `freeze_context` is
    only stored on the instance as `is_frozen`.
    """

    def __init__(self, encoder_name, freeze_context):
        super().__init__()

        encoder, hidden_dim = load_pretrained_model(encoder_name)
        self.context_encoder = encoder
        self.context_hdim = hidden_dim
        self.is_frozen = freeze_context

    def forward(self, input_ids, attn_mask, output_mask):
        """Encode a batch and gather the outputs selected by `output_mask`.

        The last element of the encoder's return value is indexed with
        `[-4:]` and the four resulting tensors are averaged; presumably these
        are the last four hidden layers -- confirm against
        `load_pretrained_model`.
        """
        with torch.no_grad():
            encoder_result = self.context_encoder(input_ids, attention_mask=attn_mask)
            last_four = encoder_result[-1][-4:]

        # Average the four tensors element-wise along a new leading axis.
        averaged = torch.stack(list(last_four), dim=0).mean(0)

        # Reduce each example with its own output mask, then concatenate.
        per_example = [
            process_encoder_outputs(averaged[row], output_mask[row], as_tensor=True)
            for row in range(averaged.size(0))
        ]
        return torch.cat(per_example, dim=0)

    def context_forward(self, context_input, context_input_mask, context_example_mask):
        """Alias of `forward` with context-specific argument names."""
        return self.forward(context_input, context_input_mask, context_example_mask)

def tokenize_glosses(gloss_arr, tokenizer, max_len):
    """Turn gloss strings into id tensors and attention masks.

    Each gloss is wrapped as ``cls + gloss + sep`` and passed through
    `normalize_length` with `max_len` and the tokenizer's pad id.

    Returns:
        (glosses, masks): two parallel lists with one tensor per gloss.
    """
    glosses, masks = [], []
    for gloss_text in gloss_arr:
        token_ids = tokenizer.encode(tokenizer.cls_token)
        token_ids = token_ids + tokenizer.encode(gloss_text)
        token_ids = token_ids + tokenizer.encode(tokenizer.sep_token)

        # One 1x1 tensor per token id so normalize_length can work on a list.
        g_ids = [torch.tensor([[tok]]) for tok in token_ids]
        n_tokens = len(g_ids)
        attn = [1] * n_tokens
        placeholder = [-1] * n_tokens  # output mask is not needed for glosses

        pad_id = tokenizer.encode(tokenizer.pad_token)[0]
        g_ids, attn, _ = normalize_length(g_ids, attn, placeholder, max_len, pad_id=pad_id)

        glosses.append(torch.cat(g_ids, dim=-1))
        masks.append(torch.tensor(attn))

    return glosses, masks

#creates a sense label/ gloss dictionary for training/using the gloss encoder
def load_and_preprocess_glosses(data, tokenizer, wn_senses, max_len=-1):
    """Build a dictionary of tokenized glosses for every lemma/pos key in `data`.

    Args:
        data: iterable of sentences, each a sequence of 5-tuples whose 2nd,
            3rd and 5th fields are lemma, pos and gold label.
        tokenizer: tokenizer forwarded to `tokenize_glosses`.
        wn_senses: mapping from `generate_key(lemma, pos)` to sense keys.
        max_len: gloss length; when it is greater than 32 the synset's usage
            examples are appended to its definition.

    Returns:
        dict mapping key -> (gloss id tensor, gloss mask tensor, sense keys).
        Keys missing from `wn_senses` are skipped.
    """
    sense_glosses = {}
    definitions_only = max_len <= 32

    for sent in data:
        for _word, lemma, pos, _inst, label in sent:
            key = generate_key(lemma, pos)

            if key not in sense_glosses:
                # Look up every sense key for this lemma/pos pair.
                try:
                    sensekey_arr = wn_senses[key]
                except KeyError:
                    continue

                gloss_arr = []
                for sensekey in sensekey_arr:
                    synset = wn.lemma_from_key(sensekey).synset()
                    text = synset.definition()
                    if not definitions_only:
                        text = text + ' ' + '. '.join(synset.examples())
                    gloss_arr.append(text)

                # Tensorize the glosses and stack them per key.
                id_list, mask_list = tokenize_glosses(gloss_arr, tokenizer, max_len)
                sense_glosses[key] = (
                    torch.cat(id_list, dim=0),
                    torch.stack(mask_list, dim=0),
                    sensekey_arr,
                )

            # The gold label must be one of the retrieved senses (or unlabeled).
            assert label in sense_glosses[key][2] or label == -1

    return sense_glosses

def preprocess_context(tokenizer, text_data, gloss_dict=None, bsz=1, max_len=-1):
    """Tokenize sentences, attach document-level context, and batch them.

    Besides its parameters, this function reads the module-level ``args``
    (``context_mode``, ``context_lens``, ``context_lenw``, ``gloss_bsz``) and,
    in 'gewo' modes, the module-level ``synset_dict``.

    Args:
        tokenizer: object exposing ``encode`` plus ``cls_token``,
            ``sep_token`` and ``pad_token``.
        text_data: iterable of sentences; each sentence is a sequence of
            ``(word, lemma, pos, inst, label)`` tuples.
        gloss_dict: mapping from ``generate_key(lemma, pos)`` to a tuple whose
            element 2 is the list of candidate senses.
            NOTE(review): the default ``None`` is not usable -- membership is
            tested on it unconditionally.
        bsz: only checked by the assertion on the first line.
        max_len: sentence truncation threshold and padding length
            (``-1`` disables truncation).

    Returns:
        ``(batched_data, example_dict, context_dict)`` where ``batched_data``
        is a list of ``(context_ids, attn_mask, output_mask, example_keys,
        instances, labels)`` tuples, ``example_dict`` maps a document id to
        its flattened ``(example_keys, instances)``, and ``context_dict`` maps
        a sentence id to the sentence ids recorded as its context.
    """
    if max_len == -1: assert bsz==1 #otherwise need max_length for padding

    context_ids = []
    context_attn_masks = []

    example_keys = []

    context_output_masks = []
    instances = []
    labels = []

    #tensorize data
    # print(tokenizer.encode(tokenizer.cls_token), tokenizer.encode(tokenizer.sep_token))
    cls = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])]
    sep = [torch.tensor([tokenizer.encode(tokenizer.sep_token)])]
    # count / tag_sense / tag_lemma are collected but not returned.
    count = []
    tag_sense, tag_lemma = [], []
    for sent in tqdm(text_data):
        #cls token aka sos token, returns a list with index
        # c_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])]
        # o_masks = [-1]
        c_ids = []
        o_masks = []
        sent_insts = []
        sent_keys = []
        sent_labels = []

        #For each word in sentence...
        key_len = []
        for idx, (word, lemma, pos, inst, label) in enumerate(sent):
            #tensorize word for context ids.lower()
            word_ids = [torch.tensor([[x]]) for x in tokenizer.encode(word)]
            c_ids.extend(word_ids)

            if label != -1:
                tag_sense.append(label)
                tag_lemma.append(generate_key(lemma, pos))

            # 0 when the key has several candidate senses, 1 when it has one.
            if generate_key(lemma, pos) in gloss_dict:
                if len(gloss_dict[generate_key(lemma, pos)][2]) > 1:
                    count.append(0)
                else:
                    count.append(1)
            #if word is labeled with WSD sense...
            if 't' in inst:
                #add word to bert output mask to be labeled
                o_masks.extend([idx]*len(word_ids))
                #track example instance id
                sent_insts.append(inst)
                #track example instance keys to get glosses
                ex_key = generate_key(lemma, pos)
                sent_keys.append(ex_key)
                key_len.append(len(gloss_dict[ex_key][2]))
                sent_labels.append(label)
            else:
                #mask out output of context encoder for WSD task (not labeled)
                o_masks.extend([-1]*len(word_ids))
                # Untagged words are still tracked here; they are filtered
                # out again after the document loop.
                sent_insts.append(inst)
                ex_key = generate_key(lemma, pos)
                sent_keys.append(ex_key)

            #break if we reach max len
            if max_len != -1 and len(c_ids) >= (max_len-1):
                break

        # c_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)]))  # aka eos token
        c_attn_mask = [1]*len(c_ids)
        # o_masks.append(-1)
        assert len(c_ids) == len(o_masks)

        #not including examples sentences with no annotated sense data
        # if len(sent_insts) > 0:
        context_ids.append(c_ids)
        context_attn_masks.append(c_attn_mask)
        context_output_masks.append(o_masks)
        example_keys.append(sent_keys)
        instances.append(sent_insts)
        labels.append(sent_labels)

    #package data
    context_dict, example_dict = dict(), dict()

    # doc_seg holds the index of the first sentence of every document, where
    # a document is identified by the leading dot-separated field(s) of the
    # first instance id of each sentence.
    doc_id, doc_seg = [], []

    for index, x in enumerate(instances):
        # print(x)
        if 'eval' in x[0]:
            inst = '.'.join(x[0].split('.')[:2])
        else:
            inst = '.'.join(x[0].split('.')[:1])
        if inst not in doc_id:
            doc_id.append(inst)
            doc_seg.append(index)
    doc_seg.append(len(instances))
    new_context, new_attn_mask, new_out_mask = [], [], []
    sent_d = []

    from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

    # Process one document (a contiguous slice of sentences) at a time.
    for seg_index, seg_id in enumerate(tqdm(doc_seg[:-1])):
        ids_c = context_ids[seg_id: doc_seg[seg_index + 1]]
        attn_masks_c = context_attn_masks[seg_id: doc_seg[seg_index + 1]]
        output_masks_c = context_output_masks[seg_id: doc_seg[seg_index + 1]]
        example_keys_c = example_keys[seg_id: doc_seg[seg_index + 1]]
        instances_c = instances[seg_id: doc_seg[seg_index + 1]]
        instances_new = [[inst for inst in inst_s if 't' in inst] for inst_s in instances_c]
        valid_instance = [i for i in instances_c[0] if i != -1][0]
        sent_ids = ['.'.join(i[0].split('.')[:-1]) if 't' in i[0] else i[0] for i in instances_c]
        # NOTE: from here on `doc_id` is rebound from the list above to a string.
        doc_id = '.'.join(sent_ids[0].split('.')[:-1])
        example_dict[doc_id] = (sum(example_keys_c, []), sum(instances_c, []))
        if 'gewo' in args.context_mode:
            vectorizer_score = TfidfVectorizer()
            example_keys_doc = set(sum(example_keys_c, []))
            # Cache gloss text in synset_dict for keys with exactly one
            # synset (keys with several synsets are skipped).
            for example_key in example_keys_doc:
                gloss_group = []
                if example_key.split('+')[1] not in 'nvar':
                    continue
                if example_key not in synset_dict:
                    syn_list = wn.synsets(example_key.split('+')[0], example_key.split('+')[1])
                    if len(syn_list) > 1:
                        continue
                    for synset in syn_list[:]:
                        # synset_group = gloss_extend(synset.name(), 'relations')
                        synset_group = [synset]
                        gloss_group.extend([' '.join(retrieve_gloss(i)) for i in synset_group])
                    synset_dict[example_key] = gloss_group

            # One pseudo-document per sentence: its keys plus any cached glosses.
            score_doc = [' '.join(examp) + ' ' + ' '.join(sum([synset_dict[i] for i in examp if i in synset_dict], [])) for examp in example_keys_c]
            score_mat = vectorizer_score.fit_transform(score_doc).toarray()
            # dataset = valid_instance.split('.')[0]
            # context_sent = pickle.load(open('./data/context_sent_%s' % dataset, 'rb'))
        # Branch on the length of the first id field.
        # NOTE(review): presumably this separates evaluation-set ids from
        # training-corpus ids -- confirm against the data format.
        if len(valid_instance.split('.')[0]) > 2:
            new_sent_ids = ['.'.join(i.split('.')[1:]) for i in sent_ids]
            doc = [' '.join(examp) for examp in example_keys_c]
            assert len(doc) == len(new_sent_ids)
            if args.context_mode == 'wo-select':
                vectorizer = CountVectorizer()
            else:
                vectorizer = TfidfVectorizer()
            vec_count = vectorizer.fit_transform(doc)
            doc_mat = vec_count.toarray()

            for sent_id, vec in enumerate(doc_mat):
                # Sentences without tagged instances produce no example.
                if not instances_new[sent_id]:
                    continue
                # Score every sentence by its weight on the terms that occur
                # in the current sentence, then keep the top context_lens
                # other sentences.
                scores = doc_mat[:, doc_mat[sent_id].nonzero()[0]].sum(1)
                id_score = [j for j in
                            sorted(zip([i for i in range(len(doc_mat))], scores), key=lambda x: x[1], reverse=True) if
                            j[0] != sent_id][:args.context_lens]
                # id_score = len_score[:args.context_lens]
                selected = [i[0] for i in id_score]
                sent_d.extend([abs(i-sent_id) for i in selected])

                # Neighbouring sentences within +/- context_lenw positions.
                window_id = [i for i in range(len(doc_mat))][
                            max(sent_id - args.context_lenw, 0):sent_id + args.context_lenw + 1]
                pure_neighbor = [i for i in window_id if i != sent_id]

                if args.context_mode == 'global':
                    ids = sorted(set([sent_id] + pure_neighbor))
                elif 'select' in args.context_mode:
                    if 'gewo' not in args.context_mode:
                        ids = sorted(set(selected + [sent_id]))
                    else:
                        # Same selection, but scored on the gloss-enriched matrix.
                        scores = score_mat[:, score_mat[sent_id].nonzero()[0]].sum(1)
                        id_score = [j for j in
                                    sorted(zip([i for i in range(len(score_mat))], scores), key=lambda x: x[1],
                                           reverse=True) if j[0] != sent_id][:args.context_lens]
                        scored = [i[0] for i in id_score]
                        ids = sorted(set(scored + [sent_id]))
                else:
                    ids = [sent_id]
                # Drop the farthest sentence until at most 510 tokens remain
                # (cls and sep are added afterwards).
                total_len = len(sum([ids_c[i]for i in ids], []))
                while total_len > 510:
                    distance_index = sorted([(abs(s_id-sent_id), s_id) for s_id in ids], reverse=True)
                    ids.remove(distance_index[0][1])
                    total_len = len(sum([ids_c[i] for i in ids], []))
                if args.context_lens > 0 or args.context_lenw > 0:
                    new_context.append(cls + sum([ids_c[i]for i in ids], []) + sep)
                    new_attn_mask.append([1] + sum([attn_masks_c[i] for i in ids], []) + [1])
                    # Only the target sentence keeps its output mask; context
                    # sentences are fully masked with -1.
                    new_out_mask.append(
                        [-1] + sum([[-1] * len(output_masks_c[i]) if i != sent_id else output_masks_c[i] for i in ids],
                                   []) + [-1])
                    assert len(new_context[-1]) == len(new_attn_mask[-1]) == len(new_out_mask[-1])
                else:
                    new_context.append(cls + ids_c[sent_id] + sep)
                    new_attn_mask.append([1] + attn_masks_c[sent_id] + [1])
                    new_out_mask.append([-1] + output_masks_c[sent_id] + [-1])
                # context_dict records selected + window sentences regardless
                # of which subset was used to build the input above.
                ids = sorted(set(selected + [sent_id] + pure_neighbor))
                context_dict[sent_ids[sent_id]] = [sent_ids[i] for i in ids]
        else:
            # Pass sentences through unchanged (no cls/sep, no extra context).
            new_context.extend(ids_c)
            new_attn_mask.extend(attn_masks_c)
            new_out_mask.extend(output_masks_c)

            for sent_id in sent_ids:
                context_dict[sent_id] = [sent_id]

    # Keep only tagged instances and drop sentences that have none.
    instances_list = [[inst for inst in inst_s if 't' in inst] for inst_s in instances]
    example_keys = [[k for j, k in enumerate(i) if 't' in instances[h][j]] for h, i in enumerate(example_keys) if
                    instances_list[h]]
    instances = [i for i in instances_list if i]
    labels = [i for i in labels if i]
    assert [len(i) for i in example_keys] == [len(i) for i in instances] == [len(i) for i in labels]
    print(len(example_keys), len(new_context))
    assert len(example_keys) == len(new_context)
    data = [list(i) for i in
            list(zip(new_context, new_attn_mask, new_out_mask, example_keys, instances, labels))]

    print('Batching data with gloss length = {}...'.format(args.gloss_bsz))
    batched_data = []
    # sent_index holds batch boundaries: a new batch starts at the sentence
    # that pushes the running candidate-sense count above args.gloss_bsz.
    sent_index, current_list = [0], []
    sent_senses = [sum([len(gloss_dict[ex_key][2]) for ex_key in sent[3]]) for sent in data]
    for index, i in enumerate(sent_senses):
        current_list.append(i)
        if sum(current_list) > args.gloss_bsz:
            sent_index.append(index)
            current_list = current_list[-1:]
    sent_index.append(len(sent_senses))

    for index, data_index in enumerate(sent_index[:-1]):
        b = data[data_index: sent_index[index + 1]]
        max_len_b = max([len(x[1]) for x in b])
        # NOTE: max_len is updated in place, so it can only grow across batches.
        if args.context_lens > 0 or args.context_lenw > 0:
            max_len = max(max_len_b, max_len)
        # print(b[0][0])
        for b_index, sent in enumerate(b):
            b[b_index][0], b[b_index][1], b[b_index][2] = normalize_length(sent[0], sent[1], sent[2], max_len,
                                                                       tokenizer.encode(tokenizer.pad_token)[0])
        # Pad to max_len, then trim every tensor to the longest item of this batch.
        context_ids = torch.cat([torch.cat(x, dim=-1) for x, _, _, _, _, _ in b], dim=0)[:, :max_len_b]
        context_attn_mask = torch.cat([torch.tensor(x).unsqueeze(dim=0) for _, x, _, _, _, _ in b], dim=0)[:,
                            :max_len_b]
        context_output_mask = torch.cat([torch.tensor(x).unsqueeze(dim=0) for _, _, x, _, _, _ in b], dim=0)[:,
                              :max_len_b]
        example_keys = []
        for _, _, _, x, _, _ in b: example_keys.extend(x)
        instances = []
        for _, _, _, _, x, _ in b: instances.extend(x)
        labels = []
        for _, _, _, _, _, x in b: labels.extend(x)
        batched_data.append(
            (context_ids, context_attn_mask, context_output_mask, example_keys, instances, labels))
    return batched_data, example_dict, context_dict


def wn_all_lexnames_groups():
    """Return a dict mapping each WordNet lexname to the list of its synsets."""
    by_lexname = {}
    for syn in wn.all_synsets():
        by_lexname.setdefault(syn.lexname(), []).append(syn)
    return by_lexname

def sec_wsd(matches, lexname_groups, curr_lemma, curr_postag, curr_vec, csi_data):
    """Re-rank the first two candidate senses using related synsets.

    For each of the first two sense keys in `matches`, the synsets returned by
    `gloss_extend(..., 'relations')` that are unique to that candidate (plus
    any synsets listed for it in `csi_data`) are compared with `curr_vec`; the
    best such similarity is added to the candidate's original score and the
    higher-scoring key wins.

    Args:
        matches: list of ``(sense_key, similarity)`` pairs; the caller is
            expected to pass them sorted best-first.
        lexname_groups: unused, kept for interface compatibility.
        curr_lemma: lemma of the target word.
        curr_postag: one of 'ADJ', 'ADV', 'NOUN', 'VERB'.
        curr_vec: context vector of the target word.
        csi_data: pair ``(synset name -> labels, label -> synset names)``.

    Returns:
        A list holding the chosen sense key, or `matches`' keys unchanged when
        there are fewer than two candidates (an empty list for empty input,
        which previously raised IndexError).

    NOTE(review): relies on module-level `key_vec` and `key_dict`, which are
    not defined in the visible part of the file.
    """
    preds = [sk for sk, _ in matches]
    preds_sim = [sim for _, sim in matches]
    if len(preds) <= 1:
        return preds

    pos2type = {'ADJ': 'as', 'ADV': 'r', 'NOUN': 'n', 'VERB': 'v'}
    synset_list = retrieve_sense(curr_lemma, pos2type[curr_postag])
    keys = preds[:2]
    try:
        synsets = {wn.lemma_from_key(key).synset(): rank for rank, key in enumerate(keys)}
    except Exception:
        # Fall back to scanning the lemma's synsets for one owning the key
        # when the direct lookup fails.
        synsets = {
            [wn.synset(k) for k in synset_list if key in [l.key() for l in wn.synset(k).lemmas()]][0]: rank
            for rank, key in enumerate(keys)}

    # Explicit dictionaries replace the former writes into locals().
    related = {}
    all_related = Counter()
    for potential_synset in synsets:
        syn_name = potential_synset.name()
        related[syn_name] = set(gloss_extend(syn_name, 'relations'))
        all_related.update(related[syn_name])

    # Keep only related synsets that are unique to a single candidate.
    for synset, count in all_related.items():
        if count == 1:
            continue
        for potential_synset in synsets:
            related[potential_synset.name()].discard(synset)

    sims = {}
    for potential_synset in synsets:
        syn_name = potential_synset.name()
        sims[syn_name] = {}

        if syn_name in csi_data[0]:
            synset_csi = sum([csi_data[1][i] for i in csi_data[0][syn_name] if i in csi_data[1]], [])
            synset_csi = [wn.synset(i) for i in synset_csi]
            combine_list = list(related[syn_name]) + synset_csi
        else:
            combine_list = list(related[syn_name])
        if not combine_list:
            continue

        key_mat = torch.cat([key_vec[key_dict[synset.lemmas()[0].key()]].unsqueeze(0) for synset in combine_list])
        sim_mat = np.dot(curr_vec, key_mat.T)
        for index, synset in enumerate(combine_list):
            sims[syn_name][synset] = (
                sim_mat[index], 'relation' if synset in related[syn_name] else 'lexname')

    # Original score plus the single best related-synset similarity
    # (np.sum of an empty list contributes 0 when nothing was compared).
    key_score = {keys[rank]: preds_sim[rank] + np.sum(
        sorted([entry[0] for entry in sims[syn.name()].values()], reverse=True)[:1])
        for syn, rank in synsets.items()}
    return [sorted(key_score.items(), key=lambda x: x[1], reverse=True)[0][0]]

def _eval(eval_data, model, gloss_dict, example_dict):
    """Run the context encoder over `eval_data` and score sense predictions.

    Behaviour depends on the module-level ``args.task``:

    * 'emb' in the task: per-sense context vectors are accumulated and written
      to ``./data/vectors/`` at the end.
    * task == 'context_vec': context vectors are stored per instance in one of
      the window/score/select dictionaries (chosen by ``args.context_mode``)
      and pickled at the end.
    * 'wsd' in the task: candidate sense vectors from the module-level
      ``SREF`` are scored against combined context vectors, predictions are
      collected, and accuracy figures are printed.

    When 'wsd' is not in the task the function calls ``quit()`` and does not
    return.

    Args:
        eval_data: batches of ``(context_ids, context_attn_mask,
            context_output_mask, example_keys, insts, labels)``.
        model: object exposing ``context_forward``.
        gloss_dict: mapping from key to a tuple whose element 2 lists the
            candidate senses.
        example_dict: mapping from document id to
            ``(example_keys, instances)``.

    Returns:
        list of ``(instance_id, predicted_label)`` tuples.

    NOTE(review): several inputs are loaded with ``pickle.load`` from files
    under ./data; these must come from a trusted source. File handles opened
    here are not explicitly closed.
    """
    lexname_groups = wn_all_lexnames_groups()
    csi_data = pickle.load(open('./data/csi_data', 'rb'))
    tag_lemma, tag_sense = pickle.load(open('./data/tag_semcor.txt', 'rb'))
    # zsl / zss collect 0/1 correctness flags for the two "unseen" subsets
    # defined further down.
    zsl, zss = [], []
    pos_tran = {'a': 'ADJ', 'n': 'NOUN', 'r': 'ADV', 'v': 'VERB'}
    eval_preds, sense_vecs = [], {}
    gold_path = os.path.join(args.data_path, 'Evaluation_Datasets/{}/{}.gold.key.txt'.format('ALL', 'ALL'))
    gold_labels = {i.split()[0]: i.split()[1:] for i in open(gold_path, 'r').readlines()}
    # `name` is the dict returned by locals(); it is used as a scratch
    # dictionary for the 'pred_c_*' / 'pred_all_*' counters below.
    name = locals()
    dataset_name = sorted(set([i.split('.')[0] for i in gold_labels]))
    mfs_list, lfs_list = [], []
    pred_dict = dict()
    pred_c_list = []
    pred_c_dict = dict()
    if os.path.exists('./data/pred_dict-%s.txt' % args.task):
        pred_dict = pickle.load(open('./data/pred_dict-%s.txt' % args.task, 'rb'))
    # Correct / total counters per dataset and per part of speech.
    for i in dataset_name:
        name['pred_c_%s' % i], name['pred_all_%s' % i] = 0, 0
    for pos in pos_tran.values():
        name['pred_c_%s' % pos], name['pred_all_%s' % pos] = 0, 0

    if 'wsd' not in args.task:
        # Vector-extraction tasks start from empty stores.
        context_sent, score_vec, window_vec, select_vec = dict(), dict(), dict(), dict()
    else:
        # WSD tasks reuse context vectors pickled by earlier runs.
        score_vec = pickle.load(open('./data/score_vec-%s.txt' % args.encoder_name, 'rb'))
        # window_vec = pickle.load(open('./data/window_vec-%s.txt' % args.encoder_name, 'rb'))
        # select_vec = pickle.load(open('./data/select_vec-%s.txt' % args.encoder_name, 'rb'))
        window_vec = pickle.load(open('./data/window_vec-%s-%d.txt' % (args.encoder_name, args.context_lenw), 'rb'))
        select_vec = pickle.load(open('./data/select_vec-%s-%d.txt' % (args.encoder_name, args.context_lens), 'rb'))
    for context_ids, context_attn_mask, context_output_mask, example_keys, insts, labels in tqdm(eval_data):
        with torch.no_grad():

            # if 'semeval2015.d003.s022.t005' not in insts:
            #     continue
            context_ids = context_ids.cuda()
            context_attn_mask = context_attn_mask.cuda()
            context_output = model.context_forward(context_ids, context_attn_mask, context_output_mask).cpu()
            # L2-normalize every context vector in place.
            for index, context in enumerate(context_output):
                context_output[index] = torch.tensor(np.array(context) / np.linalg.norm(np.array(context)))

            if 'emb' in args.task:
                # Accumulate a running sum and count of vectors per gold sense.
                assert len(labels) == len(context_output)
                for s_index, sense in enumerate(labels):
                    try:
                        sense_vecs[sense]['vecs_sum'] += context_output[s_index]
                        sense_vecs[sense]['vecs_num'] += 1
                    except KeyError:
                        sense_vecs[sense] = {'vecs_sum': context_output[s_index], 'vecs_num': 1}
                continue

            # sent_seg holds the start offset of each sentence within the
            # batch, derived from the instance id without its last field.
            sent_id, sent_seg = [], []
            key_len_list = []
            for in_index, inst in enumerate(insts):
                s_id = '.'.join(inst.split('.')[:-1])
                if s_id not in sent_id:
                    sent_id.append(s_id)
                    sent_seg.append(in_index)
            sent_seg.append(len(insts))

            # Number of candidate senses per instance, grouped by sentence.
            for seg_index, seg in enumerate(sent_seg[:-1]):
                key_len_list.append([len(gloss_dict[key][2]) for key in example_keys[seg:sent_seg[seg_index + 1]]])

            senses = [gloss_dict[key][2] for key in example_keys]
            if 'wsd' in args.task:
                # One row per candidate sense of every instance in the batch,
                # looked up in the module-level SREF mapping.
                gat_out_all = torch.cat([torch.tensor([SREF[key]], dtype=torch.float) for key in sum(senses, [])])

            for seg_index, seg in enumerate(sent_seg[:-1]):
                # curr_sent_inst = instances[seg: sent_seg[seg_index + 1]][0].split('.')[1]
                current_example_keys = example_keys[seg: sent_seg[seg_index + 1]]
                current_key_len = key_len_list[seg_index]
                current_context_output = context_output[seg: sent_seg[seg_index + 1], :]
                current_insts = insts[seg: sent_seg[seg_index + 1]]
                # current_labels = labels[seg: sent_seg[seg_index + 1]]

                if 'wsd' in args.task:
                    # Rows of gat_out_all belonging to this sentence.
                    gat_out = gat_out_all[
                              sum(sum(key_len_list[:seg_index], [])): sum(sum(key_len_list[:seg_index + 1], [])),
                              :]

                    # Pad each instance's candidate block to the largest
                    # candidate count in the sentence and stack them.
                    gloss_output_pad = torch.cat([F.pad(
                        gat_out[sum(current_key_len[:i]): sum(current_key_len[:i + 1]), :],
                        pad=[0, 0, 0, max(current_key_len) - j]).unsqueeze(0) for i, j in enumerate(current_key_len)],
                                                 dim=0)

                if 'wsd' in args.task:
                    sense_vec = []
                    assert len(list(set(['.'.join(i.split('.')[:-2]) for i in current_insts]))) == 1
                    doc_example = example_dict['.'.join(current_insts[0].split('.')[:-2])]

                    doc_example = [i for i in zip(doc_example[0], doc_example[1]) if 't' in i[1]]
                    # For each target, sum the first-sense vectors of the other
                    # tagged keys in the same document (noun/verb targets
                    # only; in 'sup' tasks only keys with a single sense).
                    for index, examp in enumerate(current_example_keys):
                        vec = []
                        for key, ins in doc_example:
                            if key == examp or key not in gloss_dict:
                                continue
                            if 'sup' in args.task:
                                if len(gloss_dict[key][2]) > 1:
                                        continue
                            if examp.split('+')[1] in 'nv':
                                vec.append(SREF[gloss_dict[key][2][0]][:1024])
                        if vec:
                            sense_vec.append(
                                torch.sum(torch.cat([torch.tensor(i).unsqueeze(0) for i in vec]), dim=0).unsqueeze(0))
                        else:
                            sense_vec.append(torch.zeros(1, 1024, dtype=torch.float))

                    sense_vec = torch.cat(sense_vec)
                    sense_vec = nn.functional.normalize(sense_vec, dim=1)

                    win_vec = torch.cat([torch.tensor(window_vec[i]).unsqueeze(0) for i in current_insts])
                    sco_vec = torch.cat([torch.tensor(score_vec[i]).unsqueeze(0) for i in current_insts])
                    sel_vec = torch.cat([torch.tensor(select_vec[i]).unsqueeze(0) for i in current_insts])

                    assert current_context_output.shape == win_vec.shape == sco_vec.shape == sense_vec.shape
                    # NOTE: sco_vec is only used in the shape check above; it
                    # does not enter the combined vector.
                    context_output_c = sel_vec + win_vec + sense_vec
                    if 'sup' in args.task:
                        context_output_c = torch.cat([context_output_c, current_context_output], dim=1)
                        context_output_c = nn.functional.normalize(context_output_c, dim=1)
                    gloss_output_pad = nn.functional.normalize(gloss_output_pad, dim=2)
                    # Similarity of each instance's context vector with each
                    # of its (padded) candidate sense vectors.
                    out = torch.bmm(gloss_output_pad, context_output_c.unsqueeze(2)).squeeze(2)

                for j, key in enumerate(current_example_keys):
                    # Vector-extraction modes: store the vector and skip prediction.
                    if args.task == 'context_vec' and args.context_mode == 'global':
                        window_vec[current_insts[j]] = np.array(current_context_output[j])
                        continue
                    elif args.task == 'context_vec' and 'gewo' in args.context_mode:
                        score_vec[current_insts[j]] = np.array(current_context_output[j])
                        continue
                    elif args.task == 'context_vec' and 'select' in args.context_mode:
                        select_vec[current_insts[j]] = np.array(current_context_output[j])
                        continue

                    # Highest-scoring candidate among the unpadded columns.
                    pred_idx = out[j:j + 1, :current_key_len[j]].topk(1, dim=-1)[1].squeeze().item()
                    pred_label = gloss_dict[key][2][pred_idx]
                    curr_vector, curr_lemma, curr_postag = context_output_c[j], key.split('+')[0], key.split('+')[1]
                    curr_postag = pos_tran[curr_postag]

                    if args.sec_wsd:
                        # Optional second pass that re-ranks the candidates.
                        matches = [(i, j) for i, j in
                                   zip(gloss_dict[key][2], out[j:j + 1, :current_key_len[j]].tolist()[0])]
                        matches = sorted(matches, key=lambda x: x[1], reverse=True)
                        preds = sec_wsd(matches, lexname_groups, curr_lemma, curr_postag, curr_vector, csi_data)
                        pred_c_dict[current_insts[j]] = preds
                        pred_label = preds[0]

                    eval_preds.append((current_insts[j], pred_label))
                    pred_dict[current_insts[j]] = pred_label

                    # zsl: key absent from tag_lemma and more than one candidate sense.
                    if key not in tag_lemma:
                        if len(gloss_dict[key][2]) > 1:
                            if pred_label in gold_labels[current_insts[j]]:
                                zsl.append(1)
                            else:
                                zsl.append(0)
                    # zss: none of the gold senses appears in tag_sense.
                    if not tag_sense.intersection(gold_labels[current_insts[j]]):
                        if pred_label in gold_labels[current_insts[j]]:
                            zss.append(1)
                        else:
                            zss.append(0)
                    # mfs_list: the first listed candidate is a gold sense;
                    # lfs_list: it is not.
                    if set(gloss_dict[key][2][:1]).intersection(gold_labels[current_insts[j]]):
                        if pred_label in gold_labels[current_insts[j]]:
                            mfs_list.append(1)
                        else:
                            mfs_list.append(0)
                    else:
                        if pred_label in gold_labels[current_insts[j]]:
                            lfs_list.append(1)
                        else:
                            lfs_list.append(0)
                    # Update the per-dataset and per-POS counters.
                    for i in dataset_name:
                        if i in current_insts[j]:
                            name['pred_all_%s' % i] += 1
                            if pred_label in gold_labels[current_insts[j]]:
                                name['pred_c_%s' % i] += 1
                                pred_c_list.append(current_insts[j])
                    for pos in pos_tran.values():
                        if pos in curr_postag:
                            name['pred_all_%s' % pos] += 1
                            if pred_label in gold_labels[current_insts[j]]:
                                name['pred_c_%s' % pos] += 1

    if 'wsd' not in args.task:
        # Persist whichever vector store this run produced, then exit.
        if args.task == 'context_vec' and args.context_mode == 'global':
            pickle.dump(window_vec, open('./data/window_vec-%s-%d.txt' % (args.encoder_name, args.context_lenw), 'wb'), -1)
            # pickle.dump(window_vec, open('./data/window_vec-%s.txt' % args.encoder_name, 'wb'), -1)
        elif args.task == 'context_vec' and 'gewo' in args.context_mode:
            pickle.dump(score_vec, open('./data/score_vec-%s.txt' % args.encoder_name, 'wb'), -1)
        elif args.task == 'context_vec' and 'select' in args.context_mode:
            pickle.dump(select_vec, open('./data/select_vec-%s-%d.txt' % (args.encoder_name, args.context_lens), 'wb'), -1)
            # pickle.dump(select_vec, open('./data/select_vec-%s.txt' % args.encoder_name, 'wb'), -1)
        else:
            # Average the accumulated vectors per sense.
            sense_dict = dict()
            for sense, vecs_info in sense_vecs.items():
                vec = vecs_info['vecs_sum'] / vecs_info['vecs_num']
                sense_dict[sense] = np.array(vec)
            if 'wngt' not in args.task:
                pickle.dump(sense_dict, open('./data/vectors/lmms-%s.txt' % args.encoder_name, 'wb'), -1)
            else:
                pickle.dump(sense_dict, open('./data/vectors/lmms-wngt-%s.txt' % args.encoder_name, 'wb'), -1)
            logging.info('Written %s %s' % (args.task, args.encoder_name))
        quit()
    else:
        # The prediction cache is only written when it does not exist yet.
        if not os.path.exists('./data/pred_dict-%s.txt' % args.task):
            pickle.dump(pred_dict, open('./data/pred_dict-%s.txt' % args.task, 'wb'), -1)
        correct_pred, all_pred = 0, 0
        test_performance = {}
        # NOTE(review): the ratios below raise ZeroDivisionError when a
        # dataset, POS class or zero-shot subset received no predictions.
        for i in dataset_name:
            correct_pred += name['pred_c_%s' % i]
            all_pred += name['pred_all_%s' % i]
            f1 = name['pred_c_%s' % i]/name['pred_all_%s' % i]
            print(i, f1, end='\t')
            test_performance[i] = f1
        for pos in pos_tran.values():
            print(pos, name['pred_c_%s' % pos] / name['pred_all_%s' % pos], end='\t')
        print('ALL', correct_pred/all_pred)
        print(sum(mfs_list)/len(mfs_list), sum(lfs_list)/len(lfs_list), len(mfs_list), len(lfs_list), len(mfs_list + lfs_list))
        print('zss %d, zsl %d' % (len(zss), len(zsl)), 'zss %f, zsl %f' % (sum(zss)/len(zss), sum(zsl)/len(zsl)))
        # Append the semeval2007 score and the overall score to a results log.
        open('./data/dev_result_%s.txt' % args.task, 'a+').write(
            '%f--' % test_performance['semeval2007'])
        open('./data/dev_result_%s.txt' % args.task, 'a+').write(
            '%s--%f\n' % (str((args.context_lens, args.context_lenw)), correct_pred / all_pred))
    return eval_preds

def train_model(args):
    '''
    Load the data, build a frozen ContextEncoder and run `_eval` on it.

    Despite its name, no parameters are optimised here: the encoder is
    created with freeze_context=True and is only handed to `_eval`.

    When 'emb' is not in args.task the 'ALL' evaluation set is loaded
    (truncated to args.dev_sent sentences), predictions are written to
    <args.ckpt>/tmp_predictions-<task>.txt and scored against the gold key
    file. Otherwise the SemCor training corpus is loaded instead and the
    scoring step is skipped, since no evaluation gold key is selected.

    args: parsed command line namespace. Reads freeze_gloss, gloss_bsz, ckpt,
        encoder_name, task, data_path, dev_sent, train_data, train_sent,
        gloss_max_length, context_bsz and context_max_length.

    Returns None.
    '''
    # print('Training WSD bi-encoder model...')
    if args.freeze_gloss: assert args.gloss_bsz == -1

    #create passed in ckpt dir if doesn't exist
    #(makedirs also creates missing parents and tolerates an existing directory)
    os.makedirs(args.ckpt, exist_ok=True)

    #LOAD PRETRAINED TOKENIZER, TRAIN AND DEV DATA
    print('Loading data + preprocessing...')
    sys.stdout.flush()

    tokenizer = load_tokenizer(args.encoder_name)

    if 'emb' not in args.task:
        eval_file = 'ALL'
        test_path = os.path.join(args.data_path, 'Evaluation_Datasets/%s/' % eval_file)
        test_data = load_data(test_path, eval_file)[:args.dev_sent]
    else:
        #no evaluation set is selected in this mode, so there is nothing to score against
        eval_file = None
        train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
        test_data = load_data(train_path, args.train_data, args.train_sent)[:]

    #load gloss dictionary (all senses from wordnet for each lemma/pos pair that occur in data)
    wn_path = os.path.join(args.data_path, 'Data_Validation/candidatesWN30.txt')
    wn_senses = load_wn_senses(wn_path)

    test_gloss_dict = load_and_preprocess_glosses(test_data, tokenizer, wn_senses, max_len=args.gloss_max_length)

    #the third value returned by preprocess_context is not needed here
    test_data, example_dict, _ = preprocess_context(tokenizer, test_data, test_gloss_dict, bsz=args.context_bsz, max_len=args.context_max_length)

    model = ContextEncoder(args.encoder_name, freeze_context=True).cuda()

    #NOTE(review): _eval appears to call quit() on some of its paths, in which
    #case nothing below runs - confirm against its definition
    eval_preds = _eval(test_data, model, test_gloss_dict, example_dict)

    #generate predictions file
    pred_filepath = os.path.join(args.ckpt, 'tmp_predictions-%s.txt' % args.task)
    with open(pred_filepath, 'w') as f:
        for inst, prediction in eval_preds:
            f.write('{} {}\n'.format(inst, prediction))

    #without an evaluation set the gold key path cannot be built
    #(this previously failed with a NameError on eval_file)
    if eval_file is None:
        return

    #run predictions through scorer
    gold_filepath = os.path.join(args.data_path, 'Evaluation_Datasets/%s/%s.gold.key.txt' % (eval_file, eval_file))
    scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
    _, _, dev_f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
    print('Dev f1 = {}'.format(dev_f1))

    return


def _load_pickle(path):
    '''Return the object unpickled from `path`, closing the file afterwards.'''
    #NOTE: unpickling can execute arbitrary code - only load vector files from a trusted source
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == "__main__":
    #Script entry point: parse the arguments, seed every RNG, load and
    #normalise the sense embeddings into module-level globals (SREF, lmms,
    #key_dict, key_count, key_vec), then run train_model.

    #parse args
    args = parser.parse_args()
    print(args)

    #set random seeds
    torch.manual_seed(args.rand_seed)
    os.environ['PYTHONHASHSEED'] = str(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)
    torch.cuda.manual_seed_all(args.rand_seed)
    np.random.seed(args.rand_seed)
    random.seed(args.rand_seed)
    #NOTE(review): the PyTorch reproducibility notes recommend benchmark = False
    #when deterministic results are wanted; left unchanged here - confirm intent
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    if 'wsd' in args.task:
        print('Loading sense embeddings...')
        if 'lmms' not in args.task:
            if 'ares' not in args.task:
                SREF = _load_pickle('data/vectors/emb_wn_%s.txt' % args.encoder_name)
                lmms = _load_pickle('data/vectors/lmms-wngt-%s.txt' % args.encoder_name)
            else:
                SREF = _load_pickle('data/vectors/ares_wn_%s.txt' % args.encoder_name)
                lmms = _load_pickle('data/vectors/ares_semcor_%s.txt' % args.encoder_name)

            #scale every sense vector to unit L2 norm and store it as a plain list
            for sense_key, sense_vector in SREF.items():
                unit_vec = np.array(sense_vector)
                SREF[sense_key] = (unit_vec / np.linalg.norm(unit_vec)).tolist()

            if 'sup' in args.task:
                #append the unit-normalised lmms vector to each sense vector;
                #senses missing from lmms get their own vector repeated so that
                #every entry ends up with the same length
                for sense_key, sense_vector in SREF.items():
                    if sense_key in lmms:
                        lmms[sense_key] = np.array(lmms[sense_key]) / np.linalg.norm(np.array(lmms[sense_key]))
                        SREF[sense_key] = sense_vector + lmms[sense_key].tolist()
                    else:
                        SREF[sense_key] = sense_vector + sense_vector
        else:
            SREF = _load_pickle('data/vectors/lmms2048.txt')
            #normalise the first 1024 and the remaining dimensions separately
            #(presumably numpy arrays holding a gloss half and a semantic half,
            #going by the variable names - TODO confirm)
            for sense_key, sense_vector in SREF.items():
                gloss = sense_vector[:1024] / np.linalg.norm(sense_vector[:1024])
                sem = sense_vector[1024:] / np.linalg.norm(sense_vector[1024:])
                SREF[sense_key] = gloss.tolist() + sem.tolist()

        print(len(list(SREF.values())[0]))
        print('Sense embeddings loaded...')

        #map each sense key to its row index in the embedding matrix
        #(dict iteration order matches the order of SREF.values() below)
        key_dict = {key: row for row, key in enumerate(SREF)}
        key_count = len(key_dict)
        key_vec = torch.tensor(list(SREF.values()))
        print(key_vec.shape)
        key_vec = nn.functional.normalize(key_vec, dim=1)
    train_model(args)

#EOF