import ast
import json
import os
from glob import glob

import numpy as np
import pandas as pd
from tqdm import tqdm
# Input directory: every "*.csv" found directly inside it is processed by the loop below.
DISAMB_OUTPUT="/mnt/data01/entity/ner_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered_wiki_qid"
# QID-keyed lookup CSV loaded into gender_df (presumably the source of "sex_or_gender" - confirm against the file).
GENDER_MAPPING="../entity_training_wikidata_times_people_gender_stacked.csv"
# QID-keyed lookup CSV loaded into occupation_df (presumably the source of "occupation" - confirm against the file).
OCCUPATION_MAPPING="../all_ents_wdata_comp_people_occupations_with_occupation_names.csv"
# Output directory for the per-file frames enriched with Label / sex_or_gender / occupation ("|"-separated).
OUTPUT_DFS="/mnt/data01/entity/ner_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered_wiki_qid_with_occ_gender"
# Output directory for the derived per-file reports: pop_*, gender_mentions_*, occupation_mentions_* and summary_*.json.
ANALYSED_OUTPUTS="/mnt/data01/entity/ner_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered_wiki_processed"
# Load the two QID-keyed lookup tables and report their sizes.
gender_df = pd.read_csv(GENDER_MAPPING)
print(len(gender_df))
occupation_df = pd.read_csv(OCCUPATION_MAPPING)
print(len(occupation_df))

# Outer join on QID so an entity present in only one of the tables is still kept.
# Both tables carry a "Label" column, which the merge suffixes to Label_x / Label_y.
gender_occ_df = gender_df.merge(occupation_df, how='outer', on='QID')

print(gender_occ_df.head())
print(gender_occ_df.columns)

# Collapse the two suffixed label columns into one: take the gender table's
# label and fall back to the occupation table's label where it is missing.
label_from_gender = gender_occ_df["Label_x"]
label_from_occupation = gender_occ_df["Label_y"]
gender_occ_df['Label'] = label_from_gender.fillna(label_from_occupation)

# Keep only the columns that the per-file merge further down needs.
kept_columns = ['QID', 'Label', "sex_or_gender", "occupation"]
gender_occ_df = gender_occ_df[kept_columns]
print(gender_occ_df.head())

# The per-file frames name the key column "qid", so align the lookup table with them.
gender_occ_df = gender_occ_df.rename(columns={"QID": "qid"})

assert "qid" in gender_occ_df.columns

def _count_mention_ids(value):
    """Return how many mention ids a single ``art_mention_ids`` cell holds.

    The column is read back from CSV, so a collection written upstream arrives
    here as text; calling ``len()`` on that text would count characters, not
    ids. The cell is assumed to be the text form of a Python list, e.g.
    "['a', 'b']" - TODO confirm against the upstream writer.

    Args:
        value: One raw cell value (string, list-like, NaN/None or other scalar).

    Returns:
        0 for missing or blank cells, the element count for list-like cells
        (or their parsed text form), and 1 for any other single value.
    """
    if value is None:
        return 0
    if isinstance(value, (list, tuple, set, frozenset)):
        return len(value)
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return 0
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            # Not a Python literal: treat the cell as one bare id.
            return 1
        if isinstance(parsed, (list, tuple, set, frozenset, dict)):
            return len(parsed)
        return 1
    # Empty CSV cells come back as float NaN (not None), which has no len().
    if pd.isna(value):
        return 0
    return 1


# Create the output directories up front so the first write cannot fail on a missing path.
os.makedirs(OUTPUT_DFS, exist_ok=True)
os.makedirs(ANALYSED_OUTPUTS, exist_ok=True)

# glob() returns files in arbitrary order; sort so runs are reproducible.
all_csvs = sorted(glob(f"{DISAMB_OUTPUT}/*.csv"))

for file in tqdm(all_csvs):
    print("Processing file: ", file)
    file_name = os.path.basename(file)
    df = pd.read_csv(file, sep=",")

    # Keep only rows whose 'distance' score (used as the similarity threshold) is >= 0.85.
    df = df[df['distance'] >= 0.85]

    # Attach Label / sex_or_gender / occupation; a left join keeps rows whose qid has no match.
    # NOTE(review): if gender_occ_df holds several rows for one qid, this merge
    # repeats the mention row once per match and inflates every sum below - confirm.
    df = pd.merge(df, gender_occ_df, on='qid', how='left')

    print(df.head())

    # .copy() makes the column subset an independent frame, so adding a column
    # later is not a chained assignment on a slice.
    df = df[['qid', 'date_cluster_id', 'Label', 'sex_or_gender', 'occupation',
             'distance', 'mention_text', 'art_mention_ids', 'fp_id']].copy()

    # Write the enriched frame before any helper column is added.
    df.to_csv(os.path.join(OUTPUT_DFS, file_name), index=False, sep="|")

    # Number of mention ids per row.
    df['sum_art_mention_ids'] = df['art_mention_ids'].map(_count_mention_ids).astype('int64')

    # Most popular entities. dropna=False keeps entities whose Label, gender or
    # occupation is missing; groupby would otherwise drop those rows silently.
    df_popular = (
        df.groupby(['qid', 'Label', 'sex_or_gender', 'occupation'], dropna=False)
        .agg({'sum_art_mention_ids': 'sum'})
        .reset_index()
        .sort_values(by='sum_art_mention_ids', ascending=False)
    )
    df_popular.to_csv(os.path.join(ANALYSED_OUTPUTS, "pop_" + file_name), index=False, sep="|")

    # Mentions per gender and per occupation. Every cell here is a scalar after
    # the CSV round-trip, so the frame is grouped directly (explode() is a no-op
    # on scalar cells). Rows with a missing key are left out of these two tables.
    gender_mentions = df.groupby('sex_or_gender').agg({'sum_art_mention_ids': 'sum'}).reset_index()
    occupation_mentions = df.groupby('occupation').agg({'sum_art_mention_ids': 'sum'}).reset_index()

    gender_mentions.to_csv(os.path.join(ANALYSED_OUTPUTS, "gender_mentions_" + file_name), index=False, sep="|")
    occupation_mentions.to_csv(os.path.join(ANALYSED_OUTPUTS, "occupation_mentions_" + file_name), index=False, sep="|")

    # Per-file summary. int() turns any NumPy integer into a plain int so that
    # json.dump accepts it regardless of the exact NumPy dtype.
    unique_qids = int(df['qid'].nunique())
    total_mentions = int(df['sum_art_mention_ids'].sum())
    summary_dict = {"unique_qids": unique_qids, "total_mentions": total_mentions}

    summary_name = "summary_" + file_name.replace(".csv", ".json")
    with open(os.path.join(ANALYSED_OUTPUTS, summary_name), 'w') as f:
        json.dump(summary_dict, f)
    