import argparse
import json
from span_output_formatting import load_gold


def cal_em(preds, golds, gold_ids, mode=None):
    """Average the span-overlap ("em") score over examples that have a prediction.

    For every id in ``gold_ids`` that is also a key of ``preds``, the spans stored
    under ``mode`` in the prediction and in the gold record are compared as sets.
    Note that, despite the metric name, the per-example score is the ratio
    ``|pred & gold| / |pred | gold|``; it equals 1.0 only when both sets match.

    Args:
        preds: Mapping from example id to a prediction record indexable by ``mode``.
        golds: Mapping from example id to a gold record indexable by ``mode``.
        gold_ids: Iterable of ids to evaluate; ids missing from ``preds`` are skipped.
        mode: Key of the span field to score. When omitted, the module-level
            ``args.mode`` (set by the command-line entry point) is used.

    Returns:
        Dict with ``"em"`` (mean score, 0.0 when nothing was scored) and
        ``"em_total"`` (number of examples scored).
    """
    if mode is None:
        # Backward-compatible fallback to the CLI namespace.
        mode = args.mode

    def _overlap_score(pred, gold):
        # An empty gold span is only matched by an empty prediction.
        if not gold:
            return 0.0 if pred else 1.0
        pred_set, gold_set = set(pred), set(gold)
        return len(pred_set & gold_set) / len(pred_set | gold_set)

    span_scores = [
        _overlap_score(preds[gold_id][mode], golds[gold_id][mode])
        for gold_id in gold_ids
        if gold_id in preds
    ]
    total = len(span_scores)
    # Guard against ZeroDivisionError when no gold id has a prediction.
    em = sum(span_scores) / total if total else 0.0
    return {"em": em, "em_total": total}


def cal_rf1(preds, golds, gold_ids, mode=None):
    """Compute mean precision, recall and F1 over the raw span elements.

    For every id in ``gold_ids`` that is also a key of ``preds``, the spans stored
    under ``mode`` are compared as sets of elements (presumably character
    offsets -- confirm against the data format).

    Args:
        preds: Mapping from example id to a prediction record indexable by ``mode``.
        golds: Mapping from example id to a gold record indexable by ``mode``.
        gold_ids: Iterable of ids to evaluate; ids missing from ``preds`` are skipped.
        mode: Key of the span field to score. When omitted, the module-level
            ``args.mode`` (set by the command-line entry point) is used.

    Returns:
        Dict with ``"r_precision"``, ``"r_recall"``, ``"r_f1"`` (means, 0.0 when
        nothing was scored) and ``"rf1_total"`` (number of examples scored).
    """
    if mode is None:
        # Backward-compatible fallback to the CLI namespace.
        mode = args.mode

    def _f1_score(pred, gold):
        pred_set, gold_set = set(pred), set(gold)
        # Two empty spans agree perfectly.
        if not pred_set and not gold_set:
            return 1.0, 1.0, 1.0
        num_same = len(pred_set & gold_set)
        if num_same == 0:
            return 0.0, 0.0, 0.0
        # The overlap counts unique elements, so the denominators must too;
        # otherwise duplicated elements would deflate precision/recall.
        precision = num_same / len(pred_set)
        recall = num_same / len(gold_set)
        return precision, recall, (2 * precision * recall) / (precision + recall)

    precision_sum = 0.0
    recall_sum = 0.0
    f1_sum = 0.0
    total = 0
    for gold_id in gold_ids:
        if gold_id not in preds:
            continue
        p, r, f1 = _f1_score(preds[gold_id][mode], golds[gold_id][mode])
        precision_sum += p
        recall_sum += r
        f1_sum += f1
        total += 1

    # Guard against ZeroDivisionError when no gold id has a prediction.
    divisor = total or 1
    return {
        "r_precision": precision_sum / divisor,
        "r_recall": recall_sum / divisor,
        "r_f1": f1_sum / divisor,
        "rf1_total": total,
    }


def cal_token_f1(preds, golds, gold_ids, mode=None):
    """Compute mean precision, recall and F1 over word indices.

    Each record's span (stored under ``mode``) is first projected onto the words
    of its ``"comment"`` text, then the predicted and gold word-index sets are
    compared.

    Args:
        preds: Mapping from example id to a prediction record with ``"comment"``
            and ``mode`` entries.
        golds: Mapping from example id to a gold record with ``"comment"`` and
            ``mode`` entries.
        gold_ids: Iterable of ids to evaluate; ids missing from ``preds`` are skipped.
        mode: Key of the span field to score. When omitted, the module-level
            ``args.mode`` (set by the command-line entry point) is used.

    Returns:
        Dict with ``"tok_precision"``, ``"tok_recall"``, ``"tok_f1"`` (means, 0.0
        when nothing was scored) and ``"tf1_total"`` (number of examples scored).
    """
    if mode is None:
        # Backward-compatible fallback to the CLI namespace.
        mode = args.mode

    def _f1_score(pred, gold):
        pred_set, gold_set = set(pred), set(gold)
        # Two empty spans agree perfectly.
        if not pred_set and not gold_set:
            return 1.0, 1.0, 1.0
        num_same = len(pred_set & gold_set)
        if num_same == 0:
            return 0.0, 0.0, 0.0
        precision = num_same / len(pred_set)
        recall = num_same / len(gold_set)
        return precision, recall, (2 * precision * recall) / (precision + recall)

    def _char_to_word(data):
        # Build the lookup once instead of scanning the span list for every word.
        span = set(data[mode])
        word_span = []
        cur_idx = 0
        for i, word in enumerate(data["comment"].split(" ")):
            # A word is selected when the offset of its first character is in the span.
            if cur_idx in span:
                word_span.append(i)
            # +1 skips the single space consumed by split(" ").
            cur_idx += len(word) + 1
        return word_span

    precision_sum = 0.0
    recall_sum = 0.0
    f1_sum = 0.0
    total = 0
    for gold_id in gold_ids:
        if gold_id not in preds:
            continue
        p, r, f = _f1_score(_char_to_word(preds[gold_id]), _char_to_word(golds[gold_id]))
        precision_sum += p
        recall_sum += r
        f1_sum += f
        total += 1

    # Guard against ZeroDivisionError when no gold id has a prediction.
    divisor = total or 1
    return {
        "tok_precision": precision_sum / divisor,
        "tok_recall": recall_sum / divisor,
        "tok_f1": f1_sum / divisor,
        "tf1_total": total,
    }


def cal_iou_f1(preds, golds, gold_ids, threshold=0.5, mode=None):
    """Compute the fraction of examples whose word-level IoU exceeds ``threshold``.

    Each record's span (stored under ``mode``) is projected onto the words of its
    ``"comment"`` text. An example counts as a hit (score 1) when both word sets
    are empty, or when their intersection-over-union is strictly greater than
    ``threshold``; otherwise it scores 0.

    Args:
        preds: Mapping from example id to a prediction record with ``"comment"``
            and ``mode`` entries.
        golds: Mapping from example id to a gold record with ``"comment"`` and
            ``mode`` entries.
        gold_ids: Iterable of ids to evaluate; ids missing from ``preds`` are skipped.
        threshold: IoU cut-off applied to every example.
        mode: Key of the span field to score. When omitted, the module-level
            ``args.mode`` (set by the command-line entry point) is used.

    Returns:
        Dict with ``"iou_f1"``: the mean hit rate (0.0 when nothing was scored).
    """
    if mode is None:
        # Backward-compatible fallback to the CLI namespace.
        mode = args.mode

    def _char_to_word(data):
        # Build the lookup once instead of scanning the span list for every word.
        span = set(data[mode])
        word_span = []
        cur_idx = 0
        for i, word in enumerate(data["comment"].split(" ")):
            # A word is selected when the offset of its first character is in the span.
            if cur_idx in span:
                word_span.append(i)
            # +1 skips the single space consumed by split(" ").
            cur_idx += len(word) + 1
        return word_span

    def _iou_hit(pred, gold):
        pred_set, gold_set = set(pred), set(gold)
        union = pred_set | gold_set
        # Two empty spans agree perfectly.
        if not union:
            return 1.0
        num_same = len(pred_set & gold_set)
        # No overlap is always a miss, regardless of the threshold value.
        if num_same == 0:
            return 0.0
        return 1.0 if num_same / len(union) > threshold else 0.0

    hits = [
        # The caller-supplied threshold is honoured here via the closure.
        _iou_hit(_char_to_word(preds[gold_id]), _char_to_word(golds[gold_id]))
        for gold_id in gold_ids
        if gold_id in preds
    ]
    # Guard against ZeroDivisionError when no gold id has a prediction.
    iou_f1 = sum(hits) / len(hits) if hits else 0.0
    return {"iou_f1": iou_f1}


def measure_faithfulness(pred_dir, gold_dir, offensive_only):
    """Load predictions and gold data, compute every span metric and print them.

    Args:
        pred_dir: Path to the prediction file; despite the name it must be a
            path ending in ``"json"``.
        gold_dir: Location of the gold data, passed straight to ``load_gold``.
        offensive_only: Flag forwarded to ``load_gold``.

    Returns:
        Dict merging the results of ``cal_em``, ``cal_rf1``, ``cal_token_f1`` and
        ``cal_iou_f1``.

    Raises:
        ValueError: If ``pred_dir`` does not end with ``"json"``.
    """
    if not pred_dir.endswith("json"):
        raise ValueError("Only supports json file")

    # Use a context manager so the prediction file is closed deterministically.
    with open(pred_dir, "r", encoding="utf-8") as pred_file:
        preds = json.load(pred_file)

    golds = load_gold(gold_dir, offensive_only)

    id_to_preds = map_id_to_data(preds)
    id_to_golds = map_id_to_data(golds)
    # Gold ids drive the evaluation; each metric skips ids without a prediction.
    gold_ids = list(id_to_golds)

    scores = {}
    for metric in (cal_em, cal_rf1, cal_token_f1, cal_iou_f1):
        scores.update(metric(id_to_preds, id_to_golds, gold_ids))

    for key, value in scores.items():
        print(f"{key}: {value}")
    return scores


def map_id_to_data(json_list):
    """Index prediction/gold records by their identifier.

    The identifier is read from the first of ``"text_id"``, ``"guid"`` or ``"id"``
    that is present in a record. When two records share an identifier, the later
    one replaces the earlier one.

    Raises:
        ValueError: If a record contains none of the identifier fields.
    """
    # Candidate identifier fields, in priority order.
    id_fields = ("text_id", "guid", "id")

    records_by_id = {}
    for record in json_list:
        id_field = next((field for field in id_fields if field in record), None)
        if id_field is None:
            raise ValueError(f"No attr for ID: {record.keys()}")
        records_by_id[record[id_field]] = record
    return records_by_id


def main():
    """Run the evaluation with the parsed command-line arguments."""
    # Reads the module-level ``args`` namespace created in the ``__main__`` block;
    # the returned scores are already printed by measure_faithfulness.
    measure_faithfulness(args.pred_dir, args.gold_dir, args.offensive_only)


if __name__ == "__main__":
    # Command-line options, listed in the order they are registered.
    cli_options = (
        ("--pred_dir", {"type": str, "required": True}),
        ("--gold_dir", {"type": str, "required": True}),
        ("--mode", {"choices": ["off_span", "tgt_span"], "required": True}),
        ("--offensive_only", {"action": "store_true"}),
    )

    parser = argparse.ArgumentParser()
    for option_name, option_kwargs in cli_options:
        parser.add_argument(option_name, **option_kwargs)

    # ``args`` must stay a module-level name: main() and the metric functions
    # look it up as a global.
    args = parser.parse_args()

    main()
