import json
import os
import random
import re
import string
import sys
from collections import defaultdict

from src.classes.qadataset import QADataset

# Seed the module-level RNG at import time so the random answer sampling
# performed by select_random_non_identical_answer is repeatable across runs.
random.seed(42)
# NOTE(review): leftover from an argparse-based CLI; the dataset name is now
# taken from sys.argv[1] in the __main__ block. Candidate for deletion.
# dset_name = os.path.basename(args.inpath).split(".")[0]

def select_random_non_identical_answer(ex, data):
    """Randomly pick an answer from ``data`` that differs from ``ex``'s gold answers.

    Args:
        ex: Example dict whose ``'answers'`` list holds the gold answer strings.
        data: Sequence of example dicts; the candidate pool is the
            concatenation of every example's ``'answers'`` list (duplicates
            included, so frequent answers are more likely to be drawn).

    Returns:
        A non-empty answer string whose normalized form (see ``normalize_text``)
        matches none of the normalized gold answers of ``ex``.

    Raises:
        ValueError: if the pool contains no acceptable candidate, instead of
            looping forever.
    """
    norm_gold_answers = {normalize_text(ga) for ga in ex['answers']}
    sample_set = [ans for item in data for ans in item['answers']]

    # Rejection sampling below only terminates if at least one acceptable
    # candidate exists, so check that up front.
    if not any(
        ans and normalize_text(ans) not in norm_gold_answers for ans in sample_set
    ):
        raise ValueError(
            "no answer in the sample pool differs from the gold answers"
        )

    # Sample from the full pool (rather than a pre-filtered list) so the
    # sequence of random draws is unchanged and seeded runs stay reproducible.
    sub_ans = None
    while not sub_ans or normalize_text(sub_ans) in norm_gold_answers:
        sub_ans = random.choice(sample_set)
    return sub_ans


def normalize_text(s):
    """Canonicalize a string for loose answer comparison.

    The pipeline, applied in this order:
      1. lowercase;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.
    """
    lowered = s.lower()
    # str.translate with a deletion table drops all punctuation in one pass.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)
    return " ".join(without_articles.split())


def _find_answer_in_context(answer_text: str, context: str):
    """Finds all instances of the `answer_text` in the context passage.
    
    Returns a list of (start index, end index) tuples.
    """
    context_spans = [
        (m.start(), m.end())
        for m in re.finditer(re.escape(answer_text.lower()), context.lower())
    ]
    return context_spans

def update_context_with_substitution(sub_answer, gold_answers, context):
    """Replace every occurrence of any gold answer in ``context`` with ``sub_answer``.

    Occurrences are located case-insensitively with ``_find_answer_in_context``.
    Each distinct surface form found (e.g. both "Paris" and "PARIS") is then
    replaced wherever it appears in ``context``.

    All forms are replaced in a single regex pass, trying longer forms first.
    This means:
      * a longer form such as "New York City" wins over an overlapping shorter
        one such as "New York";
      * the output does not depend on set iteration order, which varies with
        hash randomization;
      * text inserted for ``sub_answer`` is never rewritten by a later
        replacement.

    Returns the rewritten context, or ``context`` unchanged when no gold
    answer occurs in it.
    """
    replace_spans = []
    for orig_answer in gold_answers:
        replace_spans.extend(_find_answer_in_context(orig_answer, context))

    # Distinct surface forms; drop "" since an empty pattern matches everywhere.
    replace_strs = {context[start:end] for start, end in replace_spans}
    replace_strs.discard("")
    if not replace_strs:
        return context

    # Longest first so regex alternation prefers the longest form at each
    # position; ties are broken lexicographically for a stable pattern.
    ordered = sorted(replace_strs, key=lambda s: (-len(s), s))
    pattern = re.compile("|".join(re.escape(s) for s in ordered))
    # A callable replacement inserts sub_answer literally; a string
    # replacement would expand backslashes and group references.
    return pattern.sub(lambda _match: sub_answer, context)


def find_ex_wo_entity(dset):
    """Export examples whose gold answers lack an answer type, plus a
    substituted copy, as FiD-style JSON.

    Two files are written to the current working directory:

    * ``data.json``: one record per example in ``dset.examples`` that has at
      least one gold answer with a falsy ``answer_type``. Only those answers
      are kept. (Presumably a falsy ``answer_type`` marks a non-entity
      answer; TODO confirm.)
    * ``new_data.json``: the same records, except that every occurrence of the
      gold answers in the context is replaced by an answer sampled from other
      records. That sampled answer becomes the record's only answer.

    Each record has the keys ``question``, ``ctxs`` (a single passage with
    ``id`` "0" and an empty ``title``), ``answers`` and ``qid``. The record
    count and every original/substituted context are printed to stdout.
    """
    data = []
    for ex in dset.examples:
        answers = [answer for answer in ex.gold_answers if not answer.answer_type]
        if answers:
            data.append(
                {
                    'question': ex.query,
                    'ctxs': [{"id": "0", "title": "", "text": ex.context}],
                    'answers': [answer.text for answer in answers],
                    'qid': ex.uid,
                }
            )

    print(len(data))

    # Context managers guarantee the files are flushed and closed; the second
    # handle was previously never closed.
    with open('data.json', 'w') as fw:
        json.dump(data, fw, indent=4)

    new_data = []
    for ex in data:
        orig_context = ex['ctxs'][0]['text']
        sub_answer = select_random_non_identical_answer(ex, data)
        new_context = update_context_with_substitution(
            sub_answer, ex['answers'], orig_context
        )
        print('=' * 50)
        print('ex answer', ex['answers'])
        print('sub answer', sub_answer)
        print(orig_context)
        print(new_context)
        new_data.append(
            {
                'question': ex['question'],
                'ctxs': [{"id": "0", "title": "", "text": new_context}],
                'answers': [sub_answer],
                'qid': ex['qid'],
            }
        )

    with open('new_data.json', 'w') as fw:
        json.dump(new_data, fw, indent=4)


# TODO: decide how to determine whether an answer is an entity (currently
# approximated by whether `answer.answer_type` is set).

def group_answers_by_answer_type(dset: "QADataset"):
    """Group every typed gold answer in ``dset`` by its ``answer_type``.

    Args:
        dset: Dataset exposing ``examples``, each with a ``gold_answers``
            iterable of answers that have ``text`` and ``answer_type``.

    Returns:
        A ``defaultdict(dict)`` mapping ``answer_type`` to a dict of
        ``{answer.text: answer}``. Answers whose ``answer_type`` is falsy are
        skipped. If the same text appears more than once under one type, the
        last answer object seen wins.
    """
    # Relies on the module-level `from collections import defaultdict`.
    group_to_answer_sets = defaultdict(dict)
    for ex in dset.examples:
        for answer in ex.gold_answers:
            if answer.answer_type:
                group_to_answer_sets[answer.answer_type][answer.text] = answer
    return group_to_answer_sets


if __name__ == '__main__':
    # Expects the dataset name (as accepted by QADataset.load) as the first
    # command-line argument; exit with a usage message instead of an
    # IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit("usage: python %s DATASET_NAME" % os.path.basename(sys.argv[0]))
    dset_name = sys.argv[1]
    preprocessed_dataset = QADataset.load(dset_name)

    find_ex_wo_entity(preprocessed_dataset)

