import json
import pickle
from tqdm import tqdm
import random 
import itertools
from datetime import datetime
import numpy as np
import pandas as pd
import re
import math

from wikimapper import WikiMapper

from sentence_transformers.readers import InputExample
from transformers import AutoModel, AutoTokenizer
import pickle
import os
from sklearn.model_selection import train_test_split
import multiprocessing
from functools import partial
from sentence_transformers import SentenceTransformer
"""
TODO: This whole script needs reorganising into a more logical structure.
"""


def featurise_data(list_of_dicts, featurisation, special_tokens, model, override_max_seq_length=None):

    """
    Mark the entity mention inside each context and return the marked-up strings.

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP]. 
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438 
        NOTE(review): only the mention markers currently reach the output for 'hsu_hor'; the
        sentence split / context truncation computed in that branch is never written back.

    Only the mention markers are inserted here; this function does not add [CLS]/[SEP] itself.

    Truncation ('ent_mark' and 'prepend' only): when the tokenised context does not fit the
    token budget AND the mention starts at or beyond two thirds of the sequence length, tokens
    are dropped from the left context so the mention survives later truncation. In every other
    case the text is passed through unchanged.

    Args:
        list_of_dicts: iterable of dicts with keys "mention_start", "mention_end" (used as
            character offsets into "context") and "context".
        featurisation: one of 'ent_mark', 'prepend', 'hsu_hor'.
        special_tokens: dict holding 'men_start' and 'men_end' (ent_mark, hsu_hor) or
            'men_sep' (prepend).
        model: object exposing `max_seq_length` and a `tokenizer` with `encode` / `decode`.
        override_max_seq_length: if given, used instead of `model.max_seq_length` both for the
            token budget and for the "mention in the first two thirds" test.

    Returns:
        List of marked-up strings, one per input dict, in input order.

    Raises:
        ValueError: unknown featurisation, or a required special token is missing.
    """

    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError(f"'men_sep' must be in special tokens if featurisation is 'prepend'")

    tokenizer = model.tokenizer

    # The effective sequence length drives BOTH the budget and the position test, so that an
    # override behaves the same way as max_sequence_len does in featurise_data_pytorch.
    seq_length = model.max_seq_length if override_max_seq_length is None else override_max_seq_length
    # 10 tokens are held back -- presumably headroom for the markers / model special tokens.
    max_length = seq_length - 10

    output_list = []
    for context_dict in tqdm(list_of_dicts):

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        text = context_dict["context"]

        # Truncate (roughly want the mention in the centre)
        if featurisation in ["ent_mark", "prepend"]:

            mention_text = text[men_start:men_end]
            right_text = text[men_end:]

            left = tokenizer.encode(text[:men_start], add_special_tokens=False)
            mention = tokenizer.encode(mention_text, add_special_tokens=False)
            right = tokenizer.encode(right_text, add_special_tokens=False)

            fits_budget = len(left) + len(mention) + len(right) < max_length
            # Mention already in first two thirds - leave as is, for automatic truncation
            mention_is_early = len(left) < (2 * seq_length) / 3

            if not fits_budget and not mention_is_early:
                # Mention in final third or beyond the sequence length: trim the left context
                if len(right) < max_length / 2:
                    left_len = max_length - len(mention) - len(right)
                else:
                    left_len = round(max_length / 2)

                # left[-0:] would keep the WHOLE left context and a negative length would
                # slice from the wrong end, so a non-positive budget means "keep nothing".
                kept_left = left[-left_len:] if left_len > 0 else []
                left_text = tokenizer.decode(kept_left) if kept_left else ""

                text = left_text + " " + mention_text + right_text
                # Offsets are recomputed because decoding may not reproduce the original characters
                men_start = len(left_text) + 1
                men_end = men_start + len(mention_text)

        if featurisation == "prepend":

            # Prepend entity with special sep token
            ent_text = text[men_start:men_end] \
                    + " " + special_tokens['men_sep'] + " " \
                    + text

        else:

            # ent_mark and hsu_hor: put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

        if featurisation == "hsu_hor":

            # Split so that you have two context sentences first, followed by entity sentence.
            # The separator is the literal four characters backslash-n-backslash-n, not real newlines.
            # NOTE(review): confirm the contexts store paragraph breaks in that escaped form.
            context_sents = []
            mention_sent = None
            for snip in ent_text.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Check that you don't go over the model length limit, and truncate context sentences if you do
            mention_sent_len = len(tokenizer.encode(mention_sent))
            encoded_context = tokenizer.encode(context)
            encoded_context_len = len(encoded_context)

            if mention_sent_len + encoded_context_len > (max_length):
                # NOTE(review): `context` is not used after this point, see docstring.
                context = tokenizer.decode(encoded_context[1:(max_length) - mention_sent_len])

        output_list.append(ent_text)

    return output_list



def featurise_data_pytorch(list_of_dicts, featurisation, special_tokens, tokenizer, max_sequence_len=128):

    """
    Mark the entity mention inside each context and return the marked-up strings.

    Same featurisation as `featurise_data`, but takes the tokenizer and the sequence length
    directly instead of reading them from a model object, and shows no progress bar.

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP]. 
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438 
        NOTE(review): only the mention markers currently reach the output for 'hsu_hor'; the
        sentence split / context truncation computed in that branch is never written back.

    Only the mention markers are inserted here; this function does not add [CLS]/[SEP] itself.

    Truncation ('ent_mark' and 'prepend' only): when the tokenised context does not fit the
    token budget AND the mention starts at or beyond two thirds of `max_sequence_len`, tokens
    are dropped from the left context so the mention survives later truncation. In every other
    case the text is passed through unchanged.

    Args:
        list_of_dicts: iterable of dicts with keys "mention_start", "mention_end" (used as
            character offsets into "context") and "context".
        featurisation: one of 'ent_mark', 'prepend', 'hsu_hor'.
        special_tokens: dict holding 'men_start' and 'men_end' (ent_mark, hsu_hor) or
            'men_sep' (prepend).
        tokenizer: object with `encode(text, add_special_tokens=...)` and `decode(ids)`.
        max_sequence_len: sequence length the token budget is derived from.

    Returns:
        List of marked-up strings, one per input dict, in input order.

    Raises:
        ValueError: unknown featurisation, or a required special token is missing.
    """

    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError(f"'men_sep' must be in special tokens if featurisation is 'prepend'")

    # 10 tokens are held back -- presumably headroom for the markers / model special tokens.
    max_length = max_sequence_len - 10

    output_list = []
    for context_dict in list_of_dicts:

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        text = context_dict["context"]

        # Truncate (roughly want the mention in the centre)
        if featurisation in ["ent_mark", "prepend"]:

            mention_text = text[men_start:men_end]
            right_text = text[men_end:]

            left = tokenizer.encode(text[:men_start], add_special_tokens=False)
            mention = tokenizer.encode(mention_text, add_special_tokens=False)
            right = tokenizer.encode(right_text, add_special_tokens=False)

            fits_budget = len(left) + len(mention) + len(right) < max_length
            # Mention already in first two thirds - leave as is, for automatic truncation
            mention_is_early = len(left) < (2 * max_sequence_len) / 3

            if not fits_budget and not mention_is_early:
                # Mention in final third or beyond the sequence length: trim the left context
                if len(right) < max_length / 2:
                    left_len = max_length - len(mention) - len(right)
                else:
                    left_len = round(max_length / 2)

                # left[-0:] would keep the WHOLE left context and a negative length would
                # slice from the wrong end, so a non-positive budget means "keep nothing".
                kept_left = left[-left_len:] if left_len > 0 else []
                left_text = tokenizer.decode(kept_left) if kept_left else ""

                text = left_text + " " + mention_text + right_text
                # Offsets are recomputed because decoding may not reproduce the original characters
                men_start = len(left_text) + 1
                men_end = men_start + len(mention_text)

        if featurisation == "prepend":

            # Prepend entity with special sep token
            ent_text = text[men_start:men_end] \
                    + " " + special_tokens['men_sep'] + " " \
                    + text

        else:

            # ent_mark and hsu_hor: put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

        if featurisation == "hsu_hor":

            # Split so that you have two context sentences first, followed by entity sentence.
            # The separator is the literal four characters backslash-n-backslash-n, not real newlines.
            # NOTE(review): confirm the contexts store paragraph breaks in that escaped form.
            context_sents = []
            mention_sent = None
            for snip in ent_text.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Check that you don't go over the model length limit, and truncate context sentences if you do
            mention_sent_len = len(tokenizer.encode(mention_sent))
            encoded_context = tokenizer.encode(context)
            encoded_context_len = len(encoded_context)

            if mention_sent_len + encoded_context_len > (max_length):
                # NOTE(review): `context` is not used after this point, see docstring.
                context = tokenizer.decode(encoded_context[1:(max_length) - mention_sent_len])

        output_list.append(ent_text)

    return output_list

def sub_batch(splits, negatives_dict, ent_id_to_split, ent_id_context_id_dict, context_id_ent_id_dict, samples_per_label=4):

    """
    Group context ids into sub-batches of four entities each, separately for every split.

    Each sub-batch holds `samples_per_label` context ids from each of:
      1. an anchor entity,
      2. a first negative drawn from the anchor's negatives,
      3. a second negative drawn from the negatives of the anchor or of the first negative,
      4. a random entity of the split that is not the anchor, not one of those negatives and
         not a negative of the second negative.
    Context ids placed in a sub-batch are removed from the pool so they are not sampled again.
    Entities for which no sub-batch can be built are dropped; passes over the remaining
    entities repeat until none are left.

    Args:
        splits: iterable of split names (a dict keyed by split name also works).
        negatives_dict: entity id -> list of negative entity ids.
        ent_id_to_split: entity id -> split name. Negatives absent from this mapping are ignored.
        ent_id_context_id_dict: entity id -> list of context ids. An entity without an entry
            is treated as having no contexts.
        context_id_ent_id_dict: context id -> entity id, used to find the pool a sampled
            context has to be removed from.
        samples_per_label: number of contexts taken per entity.

    Returns:
        Dict mapping split name -> list of sub-batches; every sub-batch is a shuffled list of
        4 * samples_per_label context ids.
    """

    # Rejection-sampling attempts for the random negative before falling back to a full scan.
    max_random_tries = 1000

    # Group into subbatches
    batched_contexts = {}
    for split in splits:

        print(f"***** Creating batches from {split} split *****")

        # Entities of this split together with those of their negatives that are in the same split
        negatives = {}
        for ent_id, candidate_negs in negatives_dict.items():
            if ent_id not in ent_id_to_split or ent_id_to_split[ent_id] != split:
                continue
            in_split_negs = [
                e for e in candidate_negs
                if e in ent_id_to_split and ent_id_to_split[e] == split
            ]
            if in_split_negs:
                negatives[ent_id] = in_split_negs

        print(f"{len(negatives)} entities in {split} split with at least one negative in the split")

        # Pool of still-unused context ids per entity
        positives = {ent_id: ent_id_context_id_dict.get(ent_id, []) for ent_id in negatives}

        def has_enough(eid):
            # An entity is usable only while it is tracked and has enough unused contexts
            return eid in positives and len(positives[eid]) >= samples_per_label

        all_ents = list(positives)
        ents_remaining = list(positives)
        random.shuffle(ents_remaining)

        batches = []

        while ents_remaining:

            print("Longest remaining: ", max(len(positives[ent_id]) for ent_id in ents_remaining))

            to_remove = []

            for ent_id in tqdm(ents_remaining):

                if not has_enough(ent_id):
                    to_remove.append(ent_id)
                    continue

                # Insertion-ordered "set" of candidate hard negatives
                related_negatives = dict.fromkeys(
                    nid for nid in negatives[ent_id] if nid != ent_id and has_enough(nid)
                )
                if not related_negatives:
                    to_remove.append(ent_id)
                    continue

                # Select first negative, then widen the candidates with ITS negatives.
                # The anchor is excluded: it is typically a negative of its own negative and
                # must not be sampled a second time as a "negative" of itself.
                first_neg = random.choice(list(related_negatives))
                for nid in negatives[first_neg]:
                    if nid != ent_id and has_enough(nid):
                        related_negatives[nid] = None

                # Select second negative
                if len(related_negatives) < 2:
                    to_remove.append(ent_id)
                    continue
                second_neg = random.choice([e for e in related_negatives if e != first_neg])

                # Random negative: cheap rejection sampling first ...
                excluded = set(related_negatives)
                excluded.update(negatives[second_neg])
                excluded.add(ent_id)

                random_neg = None
                for _ in range(max_random_tries):
                    candidate = random.choice(all_ents)
                    if candidate not in excluded and has_enough(candidate):
                        random_neg = candidate
                        break

                # ... then an exhaustive scan, so that running out of eligible entities ends
                # the search instead of looping forever.
                if random_neg is None:
                    eligible = [e for e in all_ents if e not in excluded and has_enough(e)]
                    if not eligible:
                        to_remove.append(ent_id)
                        continue
                    random_neg = random.choice(eligible)

                # Positives: n examples per cluster, then the three negatives
                batch = []
                for eid in (ent_id, first_neg, second_neg, random_neg):
                    batch.extend(random.sample(positives[eid], samples_per_label))

                # Add to batch list
                random.shuffle(batch)
                batches.append(batch)

                # Remove selected from dicts
                for example in batch:
                    owner = context_id_ent_id_dict[example]
                    positives[owner] = [e for e in positives[owner] if e != example]

            print(datetime.now())
            dropped = set(to_remove)
            # Filter the list itself (not a set of it) so the pass order stays deterministic
            ents_remaining = [ent for ent in ents_remaining if ent not in dropped]

            print(len(ents_remaining), "entities left")

        batched_contexts[split] = batches

        # Print some useful things to terminal
        print(f'{len(batches)} sub-batches in {split} split')
        print(f'{len(batches) * samples_per_label * 4} contexts used')
        print(sum(len(val) for val in positives.values()), "contexts not used")

    return batched_contexts


def prep_wikipedia_data(
    dataset_path, 
    model,
    special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}, 
    featurisation='ent_mark',
    disamb_or_coref='disamb',
    batch_type='supcon_batchhard', 
    samples_per_label = 8, 
    batch_size=16,
    small=False
    ):

    """
    Load the dataset files and the pre-computed sub-batches from `dataset_path`, featurise the
    contexts and wrap them in InputExample objects.

    Output shape per split:
    - 'train' when batch_type is 'supcon_batchhard': one single-text InputExample per context,
      labelled with the numeric id of its entity.
    - every other split (and 'train' for other batch types): pair examples built per sub-batch.
        * coref:  every pair of contexts in the sub-batch, label 1.0 if same entity else 0.0.
        * disamb: every context against the first-paragraph text of every entity in the
          sub-batch, as {'CTX': ...} / {'FP': ...} dicts, label 1.0 for the matching entity.

    Sub-batches are read from (small_)subbatches.pkl, they are not recomputed here.
    Sequence trimming is delegated to featurise_data.

    Returns:
        Tuple (train, dev, test) of InputExample lists.

    Raises:
        ValueError: samples_per_label does not divide batch_size, unsupported batch_type,
            'disamb' combined with a batch_type other than 'contrastive_batchhard', or (raised
            while processing a pair split) disamb_or_coref is neither 'disamb' nor 'coref'.
    """

    # ---- argument checks ----
    if batch_size % samples_per_label != 0:
        raise ValueError("samples_per_label must be a divisor of batch_size")
    if batch_type not in ('supcon_batchhard', 'contrastive_batchhard'):
        raise ValueError("unsupported batch type")
    if disamb_or_coref == 'disamb' and batch_type != 'contrastive_batchhard':
        raise ValueError("if disamb_or_coref is disamb, batch_type must be contrastive_batchhard")

    # ---- load data ----
    splits_file = 'small_splits_v2.json' if small else 'splits_v2.json'
    with open(f'{dataset_path}/{splits_file}') as fh:
        splits = json.load(fh)

    with open(f'{dataset_path}/negatives.pkl', 'rb') as fh:
        negatives_dict = pickle.load(fh)
    with open(f'{dataset_path}/all_contexts.pkl', 'rb') as fh:
        all_contexts = pickle.load(fh)

    if disamb_or_coref == 'disamb':
        with open(f'{dataset_path}/cleaned_fp_data.json') as fh:
            cleaned_fp_data = json.load(fh)

    # ---- lookup tables: entity -> context ids, context id -> entity, context id -> context ----
    ent_id_context_id_dict = {}
    context_id_ent_id_dict = {}
    context_id_full_context_dict = {}
    for ent_id, ent_contexts in all_contexts.items():
        ent_id_context_id_dict[ent_id] = [ctx["id"] for ctx in ent_contexts]
        for ctx in ent_contexts:
            cid = ctx["id"]
            context_id_ent_id_dict[cid] = ent_id
            # Each context remembers which entity it belongs to
            ctx["ent_id"] = ent_id
            context_id_full_context_dict[cid] = ctx

    # Every context id must map to exactly one entity
    assert len(context_id_ent_id_dict) == sum(len(ids) for ids in ent_id_context_id_dict.values())

    # Numeric labels for entities, in all_contexts order
    ent_label_dict = {ent: idx for idx, ent in enumerate(all_contexts)}

    # Entity -> split; entities that appear in no split get the string "None"
    ent_id_to_split = {ent_id: split for split in splits for ent_id in splits[split]}
    for ent_id in all_contexts:
        ent_id_to_split.setdefault(ent_id, "None")

    # ---- pre-computed sub-batches ----
    subbatch_file = 'small_subbatches.pkl' if small else 'subbatches.pkl'
    with open(f'{dataset_path}/{subbatch_file}', 'rb') as fh:
        subbatches = pickle.load(fh)

    outputs = {}

    if batch_type == 'supcon_batchhard':

        # Train data stays a flat list with cluster labels instead of being turned into pairs
        print("Featurising train split with cluster labels ...")

        random.shuffle(subbatches['train'])
        train_ids = [cid for sb in subbatches['train'] for cid in sb]
        full_contexts = [context_id_full_context_dict[cid] for cid in train_ids]

        featurised_text_list = featurise_data(full_contexts, featurisation, special_tokens, model)

        label_list = [ent_label_dict[ctx["ent_id"]] for ctx in full_contexts]
        assert len(featurised_text_list) == len(label_list)

        outputs['train'] = [
            InputExample(texts=[text], label=lab)
            for text, lab in zip(featurised_text_list, label_list)
        ]

    # Dev and test (for 'supcon_batchhard'), or all splits otherwise: output as pairs
    pair_splits = [s for s in subbatches if s not in outputs]
    for spl in pair_splits:

        print(f"Featurising {spl} split with pair labels ...")

        examples = []
        outputs[spl] = examples

        split_subbatches = subbatches[spl]
        random.shuffle(split_subbatches)

        for sb in tqdm(split_subbatches):

            full_contexts = [context_id_full_context_dict[cid] for cid in sb]
            featurised_text_list = featurise_data(full_contexts, featurisation, special_tokens, model)

            if disamb_or_coref == 'coref':

                # All pairs of contexts, positive when both belong to the same entity
                label_list = [ent_label_dict[ctx["ent_id"]] for ctx in full_contexts]

                for a, b in itertools.combinations(range(len(featurised_text_list)), 2):
                    label = 1 if label_list[a] == label_list[b] else 0
                    examples.append(
                        InputExample(texts=[featurised_text_list[a], featurised_text_list[b]], label=float(label))
                    )

            elif disamb_or_coref == 'disamb':

                # Every context against the first paragraph of every entity in the sub-batch
                ent_list = [ctx["ent_id"] for ctx in full_contexts]
                unique_ent_list = list(set(ent_list))

                fp_text_dict = {
                    e: e +  " " + special_tokens['men_sep'] + " " +  cleaned_fp_data[e]
                    for e in unique_ent_list
                }

                for idx, contx in enumerate(full_contexts):
                    for fp in fp_text_dict:
                        label = 1 if contx['ent_id'] == fp else 0
                        examples.append(
                            InputExample(
                                texts=[{'CTX': featurised_text_list[idx]}, {'FP': fp_text_dict[fp]}],
                                label=float(label),
                            )
                        )

            else:
                raise ValueError("disamb_or_coref must be disamb or coref")

    return outputs["train"], outputs["dev"], outputs["test"]


def prep_wikipedia_data_for_pt_bienc(
    dataset_path,
    model,
    special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}, 
    featurisation='ent_mark',
    disamb_or_coref='disamb',
    batch_type='supcon_batchhard',
    samples_per_label = 8,
    batch_size=16,
    small=False,
    max_seq_length=128,
    save_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/disamb/outputs.pkl"
    ):

    """
    Build (or load from cache) the train/dev/test data as plain dicts of lists.

    If `save_path` exists it is unpickled and returned as is -- none of the other arguments
    are checked against the cached content. Otherwise the dataset files and the pre-computed
    sub-batches are loaded from `dataset_path`, contexts are featurised with
    featurise_data_pytorch, and the result is pickled to `save_path`.

    Every split is a dict with the keys 'sentence_1', 'sentence_2' and 'labels':
    - 'train' when batch_type is 'supcon_batchhard': one row per context; both sentence fields
      hold the same featurised text and the label is the numeric id of the entity.
    - every other split (and 'train' for other batch types): pair rows built per sub-batch.
        * coref:  every pair of contexts in the sub-batch, label 1.0 if same entity else 0.0.
        * disamb: every context against the first-paragraph text of every entity in the
          sub-batch, label 1.0 for the matching entity.

    Args:
        dataset_path: directory holding the splits, negatives, contexts, first-paragraph and
            sub-batch files.
        model: name or path handed to AutoTokenizer.from_pretrained.
        max_seq_length: forwarded to featurise_data_pytorch as max_sequence_len.
        save_path: cache file. Its directory is created if missing.

    Returns:
        Tuple (train, dev, test).

    Raises:
        ValueError: (cache miss only) samples_per_label does not divide batch_size, unsupported
            batch_type, 'disamb' combined with a batch_type other than 'contrastive_batchhard',
            or disamb_or_coref is neither 'disamb' nor 'coref'.
    """

    # ---- cache hit: return the stored result untouched ----
    if os.path.exists(save_path):
        print("File exists, loading it")
        # NOTE: unpickling runs arbitrary code from the file; only use trusted cache files.
        with open(save_path, 'rb') as f:
            outputs = pickle.load(f)

        return outputs["train"], outputs["dev"], outputs["test"]

    # ---- cheap argument checks first, before any expensive loading ----
    if batch_size % (samples_per_label) != 0:
        raise ValueError("samples_per_label must be a divisor of batch_size")
    if batch_type not in ['supcon_batchhard', 'contrastive_batchhard']:
        raise ValueError("unsupported batch type")
    if disamb_or_coref == 'disamb' and batch_type != 'contrastive_batchhard':
        raise ValueError("if disamb_or_coref is disamb, batch_type must be contrastive_batchhard")

    # Make sure the cache location is writable now rather than after all the featurisation
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model)

    # ---- load data ----
    if small:
        with open(f'{dataset_path}/small_splits_v2.json') as f:
            splits = json.load(f)
    else:
        with open(f'{dataset_path}/splits_v2.json') as f:
            splits = json.load(f)

    with open(f'{dataset_path}/negatives.pkl', 'rb') as f:
        negatives_dict = pickle.load(f)
    with open(f'{dataset_path}/all_contexts.pkl', 'rb') as f:
        all_contexts = pickle.load(f)

    if disamb_or_coref == 'disamb':
        with open(f'{dataset_path}/cleaned_fp_data.json') as f:
            cleaned_fp_data = json.load(f)

    # ---- lookup tables: entity -> context ids, context id -> entity, context id -> context ----
    ent_id_context_id_dict = {}
    context_id_ent_id_dict = {}
    context_id_full_context_dict = {}
    for ent_id in all_contexts:
        ent_id_context_id_dict[ent_id] = [context["id"] for context in all_contexts[ent_id]]
        for context in all_contexts[ent_id]:
            context_id_ent_id_dict[context["id"]] = ent_id

            # Each context remembers which entity it belongs to
            context["ent_id"] = ent_id
            context_id_full_context_dict[context["id"]] = context

    # Every context id must map to exactly one entity
    assert len(context_id_ent_id_dict) == sum(len(val) for val in ent_id_context_id_dict.values())

    # Numeric labels for entities, in all_contexts order
    ent_label_dict = {ent: i for i, ent in enumerate(all_contexts)}

    # Entity -> split; entities that appear in no split get the string "None".
    # splits, negatives_dict and this mapping are the inputs sub_batch() takes; sub-batches
    # are not regenerated here but read from the pickle below.
    ent_id_to_split = {}
    for split in splits:
        for ent_id in splits[split]:
            ent_id_to_split[ent_id] = split
    for ent_id in all_contexts:
        if ent_id not in ent_id_to_split:
            ent_id_to_split[ent_id] = "None"

    # ---- pre-computed sub-batches ----
    if small:
        with open(f'{dataset_path}/small_subbatches.pkl', 'rb') as f:
            subbatches = pickle.load(f)
    else:
        with open(f'{dataset_path}/subbatches.pkl', 'rb') as f:
            subbatches = pickle.load(f)

    outputs = {}

    if batch_type == 'supcon_batchhard':
        print("Featurising train split with cluster labels ...")
        outputs['train'] = {
            'sentence_1': [],
            'sentence_2': [],
            'labels': []
        }

        # Flatten
        random.shuffle(subbatches['train'])
        flat = []
        for sb in subbatches['train']:
            flat.extend(sb)
        full_contexts = [context_id_full_context_dict[cid] for cid in flat]

        # Featurise
        featurised_text_list = featurise_data_pytorch(full_contexts, featurisation, special_tokens, tokenizer, max_sequence_len=max_seq_length)

        # Convert labels to numbers
        label_list = [ent_label_dict[context["ent_id"]] for context in full_contexts]
        assert len(featurised_text_list) == len(label_list)

        # Single-text rows: the same text fills both sentence columns
        for text, lab in zip(featurised_text_list, label_list):
            outputs['train']['sentence_1'].append(text)
            outputs['train']['sentence_2'].append(text)
            outputs['train']['labels'].append(lab)

    # For dev and test (if 'supcon_batchhard') and for all splits otherwise: output as pairs
    for spl in [s for s in subbatches if s not in outputs]:
        print(f"Featurising {spl} split with pair labels ...")
        outputs[spl] = {
            'sentence_1': [],
            'sentence_2': [],
            'labels': []
        }

        split_subbatches = subbatches[spl]
        random.shuffle(split_subbatches)

        for sb in tqdm(split_subbatches):
            full_contexts = [context_id_full_context_dict[cid] for cid in sb]
            featurised_text_list = featurise_data_pytorch(full_contexts, featurisation, special_tokens, tokenizer, max_sequence_len=max_seq_length)

            if disamb_or_coref == 'coref':
                # All pairs of contexts, positive when both belong to the same entity
                label_list = [ent_label_dict[context["ent_id"]] for context in full_contexts]

                for first, second in itertools.combinations(range(len(featurised_text_list)), 2):
                    label = 1 if label_list[first] == label_list[second] else 0
                    outputs[spl]['sentence_1'].append(featurised_text_list[first])
                    outputs[spl]['sentence_2'].append(featurised_text_list[second])
                    outputs[spl]['labels'].append(float(label))

            elif disamb_or_coref == 'disamb':
                # Every context against the first paragraph of every entity in the sub-batch
                ent_list = [context["ent_id"] for context in full_contexts]
                unique_ent_list = list(set(ent_list))
                fp_text_dict = {}
                for e in unique_ent_list:
                    fp_text_dict[e] = e +  " " + special_tokens['men_sep'] + " " + cleaned_fp_data[e]
                for i, contx in enumerate(full_contexts):
                    for fp in fp_text_dict:
                        label = 1 if contx['ent_id'] == fp else 0
                        outputs[spl]['sentence_1'].append(featurised_text_list[i])
                        outputs[spl]['sentence_2'].append(fp_text_dict[fp])
                        outputs[spl]['labels'].append(float(label))

            else:
                raise ValueError("disamb_or_coref must be disamb or coref")

    # Write to a temporary file and rename it into place: an interrupted dump must not leave a
    # truncated file at save_path, because its mere existence counts as a cache hit next time.
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(outputs, f, protocol=4)
    os.replace(tmp_path, save_path)

    return outputs["train"], outputs["dev"], outputs["test"]


def find_sep_token(tokenizer):

    """
    Build the space-padded separator string for the given tokenizer.

    Reads `tokenizer.special_tokens_map`. When the map has an 'eos_token' the result is
    " <eos> <sep> ", otherwise " <sep> ". A missing 'sep_token' raises KeyError.
    """

    token_map = tokenizer.special_tokens_map

    parts = []
    if 'eos_token' in token_map:
        parts.append(token_map['eos_token'])
    parts.append(token_map['sep_token'])

    return " " + " ".join(parts) + " "


def clean_entity(e):

    """
    Return `e` with its first character upper-cased and the rest left untouched.

    Unlike str.capitalize, the remaining characters are not lower-cased.
    An empty string is returned unchanged instead of raising IndexError.
    """

    # Slicing (rather than indexing) keeps this safe for the empty string
    return e[:1].upper() + e[1:]


def featurise_data_with_dates(list_of_dicts, featurisation, date_featurisation, special_tokens, model, override_max_seq_length=None):

    """
    Turns a list of mention-context dicts into featurised strings, optionally adding the year.

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP]. 
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438 
        NOTE(review): as written below, the hsu_hor branch only wraps the mention in the marker
        tokens. The context / mention-sentence split it computes is never written back into the
        output, and no date is added. Confirm whether that is intended.

    Date featurisation options (the year is read from each dict's "year" key):
    - none: no date is added
    - prepend_1: year followed by the tokenizer separator, placed at the very front
    - prepend_2: for 'prepend', year + separator is placed after the mention prefix;
        for 'ent_mark' it is handled exactly like prepend_1
    - nat_lang: "In <year>, " is placed directly in front of the context text

    Args:
        list_of_dicts: dicts with keys "mention_start", "mention_end", "mention_text", "text", "year"
        featurisation: 'ent_mark', 'prepend' or 'hsu_hor'
        date_featurisation: 'none', 'prepend_1', 'prepend_2' or 'nat_lang'
        special_tokens: dict holding 'men_start' and 'men_end' (ent_mark, hsu_hor) or 'men_sep' (prepend)
        model: object exposing max_seq_length and a tokenizer with encode / decode / special_tokens_map
        override_max_seq_length: if given, used as the token budget instead of model.max_seq_length - 10

    Returns:
        List of featurised strings, one per input dict, in input order.

    Raises:
        ValueError: on an unknown featurisation / date_featurisation or a missing special token.
        AssertionError: if text[mention_start:mention_end] does not match mention_text (ignoring whitespace).
    """

    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError(f"'men_sep' must be in special tokens if featurisation is 'prepend'")
    if date_featurisation not in ['none', 'prepend_1', 'prepend_2', 'nat_lang']:
        raise ValueError("date_featurisation must be 'none', 'prepend_1', 'prepend_2', 'nat_lang'")

    # Token budget: 10 below the model limit unless the caller overrides it
    # (presumably headroom for the special / marker tokens added later - TODO confirm)
    max_length = model.max_seq_length - 10 if override_max_seq_length is None else override_max_seq_length

    sep = find_sep_token(model.tokenizer)

    # Used to compare mention strings while ignoring all whitespace
    wsp = re.compile(r'\s+')

    output_list = []
    for context_dict in list_of_dicts:

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        org_ment_text = context_dict["mention_text"]
        text = context_dict["text"]
        date = context_dict["year"]

        # Truncate (roughly want the mention in the centre)
        if featurisation in ["ent_mark", "prepend"]:

            mention_text = text[men_start:men_end]

            # Sanity check: the character offsets must point at the recorded mention
            assert re.sub(wsp, '', mention_text) == re.sub(wsp, '', org_ment_text)

            right_text = text[men_end:]
            
            # Token counts of the three pieces, without [CLS]/[SEP]-style additions
            left = model.tokenizer.encode(text[:men_start], add_special_tokens=False) 
            mention = model.tokenizer.encode(mention_text, add_special_tokens=False) 
            right = model.tokenizer.encode(right_text, add_special_tokens=False)

            if len(left) + len(mention) + len(right) < max_length:
                truncated_text = text    
            # NOTE(review): this threshold uses model.max_seq_length, not max_length, so it
            # ignores override_max_seq_length - confirm this is intended
            elif len(left) < (2*model.max_seq_length)/3:    # Mention already in first two thirds - leave as is, for automatic truncation 
                truncated_text = text
            else: # Mention in final third or beyond text length
                if len(right) < max_length/2:
                    # Short right context: give the left context whatever budget remains
                    left_len = max_length - len(mention) - len(right)
                else:
                    left_len = round(max_length/2)
                # Keep only the last left_len tokens of the left context
                left_text = model.tokenizer.decode(left[-left_len:])

                truncated_text = left_text + " " + mention_text + right_text

                # if org_ment_text != truncated_text[len(left_text) + 1:len(left_text) + len(mention_text) +1]:
                #     print(truncated_text)
                #     print("**")
                #     print(text)
                #     print("**")
                #     print(len(left_text))
                #     print(len(mention_text))
                #     print(org_ment_text)
                #     print(truncated_text[len(left_text):len(left_text) + len(mention_text)])

                # Recompute the mention offsets inside the truncated text (+1 for the joining space)
                men_start = len(left_text) + 1
                men_end = len(left_text) + len(mention_text) + 1

            text = truncated_text

            # The (possibly shifted) offsets must still select the original mention
            assert  re.sub(wsp, '', org_ment_text) == re.sub(wsp, '', text[men_start:men_end])


        if featurisation == "ent_mark":

            # Put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

            # Featurise dates (prepend_1 and prepend_2 are identical for ent_mark)
            if date_featurisation == 'prepend_1' or date_featurisation == 'prepend_2':
                ent_text = str(date) + sep + ent_text
            elif date_featurisation == 'nat_lang':
                ent_text = "In " + str(date) + ", " + ent_text


        elif featurisation == "prepend":
    
            # Prepend entity with special sep token and featurise date as well 
            if date_featurisation == 'prepend_1':
                # <year> <sep> <mention> <men_sep> <text>
                ent_text = str(date) \
                        + sep \
                        + text[men_start:men_end] \
                        + " " + special_tokens['men_sep'] + " " \
                        + text

            elif date_featurisation == 'prepend_2':
                # <mention> <men_sep> <year> <sep> <text>
                ent_text = text[men_start:men_end] \
                        + " " + special_tokens['men_sep'] + " " \
                        + str(date) \
                        + sep \
                        + text

            elif date_featurisation == 'nat_lang':
                # <mention> <men_sep> In <year>, <text>
                ent_text = text[men_start:men_end] \
                        + " " + special_tokens['men_sep'] + " " \
                        + "In " + str(date) + ", " \
                        + text
            
            elif date_featurisation == "none":
                # <mention> <men_sep> <text>
                ent_text = text[men_start:men_end] \
                        + " " + special_tokens['men_sep'] + " " \
                        + text


        elif featurisation == "hsu_hor":

            # Put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

            # Split so that you have two context sentences first, followed by entity sentence
            # (the delimiter is the literal characters backslash-n backslash-n, not real newlines)
            context_sents = []
            mention_sent = None
            for snip in ent_text.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Check that you don't go over the model length limit, and truncate context sentences if you do 
            mention_sent_len = len(model.tokenizer.encode(mention_sent))
            encoded_context = model.tokenizer.encode(context)
            encoded_context_len = len(encoded_context)

            if mention_sent_len + encoded_context_len > (max_length):
                # Slice starts at 1, presumably to skip a leading special token - TODO confirm
                context = model.tokenizer.decode(encoded_context[1:(max_length) - mention_sent_len])
            # NOTE(review): `context` and `mention_sent` are not used after this point;
            # the value appended below is the marker-wrapped ent_text
        output_list.append(ent_text)

    return output_list

def clean_entity(e):

    """
    Upper-cases the first character of an entity name and leaves the rest untouched.

    An empty string is returned unchanged.

    NOTE: this is a second definition of clean_entity in this module; being later in the
    file, it is the one that is bound at import time.
    """

    # Slicing (unlike e[0]) is safe on an empty string
    return e[:1].upper() + e[1:]

def featurise_data_with_dates_flex(list_of_dicts, featurisation, date_featurisation, special_tokens, model):

    """
    Turns a list of mention-context dicts into featurised strings, adding the year only
    for items that carry one.

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP].
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438
        NOTE(review): as written below, the hsu_hor branch only wraps the mention in the marker
        tokens. The context / mention-sentence split it computes is never written back into the
        output, and no date is added. Confirm whether that is intended.

    Date featurisation options (applied only when the item has a truthy "year"):
    - none: no date is added
    - prepend_1: year followed by the tokenizer separator, placed at the very front
    - prepend_2: for 'prepend', year + separator is placed after the mention prefix;
        for 'ent_mark' it is handled exactly like prepend_1
    - nat_lang: "In <year>, " is placed directly in front of the context text
    Items without a year are featurised as if date_featurisation were 'none'.

    Args:
        list_of_dicts: dicts with keys "mention_start", "mention_end", "mention_text", "text"
            and optionally "year"
        featurisation: 'ent_mark', 'prepend' or 'hsu_hor'
        date_featurisation: 'none', 'prepend_1', 'prepend_2' or 'nat_lang'
        special_tokens: dict holding 'men_start' and 'men_end' (ent_mark, hsu_hor) or 'men_sep' (prepend)
        model: object exposing max_seq_length and a tokenizer with encode / decode / special_tokens_map

    Returns:
        List of featurised strings, one per input dict, in input order.

    Raises:
        ValueError: on an unknown featurisation / date_featurisation or a missing special token.
        AssertionError: if text[mention_start:mention_end] does not match mention_text (ignoring whitespace).
    """

    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError(f"'men_sep' must be in special tokens if featurisation is 'prepend'")
    if date_featurisation not in ['none', 'prepend_1', 'prepend_2', 'nat_lang']:
        raise ValueError("date_featurisation must be 'none', 'prepend_1', 'prepend_2', 'nat_lang'")

    # Token budget: 10 below the model limit
    # (presumably headroom for the special / marker tokens added later - TODO confirm)
    max_length = model.max_seq_length - 10

    sep = find_sep_token(model.tokenizer)

    # Used to compare mention strings while ignoring all whitespace
    wsp = re.compile(r'\s+')

    output_list = []
    for context_dict in list_of_dicts:

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        org_ment_text = context_dict["mention_text"]
        text = context_dict["text"]
        # The year is optional; a missing or falsy year means no date is featurised
        date = context_dict.get("year")

        # Truncate (roughly want the mention in the centre)
        if featurisation in ["ent_mark", "prepend"]:

            mention_text = text[men_start:men_end]

            # Sanity check: the character offsets must point at the recorded mention
            assert re.sub(wsp, '', mention_text) == re.sub(wsp, '', org_ment_text)

            right_text = text[men_end:]

            # Token counts of the three pieces, without [CLS]/[SEP]-style additions
            left = model.tokenizer.encode(text[:men_start], add_special_tokens=False)
            mention = model.tokenizer.encode(mention_text, add_special_tokens=False)
            right = model.tokenizer.encode(right_text, add_special_tokens=False)

            if len(left) + len(mention) + len(right) < max_length:
                truncated_text = text
            elif len(left) < (2*model.max_seq_length)/3:    # Mention already in first two thirds - leave as is, for automatic truncation
                truncated_text = text
            else: # Mention in final third or beyond text length
                if len(right) < max_length/2:
                    # Short right context: give the left context whatever budget remains
                    left_len = max_length - len(mention) - len(right)
                else:
                    left_len = round(max_length/2)
                # Keep only the last left_len tokens of the left context
                left_text = model.tokenizer.decode(left[-left_len:])

                truncated_text = left_text + " " + mention_text + right_text

                # Recompute the mention offsets inside the truncated text (+1 for the joining space)
                men_start = len(left_text) + 1
                men_end = len(left_text) + len(mention_text) + 1

            text = truncated_text

            # The (possibly shifted) offsets must still select the original mention
            assert  re.sub(wsp, '', org_ment_text) == re.sub(wsp, '', text[men_start:men_end])


        if featurisation == "ent_mark":

            # Put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

            # Featurise dates (prepend_1 and prepend_2 are identical for ent_mark)
            if date:
                if date_featurisation == 'prepend_1' or date_featurisation == 'prepend_2':
                    ent_text = str(date) + sep + ent_text
                elif date_featurisation == 'nat_lang':
                    ent_text = "In " + str(date) + ", " + ent_text


        elif featurisation == "prepend":

            # Every prepend variant starts the context with "<mention> <men_sep> "
            mention_prefix = text[men_start:men_end] + " " + special_tokens['men_sep'] + " "

            if not date or date_featurisation == "none":
                # No year available (or none requested): plain undated form. Without this
                # fallback ent_text would be left unset (or stale from the previous item).
                ent_text = mention_prefix + text

            elif date_featurisation == 'prepend_1':
                # <year> <sep> <mention> <men_sep> <text>
                ent_text = str(date) + sep + mention_prefix + text

            elif date_featurisation == 'prepend_2':
                # <mention> <men_sep> <year> <sep> <text>
                ent_text = mention_prefix + str(date) + sep + text

            else:
                # 'nat_lang': <mention> <men_sep> In <year>, <text>
                ent_text = mention_prefix + "In " + str(date) + ", " + text


        elif featurisation == "hsu_hor":

            # Put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

            # Split so that you have two context sentences first, followed by entity sentence
            # (the delimiter is the literal characters backslash-n backslash-n, not real newlines)
            context_sents = []
            mention_sent = None
            for snip in ent_text.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Check that you don't go over the model length limit, and truncate context sentences if you do
            mention_sent_len = len(model.tokenizer.encode(mention_sent))
            encoded_context = model.tokenizer.encode(context)
            encoded_context_len = len(encoded_context)

            if mention_sent_len + encoded_context_len > (max_length):
                # Slice starts at 1, presumably to skip a leading special token - TODO confirm
                context = model.tokenizer.decode(encoded_context[1:(max_length) - mention_sent_len])
            # NOTE(review): `context` and `mention_sent` are not used after this point;
            # the value appended below is the marker-wrapped ent_text
        output_list.append(ent_text)

    return output_list

def prep_newspaper_wiki_data(dataset_path, model, special_tokens, featurisation, date_featurisation, presplit_path=None):
    """Takes a json that contains pairs connecting newspaper-newspaper, newspaper-wiki and wiki-wiki pairs. Everything is sampled already, just featurise pairs.
    Featurise using flex featurisation, which allows for different date featurisation options if date / year exists in the item.

    When presplit_path is not given, the entities (top-level keys of the dataset json) are split
    80-10-10 into train / test / val, the pairs of each split are flattened into one list, and the
    result is written to splits.json next to the dataset file. When presplit_path is given, the
    splits are loaded from that json instead and dataset_path is not read.

    Each pair is read as (item_1, item_2, label) where the label "same" maps to 1 and anything else to 0.

    Returns:
        (train, test, val) - note the order - each a dict with lists "sentence_1", "sentence_2", "labels".
    """

    if not presplit_path:
        ##Load data
        with open(dataset_path) as f:
            raw_data = json.load(f)

        ##Shuffle data
        random.seed(42)
        all_entities = list(raw_data.keys())
        random.shuffle(all_entities)
        # Re-seed so the global RNG is left in a known state after the shuffle
        random.seed(42)

        ###Split into train test val - 80-10-10 at the entity level
        train_entities, held_out_entities = train_test_split(all_entities, test_size=0.2, random_state=42)
        test_entities, val_entities = train_test_split(held_out_entities, test_size=0.5, random_state=42)
        print(len(train_entities), len(test_entities), len(val_entities), "Entities in train, test, val")

        ##Flatten: one list of pairs per split
        train_data = [item for t in train_entities for item in raw_data[t]]
        test_data = [item for t in test_entities for item in raw_data[t]]
        val_data = [item for t in val_entities for item in raw_data[t]]

        print(len(train_data), len(test_data), len(val_data), "Pairs in train, test, val")

        ##Save splits
        splits = {"train": train_data, "test": test_data, "val": val_data}

        # os.path keeps a bare filename (no directory part) in the current directory,
        # instead of resolving to the filesystem root
        splits_path = os.path.join(os.path.dirname(dataset_path), 'splits.json')
        with open(splits_path, 'w') as f:
            json.dump(splits, f)

    else:
        with open(presplit_path) as f:
            splits = json.load(f)

    outputs = {}   ##feat data
    for split in splits:
        outputs[split] = {"sentence_1": [], "sentence_2": [], "labels": []}

        for pair in tqdm(splits[split]):
            s_1 = pair[0]
            s_2 = pair[1]
            label = 1 if pair[2] == "same" else 0

            sentences_feat = featurise_data_with_dates_flex([s_1, s_2], featurisation, date_featurisation, special_tokens, model)

            ##Append to outputs
            outputs[split]["sentence_1"].append(sentences_feat[0])
            outputs[split]["sentence_2"].append(sentences_feat[1])
            outputs[split]["labels"].append(label)

    return outputs["train"], outputs["test"], outputs["val"]
            
            
            
        

def prep_newspaper_data(dataset_path, model, special_tokens, featurisation, date_featurisation, disamb_or_coref, input_examples=True, wikipedia_path=None, save_realizations_path='/mnt/data01/entity/eval_data/realizations.pkl'):

    """
    Loads the newspaper pair dataset, splits it by entity and featurises it.

    The dataset json maps each entity to a list of pairs (text_0, text_1, label). Entities are
    shuffled with a fixed seed and split 20 / 20 / 60 into test / dev / train. The entity split is
    cached: if save_realizations_path exists it is loaded from there, otherwise it is computed and
    pickled to that path.

    - coref: every pair becomes one example, labelled 1 if the pair label is "same" and 0 otherwise.
        With input_examples=True the example is an InputExample, otherwise a dict with the ids,
        featurised texts, label, years and entity group.
    - disamb: every distinct text (keyed by text_id) of an entity becomes {'text', 'entity'}, where
        the entity is the group entity for positives, the title looked up from neg_wiki_id for
        negatives, or 'Not in wikipedia' when no title is available.

    Args:
        dataset_path: path of the dataset json
        model, special_tokens, featurisation, date_featurisation: passed to featurise_data_with_dates
        disamb_or_coref: 'disamb' or 'coref'
        input_examples: coref only, see above
        wikipedia_path: not used by this function
        save_realizations_path: pickle file caching the entity split

    Returns:
        (train, dev, test) lists of examples.

    Raises:
        ValueError: if disamb_or_coref is neither 'disamb' nor 'coref'.
    """

    # Format data
    with open(dataset_path) as f:
        raw_data = json.load(f)

    # Split into train, dev, test
    random.seed(42)
    ent_list = list(raw_data.keys())
    random.shuffle(ent_list)

    test_perc = round(0.2 * len(ent_list))

    if os.path.exists(save_realizations_path):
        print("File exists, loading it")
        # NOTE: pickle can execute arbitrary code on load - only point this at trusted files
        with open(save_realizations_path, 'rb') as f:
            ents = pickle.load(f)
    else:
        ents = {'test': ent_list[:test_perc],
                'dev':  ent_list[test_perc: 2* test_perc],
                'train': ent_list[2* test_perc:]}
        with open(save_realizations_path, 'wb') as f:
            pickle.dump(ents, f, protocol=4)

    # Built once, on first use, rather than once per negative text
    mapper = None

    outputs = {}
    for split in ents:

        outputs[split] = []

        # Set membership keeps the per-entity check O(1); iteration order still follows raw_data
        split_ents = set(ents[split])

        if disamb_or_coref == 'coref':

            art_ids = []
            pos_count = 0
            neg_count = 0
            hard_neg_count = 0   # never incremented here; kept so the printed summary is unchanged

            for e in raw_data:
                if e not in split_ents:
                    continue

                for pair in raw_data[e]:

                    id_0 = pair[0]["wiki_id"] + "_" + pair[0]['text_id']
                    id_1 = pair[1]["wiki_id"] + "_" + pair[1]['text_id']
                    art_ids.append(id_0)
                    art_ids.append(id_1)

                    if pair[2] == "same":
                        label = 1
                        pos_count += 1
                    else:
                        label = 0
                        neg_count += 1

                    featurised_texts = featurise_data_with_dates([pair[0], pair[1]], featurisation, date_featurisation, special_tokens, model)

                    if input_examples:
                        outputs[split].append(InputExample(texts=featurised_texts, label=float(label)))
                    else:
                        outputs[split].append(
                            {
                                "id_0": id_0,
                                "id_1": id_1,
                                "texts": featurised_texts,
                                "label": label,
                                "date_0": pair[0]["year"],
                                "date_1": pair[1]["year"],
                                "entity_group": e
                            })

            print(f"{len(ents[split])} entities in {split}")
            print(f"{len(set(art_ids))} unique articles in {split}")
            print(f"{len(outputs[split])} pairs in {split}")
            print(f"{pos_count} positive pairs")
            print(f"{neg_count + hard_neg_count} negative pairs")
            print(f"{hard_neg_count} hard negative pairs")

        elif disamb_or_coref == 'disamb':

            for e in raw_data:
                if e not in split_ents:
                    continue

                # Deduplicate texts by text_id; a later occurrence overwrites an earlier one
                all_texts = {}
                for pair in raw_data[e]:

                    for t in (pair[0], pair[1]):

                        if t["classification"] == 'positive':
                            t['wp_id'] = e

                        elif isinstance(t["neg_wiki_id"], float) and np.isnan(t["neg_wiki_id"]):
                            t['wp_id'] = 'Not in wikipedia'

                        else:
                            if mapper is None:
                                mapper = WikiMapper("index_enwiki-20190420.db")

                            wikipedia_titles = mapper.id_to_titles(t["neg_wiki_id"])

                            if len(wikipedia_titles) > 0:
                                t['wp_id'] = wikipedia_titles[0].replace("_", " ")
                            else:
                                t['wp_id'] = 'Not in wikipedia'

                        all_texts[t['text_id']] = t

                all_texts = list(all_texts.values())

                ft_contexts = featurise_data_with_dates(all_texts, featurisation, date_featurisation, special_tokens, model)

                outputs[split].extend(
                    {'text': ft, 'entity': clean_entity(t['wp_id'])}
                    for ft, t in zip(ft_contexts, all_texts)
                )

        else:
            raise ValueError("disamb_or_coref must be disamb or coref")


    return outputs['train'], outputs['dev'], outputs['test']


def prep_newspaper_data_pytorch(dataset_path, model, special_tokens, featurisation,
                                 date_featurisation, disamb_or_coref, input_examples=True,
                                   wikipedia_path=None,save_realizations_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/realizations.pkl"):

    """
    Loads the newspaper pair dataset, splits it by entity and featurises the coref pairs
    into plain lists of sentences and labels.

    Only disamb_or_coref='coref' with input_examples=True is supported. Unsupported arguments
    are rejected before any file is read or written.

    The dataset json maps each entity to a list of pairs (text_0, text_1, label). Entities are
    shuffled with a fixed seed and split 20 / 20 / 60 into test / dev / train. The entity split is
    cached: if save_realizations_path exists it is loaded from there, otherwise it is computed and
    pickled to that path.

    Args:
        dataset_path: path of the dataset json
        model, special_tokens, featurisation, date_featurisation: passed to featurise_data_with_dates
        disamb_or_coref: must be 'coref'
        input_examples: must be True
        wikipedia_path: not used (it only applied to the unsupported disamb mode)
        save_realizations_path: pickle file caching the entity split

    Returns:
        (train, dev, test), each a dict with lists 'sentence_1', 'sentence_2', 'labels'
        (label 1 if the pair label is "same", else 0).

    Raises:
        ValueError: for disamb mode, an unknown mode, or input_examples=False.
    """

    # Fail fast: these combinations can never produce output
    if disamb_or_coref == 'disamb':
        raise ValueError("disamb not supported for pytorch")
    if disamb_or_coref != 'coref':
        raise ValueError("disamb_or_coref must be disamb or coref")
    if not input_examples:
        raise ValueError("input_examples must be True for pytorch")

    # Format data
    with open(dataset_path) as f:
        raw_data = json.load(f)

    # Split into train, dev, test
    random.seed(42)
    ent_list = list(raw_data.keys())
    random.shuffle(ent_list)

    test_perc = round(0.2 * len(ent_list))

    if os.path.exists(save_realizations_path):
        print("File exists, loading it")
        # NOTE: pickle can execute arbitrary code on load - only point this at trusted files
        with open(save_realizations_path, 'rb') as f:
            ents = pickle.load(f)
    else:
        ents = {'test': ent_list[:test_perc],
                'dev':  ent_list[test_perc: 2* test_perc],
                'train': ent_list[2* test_perc:]}
        with open(save_realizations_path, 'wb') as f:
            pickle.dump(ents, f,protocol=4)

    outputs = {}
    for split in ents:
        outputs[split] = {'sentence_1': [], 'sentence_2': [], 'labels': []}
        art_ids = []
        pos_count = 0
        neg_count = 0
        hard_neg_count = 0   # never incremented here; kept so the printed summary is unchanged

        # Set membership keeps the per-entity check O(1); iteration order still follows raw_data
        split_ents = set(ents[split])

        for e in raw_data:
            if e not in split_ents:
                continue

            for pair in raw_data[e]:

                id_0 = pair[0]["wiki_id"] + "_" + pair[0]['text_id']
                id_1 = pair[1]["wiki_id"] + "_" + pair[1]['text_id']
                art_ids.append(id_0)
                art_ids.append(id_1)

                if pair[2] == "same":
                    label = 1
                    pos_count += 1
                else:
                    label = 0
                    neg_count += 1

                featurised_texts = featurise_data_with_dates([pair[0], pair[1]], featurisation, date_featurisation, special_tokens, model)

                outputs[split]["sentence_1"].append(featurised_texts[0])
                outputs[split]["sentence_2"].append(featurised_texts[1])
                outputs[split]["labels"].append(label)

        print(f"{len(ents[split])} entities in {split}")
        print(f"{len(set(art_ids))} unique articles in {split}")
        print(f"{len(outputs[split]['sentence_1'])} pairs in {split}")
        print(f"{pos_count} positive pairs")
        print(f"{neg_count + hard_neg_count} negative pairs")
        print(f"{hard_neg_count} hard negative pairs")

    return outputs['train'], outputs['dev'], outputs['test']


def prep_benchmark_data(dataset_path, model, special_tokens, featurisation, disamb_or_coref, split=False, input_examples=True, wikipedia_path=None):

    """
    Loads a benchmark dataset, featurises every context and builds examples per split.

    The dataset json maps each entity to a list of context dicts. With split=True each context
    is routed by its "split" field (expected to be one of 'train', 'dev', 'val', 'test'), and
    'val' is folded into 'dev'. With split=False everything goes into the test split.

    - disamb: test contexts become {'text', 'entity'} dicts. For other splits, contexts whose
        entity has a first paragraph in cleaned_fp_data yield one positive InputExample, plus
        one negative per hard negative in negatives_dict, or else a single random negative.
        Needs wikipedia_path to contain cleaned_fp_data.json and negatives.pkl.
    - any other value (coref): all pairs of contexts within a split are labelled 1 when they
        share "wiki_entity" and 0 otherwise. In the train split negatives are down-sampled to
        the number of positives (all are kept if there are fewer negatives than positives).
        With input_examples=True the examples are InputExamples, otherwise dicts.

    Args:
        dataset_path: path of the dataset json
        model, special_tokens, featurisation: passed to featurise_data
        disamb_or_coref: 'disamb' selects disambiguation; anything else is treated as coref
        split: whether to honour the per-context "split" field
        input_examples: coref only, see above
        wikipedia_path: directory with the first-paragraph and hard-negative files (disamb only)

    Returns:
        (train, dev, test) lists of examples.
    """

    if disamb_or_coref == 'disamb':
        with open(f'{wikipedia_path}/cleaned_fp_data.json') as f:
            cleaned_fp_data = json.load(f)
        # NOTE: pickle can execute arbitrary code on load - only point this at trusted files
        with open(f'{wikipedia_path}/negatives.pkl', 'rb') as f:
            negatives_dict = pickle.load(f)

        random.seed(42)
        shuffled_ents = list(cleaned_fp_data.keys())
        random.shuffle(shuffled_ents)

    with open(dataset_path) as f:
        raw_data = json.load(f)

    splits = {'test': {}, 'val': {}, 'train': {}, 'dev': {}}

    if split:
        for e in raw_data:
            for context_dict in raw_data[e]:
                splits[context_dict["split"]].setdefault(e, []).append(context_dict)

    else:
        # In most cases, just want a test set
        splits['test'] = raw_data

    # Make 'dev' the dev set
    for ent in splits['val']:
        splits['dev'][ent] = splits['val'][ent]
    del splits['val']

    outputs = {}
    for sp in splits:

        outputs[sp] = []

        texts_in_split = []

        for e in splits[sp]:
            texts_in_split.extend(splits[sp][e])

        featurised_texts = featurise_data(texts_in_split, featurisation, special_tokens, model)

        # Look featurised texts up by the context's "index" field
        featurised_dict = {}
        for i, text in enumerate(texts_in_split):
            featurised_dict[text["index"]] = featurised_texts[i]

        if disamb_or_coref == 'disamb':

            i = 0   # number of random negatives drawn so far
            c = 0   # number of hard negatives added so far

            for text in tqdm(texts_in_split):
                text_id = text["index"]

                # Capitalise the first character; slicing is safe for an empty entity string
                ent = text["wiki_entity"]
                cleaned_ent = ent[:1].upper() + ent[1:]

                if sp == 'test':
                    outputs[sp].append({'text': featurised_dict[text_id], 'entity': cleaned_ent})
                    continue

                if cleaned_ent not in cleaned_fp_data:
                    continue

                # Positive
                fp = cleaned_ent +  " " + special_tokens['men_sep'] + " " +  cleaned_fp_data[cleaned_ent]
                outputs[sp].append(InputExample(texts=[{'CTX':featurised_dict[text_id]}, {'FP':fp}], label=float(1)))

                # Hard negatives
                if cleaned_ent in negatives_dict:

                    for hent in negatives_dict[cleaned_ent]:
                        fp = hent + " " + special_tokens['men_sep'] + " " +  cleaned_fp_data[hent]
                        outputs[sp].append(InputExample(texts=[{'CTX':featurised_dict[text_id]}, {'FP':fp}], label=float(0)))
                        c += 1

                else:
                    # Random negative. Wrap around so the index cannot run past the entity list.
                    # NOTE(review): the random entity is not checked against cleaned_ent, so it
                    # can occasionally coincide with the positive entity.
                    rent = shuffled_ents[i % len(shuffled_ents)]
                    fp = rent + " " + special_tokens['men_sep'] + " " +  cleaned_fp_data[rent]
                    outputs[sp].append(InputExample(texts=[{'CTX':featurised_dict[text_id]}, {'FP':fp}], label=float(0)))
                    i += 1

            print(f"{len(outputs[sp])} texts in {sp}")
            print(i)
            print(c)

        else:
            pairs = itertools.combinations(texts_in_split, 2)

            positives = []
            negatives = []

            for pair in tqdm(pairs):

                id_0 = pair[0]["index"]
                id_1 = pair[1]["index"]

                ft_texts = [featurised_dict[id_0], featurised_dict[id_1]]

                same_entity = pair[0]["wiki_entity"] == pair[1]["wiki_entity"]
                label = 1 if same_entity else 0

                if input_examples:
                    example = InputExample(texts=ft_texts, label=float(label))
                else:
                    example = {
                        "id_0": id_0,
                        "id_1": id_1,
                        "texts": ft_texts,
                        "label": label,
                        "entity_group": pair[0]["wiki_entity"]
                    }

                if same_entity:
                    positives.append(example)
                else:
                    negatives.append(example)

            random.seed(35)
            if sp == "train":
                # Balance the classes; random.sample raises if asked for more than is available
                negatives = random.sample(negatives, min(len(positives), len(negatives)))

            chosen = positives + negatives
            random.shuffle(chosen)

            outputs[sp] = chosen

            print(f"{len(splits[sp])} entities in {sp}")
            print(f"{len(outputs[sp])} pairs in {sp}")
            print(f"{len(positives)} positive pairs")
            print(f"{len(negatives)} negative pairs")

    return outputs['train'], outputs['dev'], outputs['test']

def prep_sotu_data(dataset_path, model, special_tokens, featurisation, date_featurisation,
                   disamb_or_coref, override_max_seq_length=None,
                   keep_entity_types=None, asymm_model=False, prepared_dataset_path=None):
    """
    Load a JSON dataset of entity mentions and prepare it for disambiguation or coreference.

    If prepared_dataset_path is given, the JSON stored there is loaded and returned as-is;
    every other argument is ignored.

    Otherwise the JSON at dataset_path is loaded. It is iterated by key and each value is
    expected to be a list of mention dicts; these lists are concatenated and featurised with
    featurise_data_with_dates. What happens next depends on disamb_or_coref:
    - disamb: returns one dict per mention with keys 'text', 'entity', 'entity_type',
        'mention_text', 'wiki_entity', 'art_id' and 'year'. 'wiki_entity' is set to
        "Not in wikipedia" when the cleaned entity is "Not in wikipedia" or when the raw
        wiki_entity is "NA" or None. If keep_entity_types is non-empty, only mentions whose
        entity_type is in it are kept.
    - coref: returns a shuffled list with one dict per unordered pair of mentions, with keys
        'id_0', 'id_1', 'texts', 'label', 'entity_types' and 'entity_group'. A pair is labelled 1
        when both mentions share the same wiki_entity and that wiki_entity is neither "NA"
        nor None, and 0 otherwise.

    Args:
        dataset_path: path of the raw JSON dataset.
        model: passed through to featurise_data_with_dates (unused when asymm_model is True).
        special_tokens: passed through to featurise_data_with_dates.
        featurisation: mention featurisation; overridden with "prepend" when asymm_model is True.
        date_featurisation: passed through to featurise_data_with_dates.
        disamb_or_coref: 'disamb' or 'coref'.
        override_max_seq_length: passed through to featurise_data_with_dates.
        keep_entity_types: optional collection of entity types to keep (disamb only).
            None or empty keeps everything.
        asymm_model: if True, featurise with SentenceTransformer("all-mpnet-base-v2") and, in
            disamb mode, wrap every featurised text as {'CTX': text}.
        prepared_dataset_path: optional path of an already prepared JSON dataset.

    Raises:
        ValueError: if disamb_or_coref is neither 'disamb' nor 'coref' (only checked when
            prepared_dataset_path is not given).
    """

    if prepared_dataset_path:
        with open(prepared_dataset_path, encoding="utf-8") as f:
            return json.load(f)

    # Fail fast: reject a bad mode before loading data, models and featurising everything
    if disamb_or_coref not in ('disamb', 'coref'):
        raise ValueError("disamb_or_coref must be disamb or coref")

    with open(dataset_path, encoding="utf-8") as f:
        raw_data = json.load(f)

    # Create list of all texts
    texts = []
    for key in raw_data:
        texts.extend(raw_data[key])

    if not asymm_model:
        featurised_texts = featurise_data_with_dates(
            texts, featurisation, date_featurisation, special_tokens, model,
            override_max_seq_length=override_max_seq_length)
    else:
        print("Using asymm model, so overriding entity featurisation to prepend")
        featurised_texts = featurise_data_with_dates(
            texts, "prepend", date_featurisation, special_tokens,
            SentenceTransformer("all-mpnet-base-v2"),
            override_max_seq_length=override_max_seq_length)

    if disamb_or_coref == 'disamb':

        if asymm_model:
            featurised_texts = [{'CTX': ft} for ft in featurised_texts]

        outputs = []
        # featurised_texts is consumed in parallel with texts (same order as the input list)
        for ft, text in zip(featurised_texts, texts):
            entity = clean_entity(text["wp_id"])

            # Mentions without a usable wikipedia link all share the "Not in wikipedia" label
            wiki_entity = text["wiki_entity"]
            if entity == "Not in wikipedia" or wiki_entity == "NA" or wiki_entity is None:
                wiki_entity = "Not in wikipedia"

            outputs.append({'text': ft, 'entity': entity,
                            'entity_type': text["entity_type"],
                            'mention_text': text["mention_text"],
                            'wiki_entity': wiki_entity,
                            'art_id': text["art_id"], 'year': text["year"]})

        # Filter by entity type
        if keep_entity_types:
            outputs = [o for o in outputs if o["entity_type"] in keep_entity_types]

    else:

        featurised_dict = {}
        for i, text in enumerate(texts):
            featurised_dict[text["index"]] = featurised_texts[i]

        positives = []
        negatives = []

        # Every unordered pair of mentions is materialised, so this grows quadratically
        for first, second in tqdm(itertools.combinations(texts, 2)):

            group = first["wiki_entity"]

            # "NA" and None both mean the mention is unlinked; two unlinked mentions are not
            # evidence of coreference, so they must not form a positive pair
            is_positive = (
                group == second["wiki_entity"]
                and group != "NA"
                and group is not None
            )

            record = {
                "id_0": first["index"],
                "id_1": second["index"],
                "texts": [featurised_dict[first["index"]], featurised_dict[second["index"]]],
                "label": 1 if is_positive else 0,
                "entity_types": list({first["entity_type"], second["entity_type"]}),
                "entity_group": group
            }

            if is_positive:
                positives.append(record)
            else:
                negatives.append(record)

        outputs = positives + negatives
        random.shuffle(outputs)

        print(f"{len(raw_data)} entities")
        print(f"{len(texts)} mentions")
        print(f"{len(outputs)} pairs")
        print(f"{len(positives)} positive pairs")
        print(f"{len(negatives)} negative pairs")

    return outputs