import ast
import json
import os

import numpy as np
import pandas as pd
import requests
from tqdm import tqdm
from wikidata.client import Client

# Every input and output of this script lives under one training-data directory.
_ENTITY_TRAINING_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training"

# Raw table: a 'QID' column plus an 'occupation' column holding a stringified
# list of occupation QIDs.
OCCUPATIONS_DF = f"{_ENTITY_TRAINING_DIR}/all_ents_wdata_comp_people_occupations.csv"
# Same table with occupation QIDs replaced by their Wikidata labels.
OCCUPATIONS_DF_WITH_LABELS = (
    f"{_ENTITY_TRAINING_DIR}/all_ents_wdata_comp_people_occupations_with_occupation_names.csv"
)
# Corpus to augment (JSON object keyed by QID).
PATH_TO_QRANK_CORPUS = (
    f"{_ENTITY_TRAINING_DIR}/formatted_first_para_data_qid_template_people_with_qrank.json"
)
# Destination for the corpus with occupations appended to each template.
PATH_TO_NEW_CORPUS = (
    f"{_ENTITY_TRAINING_DIR}/"
    "formatted_first_para_data_qid_template_people_with_qrank_3occupations_fixed.json"
)
# Build the labelled occupations table unless it already exists on disk.
# (The guard must test the *labelled* output: testing the raw input meant the
# build branch only ran when its own input was missing.)
if not os.path.exists(OCCUPATIONS_DF_WITH_LABELS):
    # On-disk cache of occupation QID -> label, so Wikidata is queried only once.
    QID_TO_OCCUPATION_LABELS = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/qid_to_occupation_labels.json"

    occupations_df = pd.read_csv(OCCUPATIONS_DF)
    print(occupations_df.head())
    occupations_df = occupations_df.dropna(subset=['occupation'])

    # The 'occupation' column holds the string repr of a list of QIDs.
    # literal_eval parses it without executing arbitrary code (unlike eval).
    occupations_df['occupation'] = occupations_df['occupation'].apply(ast.literal_eval)

    # Some rows list several occupations: flatten into one de-duplicated list of
    # occupation QIDs (sorted so the fetch order and cached JSON are stable).
    unique_occupations = sorted(
        {qid for occs in occupations_df['occupation'] for qid in occs}
    )

    def get_english_labels(qids):
        """Fetch the Wikidata label of every QID in *qids*.

        Returns a dict mapping QID -> ``str(entity.label)``. Lookups that raise
        are reported on stdout and omitted, so one bad QID does not abort the
        run. NOTE(review): assumes str() of the label yields the English label
        under the client's default locale — confirm.
        """
        client = Client()
        labels = {}
        for qid in tqdm(qids):
            try:
                entity = client.get(qid, load=True)
                labels[qid] = str(entity.label)
            except Exception as e:
                print(f"Error retrieving label for {qid}: {e}")
        return labels

    # One Wikidata lookup per QID is slow, so reuse the cached mapping if present.
    if not os.path.exists(QID_TO_OCCUPATION_LABELS):
        qid_to_occupation = get_english_labels(unique_occupations)
        with open(QID_TO_OCCUPATION_LABELS, 'w', encoding='utf-8') as f:
            json.dump(qid_to_occupation, f)
    else:
        print("Loading qid to occupation dict")
        with open(QID_TO_OCCUPATION_LABELS, 'r', encoding='utf-8') as f:
            qid_to_occupation = json.load(f)

    def replace_qids_with_occupation_names(occupations):
        """Map occupation QIDs to labels, dropping QIDs that have no label."""
        return [qid_to_occupation[qid] for qid in occupations if qid in qid_to_occupation]

    occupations_df['occupation'] = occupations_df['occupation'].apply(
        replace_qids_with_occupation_names
    )

    occupations_df.to_csv(OCCUPATIONS_DF_WITH_LABELS, index=False)
else:
    # NOTE: on this path 'occupation' holds the string repr of a list, whereas
    # on the build path above it already holds Python lists.
    occupations_df = pd.read_csv(OCCUPATIONS_DF_WITH_LABELS)
    
# Index by QID so corpus entries can be looked up by their QID.
occupations_df.set_index('QID', inplace=True)

# Maximum number of occupations kept (and later mentioned) per QID.
MAX_OCCUPATIONS = 3


def _parse_occupation_list(value):
    """Return *value* as a list of occupation labels.

    Values read back from CSV are the string repr of a list and are parsed with
    ast.literal_eval. Values produced earlier in the same run are already lists
    and are returned unchanged; calling eval() on them raised TypeError.
    """
    if isinstance(value, str):
        return ast.literal_eval(value)
    return value


occupations_df['occupation'] = occupations_df['occupation'].apply(
    lambda occs: _parse_occupation_list(occs)[:MAX_OCCUPATIONS]
)

# Load the corpus: a JSON object mapping QID -> entry dict with (at least)
# 'template' and 'text' string fields.
with open(PATH_TO_QRANK_CORPUS, encoding='utf-8') as f:
    inf_corpus = json.load(f)

# QID -> list of occupation labels; one dict lookup per entry instead of a
# DataFrame .loc (which also returns a Series, not a list, for duplicate QIDs).
qid_to_occupations = occupations_df['occupation'].to_dict()

# Append "Has worked as <occ1>, <occ2>." to each entry's template. The 'text'
# field contains the old template verbatim, so splice the new one in its place.
for qid, data in tqdm(inf_corpus.items()):
    occupations = qid_to_occupations.get(qid)
    # Skip unknown QIDs and QIDs none of whose occupations had a label;
    # otherwise the sentence would degrade to "Has worked as ."
    if not occupations:
        continue
    old_template = data['template']
    data['template'] = old_template + f" Has worked as {', '.join(occupations)}."
    data['text'] = data['text'].replace(old_template, data['template'])

# Save the updated corpus.
with open(PATH_TO_NEW_CORPUS, 'w', encoding='utf-8') as f:
    json.dump(inf_corpus, f)

print("Occupations added to template and saved to new corpus")