import json
import random
from datetime import datetime
# Fixed seed so the random.choice(...) selection made by the script below is
# reproducible from run to run (given the same input files and order).
random.seed(42)

import re
import string
def normalize_text(s):
    """Canonicalize an answer string so equivalent answers compare equal.

    Steps, applied in this order:
      1. lower-case the text;
      2. delete every character in ``string.punctuation``;
      3. replace the whole words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.

    Args:
        s: the raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punct = "".join(ch for ch in lowered if ch not in punctuation)

    # Word boundaries keep words such as "theatre" or "another" intact.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)

    tokens = without_articles.split()
    return " ".join(tokens)


def filter_overlap_pairs(nq_answers, ctx_ans_pairs):
    """Return the context/answer pairs whose answers are all new.

    Answers are compared after ``normalize_text``. A pair is kept only if
    none of its normalized ``pair['answers']`` equals a normalized entry of
    ``nq_answers``. Input order is preserved; a pair with an empty answer
    list is kept.

    Args:
        nq_answers: iterable of answer strings that must not reappear.
        ctx_ans_pairs: iterable of dicts, each with an ``'answers'`` list of
            strings.

    Returns:
        A new list holding the kept pair dicts (the same objects, not copies).
    """
    # Normalize the reference answers once, rather than rebuilding a set for
    # every pair as the previous version did.
    excluded = {normalize_text(ans) for ans in nq_answers}
    return [
        pair
        for pair in ctx_ans_pairs
        if excluded.isdisjoint(normalize_text(ans) for ans in pair['answers'])
    ]

# True: build the NQ-dev-indexed file of new answers / edited questions.
# False: only load previously produced data (branch is currently unfinished).
identify_ans_and_index = True
# True: pick the context/answer pair with the latest date instead of a
# random one.
sample_closest_date = False


def _load_jsonl(path):
    """Return a list with one parsed JSON object per line of ``path``."""
    with open(path, encoding='utf-8') as fh:
        return [json.loads(line) for line in fh]


def _to_day_month_year(date):
    """Rewrite ``date`` into the ``'%d %B %Y'`` form parsed by the sort below.

    A string containing a comma, shaped like ``'January 5, 2010'``, becomes
    ``'5 January 2010'``. Anything without a comma is treated as a bare year
    and mapped to 1 January of that year.
    """
    parts = date.split(',')
    if len(parts) >= 2:
        month_day = parts[0].split(' ')
        return '%d %s %d' % (int(month_day[1]), month_day[0], int(parts[1].strip()))
    return '1 January ' + date


if identify_ans_and_index:
    sqa_data = (_load_jsonl('data/qa_data/temp.dev.jsonl')
                + _load_jsonl('data/qa_data/temp.train.jsonl'))

    with open('../NQ/dev.json', encoding='utf-8') as fh:
        dev_data = json.load(fh)
    # keys 'question', 'id', 'nq_answers', 'timelines', 'cur_answers', 'prev_answers', 'is_dependent', 'context_answer_pairs'
    print(len(sqa_data))
    qs = [l['question'] for l in dev_data]
    # First index of each NQ dev question. This matches qs.index() but gives
    # O(1) membership tests and lookups instead of a list scan per record.
    q_to_idx = {}
    for idx, q in enumerate(qs):
        q_to_idx.setdefault(q, idx)

    output = []
    for data in sqa_data:
        if not (data['is_dependent'] and data['question'] in q_to_idx):
            continue
        new_pairs = filter_overlap_pairs(data['nq_answers'], data['context_answer_pairs'])
        if not new_pairs:
            continue
        if sample_closest_date:
            print('====')
            for pair in new_pairs:
                print('date', pair['date'])
                pair['date'] = _to_day_month_year(pair['date'])
            new_pairs.sort(key=lambda x: datetime.strptime(x['date'], '%d %B %Y'))
            print(new_pairs)
            ctx_ans_pair = new_pairs[-1]
        else:
            # Same call sequence as before, so the seeded choice is unchanged.
            ctx_ans_pair = random.choice(new_pairs)
        # Sanity check: the normalized filter above already guarantees the
        # raw answer sets are disjoint.
        assert len(set(ctx_ans_pair['answers']).intersection(set(data['nq_answers']))) == 0
        output.append({'nq_dev_id': q_to_idx[data['question']], 'answers': ctx_ans_pair['answers'], 'question': data['question'], 'edited_question': ctx_ans_pair['edited_question']})

    output.sort(key=lambda x: x['nq_dev_id'])
    # Show up to two examples for a quick visual check. Slicing avoids the
    # IndexError the old output[0]/output[1] prints raised on short results.
    for example in output[:2]:
        print(example)

    # The context manager guarantees the file is flushed and closed.
    with open('situatedqa_newans_and_idx_on_nq_norm.json', 'w', encoding='utf-8') as fw:
        fw.write(json.dumps(output, indent=4))

else:
    # NOTE(review): this reads '..._on_nq.json' while the branch above writes
    # '..._on_nq_norm.json'. Confirm which file is intended.
    with open('situatedqa_newans_and_idx_on_nq.json', encoding='utf-8') as fh:
        ans_and_idx = json.load(fh)

    with open('../NQ/dev.json', encoding='utf-8') as fh:
        data = json.load(fh)

