import json
import random
import re
from utils import set_logger, seed_everything

def main():
    """Fill each question's ``[...]`` placeholder with a shuffled mention.

    Reads ``./generated_qa.json`` and
    ``./final_curid_timestamp_mentions.json``. For every key and every
    two-month bucket, the single bracketed placeholder in each question is
    replaced by the next mention from that key/bucket's shuffled mention
    list. Each kept QA pair has ``idx`` renamed to ``retrieval_idx`` and
    gains ``mention_idx``: the ``(start, end)`` character positions of the
    inserted mention in the resolved question, with ``end`` inclusive.

    Questions that do not contain exactly one placeholder, or for which no
    mention is left, are logged with a warning and omitted from the output.
    The result is written to ``./resolved_qa_w_idx.json``.
    """
    # NOTE(review): presumably seeds `random` so the shuffles below are
    # reproducible -- confirm against utils.seed_everything.
    seed_everything()
    logger = set_logger()

    with open('./generated_qa.json', 'r', encoding='utf-8') as f:
        js_qa = json.load(f)

    with open('./final_curid_timestamp_mentions.json', 'r', encoding='utf-8') as f:
        js_mentions = json.load(f)

    months = ['0506', '0708', '0910', '1112', '0102', '0304']
    new_js = {key: {month: [] for month in months} for key in js_qa}

    placeholder_re = re.compile(r'\[(.*?)\]')
    # Sentinel so that running out of mentions is distinguishable from any
    # value that could legitimately appear in the mention list.
    exhausted = object()

    for key, value in js_qa.items():
        for month in months:
            # Element 1 of the per-month entry is used as the mention list
            # (assumed to be a list of strings -- TODO confirm the schema).
            mentions = js_mentions[key][month][1]
            random.shuffle(mentions)
            # Consume the shuffled mentions front to back without the O(n)
            # cost of list.pop(0).
            remaining_mentions = iter(mentions)

            new_qa_pairs = []

            for qa_pair in value[month]:
                qa_pair['retrieval_idx'] = qa_pair.pop('idx')
                question = qa_pair['question']
                matches = list(placeholder_re.finditer(question))
                if len(matches) != 1:
                    logger.warning(f"key: {key}, month: {month}, question: {question} / len(matches) != 1")
                    continue

                # A mention is only consumed once the question is known to
                # be usable, so skipped questions do not waste mentions.
                subs_mention = next(remaining_mentions, exhausted)
                if subs_mention is exhausted:
                    logger.warning(f"key: {key}, month: {month}, question: {question} / no mentions left")
                    continue

                start, placeholder_end = matches[0].span()
                # Splice the mention in literally. Passing it to re.sub as
                # the replacement would interpret backslash escapes and
                # group references inside the mention text.
                qa_pair['question'] = question[:start] + subs_mention + question[placeholder_end:]
                # Inclusive end: index of the mention's last character.
                qa_pair['mention_idx'] = (start, start + len(subs_mention) - 1)
                new_qa_pairs.append(qa_pair)

            new_js[key][month] = new_qa_pairs

    with open('./resolved_qa_w_idx.json', 'w', encoding='utf-8') as f:
        json.dump(new_js, f, indent=3)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()