import json
from tqdm import tqdm



def check_context_persona(file_name):
    """Report scenes whose context keywords do not start with the persona keywords.

    Loads the JSON file at ``file_name`` and, for every scene in it, compares
    ``scene['persona_kws']`` with the equally long prefix of
    ``scene['context_kws']``. Each mismatch is printed to stdout (the persona
    keywords followed by the context prefix they were compared against).
    Nothing is returned.

    Args:
        file_name: Path of the JSON file to check. It is presumably a list of
            scene dicts, each carrying ``'context_kws'`` and ``'persona_kws'``
            sequences -- verify against the data files.

    Raises:
        KeyError: If a scene dict lacks ``'context_kws'`` or ``'persona_kws'``.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for scene in tqdm(data):
        context_kws = scene['context_kws']
        persona_kws = scene['persona_kws']
        context_prefix = context_kws[:len(persona_kws)]
        # Use an ordinary comparison instead of ``assert`` inside a bare
        # ``except``: asserts are removed under ``python -O`` (which would make
        # every scene pass silently), and a bare ``except`` would also swallow
        # KeyboardInterrupt while the check is running.
        if context_prefix != persona_kws:
            print('persona_kws = ', persona_kws)
            print('context_kws = ', context_prefix)

def extract_persona_from_context(file_name):
    """Derive ``persona_kws`` for every scene and save an augmented copy.

    For each scene the persona text is read from
    ``scene['entries'][-1]['cards'][-1]['description']``. The scene's
    ``context_kws`` are then walked from the front: a keyword is accepted
    while it occurs in the not-yet-consumed remainder of the persona text,
    and the remainder is advanced past that occurrence. The first keyword
    that does not occur ends the walk. The accepted keywords are stored
    under ``scene['persona_kws']``.

    The whole dataset is written to ``<file_name without .json>_split.json``
    with 4-space indentation and compact ``','`` / ``':'`` separators.

    Args:
        file_name: Path (``str``) of the input JSON file. The persona
            description is assumed to be a text string -- confirm against
            the data files.

    Raises:
        KeyError, IndexError: If a scene lacks the expected keys or has an
            empty ``entries`` / ``cards`` list.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for scene in tqdm(data):
        context_kws = scene['context_kws']
        remaining = scene['entries'][-1]['cards'][-1]['description']
        persona_kws = []
        for kw in context_kws:
            # Single scan per keyword: find() returns -1 when it is absent.
            pos = remaining.find(kw)
            if pos < 0:
                break
            persona_kws.append(kw)
            # Keep only the text after this match so later keywords must
            # appear in the same order as in the persona description.
            remaining = remaining[pos + len(kw):]
        scene['persona_kws'] = persona_kws

    # Strip the extension only when it is really there, instead of blindly
    # chopping the last five characters off the path.
    suffix = '.json'
    stem = file_name[:-len(suffix)] if file_name.endswith(suffix) else file_name
    with open(stem + '_split.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, separators=(',', ':'))


if __name__ == "__main__":
    # Validate the already generated "*_dynamic_split.json" file of each
    # dataset split, printing any scene whose keywords are inconsistent.
    for split_name in ('train', 'valid', 'test'):
        print(f"checking {split_name} ...")
        check_context_persona(f'../data/{split_name}_dynamic_split.json')
        # To (re)generate the split file from the raw data, run instead:
        # extract_persona_from_context(f"../data/{split_name}_dynamic.json")
