import ast
import json

import pandas as pd
from tqdm import tqdm

PATH_TO_WIKIDATA_PEOPLE_DF="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_ents_wdata_comp_people.csv"
PATH_TO_DISAMB_DICT="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/disambiguation_dict_final.json"

###Load the disambiguation dict.
## JSON text is UTF-8 by spec; pass the encoding explicitly instead of relying on the locale default.
with open(PATH_TO_DISAMB_DICT, encoding="utf-8") as f:
    disamb_dict = json.load(f)

##Disamb dict maps an entity name to a collection (JSON array, loaded as a list) of other entity names.

def _parse_aliases(cell):
    """Parse one ``Aliases`` cell (a Python list literal stored as text) into a list.

    Missing cells reach a read_csv converter as an empty string; they become ``[]``
    instead of crashing the parse.
    """
    # ast.literal_eval only accepts literals; eval would execute any code present in the CSV.
    if not cell or not cell.strip():
        return []
    return ast.literal_eval(cell)


##Load df
df_wikidata_people = pd.read_csv(PATH_TO_WIKIDATA_PEOPLE_DF, converters={'Aliases': _parse_aliases})

## Collect every surface form a person can be known by. NaN cells are skipped: they are not names.
EngWikipediaTitle_set = set(df_wikidata_people['EngWikipediaTitle'].dropna().tolist())
aliases_set = {alias for aliases in df_wikidata_people['Aliases'].tolist() for alias in aliases}
label_set = set(df_wikidata_people['Label'].dropna().tolist())

all_names = EngWikipediaTitle_set.union(aliases_set).union(label_set)

##Keep only those entities in the disambiguation dict that are in the all_names set:
##  1. drop every value (candidate) that is not in all_names;
##  2. drop keys that are not in all_names and have fewer than two surviving values;
##  3. re-key each remaining unknown key under its first surviving value.

##remove values that are not in all_names
disamb_dict = {k: [i for i in v if i in all_names] for k, v in tqdm(disamb_dict.items())}

## Steps 2 and 3. Iterate over a snapshot of the keys because the dict is mutated inside the loop.
for key in list(disamb_dict):
    if key in all_names:
        continue
    values = disamb_dict.pop(key)
    if len(values) < 2:
        # An unknown key with zero or one known candidate is not worth keeping.
        continue
    new_key = values[0]
    existing = disamb_dict.get(new_key)
    if existing is None:
        disamb_dict[new_key] = values
    else:
        # new_key already has an entry: merge the candidate lists instead of overwriting them.
        disamb_dict[new_key] = existing + [v for v in values if v not in existing]
        
###Now, we want to replace the keys and values with their QIDs.

## Names that actually occur in the filtered disambiguation dict: its keys and all of its candidates.
all_keys = set(disamb_dict)
all_values = {item for sublist in disamb_dict.values() for item in sublist}

## Only these names need a QID. They are already a subset of the df-wide all_names
## (every key and value was filtered against it), so restrict the lookup set to them.
## Unioning them back into all_names would change nothing and resolve every name in the df.
all_names = all_keys.union(all_values)

## Resolve names to QIDs by EngWikipediaTitle first; Label is tried afterwards for the rest.
## Index the df by title so the matching rows can be pulled with .loc.
df_wikidata_people.set_index('EngWikipediaTitle', inplace=True)

names_in_eng_titles = set(df_wikidata_people.index)
title_matches = names_in_eng_titles.intersection(all_names)
names_not_in_eng_titles = all_names - title_matches
intersect_names_titles = list(title_matches)

print("Number of names in EngWikipediaTitle: ", len(intersect_names_titles))
## name -> QID for every name that is an English Wikipedia title
title_qids = df_wikidata_people.loc[intersect_names_titles, 'QID']
qid_dict = title_qids.to_dict()

###Next, resolve the names that are not titles against Label: re-index the df by Label.
df_wikidata_people.reset_index(inplace=True)
df_wikidata_people.set_index('Label', inplace=True)

intersect_names_not_in_titles_but_in_labels = names_not_in_eng_titles.intersection(set(df_wikidata_people.index))
names_not_in_titles_and_labels = names_not_in_eng_titles - intersect_names_not_in_titles_but_in_labels

intersect_names_not_in_titles_but_in_labels = list(intersect_names_not_in_titles_but_in_labels)
# Report the number of names actually found in Label, not the number of names missing from titles.
print("Number of names not in EngWikipediaTitle but in Label: ", len(intersect_names_not_in_titles_but_in_labels))
print("Number of names in neither EngWikipediaTitle nor Label: ", len(names_not_in_titles_and_labels))
qid_dict.update(df_wikidata_people.loc[intersect_names_not_in_titles_but_in_labels, 'QID'].to_dict())


###Now, remove all rows that are done, i.e. whose EngWikipediaTitle or Label was resolved into qid_dict above.
## Both masks are computed on the same frame, so a row matched by both its title and its label is
## dropped exactly once. Dropping by title first and then by label would raise KeyError for labels
## whose rows were already removed.
df_wikidata_people.reset_index(inplace=True)
resolved_rows = (
    df_wikidata_people['EngWikipediaTitle'].isin(intersect_names_titles)
    | df_wikidata_people['Label'].isin(intersect_names_not_in_titles_but_in_labels)
)
df_wikidata_people = df_wikidata_people[~resolved_rows].reset_index(drop=True)

print("Number of rows in df after removing names in EngWikipediaTitle and Label: ", df_wikidata_people.shape[0])
##Now, keep only aliases and qid

## TODO: resolve the remaining names (names_not_in_titles_and_labels) via Aliases.
## NOTE(review): Aliases cells hold lists (see the converter passed to read_csv), so set_index('Aliases')
## followed by set(df.index) below would likely fail on unhashable lists. Exploding the column first
## (df_wikidata_people.explode('Aliases')) is probably needed if this step is re-enabled.
##Set index as aliases
##drop where aliases is nan
# df_wikidata_people.dropna(subset=['Aliases'], inplace=True)
# df_wikidata_people.set_index('Aliases',inplace=True)

# intersect_names_not_in_titles_but_in_labels_aliases = (names_not_in_titles_and_labels.intersection(set(df_wikidata_people.index)))
# intersect_names_not_in_titles_but_in_labels_aliases=list(intersect_names_not_in_titles_but_in_labels_aliases)
# print("Number of names not in EngWikipediaTitle and Label but in Aliases: ", len(intersect_names_not_in_titles_but_in_labels_aliases))
# ##Get the qids for the aliases
# qid_dict.update(df_wikidata_people.loc[intersect_names_not_in_titles_but_in_labels_aliases, 'QID'].to_dict())

###now, in the disamb_dict, replace the values with their QIDs.
### A value without a usable QID is dropped; a key left with no values is removed.

for key in tqdm(list(disamb_dict.keys())):
    # Names that were never resolved (e.g. names found only among aliases) are absent from
    # qid_dict, so look them up with .get. pd.notna also discards NaN QIDs, which json.dump
    # would otherwise write as the non-standard token NaN.
    qids = [qid_dict.get(val) for val in disamb_dict[key]]
    qids = [qid for qid in qids if qid is not None and pd.notna(qid)]
    if qids:
        disamb_dict[key] = qids
    else:
        del disamb_dict[key]

##Now, replace the keys with their QIDs. A key without a QID is re-keyed under its first value
##(already a QID at this point) if it has more than one value; otherwise it is dropped.
for key in tqdm(list(disamb_dict.keys())):
    values = disamb_dict.pop(key)
    # Unresolved names are absent from qid_dict, so use .get; treat NaN QIDs as missing too.
    key_qid = qid_dict.get(key)
    if key_qid is not None and pd.notna(key_qid):
        new_key = key_qid
    elif len(values) > 1:
        new_key = values[0]
    else:
        continue
    existing = disamb_dict.get(new_key)
    if existing is None:
        disamb_dict[new_key] = values
    else:
        # Several names can resolve to the same QID: merge their candidate lists instead of
        # letting the later name silently overwrite the earlier one.
        disamb_dict[new_key] = existing + [v for v in values if v not in existing]
        
## Write the QID-keyed disambiguation dict to the same directory as the input files.
output_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/disambiguation_dict_final_qid.json'
with open(output_path, 'w') as out_file:
    json.dump(disamb_dict, out_file)