import os
import json
import pickle
from tqdm import tqdm

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd

from data_fns import featurise_data, prep_newspaper_data, prep_sotu_data, clean_entity


def get_fp_embeddings(fp_path, embedding_model_path, special_tokens=None, ent_featurisation='ent_mark',
                      use_multi_gpu=False, pre_featurised=True,
                      re_embed=False, override_max_seq_length=None, asymm_model=False):
    """Load cached first-paragraph (FP) embeddings, or compute and cache them.

    Reads the FP JSON at ``fp_path``, skips entries that are empty dicts, and
    returns one embedding per remaining entry, in file order. Embeddings are
    cached in ``<embedding_model_path>/fp_embeddings.pkl``. Featurised FP text
    is cached in ``<embedding_model_path>/featurised_fps.pkl``.

    Args:
        fp_path: JSON file whose values are FP dicts. Each non-empty dict is
            read for 'wiki_title', 'text' and ['wikidata_info']['wikidata_id'].
        embedding_model_path: SentenceTransformer model directory. It is also
            used as the cache directory.
        special_tokens: Passed through to ``featurise_data``.
        ent_featurisation: Featurisation mode used for symmetric models.
        use_multi_gpu: Encode with a SentenceTransformer multi-process pool.
        pre_featurised: Reuse the cached featurised text if it exists.
        re_embed: Ignore any cached embeddings and recompute them.
        override_max_seq_length: Passed through to ``featurise_data``.
        asymm_model: Use 'prepend' featurisation and wrap each FP as {'FP': text}.

    Returns:
        ``(embeddings, wik_ids, qid_list)``. Row i of ``embeddings`` corresponds
        to ``wik_ids[i]`` (wiki title) and ``qid_list[i]`` (Wikidata ID). Freshly
        computed embeddings are L2-normalised per row.
    """
    save_path = f'{embedding_model_path}/fp_embeddings.pkl'
    featurised_path = f'{embedding_model_path}/featurised_fps.pkl'

    with open(fp_path) as f:
        fp_raw = json.load(f)

    # Keep only non-empty FP entries; the order here defines the embedding row order.
    dict_list = []
    wik_ids = []
    qid_list = []
    for fp in tqdm(fp_raw.values()):
        if fp != {}:
            wik_ids.append(fp['wiki_title'])
            dict_list.append(fp)
            qid_list.append(fp['wikidata_info']['wikidata_id'])

    embeddings = None
    if os.path.exists(save_path) and not re_embed:
        print("Previous embeddings found, loading previous embeddings ...")
        # NOTE: pickle is only safe on trusted files; this cache is written by this function.
        with open(save_path, 'rb') as f:
            cached = pickle.load(f)

        # A length mismatch means the FP file changed since the cache was written.
        if len(cached) == len(wik_ids):
            print('Loaded embeddings have same length as data. Using loaded embeddings')
            embeddings = cached
        else:
            print('Loaded embeddings are different length to data. Creating new embeddings')
    elif re_embed:
        print("Re-embedding requested (re_embed=True); ignoring any previous embeddings.")
    else:
        print("No previous embeddings found.")

    if embeddings is None:
        print("Creating new embeddings ...")

        model = SentenceTransformer(embedding_model_path)

        # featurise_data reads the FP text from the 'context' key.
        for fp in dict_list:
            fp['context'] = fp['text']

        if not pre_featurised or not os.path.exists(featurised_path):
            print("Featurising data ...")
            if not asymm_model:
                featurised_fps = featurise_data(dict_list, featurisation=ent_featurisation, special_tokens=special_tokens,
                                                model=model, override_max_seq_length=override_max_seq_length)
            else:
                print("Using asymmetric model featurisation - prepend featurisation and 'FP' key for each fp.")
                featurised_fps = featurise_data(dict_list, featurisation="prepend", special_tokens=special_tokens,
                                                model=SentenceTransformer("all-mpnet-base-v2"),
                                                override_max_seq_length=override_max_seq_length)
                # Asymmetric models route inputs by key, so wrap each FP as {'FP': text}.
                featurised_fps = [{'FP': fp} for fp in featurised_fps]
            with open(featurised_path, 'wb') as f:
                pickle.dump(featurised_fps, f)
        else:
            with open(featurised_path, 'rb') as f:
                featurised_fps = pickle.load(f)

        if use_multi_gpu:
            print("Using multiple GPUs: ", torch.cuda.device_count())
            pool = model.start_multi_process_pool()
            try:
                embeddings = model.encode_multi_process(featurised_fps, batch_size=768, pool=pool)
            finally:
                # Always shut the worker processes down, even if encoding fails.
                model.stop_multi_process_pool(pool)
        else:
            embeddings = model.encode(featurised_fps, show_progress_bar=True, batch_size=768)

        # Unit-normalise rows so inner product equals cosine similarity.
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

        with open(save_path, 'wb') as f:
            pickle.dump(embeddings, f)

    return embeddings, wik_ids, qid_list


def _is_not_in_wiki(label):
    """Return True if a gold label marks a mention with no Wikipedia entity."""
    return label is None or label == '' or label == 'Not in wikipedia'


def eval(pred_labs, gt_labs):
    """Print and return entity-linking accuracy.

    A query counts as correct in either of two cases:
    - its gold label appears in its prediction list;
    - its prediction list is empty and the gold label is a "not in Wikipedia"
      marker (``None``, ``''`` or ``'Not in wikipedia'``).

    Args:
        pred_labs: One list of predicted labels per query.
        gt_labs: One gold label per query, aligned with ``pred_labs``.

    Returns:
        A dict with keys 'all', 'in_wiki' and 'not_in_wiki'. A value is
        ``nan`` when its denominator is zero (no queries of that kind).

    Raises:
        ValueError: If ``pred_labs`` and ``gt_labs`` differ in length.
    """
    if len(pred_labs) != len(gt_labs):
        raise ValueError(f'pred_labs has {len(pred_labs)} entries but gt_labs has {len(gt_labs)}')

    in_wiki_count = 0
    not_in_wiki_count = 0
    for preds, gt in zip(pred_labs, gt_labs):
        if gt in preds:
            in_wiki_count += 1
        elif len(preds) == 0 and _is_not_in_wiki(gt):
            not_in_wiki_count += 1

    in_wiki = sum(1 for gt in gt_labs if not _is_not_in_wiki(gt))
    not_in_wiki = len(gt_labs) - in_wiki

    def _ratio(num, den):
        # Avoid ZeroDivisionError when a category has no examples.
        return num / den if den else float('nan')

    results = {
        'all': _ratio(in_wiki_count + not_in_wiki_count, len(gt_labs)),
        'in_wiki': _ratio(in_wiki_count, in_wiki),
        'not_in_wiki': _ratio(not_in_wiki_count, not_in_wiki),
    }

    print(f"Total accuracy: {results['all']}")
    print(f"In wikipedia accuracy: {results['in_wiki']}")
    print(f"Not in wikipedia accuracy: {results['not_in_wiki']}")
    return results


def _load_eval_dataset(data, model, special_tokens, ent_featurisation, date_featurisation,
                       override_max_seq_length, keep_entity_types, asymm_model):
    """Prepare the built-in evaluation split named ``data``; return None if the name is unknown."""
    newspaper_paths = {
        'newspapers_hns_extended': '/mnt/data01/entity/labelled_data/new_with_negs/labeled_datasets_full_extended.json',
        'newspapers_hns_restricted': '/mnt/data01/entity/labelled_data/new_with_negs/labeled_datasets_full_restricted.json',
    }
    for prefix, dataset_path in newspaper_paths.items():
        if data in (f'{prefix}_dev', f'{prefix}_test'):
            _, dev_data, test_data = prep_newspaper_data(
                dataset_path=dataset_path,
                model=model,
                special_tokens=special_tokens,
                featurisation=ent_featurisation,
                disamb_or_coref='disamb',
                input_examples=False,
                date_featurisation=date_featurisation
            )
            return dev_data if data.endswith('_dev') else test_data

    if data == 'newspapers_sotu':
        return prep_sotu_data(
            dataset_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
            model=model,
            special_tokens=special_tokens,
            featurisation=ent_featurisation,
            disamb_or_coref='disamb',
            date_featurisation=date_featurisation,
            override_max_seq_length=override_max_seq_length,
            keep_entity_types=keep_entity_types,
            asymm_model=asymm_model
        )

    return None


def _extract_queries(ds, pre_coreffed_ds):
    """Return parallel lists (query texts, gold entity names, gold wiki IDs) from ``ds``."""
    queries = []
    entity_labels = []
    qid_list = []

    if not pre_coreffed_ds:
        # ds is a flat iterable of mention dicts.
        for d in ds:
            queries.append(d['text'])
            entity_labels.append(d['entity'])
            qid_list.append(d['wiki_entity'])
        return queries, entity_labels, qid_list

    # ds is nested as {art_id: {cluster_id: {text_id: mention_dict}}}.
    for clusters in ds.values():
        for mentions in clusters.values():
            if not mentions:
                continue
            # NOTE(review): every mention in a cluster is represented by the cluster's
            # first mention ("0"). The same query is therefore repeated once per mention,
            # which weights accuracy by cluster size. Confirm this is intended rather
            # than using each mention's own text.
            first = mentions["0"]
            for _ in mentions:
                queries.append(first['text'])
                entity_labels.append(first['entity'])
                qid_list.append(first['wiki_entity'])
    return queries, entity_labels, qid_list


def _predictions_from_neighbours(distances, neighbours, fp_ids, fp_labels, threshold=None):
    """Map neighbour indices to per-query ID and label lists, keeping only scores > ``threshold`` if given."""
    pred_labels = []
    pred_label_text = []
    for dist_row, nn_row in zip(distances, neighbours):
        kept = [nn for dist, nn in zip(dist_row, nn_row) if threshold is None or dist > threshold]
        pred_labels.append([fp_ids[nn] for nn in kept])
        pred_label_text.append([fp_labels[nn] for nn in kept])
    return pred_labels, pred_label_text


def evaluate(trained_model_path, data, ent_featurisation, date_featurisation, fp_embeddings, fp_ids, special_tokens, th=None, k=1,
             featurised_fp_text=None, override_max_seq_length=None,
             keep_entity_types=None, fp_labels=None, full_dict=None, asymm_model=False,
             prepared_ds_path=None, pre_coreffed_ds=False):
    """Link evaluation mentions to first-paragraph (FP) embeddings and report accuracy.

    Queries are embedded with the model at ``trained_model_path`` and
    L2-normalised. Each query is matched to its top-``k`` FP embeddings by
    inner product on a faiss GPU index. Accuracy is printed via ``eval``, once
    per threshold in ``th``, or once without thresholding when ``th`` is None.

    Args:
        trained_model_path: SentenceTransformer model directory. Predictions are
            written under ``<trained_model_path>/predictions/<data>/``.
        data: Name of a built-in evaluation split. It also names the output folder.
        ent_featurisation, date_featurisation, special_tokens,
        override_max_seq_length, keep_entity_types, asymm_model: Passed through
            to the dataset preparation helpers.
        fp_embeddings: FP embeddings; row i corresponds to ``fp_ids[i]``.
        fp_ids: Gold-comparable ID for each FP row.
        th: Optional iterable of similarity thresholds.
        k: Number of nearest neighbours per query.
        featurised_fp_text: Optional FP texts aligned with ``fp_embeddings``.
            When given, a predictions CSV is written.
        fp_labels: Human-readable label per FP row; defaults to ``fp_ids``.
        full_dict: Optional mapping used to look up the gold entity's text.
        prepared_ds_path: JSON dataset that replaces the prepared split.
        pre_coreffed_ds: True if the dataset is nested by article and cluster.

    Raises:
        ValueError: If ``data`` is not a known split and no ``prepared_ds_path`` is given.
    """
    if keep_entity_types is None:
        keep_entity_types = []
    if fp_labels is None:
        fp_labels = fp_ids

    model = SentenceTransformer(trained_model_path)

    # Preparation runs even when prepared_ds_path overrides the result, because the
    # prep helpers receive the model and may configure it (e.g. special tokens).
    ds = _load_eval_dataset(data, model, special_tokens, ent_featurisation, date_featurisation,
                            override_max_seq_length, keep_entity_types, asymm_model)

    if prepared_ds_path:
        print("overriding the prepared ds path")
        with open(prepared_ds_path, 'r') as f:
            ds = json.load(f)

    if ds is None:
        raise ValueError(f"Unknown evaluation dataset {data!r} and no prepared_ds_path given")

    queries, entity_labels, qid_list = _extract_queries(ds, pre_coreffed_ds)
    print("Number of queries: ", len(set(queries)))

    query_embeddings = model.encode(queries, show_progress_bar=True, batch_size=128)
    print("Query embeddings shape: ", query_embeddings.shape)
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

    # Exact inner-product search; with unit-norm vectors this is cosine similarity.
    res = faiss.StandardGpuResources()
    gpu_index_flat = faiss.GpuIndexFlatIP(res, query_embeddings.shape[1])
    try:
        gpu_index_flat.add(np.ascontiguousarray(fp_embeddings, dtype=np.float32))
        print(len(fp_embeddings), len(featurised_fp_text) if featurised_fp_text is not None else None)
        distances, neighbours = gpu_index_flat.search(
            np.ascontiguousarray(query_embeddings, dtype=np.float32), k)
    finally:
        gpu_index_flat.reset()

    if th is not None:
        # The CSV below reflects predictions at the last threshold in th.
        for threshold in th:
            print(f'** Results for {threshold} threshold, top {k} **')
            pred_labels, pred_label_text = _predictions_from_neighbours(
                distances, neighbours, fp_ids, fp_labels, threshold)
            eval(pred_labels, qid_list)
    else:
        pred_labels, pred_label_text = _predictions_from_neighbours(distances, neighbours, fp_ids, fp_labels)
        eval(pred_labels, qid_list)

    if featurised_fp_text:
        # Save query text, matched FP text, gold and predicted labels as CSV.
        save_path = f'{trained_model_path}/predictions/{data}'
        os.makedirs(save_path, exist_ok=True)
        assert len(queries) == len(neighbours) == len(qid_list) == len(pred_labels)

        result_df = pd.DataFrame()
        result_df['query_text'] = queries
        result_df['matched_text'] = [[featurised_fp_text[n] for n in nn_row] for nn_row in neighbours]
        result_df['ground_truth'] = qid_list
        result_df['predicted_labels'] = pred_labels
        result_df['gt_entity_label'] = entity_labels
        result_df['pred_entity_label'] = pred_label_text
        result_df['correct'] = [1 if gt in pred else 0 for gt, pred in zip(qid_list, pred_labels)]

        if full_dict:
            # Gold entity's FP text when the gold ID is a key of full_dict; otherwise None.
            result_df['gt_entity_text'] = [full_dict[gt]['text'] if gt in full_dict else None for gt in qid_list]

        # Repeated queries (e.g. one per coref mention) collapse to a single row.
        result_df.drop_duplicates(subset=['query_text'], inplace=True)

        result_df.to_csv(os.path.join(save_path, 'predictions.csv'), index=False)
        
            

    

 
        
        


if __name__ == '__main__':

    # trained_model_path = '/mnt/data01/entity/trained_models/cgis_model_ent_mark_incontext_90' # not finetuned
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/entity_split_newspaper_wiki_coref_disamb_more_incontext' # finetuned
    trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext' # coref model
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymmetric_disambiguation_full_100' # assym wiki only
    # trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymm_newspapers_0.3530040607711171_64_5_0.9285210198462236/' # assym wiki + news

    # Single source of truth for the FP file: the embeddings and the metadata below must come from the same file.
    fp_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/formatted_first_para_data_qid_template_people.json'

    stoks = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    # Wikidata 'instance of' labels to keep; an empty set keeps every FP.
    instance_types_to_keep = {
        'human', 'human Biblical figure', 'mythical character', 'religious character', 'historical character',
        'supernatural being',
        'fictional character', 'television character', 'fictional human', 'literary character',
        'film character', 'animated character', 'musical theatre character', 'theatrical character',
    }

    # Load FP embeddings if they exist; create them otherwise.
    fp_embeddings, fp_ids, qid_list = get_fp_embeddings(
        fp_path=fp_path,
        embedding_model_path=trained_model_path,
        special_tokens=stoks,
        use_multi_gpu=True,
        override_max_seq_length=256,
        re_embed=False,
        pre_featurised=True,
        asymm_model=False,
    )

    with open(fp_path, 'r') as f:
        fp_dict_raw = json.load(f)

    # Drop empty entries with the same rule get_fp_embeddings uses, so entry i here
    # lines up with embedding row i (dicts preserve insertion order).
    fp_dict_cleaned = {k: v for k, v in fp_dict_raw.items() if v != {}}

    print("Number of embeddings: ", len(fp_embeddings))

    # Positions of FPs whose 'instance_of_labels' share at least one type with the filter.
    if instance_types_to_keep:
        indices_to_keep = [
            i for i, v in enumerate(tqdm(fp_dict_cleaned.values()))
            if not instance_types_to_keep.isdisjoint(v['wikidata_info']['instance_of_labels'])
        ]
    else:
        indices_to_keep = list(range(len(fp_dict_cleaned)))

    # Apply the same row selection to embeddings, IDs and metadata so they stay aligned.
    fp_embeddings = np.asarray(fp_embeddings)[indices_to_keep]
    fp_ids = [fp_ids[i] for i in indices_to_keep]
    qid_list = [qid_list[i] for i in indices_to_keep]
    print("Number of embeddings after type filtering: ", len(fp_embeddings))

    fp_items = list(fp_dict_cleaned.items())
    fp_dict_cleaned = dict(fp_items[i] for i in indices_to_keep)

    fp_dict_only_text = [v['text'] for v in fp_dict_cleaned.values()]

    print(len(fp_embeddings), len(fp_ids), len(fp_dict_only_text))
    if len(fp_embeddings) == len(fp_ids) == len(fp_dict_only_text):
        print("All lengths are equal")
    else:
        print("WARNING: embedding, id and text lengths differ; predictions.csv matched_text may be misaligned")

    # # Options for the evaluation data:
    # dt = "newspapers_hns_restricted_dev"
    # dt = "newspapers_hns_restricted_test"
    # dt = "newspapers_hns_extended_dev"
    # dt = "newspapers_hns_extended_test"
    dt = "newspapers_sotu"

    evaluate(
        trained_model_path,
        data=dt,
        ent_featurisation='ent_mark',
        date_featurisation='prepend_1',
        special_tokens=stoks,
        fp_embeddings=fp_embeddings,
        fp_ids=qid_list,  # Wikidata IDs: compared against the dataset's 'wiki_entity' gold labels
        # th=np.arange(0.1, 0.99, 0.01) , # list of values to report for, not thresholded if not supplied
        k=1,    # nearest neighbour count
        featurised_fp_text=fp_dict_only_text,
        override_max_seq_length=256,
        keep_entity_types=['PER'],
        fp_labels=fp_ids,  # wiki titles: human-readable labels for the CSV
        full_dict=fp_dict_cleaned,
        asymm_model=False,
        prepared_ds_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/github_repos/end-to-end-pipeline/ds_coref_article.json",
        pre_coreffed_ds=True
    )

    # Todo:
    # - Eval for different types of entity
    # - Add option to do people only (ie. only include wikipedia fps which refer to people and only evaluate people)
    # - Code for selecting the best threshold