import os
import json
import pickle
from tqdm import tqdm

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd

from data_fns import featurise_data, prep_newspaper_data, prep_sotu_data, clean_entity
from data_fn_disamb import featurise_data_with_dates_flex

def get_mention_text_between_special_tokens(text, start_token="[M]", end_token="[/M]"):
    """Return the mention text enclosed by ``start_token`` ... ``end_token``.

    The first ``start_token`` marks the start of the mention, and the first
    ``end_token`` that occurs *after it* marks the end. If ``start_token`` is
    absent the mention starts at the beginning of ``text``; if ``end_token`` is
    absent it runs to the end of ``text``.

    The extracted span is stripped of surrounding whitespace, hyphenated line
    breaks (``"\\n-"``) are joined, and remaining newlines become spaces.
    """
    start_index = text.find(start_token)
    # A missing start marker used to yield ``-1 + len(start_token)``, i.e. an
    # arbitrary offset into the text.
    begin = 0 if start_index == -1 else start_index + len(start_token)
    # Search for the end marker only after the start marker, so a stray earlier
    # end marker cannot produce an empty/inverted slice.
    end_index = text.find(end_token, begin)
    if end_index == -1:
        # A missing end marker used to be -1, silently dropping the last char.
        end_index = len(text)
    mention_text = text[begin:end_index]
    ##trim ws
    mention_text = mention_text.strip()
    ###replace \n- with "" (join words hyphenated across a line break)
    mention_text = mention_text.replace("\n-", "")
    ##replace \n with " "
    mention_text = mention_text.replace("\n", " ")
    return mention_text

def calculate_string_match_accuracy(query_mention_text_list,query_qid_list ,fp_titles, fp_qids):
    """Calculate in-Wikipedia accuracy of an exact-string-match baseline.

    Queries whose ground-truth QID is ``'Not in wikipedia'``, ``''`` or ``None``
    are dropped. Each remaining mention is predicted as the QID of the *first*
    FP whose title equals the mention exactly, or ``'Not in wikipedia'`` when
    no title matches.

    Args:
        query_mention_text_list: Mention strings, aligned with ``query_qid_list``.
        query_qid_list: Ground-truth QID per mention.
        fp_titles: FP titles, aligned with ``fp_qids``.
        fp_qids: QID per FP.

    Returns:
        Fraction of valid queries predicted correctly, or ``0.0`` when no query
        has a valid ground-truth QID.
    """
    missing = ('Not in wikipedia', '', None)

    print(query_mention_text_list[:10])
    print(fp_titles[:10])

    ##Keep only those text and qids that don't have "Not in wikipedia" , '' or None as the qid
    valid_pairs = [(m, q) for m, q in zip(query_mention_text_list, query_qid_list) if q not in missing]
    mentions = [m for m, _ in valid_pairs]
    gt_qids = [q for _, q in valid_pairs]

    print("Number of queries with valid qids: ", len(mentions))
    print("Number of queries with valid qids: ", len(gt_qids))

    # Title -> QID of the first FP with that title (same result as
    # ``fp_qids[fp_titles.index(title)]``), built once instead of a linear scan
    # per query.
    title_to_qid = {}
    for title, qid in zip(fp_titles, fp_qids):
        title_to_qid.setdefault(title, qid)
    pred_labels = [title_to_qid.get(m, 'Not in wikipedia') for m in mentions]

    ##Calculate accuracy
    correct = [1 if gt == pred else 0 for gt, pred in zip(gt_qids, pred_labels)]
    print(correct)
    if not correct:
        # Avoid ZeroDivisionError when every query is out of Wikipedia.
        print("No queries with valid qids; accuracy undefined, returning 0.0")
        return 0.0
    accuracy = sum(correct) / len(correct)
    print("Accuracy: ", accuracy)
    return accuracy

        
    
            
    
    
    


def rerank_by_date(median_year_list, query_year_list, date_rerank_threshold, distances, neighbours):
    """Rerank each query's nearest neighbours by closeness in time.

    Scores in ``distances`` are treated as similarities (larger = closer). For
    query ``i`` the candidate set is every neighbour whose score is at least
    ``max(distances[i]) - date_rerank_threshold``.

    * If every candidate has a median year, candidates are ordered by
      ``abs(query_year_list[i] - median_year)`` ascending, ties broken by the
      higher score; out-of-threshold neighbours follow in their original order.
    * If any candidate lacks a median year (None/NaN), the original order is kept.

    Args:
        median_year_list: Median year per FP row, indexed by neighbour id.
        query_year_list: Year of each query, aligned with ``neighbours``.
        date_rerank_threshold: Similarity margin below the best score within
            which neighbours are eligible for reranking.
        distances: Per-query neighbour scores, aligned with ``neighbours``.
        neighbours: Per-query neighbour ids.

    Returns:
        ``(reranked_indices, reranked_distances)``: per-query lists with the
        same length as the input rows, ids and scores kept aligned.
    """
    reranked_indices = []
    reranked_distances = []
    for i, nn_row in enumerate(neighbours):
        ids = list(nn_row)
        scores = list(distances[i])
        if not ids:
            reranked_indices.append(ids)
            reranked_distances.append(scores)
            continue

        # Use the best score rather than position 0: the input may already have
        # been reordered by another reranker.
        cutoff = max(scores) - date_rerank_threshold
        # ``>=`` keeps the best neighbour in the candidate set even for a zero
        # threshold.
        candidates = [(n, s) for n, s in zip(ids, scores) if s >= cutoff]
        rest = [(n, s) for n, s in zip(ids, scores) if not s >= cutoff]
        years = [median_year_list[n] for n, _ in candidates]

        if not all(pd.notna(y) for y in years):
            reranked_indices.append(ids)
            reranked_distances.append(scores)
            continue

        query_year = query_year_list[i]
        # Sort (id, score) pairs together so ids and scores cannot drift apart
        # on ties; ties on year difference go to the more similar neighbour.
        ranked = [
            pair for pair, _ in sorted(
                zip(candidates, years),
                key=lambda item: (abs(query_year - item[1]), -item[0][1]),
            )
        ]
        ordered = ranked + rest
        reranked_indices.append([n for n, _ in ordered])
        reranked_distances.append([s for _, s in ordered])
    return reranked_indices, reranked_distances
        
        
        
def rerank_by_qrank(qrank_list, qrank_rerank_threshold, distances, neighbours):
    """Rerank based on QRank. Higher QRANK is better.

    Scores in ``distances`` are treated as similarities (larger = closer). For
    query ``i`` the candidate set is every neighbour whose score is at least
    ``max(distances[i]) - qrank_rerank_threshold``.

    * If every candidate has a QRank, candidates are ordered by QRank
      descending, ties broken by the higher score; out-of-threshold neighbours
      follow in their original order.
    * If any candidate lacks a QRank (None/NaN), the original order is kept.

    The previous implementation compared against ``qrank_list[i]``, i.e. the
    QRank of an unrelated FP row indexed by the *query* position, which made
    the ordering arbitrary.

    Args:
        qrank_list: QRank per FP row, indexed by neighbour id.
        qrank_rerank_threshold: Similarity margin below the best score within
            which neighbours are eligible for reranking.
        distances: Per-query neighbour scores, aligned with ``neighbours``.
        neighbours: Per-query neighbour ids.

    Returns:
        ``(reranked_indices, reranked_distances)``: per-query lists with the
        same length as the input rows, ids and scores kept aligned.
    """
    reranked_indices = []
    reranked_distances = []
    for i, nn_row in enumerate(neighbours):
        ids = list(nn_row)
        scores = list(distances[i])
        if not ids:
            reranked_indices.append(ids)
            reranked_distances.append(scores)
            continue

        cutoff = max(scores) - qrank_rerank_threshold
        # ``>=`` keeps the best neighbour in the candidate set even for a zero
        # threshold.
        candidates = [(n, s) for n, s in zip(ids, scores) if s >= cutoff]
        rest = [(n, s) for n, s in zip(ids, scores) if not s >= cutoff]
        qranks = [qrank_list[n] for n, _ in candidates]

        if not all(pd.notna(q) for q in qranks):
            reranked_indices.append(ids)
            reranked_distances.append(scores)
            continue

        # Sort (id, score) pairs together so they stay aligned; most popular
        # first, ties go to the more similar neighbour.
        ranked = [
            pair for pair, _ in sorted(
                zip(candidates, qranks),
                key=lambda item: (-item[1], -item[0][1]),
            )
        ]
        ordered = ranked + rest
        reranked_indices.append([n for n, _ in ordered])
        reranked_distances.append([s for _, s in ordered])
    return reranked_indices, reranked_distances
    
    
        
    
    
        
        
        

def article_id_to_year(article_id):
    """Extract the year (as a string) from an article id.

    Everything from the first ``"-p-"`` onwards is discarded. The last
    ``"-"``-separated field of the remainder is returned unchanged, or the
    whole remainder if it contains no ``"-"``.
    """
    stem, _sep, _page = article_id.partition("-p-")
    return stem.rpartition("-")[2]


def _create_fp_embeddings(dict_list, embedding_model_path, special_tokens, ent_featurisation,
                          use_multi_gpu, pre_featurised, override_max_seq_length, asymm_model):
    """Featurise ``dict_list`` (or reuse cached inputs), embed and L2-normalise it."""
    print("Creating new embeddings ...")

    # Load model
    model = SentenceTransformer(embedding_model_path)
    tokenizer = model.tokenizer

    for fp in dict_list:
        fp['context'] = fp['text']

    featurised_path = f'{embedding_model_path}/featurised_fps.pkl'
    featurised_fps = None
    if pre_featurised and os.path.exists(featurised_path):
        # NOTE: pickle is only safe for trusted, locally produced cache files.
        with open(featurised_path, 'rb') as f:
            featurised_fps = pickle.load(f)
        # A cache built from different FP data would silently misalign the
        # embeddings with titles/QIDs, so rebuild it when the count differs.
        if len(featurised_fps) != len(dict_list):
            print("Cached featurised data has a different length to data. Re-featurising ...")
            featurised_fps = None

    if featurised_fps is None:
        print("Featurising data ...")
        if not asymm_model:
            featurised_fps = featurise_data_with_dates_flex(dict_list,  ent_featurisation, "prepend_1", special_tokens, tokenizer, override_max_seq_length)
        else:
            print("Using asymmetric model featurisation - prepend featurisation and 'FP' key for each fp.")
            featurised_fps = featurise_data(dict_list, featurisation="prepend", special_tokens=special_tokens,  model=SentenceTransformer("all-mpnet-base-v2"),override_max_seq_length=override_max_seq_length)
            ###For each fp, make it a dict instead. {'FP':text}
            featurised_fps = [{'FP':fp} for fp in featurised_fps]
        ##Save featurised data
        with open(featurised_path, 'wb') as f:
            pickle.dump(featurised_fps, f)

    if use_multi_gpu:
        print("Using multiple GPUs: ", torch.cuda.device_count())
        pool = model.start_multi_process_pool()
        try:
            embeddings = model.encode_multi_process(featurised_fps, batch_size=720, pool=pool)
        finally:
            # Otherwise the worker processes outlive this call.
            model.stop_multi_process_pool(pool)
    else:
        embeddings = model.encode(featurised_fps, show_progress_bar=True, batch_size=720)

    # Normalize the embeddings to unit length
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)


def _select_fp_indices(entries, instance_types_to_keep, birth_date_cutoff, qids_to_remove,
                       birth_dates_only):
    """Return the sorted positions of ``entries`` that pass every enabled FP filter."""
    keep = set(range(len(entries)))

    if instance_types_to_keep:
        wanted = set(instance_types_to_keep)
        keep &= {i for i, v in enumerate(entries)
                 if wanted.intersection(v['wikidata_info']['instance_of_labels'])}
        print("indices after instance type filtering: ", len(keep))

    if birth_date_cutoff:
        # Test for a missing year first: ``None < cutoff`` raises TypeError.
        keep &= {i for i, v in enumerate(entries)
                 if pd.isna(v['birth_year']) or v['birth_year'] < birth_date_cutoff}
        print("indices after date filtering: ", len(keep))

    if qids_to_remove:
        keep -= {i for i, v in enumerate(entries)
                 if v['wikidata_info']['wikidata_id'] in qids_to_remove}
        print("indices after qid filtering: ", len(keep))

    if birth_dates_only:
        # Keep entries with at least one of birth year / death year.
        keep &= {i for i, v in enumerate(entries)
                 if not pd.isna(v['birth_year']) or not pd.isna(v['death_year'])}
        print("indices after birth/death year filtering: ", len(keep))

    return sorted(keep)


def get_fp_embeddings(fp_path, embedding_model_path, special_tokens=None, ent_featurisation='ent_mark',
                      use_multi_gpu=False,pre_featurised=True,
                      re_embed=False,override_max_seq_length=None,asymm_model=False,
                      instance_types_to_keep=('human',), birth_date_cutoff=1970,
                      qids_to_remove=None, qids_to_prune_path="qids_to_prune.json",
                      birth_dates_only=True):
    """Load (or compute and cache) first-paragraph (FP) embeddings, then filter them.

    Embeddings are cached at ``{embedding_model_path}/fp_embeddings.pkl``.
    Featurised model inputs are cached at
    ``{embedding_model_path}/featurised_fps.pkl``. A cached embedding file is
    reused only when ``re_embed`` is false and it has exactly one row per
    non-empty entry of the JSON at ``fp_path``. A cached featurised file is
    reused (when ``pre_featurised``) only if its length matches too.

    Entries are then filtered; each enabled step intersects with the previous:

    * instance type: keep entries whose ``wikidata_info.instance_of_labels``
      share a label with ``instance_types_to_keep``. Skipped if empty.
    * birth date: keep entries with no birth year or born before
      ``birth_date_cutoff``. Skipped if falsy.
    * QIDs: drop entries whose QID is in ``qids_to_remove`` or in the JSON list
      at ``qids_to_prune_path``. ``qids_to_remove=None`` uses a built-in list;
      ``qids_to_prune_path=None`` skips the file.
    * birth/death: if ``birth_dates_only``, keep only entries with a birth year
      or a death year.

    Returns:
        ``(fp_embeddings, fp_ids, qid_list, median_year_list, birth_year_list,
        qrank_list, fp_dict_only_text, fp_dict_cleaned)``, aligned row for row.
        ``fp_ids`` are the ``wiki_title`` values. Missing QRanks are replaced
        by 0.

    Raises:
        ValueError: if freshly created embeddings do not have one row per FP.
    """
    save_path = f'{embedding_model_path}/fp_embeddings.pkl'

    with open(fp_path) as f:
        fp_raw = json.load(f)

    # Empty dicts carry no page; drop them everywhere so every per-entry list
    # below stays aligned with the embedding rows.
    fp_dict_cleaned = {k: v for k, v in fp_raw.items() if v != {}}
    dict_list = list(fp_dict_cleaned.values())
    wik_ids = [fp['wiki_title'] for fp in dict_list]
    qid_list = [fp['wikidata_info']['wikidata_id'] for fp in dict_list]
    median_year_list = [fp['median_year'] for fp in dict_list]
    birth_year_list = [fp['birth_year'] for fp in dict_list]
    qrank_list = [fp['qrank'] for fp in dict_list]

    # Load embeddings if previously created
    embeddings = None
    if os.path.exists(save_path) and not re_embed:
        print("Previous embeddings found, loading previous embeddings ...")
        # NOTE: pickle is only safe for trusted, locally produced cache files.
        with open(save_path, 'rb') as f:
            cached = pickle.load(f)
        if len(cached) == len(wik_ids):
            print('Loaded embeddings have same length as data. Using loaded embeddings')
            embeddings = cached
        else:
            print('Loaded embeddings are different length to data. Creating new embeddings')
    else:
        print("No previous embeddings found./ Rembedding is set to ",re_embed)

    # Otherwise create new embeddings
    if embeddings is None:
        embeddings = _create_fp_embeddings(dict_list, embedding_model_path, special_tokens,
                                           ent_featurisation, use_multi_gpu, pre_featurised,
                                           override_max_seq_length, asymm_model)
        if len(embeddings) != len(wik_ids):
            raise ValueError(
                f"Created {len(embeddings)} embeddings for {len(wik_ids)} FP entries; "
                "featurised inputs are misaligned with the FP data."
            )
        # Save embeddings
        with open(save_path, 'wb') as f:
            pickle.dump(embeddings, f)

    if qids_to_remove is None:
        qids_to_remove = ["Q26702663","Q2292195","Q2600236",
                          "Q112039672","Q5442741","Q16354385",
                          "Q61600540","Q19276771","Q5456005","Q4659108","Q6264176","Q5300213"]
    removed_qids = set(qids_to_remove)
    if qids_to_prune_path:
        ##Load more QIDs to remove
        with open(qids_to_prune_path) as f:
            removed_qids.update(json.load(f))

    keep_idx = _select_fp_indices(dict_list, instance_types_to_keep, birth_date_cutoff,
                                  removed_qids, birth_dates_only)

    ##Filter the embeddings and ids
    print("Number of embeddings: ",len(embeddings))
    fp_embeddings = np.asarray(embeddings)[keep_idx]
    fp_ids = [wik_ids[i] for i in keep_idx]
    qid_list = [qid_list[i] for i in keep_idx]
    median_year_list = [median_year_list[i] for i in keep_idx]
    birth_year_list = [birth_year_list[i] for i in keep_idx]
    ##Make qrank None/NaN to 0
    qrank_list = [0 if pd.isna(qrank_list[i]) else qrank_list[i] for i in keep_idx]

    print("Number of embeddings after type filtering: ",len(fp_embeddings))

    ##Filter the fp_dict_cleaned (dicts preserve insertion order, matching the lists above)
    keep_set = set(keep_idx)
    fp_dict_cleaned = {k: v for i, (k, v) in enumerate(fp_dict_cleaned.items()) if i in keep_set}
    fp_dict_only_text = [v['text'] for v in fp_dict_cleaned.values()]

    print(len(fp_embeddings), len(fp_ids))
    print("All lengths are equal")

    return fp_embeddings, fp_ids, qid_list, median_year_list, birth_year_list, qrank_list, fp_dict_only_text, fp_dict_cleaned


def eval(pred_labs, gt_labs):
    """Print, and return, total / in-Wikipedia / not-in-Wikipedia accuracy.

    A query counts as correct when its ground-truth label is in its prediction
    list. A query whose ground truth is ``'Not in wikipedia'``, ``''`` or
    ``None`` also counts as correct when its prediction list is empty.

    NOTE: this name shadows the ``eval`` builtin within this module; it is kept
    for compatibility with existing callers.

    Args:
        pred_labs: Per-query lists of predicted labels, aligned with ``gt_labs``.
        gt_labs: Ground-truth label per query.

    Returns:
        Dict with keys ``'total'``, ``'in_wikipedia'`` and
        ``'not_in_wikipedia'``. An accuracy is ``nan`` when its denominator is
        zero; this previously raised ``ZeroDivisionError``.
    """
    missing = ('Not in wikipedia', '', None)

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator else float('nan')

    in_wiki_count = 0
    not_in_wiki_count = 0

    for i, gt in enumerate(gt_labs):
        if gt in pred_labs[i]:
            in_wiki_count += 1
        elif len(pred_labs[i]) == 0 and gt in missing:
            not_in_wiki_count += 1

    in_wiki = sum(1 for gt in gt_labs if gt not in missing)
    not_in_wiki = len(gt_labs) - in_wiki

    all_acc = _ratio(in_wiki_count + not_in_wiki_count, len(gt_labs))
    in_wiki_acc = _ratio(in_wiki_count, in_wiki)
    not_in_wiki_acc = _ratio(not_in_wiki_count, not_in_wiki)

    print(f'Total accuracy: {all_acc}')
    print(f'In wikipedia accuracy: {in_wiki_acc}')
    print(f'Not in wikipedia accuracy: {not_in_wiki_acc}')

    return {'total': all_acc, 'in_wikipedia': in_wiki_acc, 'not_in_wikipedia': not_in_wiki_acc}

    

def _load_evaluation_dataset(data, model, special_tokens, ent_featurisation, date_featurisation,
                             override_max_seq_length, keep_entity_types, asymm_model,
                             prepared_ds_path):
    """Build the raw evaluation examples for ``data``, optionally replaced by ``prepared_ds_path``.

    Raises:
        ValueError: if ``data`` is unknown and no ``prepared_ds_path`` is given.
    """
    extended_path = '/mnt/data01/entity/labelled_data/new_with_negs/labeled_datasets_full_extended.json'
    restricted_path = '/mnt/data01/entity/labelled_data/new_with_negs/labeled_datasets_full_restricted.json'
    newspaper_datasets = {
        "newspapers_hns_extended_dev": (extended_path, "dev"),
        "newspapers_hns_extended_test": (extended_path, "test"),
        "newspapers_hns_restricted_dev": (restricted_path, "dev"),
        "newspapers_hns_restricted_test": (restricted_path, "test"),
    }

    ds = None
    if data in newspaper_datasets:
        dataset_path, split = newspaper_datasets[data]
        _train_data, dev_data, test_data = prep_newspaper_data(
            dataset_path=dataset_path,
            model=model,
            special_tokens=special_tokens,
            featurisation=ent_featurisation,
            disamb_or_coref='disamb',
            input_examples=False,
            date_featurisation=date_featurisation
        )
        ds = dev_data if split == "dev" else test_data
    elif data == "newspapers_sotu":
        ds = prep_sotu_data(
            dataset_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
            model=model,
            special_tokens=special_tokens,
            featurisation=ent_featurisation,
            disamb_or_coref='disamb',
            date_featurisation=date_featurisation,
            override_max_seq_length=override_max_seq_length,
            keep_entity_types=keep_entity_types,
            asymm_model=asymm_model
        )

    if prepared_ds_path:
        print("overriding the prepared ds path")
        with open(prepared_ds_path, 'r') as f:
            ds = json.load(f)

    if ds is None:
        raise ValueError(f"Unknown evaluation dataset {data!r} and no prepared_ds_path given")
    return ds


def _collect_eval_queries(ds, pre_coreffed_ds, average_cluster_embs):
    """Flatten ``ds`` into parallel per-query lists.

    Returns ``(queries, entity_labels, qid_list, article_ids, cluster_ids,
    year_list)``. ``cluster_ids`` is empty unless ``pre_coreffed_ds``. In
    pre-coreferenced data (``{art_id: {cluster_id: {text_id: example}}}``)
    labels always come from the cluster's ``"0"`` mention. The query text is
    each mention's own text when ``average_cluster_embs`` is set, and the
    ``"0"`` mention's text otherwise.
    """
    queries, entity_labels, qid_list = [], [], []
    article_ids, cluster_ids, year_list = [], [], []

    if not pre_coreffed_ds:
        for d in ds:
            queries.append(d['text'])
            entity_labels.append(d['entity'])
            qid_list.append(d['wiki_entity'])
            article_ids.append(d['art_id'])
            year_list.append(d['year'])
    else:
        for art_id, clusters in ds.items():
            for cluster_id, mentions in clusters.items():
                for text_id in mentions:
                    head = mentions["0"]
                    source = mentions[text_id] if average_cluster_embs else head
                    queries.append(source['text'])
                    entity_labels.append(head['entity'])
                    qid_list.append(head['wiki_entity'])
                    article_ids.append(art_id)
                    cluster_ids.append("_".join([art_id, cluster_id]))
                    year_list.append(article_id_to_year(art_id))

    year_list = [int(y) for y in year_list]
    return queries, entity_labels, qid_list, article_ids, cluster_ids, year_list


def _search_with_reranking(index, query_embeddings, k, year_list,
                           median_year_list, date_rerank_k, date_rerank_threshold,
                           qrank_list, qrank_rerank_k, qrank_rerank_threshold):
    """Search ``index`` and optionally rerank, returning aligned top-``k`` ``(distances, neighbours)``.

    When both rerankers are enabled, QRank reranking runs first over
    ``qrank_rerank_k`` hits. The result is cut to ``date_rerank_k`` and then
    reranked by date.
    """
    use_date = bool(date_rerank_threshold and median_year_list)
    use_qrank = bool(qrank_list and qrank_rerank_threshold)

    if not (use_date or use_qrank):
        return index.search(query_embeddings, k)

    if use_qrank:
        distances, neighbours = index.search(query_embeddings, qrank_rerank_k or k)
        neighbours, distances = rerank_by_qrank(qrank_list, qrank_rerank_threshold, distances, neighbours)
        if use_date:
            neighbours = [n[:date_rerank_k] for n in neighbours]
            distances = [d[:date_rerank_k] for d in distances]
            neighbours, distances = rerank_by_date(median_year_list, year_list, date_rerank_threshold,
                                                   distances, neighbours)
    else:
        distances, neighbours = index.search(query_embeddings, date_rerank_k or k)
        neighbours, distances = rerank_by_date(median_year_list, year_list, date_rerank_threshold,
                                               distances, neighbours)

    # Keep the reranked scores alongside the reranked ids so thresholding
    # compares each neighbour with its own similarity.
    return [d[:k] for d in distances], [n[:k] for n in neighbours]


def evaluate(trained_model_path, data, ent_featurisation, date_featurisation,
             fp_embeddings, fp_ids, special_tokens, th=None, k=1,
             featurised_fp_text=None,override_max_seq_length=None,
             keep_entity_types=None,fp_labels=None,full_dict=None,asymm_model=False,
             prepared_ds_path=None,pre_coreffed_ds=False, average_cluster_embs=False, weighted_accuracy=True,
             median_year_list=None,date_rerank_k=None,date_rerank_threshold=None,
             qrank_list=None,qrank_rerank_k=None,qrank_rerank_threshold=None,calculate_string_match=False):
    """Embed evaluation mentions, link them to the nearest FP embeddings and report accuracy.

    Args:
        trained_model_path: SentenceTransformer checkpoint. Predictions are
            written to ``{trained_model_path}/predictions/{data}/predictions.csv``.
        data: Dataset name (``newspapers_hns_{extended,restricted}_{dev,test}``
            or ``newspapers_sotu``).
        fp_embeddings: FP embedding matrix, one row per FP.
        fp_ids: Predicted label per FP row.
        th: Iterable of similarity thresholds to report at. If it is ``None``
            or empty, the top ``k`` hits are used unthresholded.
        k: Number of neighbours kept per query.
        featurised_fp_text: Optional per-FP text. When given, a predictions CSV
            is written.
        keep_entity_types: Passed to SOTU preparation (``None`` means ``[]``).
        fp_labels: Human-readable label per FP row. Defaults to ``fp_ids``.
        full_dict: Optional mapping used to add ``gt_entity_text`` to the CSV.
        prepared_ds_path: JSON file overriding the prepared dataset.
        pre_coreffed_ds: Dataset is ``{art_id: {cluster_id: {text_id: example}}}``.
        average_cluster_embs: Replace each query embedding with its cluster mean.
        weighted_accuracy: If False (pre-coreferenced data only), keep one
            query per cluster.
        median_year_list, date_rerank_k, date_rerank_threshold: Date reranking.
        qrank_list, qrank_rerank_k, qrank_rerank_threshold: QRank reranking.
        calculate_string_match: Also report the exact-string-match baseline.
    """
    if keep_entity_types is None:
        keep_entity_types = []
    if fp_labels is None:
        # Fall back to the ids so the label lookups below do not fail.
        fp_labels = fp_ids

    model = SentenceTransformer(trained_model_path)

    # Extract and featurise data.
    ds = _load_evaluation_dataset(data, model, special_tokens, ent_featurisation,
                                  date_featurisation, override_max_seq_length,
                                  keep_entity_types, asymm_model, prepared_ds_path)

    (queries, entity_labels, qid_list, article_ids,
     cluster_ids, year_list) = _collect_eval_queries(ds, pre_coreffed_ds, average_cluster_embs)

    print("Number of queries: ",len(set(queries)))

    # Embed
    query_embeddings = model.encode(queries, show_progress_bar=True, batch_size=128)
    print("Query embeddings shape: ",query_embeddings.shape)

    ###Average embeddings for each cluster - matching cluster_ids
    if average_cluster_embs:
        print("Averaging cluster embeddings")
        members = {}
        for i, cluster_id in enumerate(cluster_ids):
            members.setdefault(cluster_id, []).append(i)
        for rows in members.values():
            # Clusters are disjoint, so updating in place does not affect other means.
            query_embeddings[rows] = np.mean(query_embeddings[rows], axis=0)
        print("Query embeddings shape after averaging: ",query_embeddings.shape)

    if not weighted_accuracy:
        if cluster_ids:
            ###Keep only the first query (or embedding) per cluster, in first-seen order
            first_rows = {}
            for i, cluster_id in enumerate(cluster_ids):
                first_rows.setdefault(cluster_id, i)
            rows = list(first_rows.values())
            query_embeddings = query_embeddings[rows]
            queries = [queries[i] for i in rows]
            entity_labels = [entity_labels[i] for i in rows]
            qid_list = [qid_list[i] for i in rows]
            article_ids = [article_ids[i] for i in rows]
            year_list = [year_list[i] for i in rows]
            cluster_ids = [cluster_ids[i] for i in rows]
            print("Query embeddings shape after removing duplicates: ",query_embeddings.shape)
        else:
            print("weighted_accuracy=False needs cluster ids (pre_coreffed_ds=True); keeping all queries")

    print("Query embeddings shape: ",query_embeddings.shape)
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

    # Find nearest neighbours (inner product on unit vectors = cosine similarity)
    res = faiss.StandardGpuResources()
    gpu_index_flat = faiss.GpuIndexFlatIP(res, query_embeddings.shape[1])

    ##Normalize the embeddings
    fp_embeddings = fp_embeddings / np.linalg.norm(fp_embeddings, axis=1, keepdims=True)
    gpu_index_flat.add(fp_embeddings)
    print(len(fp_embeddings), len(featurised_fp_text) if featurised_fp_text is not None else None)

    distances, neighbours = _search_with_reranking(
        gpu_index_flat, query_embeddings, k, year_list,
        median_year_list, date_rerank_k, date_rerank_threshold,
        qrank_list, qrank_rerank_k, qrank_rerank_threshold,
    )

    gpu_index_flat.reset()

    if th:
        for threshold in th:
            print(f'** Results for {threshold} threshold, top {k} **')
            pred_labels = []
            pred_label_text = []
            for nn_list, dist_list in zip(neighbours, distances):
                hits = [nn for nn, dist in zip(nn_list, dist_list) if dist >= threshold]
                pred_labels.append([fp_ids[nn] for nn in hits])
                pred_label_text.append([fp_labels[nn] for nn in hits])
            eval(pred_labels, qid_list)
    else:
        pred_labels = [[fp_ids[n] for n in l] for l in neighbours]
        pred_label_text = [[fp_labels[n] for n in l] for l in neighbours]
        eval(pred_labels, qid_list)

    if featurised_fp_text:
        ##Save prediction output (query text, matched text, ground truth and predicted labels) as csv.
        ##Predictions are those of the last threshold when ``th`` is given.
        save_path = f'{trained_model_path}/predictions/{data}'
        os.makedirs(save_path, exist_ok=True)
        assert len(queries)==len(neighbours)==len(qid_list)==len(pred_labels)

        result_df=pd.DataFrame()
        result_df['query_text'] = queries
        result_df['matched_text'] = [[featurised_fp_text[n] for n in l] for l in neighbours]
        result_df['ground_truth'] = qid_list
        result_df['predicted_labels'] = pred_labels
        result_df['gt_entity_label'] = entity_labels
        result_df['pred_entity_label']=pred_label_text
        result_df['correct'] = [1 if gt in pred else 0 for gt,pred in zip(qid_list,pred_labels)]

        if full_dict:
            ##Add the actual text of the ground truth entity - if available. Else add None
            result_df['gt_entity_text'] = [full_dict[gt]['text'] if gt in full_dict else None for gt in qid_list]

        ##Drop duplicates by query text
        result_df.drop_duplicates(subset=['query_text'],inplace=True)

        result_df.to_csv(os.path.join(save_path, 'predictions.csv'), index=False)

    if calculate_string_match:
        # Built from the final (possibly de-duplicated) queries so it stays aligned with qid_list.
        mention_text_list = [get_mention_text_between_special_tokens(q) for q in queries]
        calculate_string_match_accuracy(mention_text_list,qid_list,fp_labels,fp_ids)
        
            

    

 
        
        


if __name__ == '__main__':

    # Checkpoint under evaluation ("Best on date").
    # Other checkpoints evaluated with this script (all under
    # /mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/ unless noted):
    #   /mnt/data01/entity/trained_models/cgis_model_ent_mark_incontext_90  (not finetuned)
    #   entity_split_newspaper_wiki_coref_disamb_more_incontext             (finetuned)
    #   cgis_model_ent_mark_incontext                                       (coref model)
    #   asymmetric_disambiguation_full_100                                  (asymmetric, wiki only)
    #   asymm_newspapers_0.3530040607711171_64_5_0.9285210198462236/        (asymmetric, wiki + news)
    #   cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled
    #   cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_epoch_1  (previous best)
    #   cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_newsft_epoch_1
    #   cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_newsft
    #   cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_newsft_old
    trained_model_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_2e6"

    # Mention marker tokens used for featurisation.
    stoks = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    # NOTE(review): neither of these is passed to get_fp_embeddings below, which
    # uses its own filter settings; they currently have no effect.
    instance_types_to_keep = set(['human', 'human Biblical figure', 'mythical character', 'religious character', 'historical character',
                                  'supernatural being',
                                  'fictional character', 'television character', 'fictional human', 'literary character',
                                  'film character', 'animated character', 'musical theatre character', 'theatrical character'])
    BIRTH_DATE_CUTOFF = 1960

    # Load FP embeddings if they exist, create them if not.
    fp_outputs = get_fp_embeddings(
        fp_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/formatted_first_para_data_qid_template_people_with_qrank_3occupations.json',
        embedding_model_path=trained_model_path,
        special_tokens=stoks,
        use_multi_gpu=True,
        override_max_seq_length=256,
        re_embed=False,
        pre_featurised=True,
        asymm_model=False,
    )
    (fp_embeddings, fp_ids, qid_list, median_year_list, birth_year_list,
     qrank_list, fp_dict_only_text, fp_dict_cleaned) = fp_outputs

    # Evaluation data; options: "newspapers_hns_restricted_dev",
    # "newspapers_hns_restricted_test", "newspapers_hns_extended_dev",
    # "newspapers_hns_extended_test", "newspapers_sotu".
    dt = "newspapers_sotu"

    # Pre-coreferenced query clusters. Alternatives used before:
    #   /mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/github_repos/end-to-end-pipeline/ds_coref_article_clean.json
    #   /mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned.json
    prepared_ds_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/github_repos/end-to-end-pipeline/ds_coref_date_clean.json"

    evaluate(
        trained_model_path,
        data=dt,
        # Featurisation
        ent_featurisation='ent_mark',
        date_featurisation='prepend_1',
        special_tokens=stoks,
        override_max_seq_length=256,
        asymm_model=False,
        keep_entity_types=['PER'],
        # Candidates: QIDs are the predicted ids, wiki titles the readable labels.
        fp_embeddings=fp_embeddings,
        fp_ids=qid_list,
        fp_labels=fp_ids,
        featurised_fp_text=fp_dict_only_text,
        full_dict=fp_dict_cleaned,
        # Retrieval: thresholds to report for; unthresholded if not supplied.
        th=[0.87, 0.8729, 0.873],
        k=1,
        # Query construction
        prepared_ds_path=prepared_ds_path,
        pre_coreffed_ds=True,
        weighted_accuracy=True,
        average_cluster_embs=True,
        # Date reranking (disabled):
        # median_year_list=median_year_list, date_rerank_threshold=0.001, date_rerank_k=2,
        # QRank reranking
        qrank_list=qrank_list,
        qrank_rerank_threshold=0.01,
        qrank_rerank_k=10,
        # calculate_string_match=True
    )

    # Todo:
    # - Eval for different types of entity
    # - Add option to do people only (ie. only include wikipedia fps which refer to people and only evaluate people)
    # - Code for selecting the best threshold