import os
import sys
import json
import random
random.seed(42)
import re
import string

def normalize_text(s):
    """Canonicalize *s* for loose answer comparison.

    Steps, applied in order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an", "the" with a space;
      4. collapse runs of whitespace to single spaces and trim the ends.

    Args:
        s: the string to normalize.

    Returns:
        The normalized string (possibly empty).
    """
    # Deletion table: maps every ASCII punctuation character to None.
    strip_punct = str.maketrans("", "", string.punctuation)
    text = s.lower().translate(strip_punct)
    text = re.sub(r"\b(a|an|the)\b", " ", text, flags=re.UNICODE)
    return " ".join(text.split())

def filter_ans(org_answers, new_answers):
    """Keep only the candidate answers that are not already known.

    A candidate is considered known when its ``normalize_text`` form equals
    the ``normalize_text`` form of any entry in *org_answers*.

    Args:
        org_answers: iterable of reference answer strings.
        new_answers: iterable of candidate answer strings.

    Returns:
        A list of the surviving candidates, unnormalized, in their original
        order (duplicates among them are kept).
    """
    known = {normalize_text(ref) for ref in org_answers}
    return [cand for cand in new_answers if normalize_text(cand) not in known]

# Input schema, as read by this script:
#   dev.json: list of dicts with keys 'viewed_doc_titles', 'used_queries',
#     'annotations', 'nq_answer', 'id', 'nq_doc_title', 'question'.
#     annotations  - each has 'type'; those of type 'multipleQAs' carry 'qaPairs'.
#     used_queries - each has 'query' and 'results'; results hold (title, snippet) pairs.
#   ../NQ/dev.json: list of dicts; this script reads 'question' and 'answers'.
# Example 'qaPairs' value:
#   [{'question': 'How long does a casual match last in Rainbow Six Siege?', 'answer': ['four minutes']},
#    {'question': 'How long does a ranked match last in Rainbow Six Siege?', 'answer': ['three minutes']}]
# Each output record has keys 'nq_dev_id', 'edited_question', 'answers', 'question'.

with open('dev.json', encoding='utf-8') as f:
    data = json.load(f)
with open('../NQ/dev.json', encoding='utf-8') as f:
    nq = json.load(f)
qs = [rec['question'] for rec in nq]
print(len(data))

# Question text -> index of its FIRST occurrence in NQ dev (the same index that
# qs.index() would return), built once so each lookup is O(1).
nq_index = {}
for idx, q in enumerate(qs):
    nq_index.setdefault(q, idx)

new_data = []
only1 = 0  # skipped: one multipleQAs annotation and its sampled pair adds no new answer
qqq = 0    # skipped: no qa pair of the sampled annotation adds a new answer
for item in data:
    # Lowercase and drop the final character (presumably a trailing '?' --
    # TODO confirm) so the text matches the NQ question string.
    key = item['question'].lower()[:-1]
    if key not in nq_index:
        raise ValueError(f"{key!r} is not in NQ dev questions")
    nq_dev_id = nq_index[key]

    qas = [ann for ann in item['annotations'] if ann['type'] == 'multipleQAs']
    if not qas:
        continue

    # Sample one annotation, then one of its qa pairs. The order of the
    # random.choice calls is fixed so output is reproducible under the seed.
    annotation = random.choice(qas)
    qa = random.choice(annotation['qaPairs'])
    inst = {'nq_dev_id': nq_dev_id, 'edited_question': qa['question']}
    org_answers = nq[nq_dev_id]['answers']
    new_ans = filter_ans(org_answers, qa['answer'])

    if not new_ans:
        if len(qas) == 1:
            only1 += 1
            continue
        # Fall back to the first pair of the sampled annotation whose answers
        # include something not already among the NQ answers.
        for pair in annotation['qaPairs']:
            new_ans = filter_ans(org_answers, pair['answer'])
            if new_ans:
                inst['edited_question'] = pair['question']
                break
        else:
            # No pair in this annotation contributed a new answer.
            qqq += 1
            continue

    inst['answers'] = new_ans
    inst['question'] = qs[nq_dev_id]
    new_data.append(inst)

print(len(new_data))
print(qqq)
print(only1)

# Write the filtered records; the context manager guarantees flush and close.
with open('new_.json', 'w', encoding='utf-8') as fw:
    json.dump(new_data, fw, indent=4)
