import json
import argparse
import numpy as np
from tqdm import tqdm
from collections import Counter
from .retriever import SentenceBertRetriever


def read_cli():
    """Parse the command-line options for hint prediction.

    Returns:
        argparse.Namespace with ``index_path``, ``src_path``, ``dataset``
        (one of ``'MultiWOZ'`` / ``'SMD'``) and ``k`` (int, default 10).
    """
    cli = argparse.ArgumentParser(description='Predict Hints')
    cli.add_argument("-index", "--index_path", type=str, required=True,
                     help="Path to index file")
    cli.add_argument("-src", "--src_path", type=str, required=True,
                     help="Path to src file")
    cli.add_argument("-ds", "--dataset", type=str, required=True,
                     choices=['MultiWOZ', 'SMD'], help="dataset")
    cli.add_argument("-k", "--k", type=int, required=False, default=10,
                     help="K in KNN")
    return cli.parse_args()



def get_indexed_entity_types(entity_types):
    """Turn a list of entity types into a set of occurrence-indexed labels.

    A type ``et`` that appears ``n`` times yields ``et#0`` ... ``et#{n-1}``,
    so repeated types remain distinguishable after collapsing into a set.
    """
    return {
        f"{etype}#{idx}"
        for etype, occurrences in Counter(entity_types).items()
        for idx in range(occurrences)
    }


def knn_prediction(nrbs_hints):
    """Aggregate the hints of the retrieved neighbours into one prediction.

    Args:
        nrbs_hints: non-empty list of hint dicts, each with ``'word_count'``
            (numeric), ``'closure'`` (bool-like; cast to float) and
            ``'entity_types'`` (list of str, may contain repeats).

    Returns:
        dict with
          * ``'word_count'``: ceiling of the mean neighbour word count (int);
          * ``'closure'``: True iff strictly more than half of the
            neighbours have a truthy closure;
          * ``'entity_types'``: every occurrence of a type that at least half
            of the neighbours support (the j-th copy of ``et`` is supported
            by a neighbour holding more than j copies), ordered by the type's
            first appearance across the neighbours.

    Raises:
        ValueError: if ``nrbs_hints`` is empty (no mean can be taken).
    """
    num_nrbs = len(nrbs_hints)
    if num_nrbs == 0:
        raise ValueError("knn_prediction requires at least one neighbour hint")

    phints = dict()

    # Num words
    wlens = [x['word_count'] for x in nrbs_hints]
    phints['word_count'] = int(np.ceil(np.mean(wlens)))

    # Closure (strict majority)
    cvals = [float(x['closure']) for x in nrbs_hints]
    phints['closure'] = bool(np.mean(cvals) > 0.5)

    # Etypes: per-neighbour occurrence counts, plus first-appearance order.
    # A dict gives O(1) membership and preserves insertion order.
    first_seen = {}
    per_nrb_counts = []
    for ee in nrbs_hints:
        etypes = ee['entity_types']
        for et in etypes:
            first_seen.setdefault(et, len(first_seen))
        per_nrb_counts.append(Counter(etypes))

    pred_etypes = []
    for et in first_seen:
        max_occ = max(cnt[et] for cnt in per_nrb_counts)
        for jj in range(max_occ):
            # Support for the jj-th copy only shrinks as jj grows, so stop
            # at the first copy that falls below the 50% threshold.
            support = sum(1 for cnt in per_nrb_counts if cnt[et] > jj)
            if support / num_nrbs < 0.5:
                break
            pred_etypes.append(et)

    phints['entity_types'] = pred_etypes

    return phints


class RuleHintPredicter:
    """Predict response hints from the hints of retrieved neighbour samples.

    Neighbours are found with a ``SentenceBertRetriever`` built over
    ``index_data``; their hints are combined by ``knn_prediction``.
    """

    def __init__(self, index_data, dataset, k=None):
        """
        Args:
            index_data: index samples handed to ``SentenceBertRetriever``.
            dataset: ``'MultiWOZ'`` or ``'SMD'``; selects how the query text
                is built from a sample's context.
            k: number of neighbours to retrieve; must be set (here or later
                via ``self.k``) before ``predict_hints`` is called.
        """
        self.retriever = SentenceBertRetriever(
            index_data, dataset
        )
        self.dataset = dataset
        self.k = k

    def predict_hints(self, sample):
        """Predict the hints for one sample from its ``k`` nearest neighbours.

        Args:
            sample: dict with ``'context'`` (sequence of str turns) and
                ``'uuid'`` (forwarded to the retriever; presumably used to
                exclude the sample itself — confirm against the retriever).

        Returns:
            The hint dict produced by ``knn_prediction``.

        Raises:
            ValueError: if ``self.k`` is unset or ``self.dataset`` is not
                ``'MultiWOZ'`` / ``'SMD'``.
        """
        # Explicit checks instead of `assert`, which is stripped under -O;
        # an unknown dataset previously surfaced as an UnboundLocalError.
        if self.k is None:
            raise ValueError("k must be set before calling predict_hints")
        if self.dataset == 'MultiWOZ':
            # Only the most recent turn is used as the query.
            text = sample['context'][-1]
        elif self.dataset == 'SMD':
            # The whole dialogue history, space-joined, is the query.
            text = ' '.join(sample['context'])
        else:
            raise ValueError(f"Unsupported dataset: {self.dataset!r}")

        nrbs = self.retriever.search_top_k(text, self.k, uuid=sample['uuid'])
        nrbs_hints = [x['hints'] for x in nrbs]
        ret = knn_prediction(nrbs_hints)

        if self.k == 1:
            # Internal sanity check: with a single neighbour the aggregate
            # must reproduce that neighbour's hints exactly.
            assert nrbs_hints[0]['closure'] == ret['closure']
            assert sorted(nrbs_hints[0]['entity_types']) == sorted(ret['entity_types'])
            assert nrbs_hints[0]['word_count'] == ret['word_count']

        return ret


def run(args):
    """Predict hints for every sample in ``args.src_path`` and report metrics.

    Hints are predicted twice, first with k=1 and then with k=``args.k``.
    Each pass is stored on the sample under ``f'k{k}_hints'`` and evaluated
    against the sample's gold ``'hints'``. The annotated data is then
    written back over ``args.src_path``.

    Args:
        args: namespace with ``index_path``, ``src_path``, ``dataset``, ``k``
            (as produced by ``read_cli``).
    """
    import os

    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(args.index_path, 'r', encoding='utf-8') as fp:
        index_data = json.load(fp)

    with open(args.src_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    print(f'Loaded {len(index_data)} index samples.')
    print(f'Loaded {len(data)} data samples.')

    predicter = RuleHintPredicter(index_data, args.dataset)
    gold_hints = [x['hints'] for x in data]

    for k in (1, args.k):
        predicter.k = k
        key = f'k{k}_hints'
        for sample in tqdm(data):
            sample[key] = predicter.predict_hints(sample)

        print(f'k = {k} Results')
        evaluate(gold_hints, [x[key] for x in data])

    # We overwrite our own input, so write to a temp file and swap it in
    # atomically: a crash mid-dump can no longer truncate the source data.
    tmp_path = f'{args.src_path}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as fp:
        json.dump(data, fp, indent=2)
    os.replace(tmp_path, args.src_path)


def evaluate(gold_hints, pred_hints):
    """Print and return closure accuracy and entity-type P/R/F1.

    Entity types are compared as per-sample sets, and TP/FP/FN counts are
    summed over all samples (micro-averaged). An empty set is replaced by
    the sentinel ``'[no entity]'``, so correctly predicting "no entities"
    counts as a true positive.

    Args:
        gold_hints: list of gold hint dicts (``'closure'``, ``'entity_types'``).
        pred_hints: list of predicted hint dicts, aligned with ``gold_hints``.

    Returns:
        dict with ``'closure_accuracy'``, ``'etypes_f1'``,
        ``'etypes_precision'`` and ``'etypes_recall'``. Each is 0 when
        undefined, e.g. for empty input.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(gold_hints) != len(pred_hints):
        raise ValueError(
            f"gold_hints ({len(gold_hints)}) and pred_hints "
            f"({len(pred_hints)}) must have the same length"
        )

    corr = 0
    tp, fn, fp = 0, 0, 0
    for gold, pred in zip(gold_hints, pred_hints):
        corr += int(gold['closure'] == pred['closure'])

        gold_types = set(gold['entity_types']) or {'[no entity]'}
        pred_types = set(pred['entity_types']) or {'[no entity]'}

        tp += len(gold_types & pred_types)
        fp += len(pred_types - gold_types)
        fn += len(gold_types - pred_types)

    prec = tp / (tp + fp) if tp + fp > 0 else 0
    rec = tp / (tp + fn) if tp + fn > 0 else 0
    f1 = (2 * prec * rec) / (prec + rec) if prec + rec > 0 else 0

    acc = corr / len(gold_hints) if gold_hints else 0
    print(f'Closure Accuracy: {round(acc, 4)}')
    print(f'Etypes F1: {round(f1, 4)}')
    print(f'Etypes Precision: {round(prec, 4)}')
    print(f'Etypes Recall: {round(rec, 4)}')

    return {
        'closure_accuracy': acc,
        'etypes_f1': f1,
        'etypes_precision': prec,
        'etypes_recall': rec,
    }


if __name__ == '__main__':
    # Script entry point: parse CLI options, then predict and evaluate hints.
    run(read_cli())
