import os
import sys
import json
import pickle
from tqdm import tqdm

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd

## Import adjusted_rand_score to compute the Adjusted Rand Index (ARI)
from sklearn.metrics.cluster import adjusted_rand_score


from data_fns import  prep_sotu_data

# Append this script's parent and grandparent directories to sys.path so that the
# project-local `nlp_utils` package imported below can be resolved.
# NOTE(review): presumably `nlp_utils` lives in one of those two directories — confirm.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
grandparentdir = os.path.dirname(parentdir)
sys.path.append(parentdir)
sys.path.append(grandparentdir)

# This import must come after the sys.path changes above.
from nlp_utils.modified_sbert.cluster_fns import cluster

def article_id_to_date(article_id):
    """Return the date portion of an article id.

    Everything from the first ``"-p-"`` onward is discarded. The last three
    ``"-"``-separated fields of what remains are then re-joined with ``"-"``.

    Example: ``"paper-1950-01-02-p-3"`` -> ``"1950-01-02"``.
    """
    stem, _sep, _rest = article_id.partition("-p-")
    return "-".join(stem.split("-")[-3:])
    
    
    
    
    

def check_position_of_entity_token_in_text(text):
    """Return the word index of the first ``[M]`` mention marker in *text*.

    The ``[M]`` and ``[/M]`` markers are padded with spaces first, because they are
    not always separated from neighbouring words by whitespace. The text is then
    split on runs of any whitespace. The result is the number of whitespace-delimited
    tokens before the marker, so ``"[M]Bob"`` -> 0 and ``"Hello [M]Bob"`` -> 1.

    Returns ``float("inf")`` when the text contains no ``[M]`` marker. Such texts
    therefore sort after every text that has a marker. Previously the function
    returned ``None``, which made ``np.argsort`` raise on mixed ``None``/int input.
    """
    padded = text.replace("[M]", " [M] ").replace("[/M]", " [/M] ")
    # split() with no argument collapses repeated whitespace (including the padding
    # added above, newlines and tabs). Empty tokens therefore never inflate the index.
    tokens = padded.split()
    try:
        return tokens.index("[M]")
    except ValueError:
        return float("inf")

def coref_within_article(model, ent_featurisation, date_featurisation, special_tokens, cluster_params,
                         override_max_seq_length=None, keep_entity_types=None, save_path='ds_coref_article.json',
                         dataset_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
                         prepared_dataset_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned.json",
                         clusters_debug_path='article_text_cluster_ids.json',
                         sorted_clusters_debug_path='article_text_cluster_ids_with_pos.json'):
    """Cluster entity mentions within each article and score the clustering with ARI.

    Steps:
    1. Load the mention texts with ``prep_sotu_data``.
    2. Embed them with ``model.encode``.
    3. Group the embeddings by ``art_id``.
    4. Cluster each article's group with ``cluster("agglomerative", ...)``. A group
       with a single mention becomes one singleton cluster.
    5. Inside each cluster, order the mentions so that the one whose ``[M]`` marker
       appears earliest in its text comes first. Index 0 is then treated as the
       cluster's representative. No mention is dropped.

    Side effects: prints progress counts and the ARI, and writes three JSON files:
    ``clusters_debug_path`` (raw clusters), ``sorted_clusters_debug_path`` (clusters
    after ordering) and ``save_path``. The last has the structure
    ``{art_id: {cluster_id: {i: {"text", "wiki_entity", "entity"}}}}``.

    Args:
        model: sentence encoder. Forwarded to ``prep_sotu_data`` and used for ``encode``.
        ent_featurisation, date_featurisation, special_tokens, override_max_seq_length:
            forwarded to ``prep_sotu_data``.
        cluster_params: forwarded to ``cluster``.
        keep_entity_types: forwarded to ``prep_sotu_data``; ``None`` means ``[]``.
        save_path: output path of the per-article cluster dataset.
        dataset_path, prepared_dataset_path: forwarded to ``prep_sotu_data``.
        clusters_debug_path, sorted_clusters_debug_path: debug dump locations.

    Returns:
        The ARI score (also printed).
    """
    if keep_entity_types is None:
        keep_entity_types = []

    ds = prep_sotu_data(
        dataset_path=dataset_path,
        model=model,
        special_tokens=special_tokens,
        featurisation=ent_featurisation,
        disamb_or_coref='disamb',
        date_featurisation=date_featurisation,
        override_max_seq_length=override_max_seq_length,
        keep_entity_types=keep_entity_types,
        prepared_dataset_path=prepared_dataset_path,
    )

    # Flatten the prepared examples into parallel columns. Each example is read for
    # 'art_id', 'text', 'wiki_entity' and 'entity'.
    examples = [ds[i] for i in range(len(ds))]
    text_list = [ex['text'] for ex in examples]
    article_ids = np.array([ex['art_id'] for ex in examples])
    # Build the column arrays once, outside the per-article loop below.
    columns = {
        "text": np.array(text_list),
        "qid": np.array([ex['wiki_entity'] for ex in examples]),
        "entity": np.array([ex['entity'] for ex in examples]),
    }
    print("Total number of texts: ", len(text_list))

    text_embeddings = model.encode(text_list)

    # Group embeddings and columns by article (np.unique yields sorted article ids).
    unique_article_ids = np.unique(article_ids)
    groups = {}
    for article_id in unique_article_ids:
        in_article = article_ids == article_id
        group = {name: values[in_article] for name, values in columns.items()}
        group["embeddings"] = text_embeddings[in_article]
        groups[article_id] = group
    print("Total number of texts: ", sum(len(g["text"]) for g in groups.values()))

    # Cluster each article's mentions and gather the per-cluster columns.
    print("Total number of articles: ", len(unique_article_ids))
    out_keys = (("cluster_text", "text"), ("cluster_qid", "qid"), ("cluster_entity", "entity"))
    article_text_cluster_ids = {}
    for article_id, group in groups.items():
        embeddings = group["embeddings"]
        if embeddings.shape[0] == 1:
            # A single mention is trivially its own cluster.
            cluster_ids = {0: [0]}
        else:
            cluster_ids = cluster("agglomerative", cluster_params=cluster_params,
                                  corpus_embeddings=embeddings)
        entry = {"cluster_dict": cluster_ids}
        for out_key, column in out_keys:
            # .tolist() yields plain Python objects, so the JSON dumps below do not
            # depend on numpy scalar types.
            entry[out_key] = {
                key: group[column][np.asarray(members, dtype=int)].tolist()
                for key, members in cluster_ids.items()
            }
        article_text_cluster_ids[article_id] = entry

    print("Total number of texts: ",
          sum(len(members) for entry in article_text_cluster_ids.values()
              for members in entry["cluster_text"].values()))

    with open(clusters_debug_path, 'w') as f:
        json.dump(article_text_cluster_ids, f)

    # Order the mentions in every cluster by the word position of their [M] marker,
    # earliest first. Every mention is kept; index 0 acts as the representative.
    sorted_clusters = {}
    for article_id, entry in article_text_cluster_ids.items():
        sorted_entry = {"cluster_dict": entry["cluster_dict"]}
        for out_key, _ in out_keys:
            sorted_entry[out_key] = {}
        for key, cluster_texts in entry["cluster_text"].items():
            positions = np.array([check_position_of_entity_token_in_text(t) for t in cluster_texts])
            order = np.argsort(positions)
            for out_key, _ in out_keys:
                sorted_entry[out_key][key] = np.array(entry[out_key][key])[order].tolist()
        sorted_clusters[article_id] = sorted_entry

    with open(sorted_clusters_debug_path, 'w') as f:
        json.dump(sorted_clusters, f)

    # Final dataset: {art_id: {cluster_id: {i: {"text", "wiki_entity", "entity"}}}}.
    ds_new = {}
    for article_id, entry in sorted_clusters.items():
        ds_new[article_id] = {}
        for key, cluster_texts in entry["cluster_text"].items():
            ds_new[article_id][key] = {
                i: {"text": text, "wiki_entity": qid, "entity": entity}
                for i, (text, qid, entity) in enumerate(
                    zip(cluster_texts, entry["cluster_qid"][key], entry["cluster_entity"][key]))
            }

    print("Total number of texts: ",
          sum(len(mentions) for clusters in ds_new.values() for mentions in clusters.values()))

    with open(save_path, 'w') as f:
        json.dump(ds_new, f)

    # ARI labels:
    # - True label: article + gold QID of each mention.
    # - Predicted label: article + gold QID of the cluster's representative (index 0),
    #   i.e. every mention inherits its representative's entity.
    # NOTE(review): clusters whose representatives share a QID collapse into one
    # predicted label, so over-splitting is not penalised. Use article + cluster id
    # instead if raw clustering quality is what should be measured.
    true_labels = []
    predicted_labels = []
    for article_id, clusters in ds_new.items():
        for mentions in clusters.values():
            if not mentions:
                continue
            representative_qid = mentions[0]["wiki_entity"]
            for mention in mentions.values():
                true_labels.append(str(article_id) + "_" + str(mention["wiki_entity"]))
                predicted_labels.append(str(article_id) + "_" + str(representative_qid))

    ari = adjusted_rand_score(true_labels, predicted_labels)
    print("ARI: ", ari)
    return ari

def coref_within_date(model, ent_featurisation, date_featurisation, special_tokens, cluster_params,
                      override_max_seq_length=None, keep_entity_types=None, save_path='ds_coref_date.json',
                      dataset_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
                      prepared_dataset_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned.json",
                      clusters_debug_path='article_text_cluster_ids.json',
                      sorted_clusters_debug_path='article_text_cluster_ids_with_pos.json'):
    """Cluster entity mentions that share a publication date and score them with ARI.

    Steps:
    1. Load the mention texts with ``prep_sotu_data``.
    2. Embed them with ``model.encode``.
    3. Group the embeddings by the date that ``article_id_to_date`` extracts from each
       ``art_id``.
    4. Cluster each date's group with ``cluster("agglomerative", ...)``. A group with a
       single mention becomes one singleton cluster.
    5. Inside each cluster, order the mentions by text length in characters, longest
       first. Index 0 is then treated as the cluster's representative. No mention is
       dropped.

    Side effects: prints progress counts and the ARI, and writes three JSON files:
    ``clusters_debug_path``, ``sorted_clusters_debug_path`` and ``save_path``. The last
    has the structure ``{date: {cluster_id: {i: {"text", "wiki_entity", "entity"}}}}``.

    NOTE: the default debug paths are the same file names ``coref_within_article``
    uses by default, so running both with defaults overwrites the earlier dumps.

    Args:
        model: sentence encoder. Forwarded to ``prep_sotu_data`` and used for ``encode``.
        ent_featurisation, date_featurisation, special_tokens, override_max_seq_length:
            forwarded to ``prep_sotu_data``.
        cluster_params: forwarded to ``cluster``.
        keep_entity_types: forwarded to ``prep_sotu_data``; ``None`` means ``[]``.
        save_path: output path of the per-date cluster dataset.
        dataset_path, prepared_dataset_path: forwarded to ``prep_sotu_data``.
        clusters_debug_path, sorted_clusters_debug_path: debug dump locations.

    Returns:
        The ARI score (also printed).
    """
    if keep_entity_types is None:
        keep_entity_types = []

    ds = prep_sotu_data(
        dataset_path=dataset_path,
        model=model,
        special_tokens=special_tokens,
        featurisation=ent_featurisation,
        disamb_or_coref='disamb',
        date_featurisation=date_featurisation,
        override_max_seq_length=override_max_seq_length,
        keep_entity_types=keep_entity_types,
        prepared_dataset_path=prepared_dataset_path,
    )

    # Flatten the prepared examples into parallel columns. Each example is read for
    # 'art_id', 'text', 'wiki_entity' and 'entity'.
    examples = [ds[i] for i in range(len(ds))]
    text_list = [ex['text'] for ex in examples]
    # Mentions are grouped by publication date rather than by article.
    dates = np.array([article_id_to_date(ex['art_id']) for ex in examples])
    # Build the column arrays once, outside the per-date loop below.
    columns = {
        "text": np.array(text_list),
        "qid": np.array([ex['wiki_entity'] for ex in examples]),
        "entity": np.array([ex['entity'] for ex in examples]),
    }
    print("Total number of texts: ", len(text_list))

    text_embeddings = model.encode(text_list)

    # Group embeddings and columns by date (np.unique yields sorted dates).
    unique_dates = np.unique(dates)
    groups = {}
    for date in unique_dates:
        on_date = dates == date
        group = {name: values[on_date] for name, values in columns.items()}
        group["embeddings"] = text_embeddings[on_date]
        groups[date] = group
    print("Total number of texts: ", sum(len(g["text"]) for g in groups.values()))

    # Cluster each date's mentions and gather the per-cluster columns.
    print("Total number of dates: ", len(unique_dates))
    out_keys = (("cluster_text", "text"), ("cluster_qid", "qid"), ("cluster_entity", "entity"))
    date_text_cluster_ids = {}
    for date, group in groups.items():
        embeddings = group["embeddings"]
        if embeddings.shape[0] == 1:
            # A single mention is trivially its own cluster.
            cluster_ids = {0: [0]}
        else:
            cluster_ids = cluster("agglomerative", cluster_params=cluster_params,
                                  corpus_embeddings=embeddings)
        entry = {"cluster_dict": cluster_ids}
        for out_key, column in out_keys:
            # .tolist() yields plain Python objects, so the JSON dumps below do not
            # depend on numpy scalar types.
            entry[out_key] = {
                key: group[column][np.asarray(members, dtype=int)].tolist()
                for key, members in cluster_ids.items()
            }
        date_text_cluster_ids[date] = entry

    print("Total number of texts: ",
          sum(len(members) for entry in date_text_cluster_ids.values()
              for members in entry["cluster_text"].values()))

    with open(clusters_debug_path, 'w') as f:
        json.dump(date_text_cluster_ids, f)

    # Order the mentions in every cluster by text length in characters, longest first.
    # argsort is ascending, so it is reversed. Every mention is kept; index 0 acts as
    # the representative.
    sorted_clusters = {}
    for date, entry in date_text_cluster_ids.items():
        sorted_entry = {"cluster_dict": entry["cluster_dict"]}
        for out_key, _ in out_keys:
            sorted_entry[out_key] = {}
        for key, cluster_texts in entry["cluster_text"].items():
            lengths = np.array([len(t) for t in cluster_texts])
            order = np.argsort(lengths)[::-1]
            for out_key, _ in out_keys:
                sorted_entry[out_key][key] = np.array(entry[out_key][key])[order].tolist()
        sorted_clusters[date] = sorted_entry

    with open(sorted_clusters_debug_path, 'w') as f:
        json.dump(sorted_clusters, f)

    # Final dataset: {date: {cluster_id: {i: {"text", "wiki_entity", "entity"}}}}.
    ds_new = {}
    for date, entry in sorted_clusters.items():
        ds_new[date] = {}
        for key, cluster_texts in entry["cluster_text"].items():
            ds_new[date][key] = {
                i: {"text": text, "wiki_entity": qid, "entity": entity}
                for i, (text, qid, entity) in enumerate(
                    zip(cluster_texts, entry["cluster_qid"][key], entry["cluster_entity"][key]))
            }

    print("Total number of texts: ",
          sum(len(mentions) for clusters in ds_new.values() for mentions in clusters.values()))

    with open(save_path, 'w') as f:
        json.dump(ds_new, f)

    # ARI labels:
    # - True label: date + gold QID of each mention.
    # - Predicted label: date + gold QID of the cluster's representative (index 0).
    # NOTE(review): clusters whose representatives share a QID collapse into one
    # predicted label, so over-splitting is not penalised.
    true_labels = []
    predicted_labels = []
    for date, clusters in ds_new.items():
        for mentions in clusters.values():
            if not mentions:
                continue
            representative_qid = mentions[0]["wiki_entity"]
            for mention in mentions.values():
                true_labels.append(str(date) + "_" + str(mention["wiki_entity"]))
                predicted_labels.append(str(date) + "_" + str(representative_qid))

    ari = adjusted_rand_score(true_labels, predicted_labels)
    print("ARI: ", ari)
    return ari

if __name__ == '__main__':
    # Other checkpoints evaluated with this script (swap into checkpoint_dir to use one):
    #   '/mnt/data01/entity/trained_models/cgis_model_ent_mark_incontext_90'  -- not finetuned
    #   '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/entity_split_newspaper_wiki_coref_disamb_more_incontext'  -- finetuned
    #   '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymmetric_disambiguation_full_100'  -- asymmetric, wiki only
    #   '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/asymm_newspapers_0.3530040607711171_64_5_0.9285210198462236/'  -- asymmetric, wiki + news
    checkpoint_dir = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext'  # coref model
    encoder = SentenceTransformer(checkpoint_dir)

    # Mention-marker tokens, passed to each coref function as `special_tokens`.
    mention_tokens = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    def _evaluate(banner, coref_fn, output_path):
        """Print *banner*, then run *coref_fn* with the shared experiment settings."""
        print(banner)
        coref_fn(
            encoder, 'ent_mark', 'prepend_1', mention_tokens,
            override_max_seq_length=256,
            keep_entity_types=['PER'],
            cluster_params={'threshold': 0.15, 'clustering linkage': 'average', 'metric': 'cosine'},
            save_path=output_path,
        )

    for banner, coref_fn, output_path in (
        ("Coref within article", coref_within_article, 'ds_coref_article_clean.json'),
        ("Coref within date", coref_within_date, 'ds_coref_date_clean.json'),
    ):
        _evaluate(banner, coref_fn, output_path)

        
        