import sys
import os

import argparse
import pickle
import json
import numpy as np

def getParser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with three optional string options,
        ``--lama-model``, ``--pred-file`` and ``--ents-file``; each is ``None``
        when not supplied on the command line.
    """
    option_specs = (
        ("--lama-model", "file containing lama-model"),
        ("--pred-file", "file containing BERT predictions"),
        ("--ents-file", "file containing linked entities and types"),
    )
    arg_parser = argparse.ArgumentParser(
        description="parser for arguments",
        # Append each option's default value to its --help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, help_text in option_specs:
        arg_parser.add_argument(flag, type=str, help=help_text)
    return arg_parser

def getNumTyps(predfile, entsfile, k=10, typs=True):
    """Count distinct types (or entities) among each row's top-k predictions.

    Args:
        predfile: path to an array file read with ``np.load``. Columns
            ``3 .. k+2`` of every row are used as keys into the map returned
            by ``getEnts`` (presumably predicted entity ids -- confirm against
            the file's producer).
        entsfile: path to the linked-entities JSON, parsed by ``getEnts``.
        k: number of prediction columns, starting at column 3, to consider.
        typs: if True count distinct ``'typs'`` values, otherwise distinct
            ``'ents'`` values.

    Returns:
        An int32 array with one count per row of the prediction array. Rows
        whose count is 0 are replaced by ``max_count + 2``, a value larger
        than any real count in the result.
    """
    ents = getEnts(entsfile, None)
    preds = np.load(predfile)[:, 3:k + 3]
    field = 'typs' if typs else 'ents'
    lengths = np.zeros((preds.shape[0],), dtype=np.int32)
    for i, row in enumerate(preds):
        seen = set()
        # Iterate the sliced row itself so an array with fewer than k+3
        # columns does not raise IndexError.
        for pred in row:
            seen.update(ents.get(pred, {}).get(field, set()))
        lengths[i] = len(seen)
    if lengths.size == 0:
        # No rows: nothing to remap, and max() of an empty array would raise.
        return lengths
    # Give rows with no known types/entities a sentinel past the maximum.
    lengths[lengths == 0] = lengths.max() + 2
    return lengths

def getEnts(filename, candidates):
    """Load linked entities from JSON and build an eid -> {ents, typs} map.

    The JSON top level is iterated as ``mention -> list of entity records``.
    Every record must have an ``'eid'`` key. Each of its dict-valued fields
    contributes that field's ``'entity'`` and ``'ner'`` values.

    Args:
        filename: path to the JSON file (read as UTF-8).
        candidates: iterable of ids to look up, or None.

    Returns:
        If ``candidates`` is None, a dict mapping eid to
        ``{'ents': set, 'typs': set}``. Otherwise a dict mapping each
        candidate to its entry, or to ``{}`` when the candidate is unknown.
        Mentions with an empty record list are skipped, because they have no
        eid to key on.

    Raises:
        KeyError: if a record lacks ``'eid'`` or a dict field lacks
            ``'entity'`` / ``'ner'``.
    """
    with open(filename, 'r', encoding='utf-8') as fin:
        entities = json.load(fin)
    ent2typs = {}
    for ents in entities.values():
        if not ents:
            # No records -> no eid; skip rather than record a None key.
            continue
        cur_ents = set()
        cur_typs = set()
        eid = None
        for ent in ents:
            eid = ent['eid']
            for val in ent.values():
                if isinstance(val, dict):
                    cur_ents.add(val['entity'])
                    cur_typs.add(val['ner'])
        # NOTE(review): all records of a mention are merged under the eid of
        # the LAST record, and a later mention with the same eid overwrites
        # an earlier one -- confirm this is intended.
        ent2typs[eid] = {'ents': cur_ents, 'typs': cur_typs}
    if candidates is None:
        return ent2typs
    return {candidate: ent2typs.get(candidate, {}) for candidate in candidates}


def getLAMA(filename):
    """Load a pickled LAMA result file and extract triples and candidate scores.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading. Only
    call this on files you produced yourself or otherwise trust.

    Args:
        filename: path to the pickle. It is read as a mapping whose
            ``'list_of_results'`` entries each hold a ``'sample'`` dict
            (``sub_id``, ``pred_id``, ``obj_id``) and ``'masked_topk'`` ->
            ``'topk'``, a list of dicts with ``token_word_form`` and
            ``log_prob``.

    Returns:
        A tuple ``(triples, candidates, candidate_scores, min_scores)``:
        triples -- ``np.array`` of ``[sub_id, pred_id, obj_id]`` rows;
        candidates -- ``token_word_form`` values of the FIRST result's top-k;
        candidate_scores -- per result, ``{token_word_form: log_prob}``;
        min_scores -- per result, the smallest top-k ``log_prob`` minus 10.
        An empty ``list_of_results`` yields ``(np.array([]), [], [], [])``.
    """
    with open(filename, 'rb') as fin:
        # WARNING: never unpickle untrusted data (see docstring).
        model = pickle.load(fin)
    results = model['list_of_results']
    test_triples = []
    min_scores = []
    candidate_scores = []
    for result in results:
        sample = result['sample']
        topk = result['masked_topk']['topk']
        test_triples.append([sample['sub_id'], sample['pred_id'], sample['obj_id']])
        # A score 10 below the worst top-k log-prob (presumably used as a
        # floor for tokens outside the top-k -- confirm with callers).
        min_scores.append(min(x['log_prob'] for x in topk) - 10)
        candidate_scores.append({x['token_word_form']: x['log_prob'] for x in topk})
    # The candidate vocabulary comes from the first result only; guard the
    # empty case instead of raising IndexError.
    if results:
        candidates = [x['token_word_form'] for x in results[0]['masked_topk']['topk']]
    else:
        candidates = []
    return np.array(test_triples), candidates, candidate_scores, min_scores

def main():
    """Command-line entry point: count distinct types per prediction row.

    Parses arguments with ``getParser`` and runs ``getNumTyps`` on
    ``--pred-file`` and ``--ents-file``. argparse exits with status 2 (after
    printing usage) on invalid or missing arguments, and with status 0 for
    ``--help``.
    """
    parser = getParser()
    # Do not wrap this in try/except: argparse raises SystemExit itself, and a
    # blanket handler would turn the success exit of --help into a failure.
    params = parser.parse_args()
    if params.pred_file is None or params.ents_file is None:
        parser.error("--pred-file and --ents-file are required")
    # NOTE(review): --lama-model is parsed but currently unused, and the
    # returned per-row counts are discarded.
    getNumTyps(params.pred_file, params.ents_file)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
