import sys
import os

import argparse
import json
from copy import deepcopy
from collections import Counter
import pickle
import numpy as np

def getParser():
    """Build the command-line parser for the human-evaluation script.

    Returns:
        argparse.ArgumentParser: parser whose defaults are shown in ``--help``
        (via ``ArgumentDefaultsHelpFormatter``). Every option is optional and
        has a default, so the script can run without arguments.
    """
    parser = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--bertout", type=str,
        help="pickle containing bert predictions",
        default='../BERT/LAMA/output/results/bert_large/freebase/uncased_result.pkl',
    )
    parser.add_argument(
        "--types", type=str,
        help="file containing the types",
        default="eval/human/types.txt",
    )
    parser.add_argument(
        "--num-samples", type=int,
        help="number of samples to extract",
        default=150,
    )
    parser.add_argument(
        "--num-splits", type=int,
        help="number of files into which samples should be split",
        default=3,
    )
    parser.add_argument(
        "--num-copies", type=int,
        help="number of copies for each sample",
        default=3,
    )
    parser.add_argument(
        "--response", type=str,
        help="directory containing the response",
        default="eval/human/response/",
    )
    parser.add_argument(
        "--outfile", type=str,
        help="output file to dump the results",
        default="eval/human/perf.json",
    )

    return parser

def computeMetricsOld(preds, labels):
    """Compute set-based precision, recall and F1 for one sample.

    Args:
        preds: iterable of predicted (hashable) items; duplicates are ignored.
        labels: iterable of gold (hashable) items; duplicates are ignored.

    Returns:
        tuple: ``(precision, recall, f1)``.

        Conventions for the degenerate cases:
          * no labels and no predictions -> ``(1, 1, 1)``
          * no labels but some predictions -> ``(0, 1, 0)``
          * labels but no predictions -> precision is 1, recall and F1 are 0
    """
    # Coerce to sets so lists / numpy arrays are accepted and duplicated
    # entries cannot distort the denominators.
    preds = set(preds)
    labels = set(labels)

    if not labels:
        if not preds:
            return 1, 1, 1
        return 0, 1, 0

    common = preds & labels
    precision = len(common)/len(preds) if preds else 1
    recall = len(common)/len(labels)
    # With an empty overlap precision + recall may be 0, so short-circuit
    # instead of dividing.
    f1 = 2*(precision*recall)/(precision+recall) if common else 0
    return precision, recall, f1

def _summarizeScores(label, precisions, recalls, f1s):
    """Print and return the averaged scores of one system.

    The returned dict keeps the historical key names: ``'micro_f1'`` holds the
    mean of the per-sample F1 values and ``'macro_f1'`` holds the harmonic
    mean of the averaged precision and recall.
    """
    precision = np.mean(precisions)
    recall = np.mean(recalls)
    f1 = np.mean(f1s)
    denom = precision + recall
    # Avoid 0/0 (NaN) when both averaged precision and recall are zero.
    f1_of_means = 2*precision*recall/denom if denom > 0 else 0.0
    print(f"{label} F1: {f1}")
    print(f"{label} Precision: {precision}")
    print(f"{label} Recall: {recall}")
    print(f"{label} F1-mean: {f1_of_means}")
    return {'precision': precision,
            'recall': recall,
            'micro_f1': f1,
            'macro_f1': f1_of_means,
            }


def evalBertOld(params, ent2type, bert):
    """Score BERT's predicted tail types against the gold tail types.

    For every result, the top ``params.k`` predicted tokens that are in the
    tail vocabulary are mapped to types through ``ent2type``; the union of
    those types is compared with the sample's ``tail_types``. A random
    baseline is scored alongside, and a most-frequent-type baseline is scored
    when ``params.toptypes`` is given.

    NOTE(review): this function reads ``params.roottype``, ``params.tailvocab``,
    ``params.toptypes`` and ``params.k``, none of which are defined by
    ``getParser`` in this file, and calls ``load_tail_vocab``, which is not
    defined in the visible part of this file. It appears to be legacy code;
    confirm before relying on it.

    Args:
        params: namespace providing ``roottype``, ``tailvocab``, ``toptypes``
            and ``k``.
        ent2type: mapping from (lower-cased) entity token to its types.
        bert: dict with a ``'list_of_results'`` entry; each result provides
            ``res['sample']['tail_types']`` and ``res['masked_topk']['topk']``.

    Returns:
        dict: ``{'bert': {...}, 'random': {...}}`` plus ``'mostfreq'`` when
        ``params.toptypes`` is given. Each value holds ``precision``,
        ``recall``, ``micro_f1`` and ``macro_f1``.
    """
    allTypes = set()
    for types in ent2type.values():
        allTypes.update(types)
    if params.roottype:
        allTypes = {typ.split("/")[0] for typ in allTypes}
    allTypes = list(allTypes)
    results = bert['list_of_results']
    tail_vocab = load_tail_vocab(params.tailvocab)
    # Must be a set: set.difference() on a bare string would iterate over its
    # characters and never remove the type itself.
    ignore_types = {'common/topic'}

    use_toptypes = params.toptypes is not None
    if use_toptypes:
        with open(params.toptypes, 'r') as fin:
            toptypes = json.load(fin)
        if params.roottype:
            topTypeCounts = toptypes['rootTypeCounts']
            toptypes = toptypes['topRootTypes']
        else:
            topTypeCounts = toptypes['typeCounts']
            toptypes = toptypes['topTypes']
        mostfreq_types_all = [xx[0] for xx in sorted(topTypeCounts, key=lambda x: x[1], reverse=True)]

    bert_precisions, bert_recalls, bert_f1s = [], [], []
    random_precisions, random_recalls, random_f1s = [], [], []
    mostfreq_precisions, mostfreq_recalls, mostfreq_f1s = [], [], []

    for res in results:
        true_types = set(res['sample']['tail_types'])
        pred_types = set()
        k = 0
        for pred in res['masked_topk']['topk']:
            token = pred['token_word_form'].lower()
            if token not in tail_vocab:
                continue
            k += 1
            pred_types.update(ent2type.get(token, []))
            if k >= params.k:
                break

        # Drop ignored types on the full type names, i.e. before any
        # reduction to root types (a root never equals 'common/topic').
        pred_types -= ignore_types
        true_types -= ignore_types
        if params.roottype:
            pred_types = {typ.split('/')[0] for typ in pred_types}
            true_types = {typ.split('/')[0] for typ in true_types}

        if use_toptypes:
            pred_types = pred_types.intersection(toptypes)
            true_types = true_types.intersection(toptypes)
            # Both baselines get as many guesses as BERT produced.
            random_types = set(np.random.choice(toptypes, len(pred_types)).tolist())
            mostfreq_types = set(mostfreq_types_all[:len(pred_types)])
        else:
            random_types = set(np.random.choice(allTypes, k).tolist())

        bert_precision, bert_recall, bert_f1 = computeMetricsOld(pred_types, true_types)
        bert_precisions.append(bert_precision)
        bert_recalls.append(bert_recall)
        bert_f1s.append(bert_f1)

        random_precision, random_recall, random_f1 = computeMetricsOld(random_types, true_types)
        random_precisions.append(random_precision)
        random_recalls.append(random_recall)
        random_f1s.append(random_f1)

        # The most-frequent baseline needs the type counts, which are only
        # available when a top-types file was supplied.
        if use_toptypes:
            mostfreq_precision, mostfreq_recall, mostfreq_f1 = computeMetricsOld(mostfreq_types, true_types)
            mostfreq_precisions.append(mostfreq_precision)
            mostfreq_recalls.append(mostfreq_recall)
            mostfreq_f1s.append(mostfreq_f1)

    outdict = {}
    outdict['bert'] = _summarizeScores("Bert", bert_precisions, bert_recalls, bert_f1s)
    outdict['random'] = _summarizeScores("Random", random_precisions, random_recalls, random_f1s)
    if use_toptypes:
        outdict['mostfreq'] = _summarizeScores("Most Frequent", mostfreq_precisions, mostfreq_recalls, mostfreq_f1s)
    return outdict

def readResponse(filename):
    """Parse one tab-separated annotator response file.

    The first non-blank line is treated as a header and skipped; blank lines
    are ignored. Each remaining row is expected to look like::

        idx <TAB> sub <TAB> pred <TAB> obj [<TAB> primary_type [<TAB> other_type ...]]

    Rows are stripped of surrounding whitespace before splitting, so trailing
    empty columns are not counted. A row with exactly four columns gets the
    primary type ``'base'``.

    Args:
        filename: path of the TSV file to read.

    Returns:
        dict: maps the integer ``idx`` to a dict with keys ``'sub'``,
        ``'pred'``, ``'obj'``, ``'primary_type'`` and ``'other_types'``
        (list of non-empty strings). A later row with the same ``idx``
        overwrites the earlier one.

    Raises:
        ValueError: if a row has fewer than four columns or its first column
            is not an integer; the message names the file and line.
    """
    delim = '\t'
    resp = {}
    header_pending = True
    with open(filename, 'r') as fin:
        for linenum, line in enumerate(fin, start=1):
            line = line.strip()
            if not line:
                continue
            if header_pending:
                header_pending = False
                continue
            x = line.split(delim)
            if len(x) < 4:
                raise ValueError(
                    f"{filename}:{linenum}: expected at least 4 tab-separated "
                    f"fields, got {len(x)}"
                )
            try:
                idx = int(x[0])
            except ValueError as err:
                # Fail loudly with context instead of dropping into a debugger
                # and carrying on with an undefined/stale index.
                raise ValueError(
                    f"{filename}:{linenum}: sample index {x[0]!r} is not an integer"
                ) from err
            cur_resp = resp.setdefault(idx, {})
            cur_resp['sub'] = x[1].strip()
            cur_resp['pred'] = x[2].strip()
            cur_resp['obj'] = x[3].strip()
            if len(x) == 4:
                print(f"using default type 'base' for {cur_resp}")
                cur_resp['primary_type'] = 'base'
            else:
                cur_resp['primary_type'] = x[4].strip()
            cur_resp['other_types'] = [xx.strip() for xx in x[5:] if xx.strip() != '']
    return resp

def aggregateResponse(responses, include_others=False):
    """Pool the types chosen for each sample across several responses.

    Args:
        responses: iterable of dicts mapping a sample index to an entry with
            ``'primary_type'`` and ``'other_types'``.
        include_others: when true, each entry's ``'other_types'`` are pooled
            in addition to its ``'primary_type'``.

    Returns:
        dict: sample index -> flat list of types, in the order the responses
        were given (duplicates are kept).
    """
    pooled = {}
    for annotation in responses:
        for idx, entry in annotation.items():
            if include_others:
                chosen = [entry['primary_type'], *entry['other_types']]
            else:
                chosen = [entry['primary_type']]
            if idx not in pooled:
                pooled[idx] = []
            pooled[idx] += chosen
    return pooled

def evalPerf(bert, toptypes, resp, most_common=False):
    """Score human type annotations against the gold tail types.

    For each annotated sample the gold types are the root components (text
    before the first ``/``) of ``sample['tail_types']``, excluding the
    uninformative ``common/topic`` type and restricted to ``toptypes``.

    Args:
        bert: dict with a ``'list_of_results'`` sequence; entry ``idx``
            provides ``['sample']['tail_types']``.
        toptypes: iterable of root types that are taken into account.
        resp: mapping of sample index -> list of annotated types.
        most_common: when true, only the single most frequent annotated type
            of each sample is scored (ties go to the type seen first).

    Returns:
        dict: ``{'human': {...}}`` with ``precision`` and ``recall`` (means
        over samples), ``micro_f1`` (mean of the per-sample F1 values) and
        ``macro_f1`` (harmonic mean of the mean precision and mean recall,
        0.0 when both are 0).
    """
    results = bert['list_of_results']
    # A set, not a bare string: set.difference('common/topic') would strip
    # the individual characters instead of the type.
    ignore_types = {'common/topic'}
    precisions = []
    recalls = []
    f1s = []
    for idx, typs in resp.items():
        if most_common:
            typs = [Counter(typs).most_common(1)[0][0]]
        tail_types = results[idx]['sample']['tail_types']
        # Filter on the full type name before reducing to the root; once
        # reduced, 'common/topic' can no longer be recognised.
        true_types = {typ.split('/')[0] for typ in tail_types if typ not in ignore_types}
        true_types = true_types.intersection(toptypes)
        precision, recall, f1 = computeMetricsOld(set(typs), true_types)
        f1s.append(f1)
        precisions.append(precision)
        recalls.append(recall)
    f1 = np.mean(f1s)
    precision = np.mean(precisions)
    recall = np.mean(recalls)
    denom = precision + recall
    # Guard against 0/0 when every sample scored zero precision and recall.
    f1_of_means = 2*precision*recall/denom if denom > 0 else 0.0
    outdict = {}
    outdict['human'] = {'precision': precision,
                        'recall': recall,
                        'micro_f1': f1,
                        'macro_f1': f1_of_means,
                        }
    return outdict
 
def evaluate(params):
    """Evaluate the human responses in every configuration and dump the scores.

    Loads the BERT results pickle and the list of types, reads every response
    file in ``params.response`` whose name ends in ``1A.tsv`` or ``1B.tsv``,
    and scores the merged annotations for each combination of
    ``include_others`` x ``most_common``. The list of score dicts is written
    to ``params.outfile`` as JSON.

    Args:
        params: namespace with ``bertout``, ``types``, ``response`` and
            ``outfile`` (see ``getParser``).

    Raises:
        FileNotFoundError: if no response file is found for split A or B.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code; only point
    # --bertout at files from a trusted source.
    with open(params.bertout, 'rb') as fin:
        bert = pickle.load(fin)
    with open(params.types, 'r') as fin:
        # Blank lines would otherwise add an empty-string "type".
        toptypes = [typ.strip() for typ in fin if typ.strip()]

    split_suffixes = {'A': "1A.tsv", 'B': "1B.tsv"}
    responses = {split: [] for split in split_suffixes}
    # os.listdir returns entries in arbitrary order; sort so that the order of
    # pooled annotations (and therefore most-common tie-breaking) is stable.
    for basename in sorted(os.listdir(params.response)):
        filename = os.path.join(params.response, basename)
        for split, suffix in split_suffixes.items():
            if filename.endswith(suffix):
                responses[split].append(readResponse(filename))
                break

    missing = [suffix for split, suffix in split_suffixes.items() if not responses[split]]
    if missing:
        raise FileNotFoundError(
            f"no response files ending in {missing} found in {params.response!r}"
        )

    all_perfs = []
    for include_others in [True, False]:
        # Aggregation does not depend on most_common, so do it once per
        # include_others setting. Entries of split B win on shared indices.
        respA = aggregateResponse(responses['A'], include_others)
        respB = aggregateResponse(responses['B'], include_others)
        resp = {**respA, **respB}
        for most_common in [True, False]:
            print(f"Include Other types {include_others}")
            print(f"Most common type {most_common}")
            perf = evalPerf(bert, toptypes, resp, most_common)
            print(perf)
            perf['other_types'] = include_others
            perf['most_common'] = most_common
            all_perfs.append(perf)
    with open(params.outfile, 'w') as fout:
        json.dump(all_perfs, fout)


def main():
    """Parse the command line and run the evaluation.

    Exits with status 1 when the arguments cannot be parsed. A successful
    argparse exit (e.g. ``--help``) keeps its status 0.
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as exc:
        # argparse signals both "--help" (code 0) and usage errors via
        # SystemExit; only the latter should be reported as a failure.
        if exc.code in (0, None):
            raise
        sys.exit(1)
    evaluate(params)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

