import os
import json
import pickle
import re
from tqdm import tqdm
import openai
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd
from typing import List, Optional, Dict, Tuple, Union

from data_fns import featurise_data, prep_newspaper_data, prep_sotu_data, clean_entity

# Two JSON files holding cleaned text entries. Judging by the filename, the
# second one comes from a later GPT-4 cleaning pass — confirm. They are
# concatenated in this order into `list_clean`.
path_to_json_1 = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned_intermediate.json"
path_to_json_2 = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned_intermediate_2_gpt4.json"


trained_model_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext' # coref model


def _load_json_list(path):
    """Load the JSON document at *path* and return it, requiring a list.

    The file is decoded as UTF-8 explicitly (the JSON interchange encoding)
    rather than relying on the platform's locale default.

    Raises:
        TypeError: if the top-level JSON value is not a list, since the
            caller concatenates the results.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise TypeError(
            f"Expected a JSON list in {path}, got {type(data).__name__}"
        )
    return data


## Load both lists of cleaned texts
list_clean_1 = _load_json_list(path_to_json_1)
list_clean_2 = _load_json_list(path_to_json_2)

# Order matters: entries are later consumed positionally.
list_clean = list_clean_1 + list_clean_2

print("Loaded cleaned text list")
print("Length of cleaned text list: ", len(list_clean))

# Load the trained SentenceTransformer exactly once. It was previously
# re-loaded from the same path a few lines further down with nothing
# modifying it in between, which only cost extra load time and memory.
model = SentenceTransformer(trained_model_path)
print(model.max_seq_length)

# Marker strings passed to prep_sotu_data as `special_tokens`; the key names
# suggest mention-start / mention-end markers and a segment separator.
stoks = {'men_start': "[M]", 'men_end': "[/M]", "sep": '</s>'}

# Entity "instance of" type labels (presumably Wikidata labels, given the
# wp_ids data — confirm) considered person-like. Not referenced in the
# visible portion of this script; kept for any later use.
instance_types_to_keep = {
    'human', 'human Biblical figure', 'mythical character', 'religious character', 'historical character',
    'supernatural being',
    'fictional character', 'television character', 'fictional human', 'literary character',
    'film character', 'animated character', 'musical theatre character', 'theatrical character',
}

# Featurise the SOTU dataset for disambiguation, restricted to PER entities.
ds = prep_sotu_data(
        dataset_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_reformatted_wp_ids.json',
        model=model,
        special_tokens=stoks,
        featurisation='ent_mark',
        disamb_or_coref='disamb',
        date_featurisation='prepend_1',
        override_max_seq_length=256,
        keep_entity_types=['PER'],
    )

# Save the dataset exactly as produced by prep_sotu_data, before any text is
# replaced, so the un-cleaned version remains on disk.
with open('/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids.json', 'w', encoding='utf-8') as f:
    json.dump(ds, f)

# Cleaned texts are paired with dataset entries purely by position, so the
# two sequences must have the same length. Previously a shorter list_clean
# crashed with IndexError part-way through (after some entries had already
# been mutated) and a longer one was silently truncated; fail fast instead.
if len(list_clean) != len(ds):
    raise ValueError(
        f"Cleaned text list has {len(list_clean)} entries but the dataset has "
        f"{len(ds)}; they must align one-to-one."
    )

# Absorb any spaces already adjacent to a marker so the padding below always
# yields exactly one space (a plain str.replace doubled existing spaces).
_MEN_START_PAD_RE = re.compile(r"\[M\] *")
_MEN_END_PAD_RE = re.compile(r" *\[/M\]")

## Replace the text in the dataset
for d, cleaned in zip(ds, list_clean):
    # The sentinel value "Original" means: keep the entry's existing text.
    if cleaned != "Original":
        d['text'] = cleaned
    # Ensure exactly one space after [M] and before [/M].
    d['text'] = _MEN_END_PAD_RE.sub(" [/M]", _MEN_START_PAD_RE.sub("[M] ", d['text']))


## Save the cleaned dataset
with open('/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sotu_disamb_format_wp_ids_cleaned.json', 'w', encoding='utf-8') as f:
    json.dump(ds, f)