import json
import pickle
from tqdm import tqdm
import random 
import itertools
from datetime import datetime
import pandas as pd
import re

from sentence_transformers.readers import InputExample
from transformers import AutoModel, AutoTokenizer
import pickle
import os
from collections import defaultdict
import hashlib

def get_hash(text):
    """Return the hexadecimal MD5 digest of ``text`` (encoded with ``str.encode()``'s default UTF-8).

    Used in this file as a content key so identical context strings map to the
    same bucket.
    """
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()


# Root directory of the entity-training data (absolute path on the original
# author's machine).
# NOTE(review): the functions below take their own `dataset_path` argument;
# in the visible code this global is only referenced by the commented-out
# example call at the bottom of the file.
dataset_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/"



def _in_split_negatives(split, negatives_dict, ent_id_to_split):
    """Return {ent_id: [neg_ids]} for entities in `split`, keeping only same-split negatives and dropping entities left with none."""
    negatives = {}
    for ent_id, candidate_negs in negatives_dict.items():
        if ent_id_to_split.get(ent_id) != split:
            continue
        # .get(): a negative missing from the split map is treated as out-of-split
        # instead of raising KeyError.
        same_split = [e for e in candidate_negs if ent_id_to_split.get(e) == split]
        if same_split:
            negatives[ent_id] = same_split
    return negatives


def _add_usable_negatives(related, candidate_ids, anchor_id, positives, samples_per_label):
    """Insert into `related` (a dict used as an ordered set) each candidate other than the anchor that still has enough unused contexts."""
    for nid in candidate_ids:
        # positives.get: a same-split negative may itself have no negatives and
        # therefore no entry in `positives`; such entities cannot be sampled.
        if nid != anchor_id and len(positives.get(nid, ())) >= samples_per_label:
            related[nid] = None


def _pick_random_negative(all_ents, positives, excluded, samples_per_label, max_tries=1000):
    """Return a random entity not in `excluded` with enough unused contexts, or None if no such entity exists."""
    def usable(ent):
        return ent not in excluded and len(positives[ent]) >= samples_per_label

    # Rejection sampling is cheap while plenty of entities are still usable ...
    for _ in range(max_tries):
        choice = random.choice(all_ents)
        if usable(choice):
            return choice
    # ... but can never succeed once every entity is excluded or exhausted
    # (the previous unbounded loop hung forever here), so fall back to a scan.
    candidates = [e for e in all_ents if usable(e)]
    return random.choice(candidates) if candidates else None


def sub_batch(splits, negatives_dict, ent_id_to_split, ent_id_context_id_dict, context_id_ent_id_dict, samples_per_label=4):
    """
    Group context ids into sub-batches containing hard negatives, per split.

    Each sub-batch holds ``samples_per_label`` context ids from each of four
    distinct entities: an anchor entity, two of its (transitive) hard
    negatives, and one random entity that is not among those hard negatives.
    Contexts are consumed as they are used, so no context id appears in two
    sub-batches of the same split.

    Args:
        splits: iterable of split names to process.
        negatives_dict: {ent_id: iterable of hard-negative ent_ids}.
        ent_id_to_split: {ent_id: split name}.
        ent_id_context_id_dict: {ent_id: list of context ids}; not mutated.
        context_id_ent_id_dict: {context id: ent_id}.
        samples_per_label: contexts drawn per entity in each sub-batch.

    Returns:
        {split: [sub_batch, ...]} where each sub_batch is a shuffled list of
        ``4 * samples_per_label`` context ids.
    """
    batched_contexts = {}
    for split in splits:

        print(f"***** Creating batches from {split} split *****")

        negatives = _in_split_negatives(split, negatives_dict, ent_id_to_split)
        print(f"{len(negatives)} entities in {split} split with at least one negative in the split")

        # Unused context ids per entity (copied so the caller's lists are untouched).
        positives = {ent_id: list(ent_id_context_id_dict[ent_id]) for ent_id in negatives}

        all_ents = list(positives)
        ents_remaining = list(positives)
        random.shuffle(ents_remaining)

        batches = []

        while ents_remaining:

            print("Longest remaining: ", max(len(positives[ent_id]) for ent_id in ents_remaining))

            to_remove = set()

            for ent_id in tqdm(ents_remaining):
                pos_list = positives[ent_id]
                if len(pos_list) < samples_per_label:
                    to_remove.add(ent_id)
                    continue

                # Hard negatives of the anchor that can still supply enough contexts.
                related = {}
                _add_usable_negatives(related, negatives[ent_id], ent_id, positives, samples_per_label)
                if not related:
                    to_remove.add(ent_id)
                    continue

                # Widen the pool with the first negative's own hard negatives.
                # The anchor is excluded so it can never be picked as its own negative.
                first_neg = random.choice(list(related))
                _add_usable_negatives(related, negatives[first_neg], ent_id, positives, samples_per_label)
                if len(related) < 2:
                    to_remove.add(ent_id)
                    continue
                second_neg = random.choice([e for e in related if e != first_neg])

                # The random negative must not be a known hard negative of the chosen entities.
                excluded = set(related)
                excluded.update(negatives[second_neg])
                excluded.add(ent_id)
                random_neg = _pick_random_negative(all_ents, positives, excluded, samples_per_label)
                if random_neg is None:
                    # Nothing eligible is left; retrying later cannot help.
                    to_remove.add(ent_id)
                    continue

                batch = []
                for member in (ent_id, first_neg, second_neg, random_neg):
                    batch.extend(random.sample(positives[member], samples_per_label))
                assert len(batch) == 4 * samples_per_label

                random.shuffle(batch)
                batches.append(batch)

                # Consume the sampled contexts so they are not reused.
                used = defaultdict(set)
                for cid in batch:
                    used[context_id_ent_id_dict[cid]].add(cid)
                for owner, cids in used.items():
                    positives[owner] = [c for c in positives[owner] if c not in cids]

            print(datetime.now())
            # Keep the shuffled order rather than round-tripping through a set.
            ents_remaining = [ent for ent in ents_remaining if ent not in to_remove]

            print(len(ents_remaining), "entities left")

        batched_contexts[split] = batches

        # Print some useful things to terminal
        print(f'{len(batches)} sub-batches in {split} split')
        print(f'{len(batches) * samples_per_label * 4} contexts used')
        print(sum(len(val) for val in positives.values()), "contexts not used")

    return batched_contexts


def featurise_data_pytorch(list_of_dicts, featurisation, special_tokens, tokenizer,max_sequence_len=128):

    """
    Convert mention-in-context dicts into strings with the mention marked.

    Each dict must provide ``"context"`` (the text) and ``"mention_start"`` /
    ``"mention_end"`` (character offsets of the mention within that text).

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP].
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438

    For 'ent_mark' and 'prepend', a text whose mention lies late in a long
    text is cut on the left so the mention stays inside the token budget.

    Args:
        list_of_dicts: iterable of context dicts as described above.
        featurisation: one of 'ent_mark', 'prepend', 'hsu_hor'.
        special_tokens: marker strings; 'men_start' and 'men_end' are required
            for 'ent_mark'/'hsu_hor', 'men_sep' for 'prepend'.
        tokenizer: object with ``encode(text, add_special_tokens=...)`` and
            ``decode(ids)`` (e.g. a Hugging Face tokenizer). 'hsu_hor' also uses
            ``tokenizer.sep_token`` when the tokenizer has one.
        max_sequence_len: model sequence length; the working budget is 10
            tokens smaller.

    Returns:
        List of featurised strings, one per input dict, in input order.

    Raises:
        ValueError: unknown featurisation, or a required special token is missing.
    """
    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError("'men_sep' must be in special tokens if featurisation is 'prepend'")

    # Presumably headroom for the markers / special tokens added later.
    max_length = max_sequence_len - 10

    output_list = []
    for context_dict in list_of_dicts:

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        text = context_dict["context"]

        # Truncate (roughly want the mention in the centre)
        if featurisation in ["ent_mark", "prepend"]:

            mention_text = text[men_start:men_end]
            right_text = text[men_end:]

            left = tokenizer.encode(text[:men_start], add_special_tokens=False)
            mention = tokenizer.encode(mention_text, add_special_tokens=False)
            right = tokenizer.encode(right_text, add_special_tokens=False)

            fits = len(left) + len(mention) + len(right) < max_length
            # Mention already in the first two thirds: leave the text for
            # automatic truncation at tokenisation time.
            mention_early = len(left) < (2 * max_sequence_len) / 3
            if not (fits or mention_early):
                if len(right) < max_length / 2:
                    left_len = max_length - len(mention) - len(right)
                else:
                    left_len = round(max_length / 2)
                # A non-positive budget means no left context fits. Guard it
                # explicitly: left[-0:] is the WHOLE list and left[-(-k):]
                # drops only k tokens.
                left_text = tokenizer.decode(left[-left_len:]) if left_len > 0 else ""
                text = left_text + " " + mention_text + right_text
                men_start = len(left_text) + 1
                men_end = men_start + len(mention_text)

        if featurisation == "prepend":
            # Prepend entity with special sep token
            ent_text = text[men_start:men_end] \
                    + " " + special_tokens['men_sep'] + " " \
                    + text
        else:
            # ent_mark and hsu_hor: put special tokens around the entity
            ent_text = text[:men_start] \
                    + " " + special_tokens['men_start'] + " " \
                    + text[men_start:men_end] \
                    + " " + special_tokens['men_end'] + " " \
                    + text[men_end:]

        if featurisation == "hsu_hor":
            # Split so that you have two context sentences first, followed by entity sentence
            context_sents = []
            mention_sent = None
            for snip in ent_text.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Truncate the context sentences if context + mention sentence exceed the budget
            mention_sent_len = len(tokenizer.encode(mention_sent))
            encoded_context = tokenizer.encode(context)
            if mention_sent_len + len(encoded_context) > max_length:
                budget = max_length - mention_sent_len
                # [1:...] skips the leading special token added by encode();
                # a budget <= 1 leaves no room for any context token.
                context = tokenizer.decode(encoded_context[1:budget]) if budget > 1 else ""

            # Previously `context` was computed and then discarded, so hsu_hor
            # returned the same string as ent_mark. Assemble the documented
            # order: context sentences, separator, mention sentence.
            sep = getattr(tokenizer, "sep_token", None) or ""
            ent_text = " ".join(part for part in (context, sep, mention_sent) if part)

        output_list.append(ent_text)

    return output_list


def prep_wikipedia_data_coref_hn_for_pt_bienc(
    dataset_path,
    model,
    special_tokens=None,
    featurisation='ent_mark',
    batch_type='contrastive_batchhard',
    samples_per_label = 8,
    batch_size=16,
    small=False,
    max_seq_length=128,
    save_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/outputs_within_hn.pkl",
    within_context_fraction=1.0,
    ):

    """
    Open data, featurise contexts and turn pre-built sub-batches into labelled pairs.

    If ``save_path`` already exists, its pickled result is returned unchanged.
    Otherwise, the function does the following:
    - it loads splits, contexts and negatives from ``dataset_path``;
    - it loads pre-computed sub-batches from ``subbatches.pkl`` (or ``small_subbatches.pkl``);
    - it featurises every context with ``featurise_data_pytorch``;
    - it emits all within-sub-batch pairs, with label 1.0 for the same entity and 0.0 otherwise;
    - it adds extra pairs drawn from identical context strings that carry more than one entity;
    - it shuffles each split with seed 42 and pickles the result to ``save_path``.

    Length handling is done by ``featurise_data_pytorch``:
    - ent_mark / prepend: left context is trimmed when the mention lies late in a long text.
    - hsu_hor: context sentences are trimmed to fit next to the mention sentence.
    Anything still too long is presumably truncated later by the tokenizer.

    Args:
        dataset_path: directory containing the input files.
        model: name/path passed to ``AutoTokenizer.from_pretrained``.
        special_tokens: marker strings; defaults to
            {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}.
        featurisation: 'ent_mark', 'prepend' or 'hsu_hor'.
        batch_type: only 'contrastive_batchhard' is accepted.
        samples_per_label: must divide ``batch_size``.
        batch_size: only used for the divisibility check above.
        small: use the ``small_*`` variants of the split and sub-batch files.
        max_seq_length: passed to ``featurise_data_pytorch``.
        save_path: cache file for the result.
        within_context_fraction: fraction (0..1) of multi-entity identical
            contexts per sub-batch turned into extra pairs; 1.0 (all of them)
            reproduces the previous hard-coded behaviour.

    Returns:
        (train, dev, test): dicts with parallel 'sentence_1', 'sentence_2' and
        'labels' sequences.

    Raises:
        ValueError: invalid ``samples_per_label``/``batch_size``, ``batch_type``
            or ``within_context_fraction``.
    """

    # Reuse a cached result if present.
    # NOTE: pickle.load can execute arbitrary code; only point save_path at trusted files.
    if os.path.exists(save_path):
        print("File exists, loading it")
        with open(save_path, 'rb') as f:
            outputs = pickle.load(f)

        return outputs["train"], outputs["dev"], outputs["test"]

    if special_tokens is None:
        # Built per call to avoid a shared mutable default argument.
        special_tokens = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    tokenizer = AutoTokenizer.from_pretrained(model)
    if batch_size % (samples_per_label) != 0:
        raise ValueError("samples_per_label must be a divisor of batch_size")
    if batch_type not in ['contrastive_batchhard']:
        raise ValueError("unsupported batch type")
    if not 0.0 <= within_context_fraction <= 1.0:
        raise ValueError("within_context_fraction must be between 0 and 1")

    # Load data
    splits_file = 'small_splits_v2.json' if small else 'splits_v2.json'
    with open(f'{dataset_path}/{splits_file}') as f:
        splits = json.load(f)

    with open(f'{dataset_path}/all_contexts.pkl', 'rb') as f:
        all_contexts = pickle.load(f)

    # Only needed to regenerate the sub-batches (see the commented-out sub_batch call below).
    with open(f'{dataset_path}/negatives.pkl', 'rb') as f:
        negatives_dict = pickle.load(f)

    # Dictionary mapping entity IDs to list of Context IDs and reverse.
    # Each context dict is also tagged in place with its owning "ent_id".
    ent_id_context_id_dict = {}
    context_id_ent_id_dict = {}
    context_id_full_context_dict = {}
    for ent_id, contexts in all_contexts.items():
        ent_id_context_id_dict[ent_id] = [context["id"] for context in contexts]
        for context in contexts:
            context_id_ent_id_dict[context["id"]] = ent_id
            context["ent_id"] = ent_id
            context_id_full_context_dict[context["id"]] = context

    # Context ids must be unique across entities.
    assert len(context_id_ent_id_dict) == sum([len(val) for val in list(ent_id_context_id_dict.values())])

    # The same context text can appear under several entities with different ids.
    # Group context dicts by the MD5 of their text: {text_hash: [context dicts]}.
    alt_context_dict = {}
    for context in tqdm(context_id_full_context_dict.values()):
        alt_context_dict.setdefault(get_hash(context['context']), []).append(context)

    # Numeric labels for entities
    ent_label_dict = {ent: i for i, ent in enumerate(all_contexts)}

    # Map ent_id to split ("None" for entities not in any split)
    ent_id_to_split = {}
    for split in splits:
        for ent_id in splits[split]:
            ent_id_to_split[ent_id] = split
    for ent_id in all_contexts:
        if ent_id not in ent_id_to_split:
            ent_id_to_split[ent_id] = "None"

    # Sub-batches are loaded from disk; to regenerate them:
    # subbatches = sub_batch(splits, negatives_dict, ent_id_to_split, ent_id_context_id_dict, context_id_ent_id_dict, samples_per_label)
    # with open(f'{dataset_path}/subbatches.pkl', 'wb') as f:
    #     pickle.dump(subbatches, f)
    subbatch_file = 'small_subbatches.pkl' if small else 'subbatches.pkl'
    with open(f'{dataset_path}/{subbatch_file}', 'rb') as f:
        subbatches = pickle.load(f)

    outputs = {}

    # Every split is output as labelled pairs
    for spl in subbatches:
        print(f"Featurising {spl} split with pair labels ...")
        outputs[spl] = {
            'sentence_1': [],
            'sentence_2': [],
            'labels': []
        }

        split_subbatches = subbatches[spl]
        random.shuffle(split_subbatches)
        counter_neg = 0
        counter_pos = 0
        for sb in tqdm(split_subbatches):

            full_contexts = [context_id_full_context_dict[cid] for cid in sb]
            featurised_text_list = featurise_data_pytorch(full_contexts, featurisation, special_tokens, tokenizer, max_sequence_len=max_seq_length)
            label_list = [ent_label_dict[context["ent_id"]] for context in full_contexts]

            # All pairs within the sub-batch: 1.0 = same entity, 0.0 = different
            for i, j in itertools.combinations(range(len(featurised_text_list)), 2):
                outputs[spl]['sentence_1'].append(featurised_text_list[i])
                outputs[spl]['sentence_2'].append(featurised_text_list[j])
                outputs[spl]['labels'].append(1.0 if label_list[i] == label_list[j] else 0.0)

            # Extra pairs from identical context text carrying more than one entity
            # (the same text marked at different mentions).
            full_contexts_hashes = list(set([get_hash(context["context"]) for context in full_contexts]))
            contexts_with_more_than_one_entity = [alt_context_dict[h] for h in full_contexts_hashes if len(alt_context_dict[h]) > 1]
            num_contexts_to_sample = int(len(contexts_with_more_than_one_entity) * within_context_fraction)
            contexts_to_sample = random.sample(contexts_with_more_than_one_entity, num_contexts_to_sample)
            for same_text_contexts in contexts_to_sample:
                context_1, context_2 = random.sample(same_text_contexts, 2)
                featurised_1, featurised_2 = featurise_data_pytorch([context_1, context_2], featurisation, special_tokens, tokenizer, max_sequence_len=max_seq_length)
                outputs[spl]['sentence_1'].append(featurised_1)
                outputs[spl]['sentence_2'].append(featurised_2)

                if context_1["ent_id"] == context_2["ent_id"]:
                    outputs[spl]['labels'].append(1.0)
                    counter_pos += 1
                else:
                    outputs[spl]['labels'].append(0.0)
                    counter_neg += 1

        print(f"Added {counter_neg} within context negative pairs to {spl} split")
        print(f"Added {counter_pos} within context positive pairs to {spl} split")

    # Shuffle the data for the final time - keep pairs and labels together
    print("Shuffling data ...")
    for split in outputs:
        print(f"Shuffling {split} split ...")
        data = list(zip(outputs[split]['sentence_1'], outputs[split]['sentence_2'], outputs[split]['labels']))
        print(f"Number of examples in: {split} ", len(data))
        random.seed(42)
        random.shuffle(data)
        # zip(*[]) yields nothing, so an empty split cannot be unpacked into three names.
        outputs[split]['sentence_1'], outputs[split]['sentence_2'], outputs[split]['labels'] = (
            zip(*data) if data else ((), (), ())
        )

    with open(save_path, 'wb') as f:
        pickle.dump(outputs, f, protocol=4)

    return outputs["train"], outputs["dev"], outputs["test"]

def data_shuffler(all_splits_pickle, seed=42):
    """
    Shuffle the pairs within every split of a saved outputs pickle and save it back in place.

    Each split is expected to be a dict with parallel 'sentence_1',
    'sentence_2' and 'labels' sequences. The three are shuffled together, so
    each pair keeps its label. The RNG is reseeded with ``seed`` before each
    split, so the result is deterministic.

    Args:
        all_splits_pickle: path of the pickle to read and overwrite.
            NOTE: pickle.load can execute arbitrary code; only use trusted files.
        seed: seed passed to ``random.seed`` before shuffling each split.

    Returns:
        The shuffled outputs dict (the same object written back to disk).
    """
    with open(all_splits_pickle, 'rb') as f:
        outputs = pickle.load(f)

    for split, columns in outputs.items():
        print(f"Shuffling {split} split ...")
        data = list(zip(columns['sentence_1'], columns['sentence_2'], columns['labels']))
        random.seed(seed)
        random.shuffle(data)
        # zip(*[]) yields nothing, so an empty split cannot be unpacked into three names.
        columns['sentence_1'], columns['sentence_2'], columns['labels'] = (
            zip(*data) if data else ((), (), ())
        )

    with open(all_splits_pickle, 'wb') as f:
        pickle.dump(outputs, f, protocol=4)

    return outputs

# data_shuffler("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/outputs_within_hn.pkl")

# prep_wikipedia_data_coref_hn_for_pt_bienc(dataset_path,"sentence-transformers/all-mpnet-base-v2",
#                                           save_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/outputs_within_hn_small_2.pkl")