import copy
import pandas
import json
import string

def remove_punctuation(s):
    """
    Return ``s`` with every character listed in ``string.punctuation`` removed.

    :param s: the string to clean
    :return: a new string without ASCII punctuation; all other characters are kept as-is
    """
    # Mapping a character to None in a translation table deletes it.
    deletion_table = str.maketrans({mark: None for mark in string.punctuation})
    return s.translate(deletion_table)


def map_feedback_inconstency_to_sentence(sent_summary, feedback, threshold=0.8):
    """
    Attach to each feedback item the summary sentence its inconsistency span refers to.

    Both the span and each sentence are lower-cased, stripped and cleaned with
    ``remove_punctuation`` before comparison. A sentence matches when either
    normalized text contains the other, or when at least ``threshold`` of the
    span's distinct tokens occur in the sentence. The first matching sentence
    (in ``sent_summary`` order) wins.

    Spans or sentences that contain no tokens after normalization are never
    matched: an empty string is a substring of every string, so they would
    otherwise match vacuously (and an empty span would divide by zero in the
    token-overlap ratio).

    :param sent_summary: iterable of summary sentences (strings)
    :param feedback: iterable of dicts, each with an ``'inconsistency'`` string
    :param threshold: minimum fraction of the span's distinct tokens that must
        appear in a sentence for the token-overlap fallback to match
    :return: list of the feedback dicts that could be mapped; each one is
        modified in place to carry the original sentence under ``'sentence'``.
        Feedback that matches no sentence is dropped.
    """
    # Normalize every sentence once instead of once per feedback item.
    normalized_sentences = []
    for sentence in sent_summary:
        s_copy = remove_punctuation(sentence.lower().strip())
        normalized_sentences.append((sentence, s_copy, set(s_copy.split())))

    new_feedback = []
    for f in feedback:
        inconsistency_copy = remove_punctuation(f['inconsistency'].strip().lower())
        words_inconsistency = set(inconsistency_copy.split())
        if not words_inconsistency:
            # Nothing left to match on; skip rather than match vacuously or divide by zero.
            continue

        for sentence, s_copy, words_sentence in normalized_sentences:
            if not words_sentence:
                # A blank sentence is a substring of any span; it carries no evidence.
                continue
            if inconsistency_copy in s_copy or s_copy in inconsistency_copy:
                matched = True
            else:
                # Sometimes the span is decontextualized (reworded), so fall back to
                # requiring that enough of its tokens appear in the sentence.
                shared = words_inconsistency & words_sentence
                matched = len(shared) / len(words_inconsistency) >= threshold
            if matched:
                f['sentence'] = sentence
                new_feedback.append(f)
                break
    return new_feedback


def get_sentencewise_data(combined_pd):
    """
    Unroll summary-level rows into one row per summary sentence.

    Expects a combined dataframe where each row is a summary with its feedback
    and refinements; the feedback should have already gone through the
    "map_feedback_inconstency_to_sentence" method so every item has a
    ``'sentence'`` key.

    For each input row, the original feedback list is preserved under
    ``'feedback_summary'`` and ``'feedback'`` is reset to ``""``. Then, for each
    (sentence, label) pair zipped from ``'sent_summary'`` and
    ``'sent_wise_labels'``, a deep copy of the row is emitted with:

    * ``'minicheck_label'`` set to the zipped label;
    * ``'label'`` = 0 and all keys of the first feedback item whose
      ``'sentence'`` equals the sentence copied onto the row (overwriting
      same-named fields), or
    * ``'label'`` = 1 and ``'sentence'`` set to the sentence when no feedback
      item refers to it.

    A missing (null) feedback value is treated like an empty list. An empty
    dataframe yields an empty dataframe.

    Prints two counters: the number of sentences that received feedback and
    the number of summaries that had any feedback.

    :param combined_pd: dataframe with ``'sent_summary'``, ``'sent_wise_labels'``
        and ``'feedback'`` columns
    :return: dataframe with one row per summary sentence
    """
    # Round-trip through JSON to get plain Python records. Parsing the records
    # array as a whole (rather than splitting line-delimited output) returns []
    # for an empty frame instead of a blank line that json.loads would reject.
    records = json.loads(combined_pd.to_json(orient="records"))
    sentence_wise_unrolled = []
    found = 0
    total = 0
    for record in records:
        sent_summary = record['sent_summary']
        labels = record['sent_wise_labels']
        feedback = record['feedback'] or []
        record['feedback_summary'] = copy.deepcopy(record['feedback'])
        record['feedback'] = ""
        if feedback:
            total += 1
        for sentence, minicheck_label in zip(sent_summary, labels):
            row = copy.deepcopy(record)
            # Only the first feedback item pointing at this sentence is used.
            mapped_feedback = next(
                (f for f in feedback if f['sentence'] == sentence), None
            )
            # Key insertion order below is kept as-is because it determines
            # the column order of the resulting dataframe.
            if mapped_feedback is None:
                row['sentence'] = sentence
                row['minicheck_label'] = minicheck_label
                row['label'] = 1
            else:
                row['minicheck_label'] = minicheck_label
                row['label'] = 0
                row.update(mapped_feedback)
                found += 1
            sentence_wise_unrolled.append(row)
    print(found, total)
    return pandas.DataFrame(sentence_wise_unrolled)
