import pandas as pd
import numpy as np
import json
from tqdm import tqdm

# Every input and output of this script lives in the same directory.
_ENTITY_TRAINING_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training"
# Entity table (CSV) with the QID, PersonalInfo and Positions columns read below.
PATH_TO_ALL_ENT_DF = f"{_ENTITY_TRAINING_DIR}/all_ents_wdata_comp_people.csv"
# Inference corpus: a JSON object keyed by QID.
PATH_TO_INF_CORPUS = f"{_ENTITY_TRAINING_DIR}/formatted_first_para_data_qid_template_people.json"
# Destination for the corpus enriched with median/birth/death years.
PATH_TO_INF_CORPUS_with_MEDIAN_DATES = f"{_ENTITY_TRAINING_DIR}/formatted_first_para_data_qid_template_people_with_median_dates.json"


def _parse_wikidata_year(time_str):
    """Return the signed year at the start of a Wikidata-style time string.

    Times look like ``"+2020-05-18T00:00:00Z"``: a sign, the year digits, then
    ``-MM-DD...``. Reading the sign and every digit up to the first ``-`` keeps
    BCE years (leading ``-``) negative and years with more than four digits
    intact, which a fixed ``[1:5]`` slice would not.

    Returns None for a non-string, empty, or unparseable value.
    """
    if not isinstance(time_str, str) or not time_str:
        return None
    sign = -1 if time_str.startswith('-') else 1
    year_digits = time_str.lstrip('+-').split('-', 1)[0]
    try:
        return sign * int(year_digits)
    except ValueError:
        return None


def _extract_year(personal_info, key):
    """Return the year of ``personal_info[key]['time']``, or None if unavailable.

    Tolerates a non-dict ``personal_info`` (e.g. NaN for an empty cell), a
    ``key`` entry that is None, and an entry without ``'time'``; all of these
    yield None rather than a placeholder year.
    """
    if not isinstance(personal_info, dict):
        return None
    entry = personal_info.get(key)
    if not isinstance(entry, dict):
        return None
    return _parse_wikidata_year(entry.get('time'))


def _mean_position_year(positions):
    """Return the mean over positions of each position's mean year, or None.

    ``positions`` maps a position QID to a list of time strings, e.g.
    ``{'Q65520895': ['+2020-05-18T00:00:00Z'], 'Q65457859': ['+2020-05-23T00:00:00Z']}``.
    Each position counts once regardless of how many dates it lists. Positions
    with no parseable dates are skipped, so one empty list no longer turns the
    whole result into NaN.
    """
    if not isinstance(positions, dict):
        return None
    per_position_means = []
    for dates in positions.values():
        years = [year for year in map(_parse_wikidata_year, dates or ()) if year is not None]
        if years:
            per_position_means.append(np.mean(years))
    return float(np.mean(per_position_means)) if per_position_means else None


def calculate_median_year(df, *, min_year=1500, max_year=2100):
    """Estimate one representative year per entity (row of ``df``).

    Despite the name this is a heuristic point estimate, not a statistical
    median:

    1. If the row has position dates (``Positions``), use the mean over
       positions of each position's mean year. This value is not range-filtered.
    2. Otherwise use the birth/death years from ``PersonalInfo``, ignoring any
       year outside ``[min_year, max_year]``:

       - both known:      ``(3 * birth + 4 * death) / 7``
       - only birth:      ``birth + 40``
       - only death:      ``death - 30``
       - neither:         ``None``

    Args:
        df: DataFrame with a ``PersonalInfo`` column of dicts shaped like
            ``{'date_of_birth': {'time': ...}, 'date_of_death': {'time': ...}}``
            and a ``Positions`` column of dicts mapping a position QID to a
            list of time strings.
        min_year: Smallest birth/death year accepted as plausible (inclusive).
        max_year: Largest birth/death year accepted as plausible (inclusive).

    Returns:
        A tuple ``(median_years, birth_years, death_years)``:

        - ``median_years`` is a list with one estimate per row, or None where
          nothing is known.
        - ``birth_years`` and ``death_years`` are Series aligned with
          ``df.index``. They hold the parsed, *unfiltered* years, with
          None/NaN where missing.
    """
    birth_years = df['PersonalInfo'].apply(lambda info: _extract_year(info, 'date_of_birth'))
    death_years = df['PersonalInfo'].apply(lambda info: _extract_year(info, 'date_of_death'))
    position_years = df['Positions'].apply(_mean_position_year)

    def _plausible(year):
        # Check for missing values first: an all-None Series has object dtype,
        # and comparing None with an int would raise TypeError.
        return pd.notna(year) and min_year <= year <= max_year

    median_years = []
    for birth_year, death_year, position_year in tqdm(zip(birth_years, death_years, position_years), total=len(birth_years)):
        if pd.notna(position_year):
            median_years.append(position_year)
            continue

        birth = birth_year if _plausible(birth_year) else None
        death = death_year if _plausible(death_year) else None

        if birth is not None and death is not None:
            # Weighted towards the death year, i.e. later in the person's life.
            median_years.append((3 * birth + 4 * death) / 7)
        elif birth is not None:
            median_years.append(birth + 40)
        elif death is not None:
            median_years.append(death - 30)
        else:
            median_years.append(None)

    return median_years, birth_years, death_years



###Run as script
if __name__ == '__main__':
    # NOTE(review): `eval` executes arbitrary Python from every cell of these
    # columns. The cells are expected to hold dict literals, so
    # `ast.literal_eval` should be a safe drop-in replacement -- confirm
    # against the CSV and switch.
    all_ents_df = pd.read_csv(PATH_TO_ALL_ENT_DF, converters={'PersonalInfo': eval, 'Positions': eval})

    print(all_ents_df.head())
    ## In all_ents_df the Wikidata id is stored in the column "QID".
    ## Birth and death info, when available, is stored in the column "PersonalInfo" as a dict
    ## {"date_of_birth": {"time": ...}, "date_of_death": {"time": ...}}. Times start with a sign
    ## (e.g. "+1900-..."), and only the year is used.
    ## calculate_median_year prefers the mean year of the entity's position dates. Without
    ## positions it combines birth and death years (each kept only within 1500-2100) as
    ## (3*birth + 4*death) / 7. With only one of them known it uses birth + 40 or death - 30.
    all_ents_df['median_year'], all_ents_df['birth_year'], all_ents_df['death_year'] = calculate_median_year(all_ents_df)

    ## Drop rows without a QID and keep the first row per QID, so every lookup below yields a
    ## scalar rather than a Series. Then index by QID.
    all_ents_df = all_ents_df.dropna(subset=['QID'])
    n_duplicates = int(all_ents_df['QID'].duplicated().sum())
    if n_duplicates:
        print(f"Dropping {n_duplicates} rows with duplicate QIDs (keeping the first)")
    all_ents_df = all_ents_df.drop_duplicates(subset='QID', keep='first').set_index('QID')

    def _to_json_value(value):
        """Return `value` as a JSON-serializable Python scalar; NaN/None become None.

        numpy integer scalars (e.g. from an int64 column) are rejected by `json`,
        and a float NaN would be written as the non-standard token `NaN`.
        """
        if pd.isna(value):
            return None
        return value.item() if isinstance(value, np.generic) else value

    ## Now, the inference corpus: a dict with the QID as key and a dict as value.
    ## For each QID we add "median_year", "birth_year" and "death_year" to that dict.
    with open(PATH_TO_INF_CORPUS, encoding='utf-8') as f:
        inf_corpus = json.load(f)

    year_columns = ('median_year', 'birth_year', 'death_year')
    missing_qids = 0
    for qid, data in tqdm(inf_corpus.items()):
        known = qid in all_ents_df.index
        if not known:
            missing_qids += 1
        for column in year_columns:
            # `.at` keeps each column's own dtype, so integer years stay integers.
            data[column] = _to_json_value(all_ents_df.at[qid, column]) if known else None
    if missing_qids:
        print(f"{missing_qids} corpus QIDs were not found in {PATH_TO_ALL_ENT_DF}; their years are set to null")

    ## Save the updated inf corpus
    with open(PATH_TO_INF_CORPUS_with_MEDIAN_DATES, 'w', encoding='utf-8') as f:
        json.dump(inf_corpus, f)