from qa_generation_prompt import Prompter
from utils import set_logger, seed_everything, get_informative_paragraphs
from tqdm.auto import tqdm

import json
import re
import os

class Generator:
    '''
    Generate question-answer pairs from cached Wikipedia article snapshots.

    Each call to ``generate`` reads ``./wiki_text/{curid}_{month}.json``, picks
    paragraphs via ``get_informative_paragraphs`` and asks the prompter for one
    QA pair per paragraph.
    '''

    # Captures the text inside each ``{...}`` span of the prompter's reply; the
    # first capture is used as the question and the second as the answer.
    # NOTE(review): '.' does not match newlines, so a brace-delimited span that
    # crosses a line break is not captured and that pair is skipped.
    _QA_PATTERN = re.compile(r"\{(.*?)\}")

    def __init__(self, logger):
        '''
        Args:
            logger: logger used to report missing snapshot files and failed generations.
        '''
        self.prompter = Prompter()
        self.logger = logger

    def generate(self, month, curid, entity, length):
        '''
        Given month, curid, and number of qas, generate question-answer pairs using the wikipedia article.

        Args:
            month: month tag that is part of the snapshot file name (main() uses tags like '0506').
            curid: article id that is part of the snapshot file name.
            entity: entity name passed to the prompter and included in log messages.
            length: requested number of QA pairs; forwarded to get_informative_paragraphs.

        Returns:
            A list of dicts with keys "question", "answer", "grounded_text" and "idx",
            sorted by ``idx[0]``. Returns [] when the snapshot file does not exist.
            The list can be shorter than ``length`` when too few paragraphs are
            available or when a prompter reply cannot be parsed into exactly two
            ``{...}`` spans; both cases are logged as warnings.
        '''
        file_path = f'./wiki_text/{curid}_{month}.json'
        if not os.path.exists(file_path):
            self.logger.error(f"Month: {month}, Entity: {entity}, Curid: {curid} : File not found {file_path}")
            return []

        # Explicit UTF-8: the default text encoding is locale-dependent, and
        # Wikipedia text routinely contains non-ASCII characters.
        with open(file_path, 'r', encoding='utf-8') as f:
            wiki = json.load(f)

        article = wiki['text']
        valid_paragraphs = get_informative_paragraphs(length, article)

        if len(valid_paragraphs) < length:
            self.logger.warning(f"Month: {month}, Entity: {entity}, Curid: {curid} = valid paragraph {len(valid_paragraphs)} < num length {length}")

        out = []
        for i, paragraph in enumerate(valid_paragraphs):
            # paragraph[0] is space-joined into the prompt context, so it is presumably
            # a list of sentences; paragraph[1] / paragraph[2] look like position
            # indices — TODO confirm against get_informative_paragraphs.
            sentences = paragraph[0]
            prompt_result = self.prompter.prompt(' '.join(sentences), entity)
            extracted_list = self._QA_PATTERN.findall(prompt_result)

            # Anything other than exactly one question span and one answer span is unusable.
            if len(extracted_list) != 2:
                self.logger.warning(f"Month: {month}, Entity: {entity}, Curid: {curid} = Failed to generate {i}-th qa pair due to GPT-4 error")
                continue

            out.append({
                "question": extracted_list[0],
                "answer": extracted_list[1],
                "grounded_text": sentences,
                "idx": (paragraph[1], paragraph[2]),
            })

        # Order the pairs by the first index value of their source paragraph.
        out.sort(key=lambda x: x["idx"][0])
        return out

def _dump_json_atomic(obj, path):
    '''
    Write ``obj`` as JSON to ``path`` via a temp file plus ``os.replace``, so an
    interruption mid-write never leaves a truncated file at ``path``.
    '''
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=3)
    # The temp file sits in the same directory as the target, so os.replace
    # swaps it in atomically (POSIX and Windows).
    os.replace(tmp_path, path)


def main():
    '''
    Generate QA pairs for every (entity, month) in final_curid_timestamp_mentions.json
    and save them to generated_qa.json.

    Each input entry maps month tags to a list whose first element is the requested
    number of QA pairs, alongside a 'curid' key. The output maps each input key to
    {month: [qa, ...]}. Progress is checkpointed after every generated month so a
    crash does not discard completed (expensive) generations.
    '''
    seed_everything()
    logger = set_logger()
    generator = Generator(logger)
    output_path = './generated_qa.json'

    with open('./final_curid_timestamp_mentions.json', 'r', encoding='utf-8') as f:
        mentions = json.load(f)

    new_js = {key: {'0506': [], '0708': [], '0910': [], '1112': [], '0102': [], '0304': []} for key in mentions}

    for key, value in tqdm(mentions.items()):
        entity = key.split('/')[-1]
        curid = str(value['curid'])
        for month, in_value in tqdm(value.items(), desc=f"key = {key}", leave=False):
            if month == 'curid':
                continue
            length = in_value[0]
            if length == 0:
                # Explicit assignment also covers month tags missing from the template above.
                new_js[key][month] = []
                logger.info(f"Month: {month}, Entity: {entity}, Curid: {curid} = length = 0")
                continue

            qa = generator.generate(month, curid, entity, length)
            new_js[key][month] = qa
            logger.info(f"Month: {month}, Entity: {entity}, Curid: {curid} = Generated {len(qa)} qa pairs")

            # Checkpoint after each generation; rewriting the whole file each time is
            # cheap compared with the generation step it protects.
            _dump_json_atomic(new_js, output_path)

    # Final write so the output exists even when no month required generation.
    _dump_json_atomic(new_js, output_path)

# Run the QA generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()