def normalize_answer(s):
    """Normalize an answer string for lenient comparison.

    The pipeline is applied in this order:
      1. lowercase the text;
      2. delete every ASCII punctuation character (``string.punctuation``);
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs into single spaces and trim both ends.
    """
    import re
    import string

    text = s.lower()
    # Deletion table: map every punctuation character to None.
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())


def f1_score(prediction, ground_truth):
    """Compute token-level F1 between a predicted and a reference answer.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Overlap is a multiset intersection: a token that appears in
    both strings counts ``min(count_in_prediction, count_in_reference)``
    times. F1 is the harmonic mean of precision (overlap / number of
    prediction tokens) and recall (overlap / number of reference tokens).

    Returns:
        A float in [0, 1]. It is 0.0 when there is no token overlap,
        including when either side normalizes to an empty string.
    """
    from collections import Counter
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # No overlap. This also guards the divisions below when either token
        # list is empty, since an empty list cannot share tokens.
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return True iff both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against each reference and return the best score.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        prediction: the predicted answer.
        ground_truths: iterable of acceptable reference answers. The metric
            is evaluated once per reference, in iteration order.

    Raises:
        ValueError: if ``ground_truths`` is empty (propagated from ``max``).
    """
    return max(metric_fn(prediction, reference) for reference in ground_truths)



def _parse_sentence_list(raw):
    """Parse ``raw`` as a JSON array of strings; return the list, or None if invalid."""
    import json
    try:
        parsed = json.loads(raw)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (malformed JSON); TypeError
        # covers inputs json.loads cannot accept at all (e.g. None).
        return None
    # Require an actual list: a bare JSON string would otherwise be iterated
    # character by character, and a JSON number is not iterable at all.
    if not isinstance(parsed, list) or not all(isinstance(x, str) for x in parsed):
        return None
    return parsed


def sentence_set_similarity(preds, labels, device='cpu'):
    """Score sentence-set predictions against sentence-set references with BERTScore.

    Each element of ``preds`` and ``labels`` is a JSON string encoding a list
    of sentences. For each (pred, label) example:
      * every pred x label sentence pair is scored with BERTScore
        (roberta-large, ``rescale_with_baseline=True``);
      * the pair scores form an ``[n_pred, n_label]`` F-score matrix;
      * set-level precision is the mean, over predicted sentences, of each
        one's best reference score;
      * set-level recall is the mean, over reference sentences, of each one's
        best predicted score;
      * F1 is the harmonic mean of the two.
    Baseline-rescaled scores may fall outside [0, 1].

    Args:
        preds: list of JSON strings, each a list of predicted sentences.
        labels: list of JSON strings, each a list of reference sentences.
        device: device passed to ``bert_score.BERTScorer``.

    Returns:
        A dict mapping "precision", "recall" and "f1" to 1-D tensors, with one
        value per scored example.
          * Examples whose label is not a non-empty JSON list of strings are
            skipped, so the tensors can be shorter than ``labels``.
          * Examples whose prediction is not a non-empty JSON list of strings
            score 0.0 on all three metrics.

    Note:
        ``preds`` and ``labels`` are paired positionally with ``zip``, which
        stops at the shorter list. A new BERTScorer, and therefore a new
        model, is built on every call.
    """
    from tqdm.auto import tqdm
    import torch
    import bert_score

    assert isinstance(preds, list), 'preds is not a list'
    assert isinstance(labels, list), 'labels is not a list'
    bert_scorer = bert_score.BERTScorer(model_type="roberta-large", lang="en", rescale_with_baseline=True, device=device)
    results = []
    for pred_str, label_str in tqdm(zip(preds, labels), total=min(len(preds), len(labels))):
        label_sents = _parse_sentence_list(label_str)
        if not label_sents:
            # Invalid or empty reference: there is nothing to score against.
            continue
        pred_sents = _parse_sentence_list(pred_str)
        if not pred_sents:
            # Invalid or empty prediction: counts as a complete miss.
            results.append({"precision": 0.0, "recall": 0.0, "f1": 0.0})
            continue
        # Build all pairs in row-major order (prediction outer, label inner)
        # so the flat scores reshape directly into [n_pred, n_label].
        pair_preds = [p for p in pred_sents for _ in label_sents]
        pair_labels = [l for _ in pred_sents for l in label_sents]
        _, _, pair_f = bert_scorer.score(pair_preds, pair_labels)
        sim = pair_f.reshape(len(pred_sents), len(label_sents))
        precision = sim.max(dim=1).values.mean().item()
        recall = sim.max(dim=0).values.mean().item()
        f1 = 2 * precision * recall / (precision + recall + 1e-12)
        results.append({"precision": precision, "recall": recall, "f1": f1})
    keys = ["precision", "recall", "f1"]
    return {
        key: torch.tensor([r[key] for r in results]) for key in keys
    }

if __name__ == "__main__":
    # Smoke test for sentence_set_similarity on a single example. The
    # commented-out snippets below cross-check BERTScore against other
    # implementations on a toy pair.

    # predictions = ["hello world", "general kenobi"]
    # references = ["goodnight moon", "the sun is shining"]

    # import torchmetrics
    # results = torchmetrics.functional.text.bert.bert_score(predictions, references, model_name_or_path="roberta-large", device="cpu")
    # print("torchmetrics", results)

    # import evaluate
    # bertscore = evaluate.load("bertscore")
    # results = bertscore.compute(predictions=predictions, references=references, model_type="roberta-large", device="cpu")
    # print("huggingface eval", results)

    # results = bert_score.score(cands=predictions, refs=references, model_type="roberta-large", device="cpu")
    # print("official bert_score", results)

    # Each entry must be valid JSON: a stray doubled quote after "[" used to
    # make this prediction unparseable, so it silently scored 0.
    preds = [
        """["Presidential elections were held in Sri Lanka for the first time on 20 October 1982.", "Names were accepted on 17 September 1982.", "Election participation was 81.06%.", "The election was described as a fight between capitalism and socialism.", "Hector Kobbekaduwa advocated to carry on the policies of the Sri Lanka Freedom Party - led regime from 1970-1977.", "Hector Kobbekaduwa was expected to undo most of the open market and capitalist reforms brought in by J. R. Jayewardene.", "Incumbent president Jayewardene of the governing United National Party was elected.", "Jayewardene received 53% of all votes cast.", "The SLFP lost a significant number of votes in Tamil speaking areas such as Point Pedro."]""",
    ]
    labels = [
        """["Nominations for the presidential elections were accepted on 17 September 1982.", "Electoral participation in the presidential elections was 81.06%.", "The election was described as a fight between capitalism and socialism.", "Hector Kobbekaduwa advocated to carry on the policies of the Sri Lanka Freedom Party-led regime from 1970-1977.", "Hector Kobbekaduwa was expected to undo most of the open market and capitalist reforms brought in by J. R. Jayewardene.", "Incumbent president Jayewardene of the governing United National Party was elected.", "Jayewardene received 53% of all votes cast in the presidential election.", "Although the SLFP lost, they managed to win a significant number of votes in Tamil speaking areas such as Point Pedro."]""",
        # """["Presidential elections were held in Sri Lanka for the first time on 20 October 1982.", "Nominations for the presidential elections were accepted on 17 September 1982.", "Electoral participation in the presidential elections was 81.06%.", "The election was described as a fight between capitalism and socialism.", "Hector Kobbekaduwa advocated to carry on the policies of the Sri Lanka Freedom Party-led regime from 1970-1977.", "Hector Kobbekaduwa was expected to undo most of the open market and capitalist reforms brought in by J. R. Jayewardene.", "Incumbent president Jayewardene of the governing United National Party was elected.", "Jayewardene received 53% of all votes cast in the presidential election.", "Although the SLFP lost, they managed to win a significant number of votes in Tamil speaking areas such as Point Pedro."]""",
    ]
    print(sentence_set_similarity(preds, labels))
