import sys
import os

from copy import deepcopy
import json
import pickle
import argparse
import numpy as np
from click import progressbar
from tqdm import tqdm
from functools import partial
import multiprocessing as mp

def getParser():
    """Build the command-line parser for the type-evaluation script.

    Returns:
        argparse.ArgumentParser: parser with the required options
        ``--bertout``, ``--ent2type`` and ``--tailvocab`` and the optional
        ``--toptypes`` (default ``None``), ``--roottype`` (flag, default
        ``False``) and ``--k`` (int, default ``5``).
    """
    # (flag, keyword arguments) pairs, registered in this order so the
    # generated help text lists the options in the same sequence.
    option_specs = (
        ("--bertout", dict(type=str, help="bert predictions output file [pickle]", required=True)),
        ("--ent2type", dict(type=str, help="file containing entity name to type mapping [json]", required=True)),
        ("--toptypes", dict(type=str, help="file containing top types for evaluation [json]", default=None)),
        ("--tailvocab", dict(type=str, help="tail vocab [txt]", required=True)),
        ("--roottype", dict(action="store_true", default=False, help="use this flag to use the root type")),
        ("--k", dict(type=int, help="number of prediction to use", default=5)),
    )
    cli = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli

def load_tail_vocab(filename):
    """Read a vocabulary file into a set of entries.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped and duplicates collapse into one entry.

    Args:
        filename: path of the text file to read, one entry per line.

    Returns:
        set[str]: the distinct non-empty stripped lines.
    """
    with open(filename, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return {entry for entry in stripped_lines if entry}

def computeMetrics(preds_list, labels):
    """Score a ranked list of predicted type sets against the gold types.

    The rank is the 1-based position of the first entry of ``preds_list``
    that shares at least one type with ``labels``.

    Args:
        preds_list: ranked list (best first); each item is an iterable of
            types for one prediction.
        labels: set of gold types.

    Returns:
        tuple ``(rank, inv_rank, hits)`` where

        * ``rank`` is the position of the first matching prediction, or
          ``len(preds_list) + 1`` when nothing matches;
        * ``inv_rank`` is ``1.0 / rank`` for a match and ``0.0`` for a miss;
        * ``hits`` maps each cutoff in ``[1, 3, 5, 10]`` to ``1.0`` when a
          match occurred at or before that cutoff, else ``0.0``.
    """
    hit_list = [1, 3, 5, 10]
    rank = next(
        (pos for pos, preds in enumerate(preds_list, start=1)
         if labels.intersection(preds)),
        None,
    )
    found = rank is not None
    if not found:
        # Keep the "one past the end" rank for the mean-rank statistic, but
        # do not let it leak into hits / reciprocal rank: with a short (or
        # empty) prediction list this fallback rank can be <= a cutoff and
        # would otherwise score a miss as a hit.
        rank = len(preds_list) + 1
    inv_rank = 1.0 / rank if found else 0.0
    hits = {hit: 1.0 if found and rank <= hit else 0.0 for hit in hit_list}
    return rank, inv_rank, hits

def _normalise_types(raw_types, roottype, ignore_types, toptypes=None):
    """Turn a raw type list into the set of types used for scoring."""
    if roottype:
        # Keep only the part of each type before the first '/'.
        kept = {typ.split('/')[0] for typ in raw_types}
    else:
        kept = set(raw_types)
    kept = kept.difference(ignore_types)
    if toptypes is not None:
        kept = kept.intersection(toptypes)
    return kept


def processOneResult(roottype, ignore_types, tail_vocab, ent2type, toptypes, res):
    """Compute ranking metrics for one BERT result and a random baseline.

    For every top-k prediction whose lower-cased ``token_word_form`` is in
    ``tail_vocab``, the token's types (from ``ent2type``) form one entry of
    the ranked prediction list. For each such entry a token is also drawn
    uniformly from ``tail_vocab`` and its types form the matching entry of
    the random baseline list. Both lists are scored with ``computeMetrics``
    against the sample's ``tail_types``.

    Args:
        roottype: if true, reduce every type to the part before the first '/'.
        ignore_types: type name(s) to drop before scoring. A single string is
            treated as one type name; ``None`` means nothing is ignored.
        tail_vocab: collection of valid tail tokens.
        ent2type: mapping from token to a list of types.
        toptypes: optional collection; when given, predicted types are
            restricted to it (the gold types are not).
        res: one result dict, read via ``res['sample']['tail_types']`` and
            ``res['masked_topk']['topk']``.

    Returns:
        dict with keys ``'bert'`` and ``'random'``, each holding ``'rank'``,
        ``'inv_rank'`` and ``'hits'`` as returned by ``computeMetrics``.
    """
    # A bare string would be consumed character by character by
    # set.difference, so the named type would never actually be removed.
    if ignore_types is None:
        ignore_types = set()
    elif isinstance(ignore_types, str):
        ignore_types = {ignore_types}
    # NOTE(review): with roottype the ignore filter runs after the split, so
    # a full name such as 'common/topic' cannot match a root type -- confirm
    # whether ignored types should be dropped before taking the root.

    tail_vocab_list = list(tail_vocab)
    true_types = _normalise_types(res['sample']['tail_types'], roottype, ignore_types)

    preds_list = []
    random_preds_list = []
    for pred in res['masked_topk']['topk']:
        token = pred['token_word_form'].lower()
        if token not in tail_vocab:
            continue
        preds_list.append(
            _normalise_types(ent2type.get(token, []), roottype, ignore_types, toptypes))
        # Random baseline: pick any tail token and use its types instead.
        # NOTE(review): forked pool workers may inherit the same numpy RNG
        # state -- confirm whether per-worker seeding is wanted.
        random_token = np.random.choice(tail_vocab_list)
        random_preds_list.append(
            _normalise_types(ent2type.get(random_token, []), roottype, ignore_types, toptypes))

    outdict = {}
    for name, ranked in (('bert', preds_list), ('random', random_preds_list)):
        rank, inv_rank, hits = computeMetrics(ranked, true_types)
        outdict[name] = {'rank': rank, 'inv_rank': inv_rank, 'hits': hits}
    return outdict


def evalBert(params, ent2type, bert):
    """Evaluate BERT type predictions with rank-based metrics.

    Every entry of ``bert['list_of_results']`` is scored in a worker pool by
    ``processOneResult``; the per-result ranks, reciprocal ranks and hits are
    then averaged for BERT and for the random baseline, printed, and returned.

    Args:
        params: namespace providing ``roottype``, ``toptypes`` (path to a JSON
            file or ``None``) and ``tailvocab`` (path to the vocabulary file).
        ent2type: mapping from entity token to a list of types.
        bert: dict whose ``'list_of_results'`` entry is the list of results.

    Returns:
        dict with keys ``'bert'`` and ``'random'``; each value holds ``'mr'``
        (mean rank), ``'mrr'`` (mean reciprocal rank) and ``'hits'`` (cutoff
        -> mean hit rate for cutoffs 1, 3, 5 and 10).
    """
    if params.toptypes is not None:
        with open(params.toptypes, 'r') as fin:
            toptype_data = json.load(fin)
        toptypes = toptype_data['topRootTypes' if params.roottype else 'topTypes']
    else:
        toptypes = None

    results = bert['list_of_results']
    tail_vocab = load_tail_vocab(params.tailvocab)
    # Must be a collection of names: a bare string passed to set.difference
    # is treated as a sequence of single characters.
    ignore_types = {'common/topic'}
    hit_list = [1, 3, 5, 10]

    worker = partial(processOneResult, params.roottype, ignore_types,
                     tail_vocab, ent2type, toptypes)
    # The context manager shuts the pool down once the blocking map returns,
    # so repeated calls do not accumulate worker processes.
    with mp.Pool(mp.cpu_count()) as pool:
        perfs = pool.map(worker, results)

    outdict = {}
    for system, label in (('bert', 'Bert'), ('random', 'Random')):
        mr = np.mean([perf[system]['rank'] for perf in perfs])
        mrr = np.mean([perf[system]['inv_rank'] for perf in perfs])
        mean_hits = {
            hit: np.mean([perf[system]['hits'][hit] for perf in perfs])
            for hit in hit_list
        }
        print(f"{label} MR: {mr}")
        print(f"{label} MRR: {mrr}")
        for hit in hit_list:
            print(f"{label} Hits@{hit}: {mean_hits[hit]}")
        outdict[system] = {'mr': mr, 'mrr': mrr, 'hits': mean_hits}
    return outdict

def computeMetricsOld(preds, labels):
    """Compute set-based precision, recall and F1 of ``preds`` vs ``labels``.

    Conventions for empty inputs:

    * no labels and no predictions -> ``(1, 1, 1)``;
    * no labels but some predictions -> ``(0, 1, 0)``;
    * labels but no predictions -> precision ``1``, recall ``0.0``, F1 ``0``.

    Args:
        preds: set of predicted types.
        labels: set of gold types.

    Returns:
        tuple ``(precision, recall, f1)``.
    """
    # Nothing to find: only an empty prediction is fully correct.
    if len(labels) == 0:
        if len(preds) == 0:
            return 1, 1, 1
        return 0, 1, 0

    n_overlap = len(preds.intersection(labels))
    precision = n_overlap / len(preds) if len(preds) != 0 else 1
    recall = n_overlap / len(labels)
    # With no overlap the harmonic mean is defined as 0 here.
    f1 = 2 * (precision * recall) / (precision + recall) if n_overlap != 0 else 0
    return precision, recall, f1

def _f1_of_means(precision, recall):
    """Return the harmonic mean of two averaged scores (0.0 if both are 0)."""
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * precision * recall / total


def evalBertOld(params, ent2type, bert):
    """Evaluate BERT type predictions with set-based precision/recall/F1.

    For each result the types of the first ``params.k`` in-vocabulary
    predictions are pooled and compared with the gold ``tail_types`` using
    ``computeMetricsOld``. Two baselines are scored the same way: randomly
    drawn types and the most frequent types.

    Args:
        params: namespace providing ``roottype``, ``toptypes`` (path to a JSON
            file or ``None``), ``tailvocab`` (vocabulary file path) and ``k``.
        ent2type: mapping from entity token to a list of types.
        bert: dict whose ``'list_of_results'`` entry is the list of results.

    Returns:
        dict with keys ``'bert'``, ``'random'`` and ``'mostfreq'``; each value
        holds ``'precision'`` and ``'recall'`` (means over results),
        ``'micro_f1'`` (mean of per-result F1) and ``'macro_f1'`` (harmonic
        mean of the mean precision and mean recall).
    """
    # Must be a collection of names: a bare string passed to set.difference
    # is treated as a sequence of single characters.
    ignore_types = {'common/topic'}
    results = bert['list_of_results']
    tail_vocab = load_tail_vocab(params.tailvocab)

    if params.toptypes is not None:
        with open(params.toptypes, 'r') as fin:
            toptype_data = json.load(fin)
        if params.roottype:
            topTypeCounts = toptype_data['rootTypeCounts']
            toptypes = toptype_data['topRootTypes']
        else:
            topTypeCounts = toptype_data['typeCounts']
            toptypes = toptype_data['topTypes']
        # topTypeCounts is indexed as (type, count) pairs.
        mostfreq_types_all = [xx[0] for xx in sorted(topTypeCounts, key=lambda x: x[1], reverse=True)]
        allTypes = []
    else:
        # Without a top-types file there are no precomputed counts, so rank
        # types by how many ent2type entries carry them. Previously this path
        # failed because the most-frequent baseline was never defined.
        toptypes = None
        type_counts = {}
        for types in ent2type.values():
            names = {typ.split('/')[0] for typ in types} if params.roottype else set(types)
            for name in names:
                type_counts[name] = type_counts.get(name, 0) + 1
        allTypes = list(type_counts)
        ranked = sorted(type_counts.items(), key=lambda item: item[1], reverse=True)
        mostfreq_types_all = [name for name, _ in ranked if name not in ignore_types]

    systems = (('bert', 'Bert'), ('random', 'Random'), ('mostfreq', 'Most Frequent'))
    scores = {name: {'precision': [], 'recall': [], 'f1': []} for name, _ in systems}

    for res in tqdm(results):
        true_types = res['sample']['tail_types']
        pred_types = set()
        k = 0
        for pred in res['masked_topk']['topk']:
            token = pred['token_word_form'].lower()
            if token in tail_vocab:
                k += 1
                pred_types.update(ent2type.get(token, []))
                if k >= params.k:
                    break
        if params.roottype:
            pred_types = {typ.split('/')[0] for typ in pred_types}
            true_types = {typ.split('/')[0] for typ in true_types}
        pred_types = set(pred_types).difference(ignore_types)
        true_types = set(true_types).difference(ignore_types)

        if toptypes is not None:
            pred_types = pred_types.intersection(toptypes)
            true_types = true_types.intersection(toptypes)
            random_pool, n_random = toptypes, len(pred_types)
        else:
            random_pool, n_random = allTypes, k
        # Always hand computeMetricsOld a set; it calls .intersection on it.
        if n_random and random_pool:
            random_types = set(np.random.choice(random_pool, n_random).tolist())
        else:
            random_types = set()
        mostfreq_types = set(mostfreq_types_all[:len(pred_types)])

        candidates = {'bert': pred_types, 'random': random_types, 'mostfreq': mostfreq_types}
        for name, _ in systems:
            precision, recall, f1 = computeMetricsOld(candidates[name], true_types)
            scores[name]['precision'].append(precision)
            scores[name]['recall'].append(recall)
            scores[name]['f1'].append(f1)

    outdict = {}
    for name, label in systems:
        mean_f1 = np.mean(scores[name]['f1'])
        mean_precision = np.mean(scores[name]['precision'])
        mean_recall = np.mean(scores[name]['recall'])
        f1_mean = _f1_of_means(mean_precision, mean_recall)
        print(f"{label} F1: {mean_f1}")
        print(f"{label} Precision: {mean_precision}")
        print(f"{label} Recall: {mean_recall}")
        print(f"{label} F1-mean: {f1_mean}")
        outdict[name] = {'precision': mean_precision,
                         'recall': mean_recall,
                         'micro_f1': mean_f1,
                         'macro_f1': f1_mean,
                         }
    return outdict

def evalAll(params):
    """Run ``evalBert`` over a grid of settings and dump all results to JSON.

    The grid covers ``k`` in (1, 5, 10), top-type files for 10, 20, ..., 100
    types, and both values of ``roottype``. Each run's parameters and
    performance are collected and written to ``./bert_type_perf_tmp.json``.

    Args:
        params: parsed arguments; ``bertout`` and ``ent2type`` are read here
            and a modified deep copy is handed to ``evalBert`` for every run.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use it on
    # prediction files from a trusted source.
    with open(params.bertout, 'rb') as fin:
        bert = pickle.load(fin)
    with open(params.ent2type, 'r') as fin:
        ent2type = json.load(fin)

    toptypefile = "data/FB15K_all/types/top_%d_types.json"
    outfile = "./bert_type_perf_tmp.json"

    # Flatten the three nested sweeps into one ordered list of settings.
    # NOTE(review): evalBert as written in this file never reads params.k, so
    # the pred_k sweep appears to repeat the same evaluation -- confirm.
    grid = [
        (pred_k, type_k, roottype)
        for pred_k in [1, 5, 10]
        for type_k in range(10, 101, 10)
        for roottype in [True, False]
    ]

    all_runs = []
    for pred_k, type_k, roottype in grid:
        run_params = deepcopy(params)
        run_params.k = pred_k
        run_params.roottype = roottype
        run_params.toptypes = toptypefile % type_k
        print(f"pred_k: {pred_k}\t type_k: {type_k}\t roottype: {roottype}")
        perf = evalBert(run_params, ent2type, bert)
        all_runs.append({'params': vars(run_params), 'perf': perf})

    with open(outfile, 'w') as fout:
        json.dump(all_runs, fout)

def main():
    """Parse the command line, load the inputs and run ``evalBertOld``.

    Exits with status 1 when argument parsing fails and with status 0 when
    argparse exits successfully (for example after ``--help``).
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as exc:
        # argparse signals both "--help" (code 0) and usage errors through
        # SystemExit. Keep the historical exit status 1 for errors, but do
        # not report a successful --help as a failure.
        sys.exit(0 if exc.code in (0, None) else 1)
    # NOTE: pickle.load executes arbitrary code from the file; only use it on
    # prediction files from a trusted source.
    with open(params.bertout, 'rb') as fin:
        bert = pickle.load(fin)
    with open(params.ent2type, 'r') as fin:
        ent2type = json.load(fin)
    # evalBert(params, ent2type, bert) and evalAll(params) are the alternative
    # entry points defined in this file.
    evalBertOld(params, ent2type, bert)

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
