import pickle 

##Import tokenizer from HF
from transformers import AutoTokenizer
from tqdm import tqdm
from glob import glob
import os
from data_fns import featurise_data 
from multiprocessing import Pool
from sentence_transformers import SentenceTransformer
import sys
import numpy as np
import pandas as pd
import json
import argparse
import faiss


##import ARI from sklearn
from sklearn.metrics import adjusted_rand_score
# print(data.keys())

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
grandparentdir = os.path.dirname(parentdir)
sys.path.append(parentdir)
sys.path.append(grandparentdir)

from nlp_utils.modified_sbert.cluster_fns import get_cluster_assignment





def get_fp_embeddings(fp_path, embedding_model_path, special_tokens=None, ent_featurisation='ent_mark',
                      use_multi_gpu=False,pre_featurised=True,
                      re_embed=False,override_max_seq_length=None,asymm_model=False):
    """Load (or compute) embeddings for the wiki first-paragraph corpus and filter them.

    Reads the first-paragraph JSON at ``fp_path`` (a dict of entries; empty
    ``{}`` entries are skipped) and obtains one embedding per non-empty entry.

    Embeddings are cached in ``{embedding_model_path}/fp_embeddings.pkl`` and
    reused when ``re_embed`` is False and the cached array has one row per
    non-empty entry. Otherwise the entries are featurised (or the featurised
    inputs are loaded from ``{embedding_model_path}/featurised_fps.pkl`` when
    ``pre_featurised`` is True and that file exists), encoded with the
    SentenceTransformer at ``embedding_model_path``, L2-normalised and cached.

    Entries are then filtered, keeping only those that:
      * have 'human' among ``wikidata_info['instance_of_labels']``;
      * have a missing ``birth_year`` or one earlier than 1970;
      * are not in the QID blacklist (hard-coded QIDs plus those listed in
        ``qids_to_prune.json``, read from the current working directory);
      * have at least one of ``birth_year`` / ``death_year``.

    Args:
        fp_path: Path to the first-paragraph JSON file.
        embedding_model_path: SentenceTransformer model directory; also where
            the embedding and featurisation caches are read and written.
        special_tokens: Special-token mapping passed to the featuriser.
        ent_featurisation: Entity featurisation mode for the symmetric model.
        use_multi_gpu: Encode with a SentenceTransformer multi-process pool.
        pre_featurised: Reuse ``featurised_fps.pkl`` if it exists.
        re_embed: Ignore any cached embeddings and re-encode.
        override_max_seq_length: Passed through to the featuriser.
        asymm_model: Use the asymmetric-model featurisation ({'FP': text}).

    Returns:
        Tuple ``(fp_embeddings, fp_ids, qid_list, median_year_list,
        birth_year_list, qrank_list, fp_dict_only_text)``, all aligned by
        position. Missing QRank values are replaced by 0.
    """
    new_embeddings_needed = True
    save_path = f'{embedding_model_path}/fp_embeddings.pkl'
    featurised_path = f'{embedding_model_path}/featurised_fps.pkl'

    with open(fp_path) as f:
        fp_raw = json.load(f)

    # Drop empty entries once. Embeddings, ids and filters below are all
    # aligned with the insertion order of this dict.
    fp_dict_cleaned = {k: v for k, v in fp_raw.items() if v != {}}
    dict_list = list(fp_dict_cleaned.values())
    wik_ids = [fp['wiki_title'] for fp in dict_list]
    qid_list = [fp['wikidata_info']['wikidata_id'] for fp in dict_list]
    median_year_list = [fp['median_year'] for fp in dict_list]
    birth_year_list = [fp['birth_year'] for fp in dict_list]
    qrank_list = [fp['qrank'] for fp in dict_list]

    # Reuse cached embeddings only when they line up one-to-one with the entries.
    if os.path.exists(save_path) and not re_embed:
        print("Previous embeddings found, loading previous embeddings ...")
        with open(save_path, 'rb') as f:
            embeddings = pickle.load(f)  # local cache written by this function

        if len(embeddings) == len(wik_ids):
            print('Loaded embeddings have same length as data. Using loaded embeddings')
            new_embeddings_needed = False
        else:
            print('Loaded embeddings are different length to data. Creating new embeddings')
    else:
        print("No previous embeddings found./ Rembedding is set to ",re_embed)

    if new_embeddings_needed:
        print("Creating new embeddings ...")

        model = SentenceTransformer(embedding_model_path)
        tokenizer = model.tokenizer

        # Expose each entry's text under 'context' for the featurisers
        # (mutates the loaded entry dicts in place).
        for fp in dict_list:
            fp['context'] = fp['text']

        if not pre_featurised or not os.path.exists(featurised_path):
            print("Featurising data ...")
            if not asymm_model:
                # NOTE(review): featurise_data_with_dates_flex is not imported in this
                # file; this branch raises NameError unless it is provided elsewhere.
                featurised_fps = featurise_data_with_dates_flex(dict_list,  ent_featurisation, "prepend_1", special_tokens, tokenizer, override_max_seq_length)
            else:
                print("Using asymmetric model featurisation - prepend featurisation and 'FP' key for each fp.")
                featurised_fps = featurise_data(dict_list, featurisation="prepend", special_tokens=special_tokens,  model=SentenceTransformer("all-mpnet-base-v2"),override_max_seq_length=override_max_seq_length)
                # The asymmetric model takes one {'FP': text} dict per entry.
                featurised_fps = [{'FP': fp} for fp in featurised_fps]
            with open(featurised_path, 'wb') as f:
                pickle.dump(featurised_fps, f)
        else:
            with open(featurised_path, 'rb') as f:
                featurised_fps = pickle.load(f)

        if use_multi_gpu:
            import torch  # only needed to report the GPU count

            print("Using multiple GPUs: ", torch.cuda.device_count())
            pool = model.start_multi_process_pool()
            try:
                embeddings = model.encode_multi_process(featurised_fps, batch_size=64, pool=pool)
            finally:
                # Terminate the worker processes; otherwise they outlive this call.
                model.stop_multi_process_pool(pool)
        else:
            embeddings = model.encode(featurised_fps, show_progress_bar=True, batch_size=64)

        # Normalise to unit length so inner product equals cosine similarity.
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

        with open(save_path, 'wb') as f:
            pickle.dump(embeddings, f)

    instance_types_to_keep = {'human'}
    BIRTH_DATE_CUTOFF = 1970
    QIDS_TO_REMOVE = ["Q26702663","Q2292195","Q2600236",
                      "Q112039672","Q5442741","Q16354385",
                      "Q61600540","Q19276771","Q5456005","Q4659108","Q6264176","Q5300213"]
    with open("qids_to_prune.json") as f:
        QIDS_TO_REMOVE.extend(json.load(f))
    QIDS_TO_REMOVE = set(QIDS_TO_REMOVE)
    BIRTH_DATES_ONLY = True

    # Start from every entry and narrow down with each active filter, so a
    # disabled filter never empties the selection.
    indices_to_keep = set(range(len(dict_list)))

    if instance_types_to_keep:
        indices_to_keep &= {
            i for i, v in enumerate(dict_list)
            if instance_types_to_keep.intersection(v['wikidata_info']['instance_of_labels'])
        }
        print("indices after instance type filtering: ",len(indices_to_keep))

    if BIRTH_DATE_CUTOFF:
        # Test for a missing year first: comparing None with an int raises TypeError.
        indices_to_keep &= {
            i for i, v in enumerate(dict_list)
            if pd.isna(v['birth_year']) or v['birth_year'] < BIRTH_DATE_CUTOFF
        }
        print("indices after date filtering: ",len(indices_to_keep))

    if QIDS_TO_REMOVE:
        indices_to_keep -= {i for i, qid in enumerate(qid_list) if qid in QIDS_TO_REMOVE}
        print("indices after qid filtering: ",len(indices_to_keep))

    if BIRTH_DATES_ONLY:
        # Drop entries for which neither the birth year nor the death year is available.
        indices_to_keep &= {
            i for i, v in enumerate(dict_list)
            if not pd.isna(v['birth_year']) or not pd.isna(v['death_year'])
        }
        print("indices after birth/death year filtering: ",len(indices_to_keep))

    print("Number of embeddings: ",len(embeddings))
    keep = sorted(indices_to_keep)

    embeddings = np.asarray(embeddings)
    fp_embeddings = embeddings[[i for i in keep if i < len(embeddings)]]
    fp_ids = [wik_ids[i] for i in keep]
    qid_list = [qid_list[i] for i in keep]
    median_year_list = [median_year_list[i] for i in keep]
    birth_year_list = [birth_year_list[i] for i in keep]
    # pd.isna covers both None and NaN; missing QRank becomes 0.
    qrank_list = [0 if pd.isna(qrank_list[i]) else qrank_list[i] for i in keep]
    fp_dict_only_text = [dict_list[i]['text'] for i in keep]

    print("Number of embeddings after type filtering: ",len(fp_embeddings))
    print(len(fp_embeddings), len(fp_ids))
    if len(fp_embeddings) == len(fp_ids) == len(fp_dict_only_text):
        print("All lengths are equal")
    else:
        # Can happen if a stale featurised_fps.pkl was reused for different data.
        print("WARNING: embeddings and ids have different lengths; outputs are misaligned")

    return fp_embeddings, fp_ids, qid_list, median_year_list, birth_year_list, qrank_list, fp_dict_only_text

def rerank_by_qrank(qrank_list, qrank_rerank_threshold, distances, neighbours):
    """Rerank near-tied nearest neighbours by QRank (higher QRank is better).

    Scores are similarities: larger is better, with the best hit first. For
    each query, the neighbours whose score is strictly greater than
    ``best_score - qrank_rerank_threshold`` form a "near-tie" window. The
    window is reordered by the neighbours' QRank, highest first. Neighbours
    with equal QRank keep their original (score) order.

    If the window is empty, or any neighbour in it has a missing QRank, the
    query's full neighbour list is returned in its original order.

    Args:
        qrank_list: QRank of every corpus item, indexed by corpus position.
        qrank_rerank_threshold: Maximum score gap to the top hit for a
            neighbour to count as tied with it.
        distances: Per-query scores, best first (e.g. faiss ``D`` from an
            inner-product index).
        neighbours: Per-query corpus indices aligned with ``distances``
            (e.g. faiss ``I``). Negative indices (faiss padding) are never
            placed in the window.

    Returns:
        ``(reranked_indices, reranked_distances)``: lists of per-query lists
        in which every score stays paired with its neighbour.
    """
    reranked_indices = []
    reranked_distances = []
    for row_neighbours, row_distances in zip(neighbours, distances):
        row_neighbours = list(row_neighbours)
        row_distances = list(row_distances)
        if not row_distances:
            reranked_indices.append(row_neighbours)
            reranked_distances.append(row_distances)
            continue

        threshold = row_distances[0] - qrank_rerank_threshold
        # Keep (neighbour, score) pairs together so reordering cannot misalign them.
        window = [
            (n, d) for n, d in zip(row_neighbours, row_distances)
            if d > threshold and n >= 0
        ]

        if not window or any(pd.isna(qrank_list[n]) for n, _ in window):
            reranked_indices.append(row_neighbours)
            reranked_distances.append(row_distances)
            continue

        # list.sort is stable even with reverse=True, so QRank ties stay in score order.
        window.sort(key=lambda pair: qrank_list[pair[0]], reverse=True)
        reranked_indices.append([n for n, _ in window])
        reranked_distances.append([d for _, d in window])

    return reranked_indices, reranked_distances
    

## Step 1: Convert the NER outputs to disamb format (article: disamb_list).
## Not needed in this pipeline (the `format_data` step raises).

## Step 1*: Featurise the disamb-format articles (entity-marked text) as input for the coref model.

def featurise_data_has():
    """Featurise the disamb-formatted HAS article files for the coref model.

    Every ``*.json`` file in INPUT_DIR holds a dict ``key -> record``. The
    records are passed through ``featurise_data`` (``ent_mark``
    featurisation, mention special tokens, max sequence length 256). The
    result is pickled to OUTPUT_DIR under the input's base name with a
    ``.pkl`` suffix, as a dict ``key -> featurised record``.

    Inputs whose ``.pkl`` output already exists are skipped, so an
    interrupted run can be resumed.
    """
    INPUT_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted"
    OUTPUT_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised"
    stoks = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    # Outputs are .pkl files, so map each one back to the .json input it was
    # produced from. Globbing *.json in OUTPUT_DIR would never match anything.
    files_done = {
        os.path.join(INPUT_DIR, os.path.splitext(os.path.basename(path))[0] + ".json")
        for path in glob(OUTPUT_DIR + "/*.pkl")
    }
    file_list = [path for path in glob(INPUT_DIR + "/*.json") if path not in files_done]
    print("Files to process:", len(file_list))
    if not file_list:
        return

    # The model handed to featurise_data is the same for every file: load it once.
    featurisation_model = SentenceTransformer("all-mpnet-base-v2")

    for file_name in tqdm(file_list, total=len(file_list), desc="files done"):
        with open(file_name, 'r') as f:
            data = json.load(f)

        article_key_list = list(data.keys())
        if article_key_list:
            article_dict_list = [data[key] for key in article_key_list]
            disamb_data_feat = featurise_data(article_dict_list, featurisation="ent_mark", special_tokens=stoks, model=featurisation_model, override_max_seq_length=256)
            data_featurised = dict(zip(article_key_list, disamb_data_feat))
        else:
            # Nothing to featurise; still write an empty output so the file counts as done.
            data_featurised = {}

        output_name = os.path.splitext(os.path.basename(file_name))[0] + ".pkl"
        with open(os.path.join(OUTPUT_DIR, output_name), 'wb') as f:
            pickle.dump(data_featurised, f)

## Step 2: Embed the featurised articles with the coref model; save as mention_article_id: embedding.

def embed_coref():
    """Encode the featurised coref inputs with the coref model, file by file.

    Each ``*.pkl`` in the featurised directory holds a dict
    ``mention_id -> featurised input``. Files without a same-named output
    yet are encoded with the coref SentenceTransformer. The result is
    pickled as ``mention_id -> embedding`` under the same file name in the
    output directory.
    """
    source_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised"
    target_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_corr_embedded"

    # Translate finished outputs back into the input paths they came from.
    finished = {path.replace(target_dir, source_dir) for path in glob(target_dir + "/*.pkl")}
    pending = [path for path in glob(source_dir + "/*.pkl") if path not in finished]
    print("Files to process:", len(pending))

    coref_model = SentenceTransformer("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext")

    for path in tqdm(pending, total=len(pending), desc="files done"):
        with open(path, 'rb') as handle:
            featurised = pickle.load(handle)

        mention_ids = list(featurised)
        vectors = coref_model.encode(list(featurised.values()), batch_size=64, show_progress_bar=True)

        destination = target_dir + "/" + path.split("/")[-1]
        with open(destination, 'wb') as handle:
            pickle.dump(dict(zip(mention_ids, vectors)), handle, protocol=pickle.HIGHEST_PROTOCOL)
      
    
    
    


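## Step 3: Cluster the coref embeddings within each unit (article or chapter, per the clustering mode), then save a dict keyed by <cluster_id>_<unit> holding each cluster's mention_ids; with clustering disabled, each mention becomes its own entry.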

def coref_cluster(clustering="NONE"):
    """Group coref-embedded mentions into clusters and save one dict per file.

    Args:
        clustering: How mentions are grouped before clustering:
            * "NONE": no clustering; every mention becomes its own entry
              ``mention_id -> {'cluster_mention_ids': [mention_id]}``.
            * "ARTICLE": unit is ``"_".join(mention_id.split("_")[1:])``.
            * "CHAPTER": unit is ``"_".join(mention_id.split("_")[1:-1])``.
            * "DATE": not implemented (raises ValueError).
            For ARTICLE/CHAPTER, the mentions in each unit are clustered with
            agglomerative clustering (cosine metric, average linkage,
            threshold 0.15). Each cluster is saved as
            ``f"{cluster_id}_{unit}" -> {'embedding': mean embedding,
            'cluster_mention_ids': [...]}``.

    Input files whose same-named output already exists are skipped.

    Raises:
        ValueError: If ``clustering`` is unknown or "DATE".
    """
    INPUT_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_corr_embedded"
    OUTPUT_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered"

    if clustering not in ("ARTICLE", "DATE", "CHAPTER", "NONE"):
        raise ValueError(f"Unknown clustering mode: {clustering!r}")
    if clustering == "DATE":
        raise ValueError("Not implemented")

    file_list = glob(INPUT_DIR + "/*.pkl")
    print("Files found:", len(file_list))
    files_done = {path.replace(OUTPUT_DIR, INPUT_DIR) for path in glob(OUTPUT_DIR + "/*.pkl")}
    file_list = [path for path in file_list if path not in files_done]
    print("Files to process:", len(file_list))

    if clustering == "NONE":
        print("Not clustering")

    for file_name in tqdm(file_list, total=len(file_list), desc="files done"):
        with open(file_name, 'rb') as f:
            data = pickle.load(f)  # mention_id -> embedding

        if clustering == "NONE":
            output = {mention_id: {'cluster_mention_ids': [mention_id]} for mention_id in data}
        else:
            # Group mention embeddings by unit (article or chapter id).
            units = {}
            for mention_id, embedding in data.items():
                parts = mention_id.split("_")
                unit = "_".join(parts[1:]) if clustering == "ARTICLE" else "_".join(parts[1:-1])
                units.setdefault(unit, {})[mention_id] = embedding

            output = {}
            for unit, unit_mentions in tqdm(units.items()):
                mention_ids = list(unit_mentions)
                mention_embeddings = np.array(list(unit_mentions.values()))
                if mention_embeddings.shape[0] == 1:
                    # A single mention cannot be clustered; it forms cluster 0.
                    cluster_ids = [0]
                else:
                    cluster_ids = get_cluster_assignment("agglomerative", cluster_params={'threshold': 0.15, 'clustering linkage': 'average', 'metric': 'cosine'}, corpus_embeddings=mention_embeddings)

                # Collect row positions per cluster label, then average each group.
                members = {}
                for position, cluster_id in enumerate(cluster_ids):
                    members.setdefault(cluster_id, []).append(position)
                for cluster_id, positions in members.items():
                    output[str(cluster_id) + "_" + unit] = {
                        'embedding': np.mean(mention_embeddings[positions], axis=0),
                        'cluster_mention_ids': [mention_ids[p] for p in positions],
                    }

        with open(OUTPUT_DIR + "/" + file_name.split("/")[-1], 'wb') as f:
            pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)


##Step 4 - Embed the articles using the disamb model
def embed_disamb():
    """Encode the featurised inputs with the disambiguation model, file by file.

    Each ``*.pkl`` in the coref-featurised directory holds a dict
    ``mention_id -> featurised input``. Files without a same-named output
    yet are encoded with the disamb-tuned SentenceTransformer. The result is
    pickled as ``mention_id -> embedding`` under the same file name in the
    output directory.
    """
    source_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised"
    target_dir = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_disamb_embedded"

    # Translate finished outputs back into the input paths they came from.
    finished = {path.replace(target_dir, source_dir) for path in glob(target_dir + "/*.pkl")}
    pending = [path for path in glob(source_dir + "/*.pkl") if path not in finished]

    disamb_model = SentenceTransformer("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_epoch_1")

    for path in tqdm(pending, total=len(pending), desc="files done"):
        with open(path, 'rb') as handle:
            featurised = pickle.load(handle)

        mention_ids = list(featurised)
        vectors = disamb_model.encode(list(featurised.values()), batch_size=64, show_progress_bar=True)

        destination = target_dir + "/" + path.split("/")[-1]
        with open(destination, 'wb') as handle:
            pickle.dump(dict(zip(mention_ids, vectors)), handle, protocol=pickle.HIGHEST_PROTOCOL)
  
  



## Step 5: For each cluster, find the nearest neighbour in the wiki first-paragraph corpus and record its QID, along with the cluster's mention texts and mention ids.

def search_qid(rerank_qrank=True, qrank_rerank_threshold=0.01, n_candidates=10):
    """Link each coref cluster to a Wikidata QID by nearest-neighbour search.

    Builds a GPU inner-product faiss index over the L2-normalised wiki
    first-paragraph embeddings returned by ``get_fp_embeddings``. For every
    cluster file, it averages the disamb-model embeddings of each cluster's
    mentions, normalises them and looks up the closest corpus entry.

    Args:
        rerank_qrank: If True, retrieve ``n_candidates`` neighbours and let
            ``rerank_by_qrank`` reorder near-tied candidates; otherwise take
            the single nearest neighbour.
        qrank_rerank_threshold: Score gap passed to ``rerank_by_qrank``.
        n_candidates: Number of neighbours retrieved when reranking.

    For each input file, writes a CSV sorted by similarity (descending) and a
    JSON mapping ``cluster_key -> {qid, fp_id, mention_text,
    art_mention_ids}``. Every cluster file is processed on each run; files
    with no clusters are skipped.
    """
    INPUT_DIR_EMBEDDINGS = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_disamb_embedded"
    INPUT_DIR_CLUSTERS = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered"
    INPUT_DIR_FORMATTED = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted"
    OUTPUT_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/has/has_outputs_disamb_formatted_coref_featurised_corr_embedded_date_clustered_wiki_qid"

    file_list = glob(INPUT_DIR_CLUSTERS + "/*.pkl")
    stoks = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}
    trained_model_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_epoch_1"

    fp_embeddings, fp_ids, qid_list, median_year_list, birth_year_list, qrank_list, fp_dict_only_text = get_fp_embeddings(
        fp_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/formatted_first_para_data_qid_template_people_with_qrank_3occupations.json',
        embedding_model_path=trained_model_path,
        special_tokens=stoks,
        use_multi_gpu=True,
        override_max_seq_length=256,
        re_embed=False,
        pre_featurised=True,
        asymm_model=False,
    )

    print("Setting up index")
    res = faiss.StandardGpuResources()
    # Unit vectors make inner product equal cosine similarity; faiss needs C-contiguous float32.
    fp_embeddings = fp_embeddings / np.linalg.norm(fp_embeddings, axis=1, keepdims=True)
    fp_embeddings = np.ascontiguousarray(fp_embeddings, dtype=np.float32)
    index = faiss.GpuIndexFlatIP(res, fp_embeddings.shape[1])
    index.add(fp_embeddings)

    for file_name in tqdm(file_list, total=len(file_list), desc="files done"):
        base_name = os.path.basename(file_name)
        stem = os.path.splitext(base_name)[0]

        with open(file_name, 'rb') as f:
            cluster_data = pickle.load(f)
        # Swap only the extension; a blanket str.replace("pkl", ...) could alter the name itself.
        with open(os.path.join(INPUT_DIR_FORMATTED, stem + ".json"), 'r') as f:
            text_data = json.load(f)
        with open(os.path.join(INPUT_DIR_EMBEDDINGS, base_name), 'rb') as f:
            disamb_emb_data = pickle.load(f)

        # Keep only the member mention ids of each cluster.
        cluster_data = {key: value['cluster_mention_ids'] for key, value in cluster_data.items()}
        if not cluster_data:
            print("No clusters in", file_name, "- skipping")
            continue
        cluster_keys = list(cluster_data)

        cluster_text_data_mention_text = {
            key: [text_data[mention_id]["mention_text"] for mention_id in mention_ids]
            for key, mention_ids in cluster_data.items()
        }

        # Query vector per cluster: mean of its mentions' disamb embeddings, normalised.
        query_embds = np.array([
            np.mean([disamb_emb_data[mention_id] for mention_id in cluster_data[key]], axis=0)
            for key in cluster_keys
        ])
        query_embds = query_embds / np.linalg.norm(query_embds, axis=1, keepdims=True)
        query_embds = np.ascontiguousarray(query_embds, dtype=np.float32)
        print(query_embds.shape)

        if rerank_qrank:
            D, I = index.search(query_embds, n_candidates)
            I, D = rerank_by_qrank(qrank_list, qrank_rerank_threshold, D, I)
            top_idx = [row[0] for row in I]
            top_dist = [row[0] for row in D]
        else:
            print("Searching")
            D, I = index.search(query_embds, 1)
            top_idx = I[:, 0]
            top_dist = D[:, 0]

        output_dict = {
            key: {
                "qid": qid_list[top_idx[i]],
                "fp_id": fp_ids[top_idx[i]],
                "mention_text": cluster_text_data_mention_text[key],
                "art_mention_ids": cluster_data[key],
            }
            for i, key in enumerate(cluster_keys)
        }

        output_df = pd.DataFrame(output_dict).T
        output_df["date_cluster_id"] = output_df.index
        output_df["distance"] = np.asarray(top_dist)
        # Scores are inner-product similarities, so the best matches come first.
        output_df = output_df.sort_values("distance", ascending=False)
        output_df.to_csv(os.path.join(OUTPUT_DIR, stem + ".csv"), index=False)

        # NOTE: this file holds JSON despite its .pkl name; the name is kept
        # unchanged in case downstream code expects it.
        with open(os.path.join(OUTPUT_DIR, base_name), 'w') as f:
            json.dump(output_dict, f)
    

 
  
##Run as script
if __name__ == "__main__":
    # Map each --step value to the pipeline function(s) it runs, in order.
    # "all" runs only coref_cluster and search_qid; run the other steps individually.
    step_functions = {
        "featurise_data": [featurise_data_has],
        "embed_coref": [embed_coref],
        "coref_cluster": [coref_cluster],
        "embed_disamb": [embed_disamb],
        "search_qid": [search_qid],
        "all": [coref_cluster, search_qid],
    }

    parser = argparse.ArgumentParser(description='Run entity disambiguation pipeline')
    # choices= makes argparse reject unknown steps instead of silently doing nothing.
    parser.add_argument('--step', type=str, help='Step to run', required=True,
                        choices=["format_data", *step_functions])
    args = parser.parse_args()

    if args.step == "format_data":
        raise ValueError("Not needed")
    for step_function in step_functions[args.step]:
        step_function()
    
    

    
    
    
