import argparse
import pprint
import re
import string
from rouge import Rouge
import fire
import json
from collections import Counter
from nltk.tokenize import sent_tokenize


# utility to normalize answers (gold answers and predictions) before scoring
def normalize_answer(answer):
    """Normalize an answer string for token-level comparison.

    Only the first sentence (as split by ``nltk.sent_tokenize``) is kept.
    That sentence is then processed in order:

    1. Drop ``</s>`` markers, strip it, and keep only its first line.
    2. Lower-case it.
    3. Remove ASCII punctuation.
    4. Replace the articles "a", "an", "the" with spaces.
    5. Collapse runs of whitespace.

    Args:
        answer: Raw gold answer or model prediction text.

    Returns:
        The normalized string. Returns ``""`` when *answer* yields no sentence
        (e.g. it is empty or whitespace-only).
    """
    # sent_tokenize returns [] for empty/whitespace-only text, so indexing
    # [0] unconditionally would raise IndexError on empty predictions.
    sentences = sent_tokenize(answer)
    answer = sentences[0] if sentences else ""

    def truncate(text):
        # Remove end-of-sequence markers and keep only the first line.
        return text.replace('</s>', '').strip().split('\n')[0]

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(truncate(answer)))))



# F1 score definition
def _f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and a reference answer.

    Both strings go through ``normalize_answer`` and are split on whitespace.
    Overlap is the size of the multiset (``Counter``) intersection of the two
    token lists.

    Returns:
        Harmonic mean of precision and recall, or ``0`` when no tokens are
        shared.
    """
    pred_bag = normalize_answer(prediction).split()
    gold_bag = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_bag) & Counter(gold_bag)).values())
    if not overlap:
        return 0
    p = overlap / len(pred_bag)
    r = overlap / len(gold_bag)
    return 2 * p * r / (p + r)

def _f1_score_exclude_question(question, prediction, ground_truth):
    """Token-level F1 that does not reward overlap explained by the question.

    Tokens shared by the prediction and the ground truth are counted as in
    ``_f1_score``. Occurrences that also appear in the question are then
    removed from that overlap, so echoing question words earns nothing.
    Precision and recall denominators still use the full token lists.

    Args:
        question: Question text; its tokens are excluded from the overlap.
        prediction: Model output text.
        ground_truth: Reference answer text.

    Returns:
        F1 in [0, 1], or ``0`` when no non-question tokens overlap.
    """
    question_tokens = normalize_answer(question).split()
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    # Parenthesized on purpose: ``-`` binds tighter than ``&``. Without the
    # parentheses the question is subtracted from the ground truth only,
    # before intersecting, so question words could still count as overlap.
    common = (Counter(prediction_tokens) & Counter(ground_truth_tokens)) - Counter(question_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize

# F1 score restricted to tokens with selected POS tags (nouns and numbers by default)
def _f1_score_only_tag(prediction, ground_truth, tag=('NNP', 'NN', 'NNS', 'NNPS', 'CD')):
    """Token-level F1 restricted to tokens with selected POS tags.

    Both strings are normalized with ``normalize_answer``, word-tokenized and
    POS-tagged with NLTK. Only tokens whose tag is in *tag* take part in the
    F1 computation.

    Args:
        prediction: Model output text.
        ground_truth: Reference answer text.
        tag: Penn Treebank tags to keep. Defaults to nouns (``NN``, ``NNS``,
            ``NNP``, ``NNPS``) and cardinal numbers (``CD``). An immutable
            tuple is used as the default to avoid the shared
            mutable-default pitfall; any container supporting ``in`` works.

    Returns:
        F1 over the kept tokens, or ``0`` when none of them overlap.
    """
    prediction = normalize_answer(prediction)
    ground_truth = normalize_answer(ground_truth)
    prediction_tokens = [tok for tok, pos in pos_tag(word_tokenize(prediction)) if pos in tag]
    ground_truth_tokens = [tok for tok, pos in pos_tag(word_tokenize(ground_truth)) if pos in tag]
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


# ROUGE-L score definition
def _rougel_score(prediction, ground_truth):
    """ROUGE-L F-measure of *prediction* against *ground_truth*.

    The strings are scored exactly as given; no normalization happens here.
    A ``ValueError`` from the ``rouge`` package (e.g. "Hypothesis is empty.")
    is treated as a score of ``0.0``.
    """
    scorer = Rouge()
    try:
        averaged = scorer.get_scores(prediction, ground_truth, avg=True)
    except ValueError:  # e.g. "Hypothesis is empty."
        return 0.0
    rouge_l = averaged["rouge-l"]
    return rouge_l["f"]


def _get_gold_and_pred(results):
    """Return normalized gold answers and predictions, aligned with *results*.

    Each record must provide ``answer``, ``prompt`` and
    ``prediction[0]["generated_text"]``. The first ``len(prompt)`` characters
    of the generated text are sliced off before normalization.
    """
    golds = [normalize_answer(record['answer']) for record in results]
    preds = []
    for record in results:
        generated = record['prediction'][0]["generated_text"]
        preds.append(normalize_answer(generated[len(record["prompt"]):]))
    return golds, preds


def _calculate_metrics(results):
    """Compute mean token F1 and ROUGE-L over all records and print them.

    Empty (post-normalization) predictions are not scored but still count in
    the denominator, so they effectively contribute zero. F1 and then ROUGE-L
    are printed as percentages rounded to two decimals. The returned values
    are the unscaled means.

    Returns:
        ``{"downstream": {"f1": ..., "rougel": ...}}``.
    """
    golds, preds = _get_gold_and_pred(results)
    n = len(golds)

    f1_sum = 0
    rougel_sum = 0
    for gold, pred in zip(golds, preds):
        if not pred:  # empty answer
            continue
        f1_sum += _f1_score(pred, gold)
        rougel_sum += _rougel_score(pred, gold)

    if n > 0:
        f1_sum /= n
        rougel_sum /= n

    for score in (f1_sum, rougel_sum):
        print(round(score * 100, 2))

    return {"downstream": {"f1": f1_sum, "rougel": rougel_sum}}


def evaluate(fn):
    """Load a JSON list of result records from *fn* and print overall metrics.

    Args:
        fn: Path to the JSON results file.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(fn, encoding="utf-8") as f:
        js = json.load(f)  # list of dictionary results
    result = _calculate_metrics(js)
    print(result)



def check_hit(item):
    """Label whether any grounded evidence text occurs in a retrieved document.

    Every entry of ``item["grounded_text"]`` is tested as a plain substring
    against every ``item["retrieval"][i]["document"]``.

    Returns:
        ``"hit"`` on the first match, otherwise ``"miss"``.
    """
    documents = [entry["document"] for entry in item["retrieval"]]
    evidence = item["grounded_text"]
    found = any(snippet in doc for doc in documents for snippet in evidence)
    return "hit" if found else "miss"
    
def _get_gold_pred_and_hits(results):
    """Return normalized questions, golds, predictions and retrieval hit labels.

    All four lists are aligned with *results*. Hit labels come from
    ``check_hit`` (``"hit"`` / ``"miss"``). Predictions have the first
    ``len(prompt)`` characters of the generated text removed before
    normalization.
    """
    questions = [normalize_answer(r['question']) for r in results]
    golds = [normalize_answer(r['answer']) for r in results]
    preds = [
        normalize_answer(r['prediction'][0]["generated_text"][len(r["prompt"]):])
        for r in results
    ]
    hits = [check_hit(r) for r in results]
    return questions, golds, preds, hits


def _calculate_metrics_per_type(results):
    """Print mean F1 / ROUGE-L overall and split by retrieval hit vs. miss.

    Records whose normalized prediction is empty are skipped entirely. They
    add nothing to the scores or to the counts behind the hit/miss rates.
    A group with no scored records is reported as ``n/a`` instead of raising
    ``ZeroDivisionError``.

    Args:
        results: Result records, as accepted by ``_get_gold_pred_and_hits``.
    """
    _, golds, preds, hits = _get_gold_pred_and_hits(results)

    groups = ["all", "hit", "miss"]
    totals = {g: {m: 0 for m in ["f1", "rougel"]} for g in groups}
    counts = {g: 0 for g in groups}

    for gold, pred, hit in zip(golds, preds, hits):
        # NOTE(review): unlike _calculate_metrics, empty predictions are
        # excluded from the counts here rather than scored as zero.
        if len(pred) == 0:  # empty answer
            continue
        f1 = _f1_score(pred, gold)
        rougel = _rougel_score(pred, gold)
        for group in ("all", hit):
            totals[group]["f1"] += f1
            totals[group]["rougel"] += rougel
            counts[group] += 1

    rate_labels = {"hit": "Hit rate:", "miss": "Miss rate:"}
    for eval_type, sums in totals.items():
        print("----------------\nType: ", eval_type)
        if eval_type in rate_labels and counts["all"]:
            print(rate_labels[eval_type], round(counts[eval_type]/counts["all"] * 100, 2))
        n = counts[eval_type]
        for metric, value in sums.items():
            if n:
                print(f"{metric}: {round(value/n * 100, 2)}")
            else:
                print(f"{metric}: n/a (no scored examples)")
        print("----------------\n")


def _get_gold_pred_and_el_succeed(results, gold_dict):
    """Return normalized questions, golds, predictions and entity-linking labels.

    All four lists are aligned with *results*. A record's label is
    ``"Succeed"`` when ``record["entity_pred"][0]`` equals
    ``gold_dict[record["qa_id"]]``, and ``"Fail"`` otherwise.
    """
    def el_label(record):
        matched = record["entity_pred"][0] == gold_dict[record["qa_id"]]
        return "Succeed" if matched else "Fail"

    questions = [normalize_answer(r['question']) for r in results]
    golds = [normalize_answer(r['answer']) for r in results]
    preds = [
        normalize_answer(r['prediction'][0]["generated_text"][len(r["prompt"]):])
        for r in results
    ]
    els = [el_label(r) for r in results]
    return questions, golds, preds, els
    
            
def _calculate_metrics_per_el(results, gold_dict):
    """Print mean F1 / ROUGE-L overall and split by entity-linking outcome.

    Records are labeled ``"Succeed"`` / ``"Fail"`` by
    ``_get_gold_pred_and_el_succeed``. Records whose normalized prediction is
    empty are skipped. They add nothing to the scores or to the counts behind
    the success/failure rates. A group with no scored records (e.g. when the
    input is an empty retrieval-hit or -miss subset, or every record
    succeeds) is reported as ``n/a`` instead of raising ``ZeroDivisionError``.

    Args:
        results: Result records.
        gold_dict: Mapping looked up by each record's ``qa_id``.
    """
    _, golds, preds, els = _get_gold_pred_and_el_succeed(results, gold_dict)

    groups = ["All", "Succeed", "Fail"]
    totals = {g: {m: 0 for m in ["f1", "rougel"]} for g in groups}
    counts = {g: 0 for g in groups}

    for gold, pred, el in zip(golds, preds, els):
        # NOTE(review): unlike _calculate_metrics, empty predictions are
        # excluded from the counts here rather than scored as zero.
        if len(pred) == 0:  # empty answer
            continue
        f1 = _f1_score(pred, gold)
        rougel = _rougel_score(pred, gold)
        for group in ("All", el):
            totals[group]["f1"] += f1
            totals[group]["rougel"] += rougel
            counts[group] += 1

    rate_labels = {"Succeed": "Succeed rate:", "Fail": "Failure rate:"}
    for eval_type, sums in totals.items():
        print("----------------\nType: ", eval_type)
        if eval_type in rate_labels and counts["All"]:
            print(rate_labels[eval_type], round(counts[eval_type]/counts["All"] * 100, 2))
        n = counts[eval_type]
        for metric, value in sums.items():
            if n:
                print(f"{metric}: {round(value/n * 100, 2)}")
            else:
                print(f"{metric}: n/a (no scored examples)")
        print("----------------\n")
            
# NOTE(review): disabled draft kept inside a bare string literal, so it never
# runs. It would not work if re-enabled: `chechk_hit` is misspelled and `x` is
# undefined inside the loop. It appears to be superseded by
# `_get_result_if_hit` below, which selects records by hit/miss label.
'''
def _get_gold_pred_if_hit(results, is_hit=True):
    
    hits = list(map(chechk_hit, results))
    questions, golds, preds = [], [], []
    target = "hit" if is_hit else "miss"
    for hit in hits:
        if hit == target:
            questions.append(normalize_answer(x['question']))
            golds.append(normalize_answer(x['answer']))
            preds.append(normalize_answer(x['prediction'][0]["generated_text"][len(x["prompt"]):]))
    return questions, golds, preds, hits
'''

def _get_result_if_hit(results, is_hit=True):
    """Select the records whose ``check_hit`` label matches *is_hit*.

    Args:
        results: Result records, as accepted by ``check_hit``.
        is_hit: Keep retrieval hits when truthy, misses otherwise.

    Returns:
        A new list of the matching records, in input order.
    """
    labels = [check_hit(r) for r in results]
    wanted = "hit" if is_hit else "miss"
    return [r for r, label in zip(results, labels) if label == wanted]

        
def _calculate_metrics_per_hit_el(results, gold_dict):
    """Print entity-linking-split metrics separately for retrieval hits and misses.

    Args:
        results: Result records.
        gold_dict: Mapping passed through to ``_calculate_metrics_per_el``.
    """
    sections = ((True, "Retrieval Hit"), (False, "\n\nRetrieval Miss"))
    for is_hit, header in sections:
        subset = _get_result_if_hit(results, is_hit=is_hit)
        print(header)
        _calculate_metrics_per_el(subset, gold_dict)
    
    

def evaluate_per_hit(fn):
    """Load JSON result records from *fn* and print metrics split by retrieval hit/miss.

    Args:
        fn: Path to the JSON results file.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(fn, encoding="utf-8") as f:
        js = json.load(f)  # list of dictionary results
    _calculate_metrics_per_type(js)
    
    
def evaluate_per_el(fn, el_gold_fn):
    """Load results and gold entity data, then print metrics split by entity linking.

    Args:
        fn: Path to the JSON list of result records.
        el_gold_fn: Path to a JSON object that is looked up by each record's
            ``qa_id``; the value is compared with ``entity_pred[0]``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(fn, encoding="utf-8") as f:
        js = json.load(f)
    with open(el_gold_fn, encoding="utf-8") as f:
        gold_dict = json.load(f)
    _calculate_metrics_per_el(js, gold_dict)
    
    
def evaluate_per_hit_el(fn, el_gold_fn):
    """Print entity-linking-split metrics separately for retrieval hits and misses.

    Args:
        fn: Path to the JSON list of result records.
        el_gold_fn: Path to a JSON object that is looked up by each record's
            ``qa_id``; the value is compared with ``entity_pred[0]``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(fn, encoding="utf-8") as f:
        js = json.load(f)
    with open(el_gold_fn, encoding="utf-8") as f:
        gold_dict = json.load(f)
    _calculate_metrics_per_hit_el(js, gold_dict)


def main(fn, per_hit=False, per_el=False, per_hit_el=True, el_gold_fn=None):
    """Command-line entry point (exposed through python-fire).

    Exactly one evaluation runs, chosen in priority order:
    ``per_hit`` > ``per_el`` > ``per_hit_el`` > plain ``evaluate``.
    ``per_hit_el`` defaults to True, so plain evaluation requires passing
    ``--per_hit_el=False``.

    Args:
        fn: Path to the JSON results file.
        per_hit: Split metrics by retrieval hit/miss.
        per_el: Split metrics by entity-linking success/failure.
        per_hit_el: Split by retrieval hit/miss, then by entity linking.
        el_gold_fn: Path to the JSON file looked up by ``qa_id`` for the
            entity-linking modes. It is required by ``per_el`` and
            ``per_hit_el``; previously it could not be supplied at all, so
            those modes always failed with TypeError.

    Raises:
        ValueError: If an entity-linking mode is selected without
            ``el_gold_fn``.
    """
    needs_gold = not per_hit and (per_el or per_hit_el)
    if needs_gold and el_gold_fn is None:
        raise ValueError(
            "el_gold_fn is required for per_el / per_hit_el evaluation "
            "(pass --el_gold_fn PATH, or --per_hit_el=False for plain evaluation)"
        )
    if per_hit:
        evaluate_per_hit(fn)
    elif per_el:
        evaluate_per_el(fn, el_gold_fn)
    elif per_hit_el:
        evaluate_per_hit_el(fn, el_gold_fn)
    else:
        evaluate(fn)


if __name__ == "__main__":
    # Expose `main` as a CLI; python-fire maps command-line flags to its parameters.
    fire.Fire(component=main)