import json
import pickle
import numpy as np
from tqdm import tqdm
import random 
import itertools
from datetime import datetime
import pandas as pd
import re
from sentence_transformers import SentenceTransformer
from sentence_transformers.readers import InputExample
from transformers import AutoModel, AutoTokenizer
import pickle
import os
from collections import defaultdict
import hashlib

def get_hash(text):
    """Return the hex MD5 digest of ``text`` (UTF-8 encoded).

    MD5 is not collision resistant; do not rely on it for security.
    """
    payload = bytes(text, "utf-8")
    return hashlib.md5(payload).hexdigest()


def find_sep_token(tokenizer):
    """
    Build the separator string used to join text segments for ``tokenizer``.

    When the tokenizer's ``special_tokens_map`` has an EOS token the result is
    " <eos> <sep> ", otherwise just " <sep> " (each token padded by spaces).
    """
    token_map = tokenizer.special_tokens_map
    pieces = [token_map['sep_token']]
    if 'eos_token' in token_map:
        pieces.insert(0, token_map['eos_token'])
    return " " + " ".join(pieces) + " "

# Root directory of the entity-training data on the mounted volume. Not read by
# the functions in this file: prep_wikipedia_data_disamb_hn_for_pt_bienc takes
# its own ``dataset_path`` argument, the others use explicit file paths.
dataset_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/"



def _split_three_way(items, split_ratio):
    """Cut ``items`` into consecutive (train, dev, test) slices by ``split_ratio``."""
    n_items = len(items)
    train_end = int(n_items * split_ratio[0])
    dev_end = int(n_items * (split_ratio[0] + split_ratio[1]))
    return items[:train_end], items[train_end:dev_end], items[dev_end:]


def _take_fraction(items, fraction):
    """Return the leading ``fraction`` of ``items`` (length rounded down)."""
    return items[:int(len(items) * fraction)]


def positives(context_dict, disamb_dict, split_ratio=(0.8, 0.1, 0.1),
              output_dir="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training"):
    """Split qids into train/dev/test and write the split files.

    Disambiguation clusters (a key qid plus the qids listed under it) are
    assigned to splits as whole units, so a cluster is never divided between
    splits. A qid listed in several clusters can still end up in more than one
    split; the overlap counts printed below reveal this. Context qids that do
    not occur in any cluster are shuffled and split with the same ratios. Each
    split is its cluster part followed by its context part.

    Three JSON files are written into ``output_dir``:
      * splits_v3.json       - the full splits
      * small_splits_v3.json - first 10% of each part of each split
      * mid_splits_v3.json   - first 50% of each part of each split

    Args:
        context_dict: mapping whose keys are qids (values are not used).
        disamb_dict: mapping cluster-key qid -> list of member qids.
        split_ratio: (train, dev, test) fractions; test takes the remainder.
        output_dir: directory the JSON files are written to.
    """
    print("Splitting data ...")

    # Shuffle cluster keys (seeded) and split them, then expand each key into
    # key + members.
    cluster_keys = list(disamb_dict.keys())
    random.seed(42)
    random.shuffle(cluster_keys)

    def _flatten(keys):
        flat = []
        for qid in keys:
            flat.append(qid)
            flat.extend(disamb_dict[qid])
        return flat

    cluster_parts = [_flatten(part) for part in _split_three_way(cluster_keys, split_ratio)]

    # Context qids outside every cluster. Built in dict order rather than from a
    # set: set iteration order varies with string hash randomisation, which
    # would make the seeded shuffle irreproducible across runs.
    clustered = set().union(*cluster_parts)
    remaining = [qid for qid in context_dict if qid not in clustered]
    random.seed(42)
    random.shuffle(remaining)
    # All three parts are sliced from the full shuffled pool, so they are
    # disjoint (dev/test must not be cut from the train slice).
    context_parts = _split_three_way(remaining, split_ratio)

    splits, small_splits, mid_splits = {}, {}, {}
    for name, cluster_part, context_part in zip(('train', 'dev', 'test'), cluster_parts, context_parts):
        splits[name] = cluster_part + context_part
        small_splits[name] = _take_fraction(cluster_part, 0.1) + _take_fraction(context_part, 0.1)
        mid_splits[name] = _take_fraction(cluster_part, 0.5) + _take_fraction(context_part, 0.5)

    # Sanity check: pairwise overlap between the full splits (ideally 0 each).
    train_set, dev_set, test_set = (set(splits[name]) for name in ('train', 'dev', 'test'))
    print(len(train_set & dev_set))
    print(len(train_set & test_set))
    print(len(dev_set & test_set))

    for filename, payload in (("splits_v3.json", splits),
                              ("small_splits_v3.json", small_splits),
                              ("mid_splits_v3.json", mid_splits)):
        with open(os.path.join(output_dir, filename), 'w') as f:
            json.dump(payload, f)
    
    


def negatives(clean_disamb_dict,save_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/negatives.pkl",
              fp_dict="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/formatted_first_para_data_qid_template_people_with_qrank_3occupations.json"):
    """Build a qid -> negative-qids map from clusters and pickle it to ``save_path``.

    For every cluster (the key qid plus its member list), each qid in the
    cluster receives every *other* qid of that cluster that appears as a key in
    the ``fp_dict`` JSON. Negatives accumulate across clusters and are
    deduplicated and sorted. Qids with no negatives, or that are not keys of
    the ``fp_dict`` JSON themselves, are dropped.

    Args:
        clean_disamb_dict: mapping cluster-key qid -> list of member qids.
            It is not modified.
        save_path: path the resulting dict is pickled to.
        fp_dict: path to a JSON file whose top-level keys decide which qids are
            eligible (presumably qids that have a first paragraph - confirm).
    """
    print("Creating negative dict ...")

    # ``fp_dict`` is a file path; load its contents under a separate name.
    with open(fp_dict) as f:
        fp_entries = json.load(f)

    print(len(clean_disamb_dict))

    negative_sets = defaultdict(set)
    for cluster_key in tqdm(clean_disamb_dict):
        # Build a fresh list so the caller's member list is not appended to.
        cluster = list(clean_disamb_dict[cluster_key]) + [cluster_key]
        eligible = {qid for qid in cluster if qid in fp_entries}
        for ent in cluster:
            negative_sets[ent].update(eligible - {ent})

    # Drop empty negative lists and anchors without an fp entry; sort for a
    # deterministic order (all members are str since they matched JSON keys).
    negatives_dict = {
        ent: sorted(negs)
        for ent, negs in negative_sets.items()
        if negs and ent in fp_entries
    }

    with open(save_path, 'wb') as f:
        pickle.dump(negatives_dict, f)

    print(len(negatives_dict))




_WHITESPACE_RE = re.compile(r'\s+')


def _resolve_date(context_dict, date_featurisation, noise_years):
    """Return the record's year as an integer string (optionally noised), or None.

    'median_year' takes precedence over 'year'. Returns None when
    ``date_featurisation == 'none'``, when neither key is present, or when the
    value is null/NaN. With ``noise_years > 0`` a uniform integer offset in
    [-noise_years, +noise_years] is added.
    """
    if date_featurisation == "none":
        return None
    if 'median_year' in context_dict:
        raw_year = context_dict['median_year']
    elif 'year' in context_dict:
        raw_year = context_dict['year']
    else:
        return None
    if pd.isnull(raw_year):
        return None
    year = int(raw_year)
    if noise_years > 0:
        # randint's upper bound is exclusive; +1 makes the range symmetric.
        year += np.random.randint(-noise_years, noise_years + 1)
    return str(year)


def _truncate_around_mention(text, men_start, men_end, org_ment_text, tokenizer, max_length):
    """Trim left context so the mention stays within roughly ``max_length`` tokens.

    Returns ``(text, men_start, men_end)`` with offsets into the returned text.
    The text is returned unchanged when it already fits, or when the mention
    starts within the first two thirds of the window (left for automatic
    truncation downstream).
    """
    mention_text = text[men_start:men_end]
    assert _WHITESPACE_RE.sub('', mention_text) == _WHITESPACE_RE.sub('', org_ment_text)
    right_text = text[men_end:]

    left = tokenizer.encode(text[:men_start], add_special_tokens=False)
    mention = tokenizer.encode(mention_text, add_special_tokens=False)
    right = tokenizer.encode(right_text, add_special_tokens=False)

    if len(left) + len(mention) + len(right) < max_length:
        return text, men_start, men_end
    if len(left) < (2 * max_length) / 3:
        return text, men_start, men_end

    # Mention is in the final third of the window or beyond: keep only the
    # left-context tokens closest to the mention.
    if len(right) < max_length / 2:
        left_len = max_length - len(mention) - len(right)
    else:
        left_len = round(max_length / 2)
    # Clamp to [0, len(left)]; a non-positive left_len would otherwise make
    # ``left[-left_len:]`` keep all (or most) of the left context.
    left_len = min(max(left_len, 0), len(left))
    left_text = tokenizer.decode(left[len(left) - left_len:])

    truncated_text = left_text + " " + mention_text + right_text
    new_start = len(left_text) + 1
    new_end = new_start + len(mention_text)

    if org_ment_text != truncated_text[new_start:new_end]:
        # Diagnostic output only; the caller asserts whitespace-insensitive equality.
        print(truncated_text)
        print("**")
        print(text)
        print("**")
        print(len(left_text))
        print(len(mention_text))
        print(org_ment_text)
        print(truncated_text[new_start:new_end])

    return truncated_text, new_start, new_end


def _mark_mention(text, men_start, men_end, special_tokens):
    """Surround ``text[men_start:men_end]`` with the men_start / men_end marker tokens."""
    return text[:men_start] \
        + " " + special_tokens['men_start'] + " " \
        + text[men_start:men_end] \
        + " " + special_tokens['men_end'] + " " \
        + text[men_end:]


def _prefix_date(text, date, date_featurisation, sep):
    """Prepend ``date`` to ``text`` per ``date_featurisation``; no-op without a date."""
    if not date:
        return text
    if date_featurisation in ('prepend_1', 'prepend_2'):
        return str(date) + sep + text
    if date_featurisation == 'nat_lang':
        return "In " + str(date) + ", " + text
    return text


def featurise_data_with_dates_flex(list_of_dicts, featurisation, date_featurisation, special_tokens, tokenizer, max_sequence_len=128,noise_years=5):

    """
    Turn mention records into model input strings, one per record.

    Each record must have 'mention_start', 'mention_end', 'mention_text' and
    either 'text' or 'context'; 'median_year' / 'year' are used for dates.

    Featurisation options for how entity mention is featurised:
    - ent_mark: Puts special tokens around mention ie. [CLS] ctxt_l [MEN_s] mention [MEN_e] ctxt_r [SEP]. 
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - prepend: Prepends the mention using a special sep token ie. [CLS] mention [MEN] context [SEP]
        Following Wu et al. (2019) https://arxiv.org/abs/1911.03814 
    - hsu_hor: Puts special tokens around mention, splits into 2 context sentence and 1 mention sentence,
        Orders with context sentences first, then sep, then mention sentence.
        Following Hsu and Horwood (2022) https://arxiv.org/abs/2205.11438 

    Date featurisation ('none', 'prepend_1', 'prepend_2', 'nat_lang') adds the
    record's year, shifted by up to +/- ``noise_years`` when noise_years > 0.
    Records without a usable year are featurised as if date_featurisation
    were 'none'.

    Raises:
        ValueError: on an unknown featurisation/date_featurisation or when a
            required special token is missing.
    """

    if featurisation not in ['ent_mark', 'prepend', 'hsu_hor']:
        raise ValueError("featurisation must be one of 'ent_mark', 'prepend', 'hsu_hor'")
    if featurisation in ['ent_mark', 'hsu_hor']:
        if 'men_start' not in special_tokens:
            raise ValueError(f"'men_start' must be in special tokens if featurisation is {featurisation}")
        if 'men_end' not in special_tokens:
            raise ValueError(f"'men_end' must be in special tokens if featurisation is {featurisation}")
    if featurisation == 'prepend':
        if 'men_sep' not in special_tokens:
            raise ValueError(f"'men_sep' must be in special tokens if featurisation is 'prepend'")
    if date_featurisation not in ['none', 'prepend_1', 'prepend_2', 'nat_lang']:
        raise ValueError("date_featurisation must be 'none', 'prepend_1', 'prepend_2', 'nat_lang'")

    # Token budget for the text; presumably the 10 spare tokens leave room for
    # special/marker/date tokens added later - confirm.
    max_length = max_sequence_len - 10

    sep = find_sep_token(tokenizer)

    output_list = []
    for context_dict in tqdm(list_of_dicts):

        men_start = context_dict["mention_start"]
        men_end = context_dict["mention_end"]
        org_ment_text = context_dict["mention_text"]
        text = context_dict["text"] if "text" in context_dict else context_dict["context"]
        date = _resolve_date(context_dict, date_featurisation, noise_years)

        # Truncate (roughly want the mention in the centre)
        if featurisation in ("ent_mark", "prepend"):
            text, men_start, men_end = _truncate_around_mention(
                text, men_start, men_end, org_ment_text, tokenizer, max_length)
            assert _WHITESPACE_RE.sub('', org_ment_text) == _WHITESPACE_RE.sub('', text[men_start:men_end])

        if featurisation == "ent_mark":
            ent_text = _prefix_date(_mark_mention(text, men_start, men_end, special_tokens),
                                    date, date_featurisation, sep)

        elif featurisation == "prepend":
            men_prefix = text[men_start:men_end] + " " + special_tokens['men_sep'] + " "
            if date and date_featurisation == 'prepend_1':
                ent_text = str(date) + sep + men_prefix + text
            elif date and date_featurisation == 'prepend_2':
                ent_text = men_prefix + str(date) + sep + text
            elif date and date_featurisation == 'nat_lang':
                ent_text = men_prefix + "In " + str(date) + ", " + text
            else:
                # No usable date (including date_featurisation == 'none').
                ent_text = men_prefix + text

        else:  # "hsu_hor"
            marked = _mark_mention(text, men_start, men_end, special_tokens)

            # Separate the snippet containing the mention marker from the others.
            context_sents = []
            mention_sent = None
            for snip in marked.split('\\n\\n'):
                if special_tokens['men_start'] in snip:
                    mention_sent = snip
                else:
                    context_sents.append(snip)

            context = " ".join(context_sents[:2])

            # Stay within the length limit by trimming the context sentences.
            mention_sent_len = len(tokenizer.encode(mention_sent))
            encoded_context = tokenizer.encode(context)
            if mention_sent_len + len(encoded_context) > max_length:
                # Index 0 is skipped (leading special token from encode); the
                # stop is clamped so a long mention sentence cannot produce a
                # negative stop that would keep most of the context.
                keep_until = max(max_length - mention_sent_len, 1)
                context = tokenizer.decode(encoded_context[1:keep_until])

            # Context first, then sep, then the mention sentence.
            ent_text = _prefix_date(context + sep + mention_sent, date, date_featurisation, sep)

        output_list.append(ent_text)

    return output_list



def _sample_contexts(all_contexts, qid, k):
    """Return up to ``k`` of ``qid``'s contexts, sampled without replacement."""
    contexts = all_contexts[qid]
    return random.sample(contexts, min(k, len(contexts)))


def _add_positive_pairs(out, split_qids, all_contexts, fp_dict, k):
    """Append (context, own first paragraph, 1.0) rows to ``out``; return the count added."""
    added = 0
    for qid in tqdm(split_qids):
        if qid not in all_contexts or qid not in fp_dict:
            continue
        contexts = _sample_contexts(all_contexts, qid, k)
        out['sentence_1'].extend(contexts)
        out['sentence_2'].extend([fp_dict[qid]] * len(contexts))
        out['labels'].extend([1.0] * len(contexts))
        added += len(contexts)
    return added


def _add_easy_negative_pairs(out, split_qids, all_contexts, fp_dict, k):
    """Append (context, another qid's first paragraph from this split, 0.0) rows; return the count added."""
    # Candidates are limited to this split's qids that have an fp entry; a list
    # (not a set) keeps random.sample valid on Python 3.11+.
    fp_pool = [qid for qid in dict.fromkeys(split_qids) if qid in fp_dict]
    added = 0
    for qid in tqdm(split_qids):
        if qid not in all_contexts:
            continue
        contexts = _sample_contexts(all_contexts, qid, k)
        # Draw one spare so the anchor's own qid can be discarded cheaply: a
        # qid's own first paragraph must never be labelled a negative.
        drawn = random.sample(fp_pool, min(len(contexts) + 1, len(fp_pool)))
        neg_qids = [q for q in drawn if q != qid][:len(contexts)]
        for context, neg_qid in zip(contexts, neg_qids):
            out['sentence_1'].append(context)
            out['sentence_2'].append(fp_dict[neg_qid])
            out['labels'].append(0.0)
            added += 1
    return added


def _add_hard_negative_pairs(out, split_qids, all_contexts, fp_dict, negatives_dict, k):
    """Append (context, first paragraph of a listed negative qid, 0.0) rows; return the count added."""
    added = 0
    for qid in tqdm(split_qids):
        if qid not in all_contexts or qid not in fp_dict or qid not in negatives_dict:
            continue
        # Only negatives with an fp entry can be paired; skip if none remain.
        # NOTE(review): negatives_dict is not split-aware, so these qids may
        # belong to another split - confirm this is intended.
        choices = [q for q in negatives_dict[qid] if q in fp_dict]
        if not choices:
            continue
        contexts = _sample_contexts(all_contexts, qid, k)
        neg_qids = random.choices(choices, k=len(contexts))  # with replacement
        for context, neg_qid in zip(contexts, neg_qids):
            out['sentence_1'].append(context)
            out['sentence_2'].append(fp_dict[neg_qid])
            out['labels'].append(0.0)
            added += 1
    return added


def prep_wikipedia_data_disamb_hn_for_pt_bienc(
    dataset_path,
    model,
    special_tokens=None,
    featurisation='ent_mark',
    batch_type='contrastive_batchhard',
    samples_per_label = 8,
    batch_size=16,
    small=False,
    max_seq_length=128,
    date_featurisation='none',
    save_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/outputs_within_hn_v3.pkl",
    stats_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/dataset_stats_v3.json",
    ):

    """
    Build (or load cached) featurised (context, first paragraph, label) pairs.

    For every split in the splits JSON (small_splits_v3.json when ``small``):
      * positives      - up to ``samples_per_label`` contexts of a qid paired
                         with that qid's own first paragraph (label 1.0);
      * easy negatives - up to ``samples_per_label`` contexts of a qid paired
                         with first paragraphs of other qids sampled from the
                         same split (label 0.0);
      * hard negatives - up to ``2 * samples_per_label`` contexts of a qid
                         paired, with replacement, with first paragraphs of the
                         qids in ``negatives_dict[qid]`` (label 0.0).
    Both columns are then featurised with featurise_data_with_dates_flex, which
    also handles truncation to ``max_seq_length``.

    The result is pickled to ``save_path`` and per-split counts are written to
    ``stats_path``. If ``save_path`` already exists it is loaded and returned
    without recomputation.

    ``special_tokens`` defaults to
    {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}.

    Returns:
        (train, dev, test): dicts with keys 'sentence_1', 'sentence_2', 'labels'.

    Raises:
        ValueError: if batch_size is not a multiple of samples_per_label, or
            batch_type is not 'contrastive_batchhard'.
    """
    if special_tokens is None:
        special_tokens = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}

    if os.path.exists(save_path):
        print("File exists, loading it")
        # Local cache written by this function; pickle must not load untrusted files.
        with open(save_path, 'rb') as f:
            outputs = pickle.load(f)
        return outputs["train"], outputs["dev"], outputs["test"]

    if batch_size % (samples_per_label) != 0:
        raise ValueError("samples_per_label must be a divisor of batch_size")
    if batch_type not in ['contrastive_batchhard']:
        raise ValueError("unsupported batch type")
    tokenizer = AutoTokenizer.from_pretrained(model)

    # Load data
    splits_file = 'small_splits_v3.json' if small else 'splits_v3.json'
    with open(f'{dataset_path}/{splits_file}') as f:
        splits = json.load(f)  # list of qids for each split (train, dev, test)

    with open(f'{dataset_path}/all_contexts_humans_mapped_all_contexts_with_median_year.pkl', 'rb') as f:
        all_contexts = pickle.load(f)  # keys are qids

    with open(f'{dataset_path}/negatives_family_disamb.pkl', 'rb') as f:
        negatives_dict = pickle.load(f)  # qid -> list of negative qids

    with open(f'{dataset_path}/formatted_first_para_data_qid_template_people_with_qrank_3occupations.json') as f:
        fp_dict = json.load(f)  # keys are qids

    split_names = ['train', 'dev', 'test']
    # A separate counter dict per split, so each split's numbers stay its own.
    dataset_stats = {
        split: {'positives': 0, 'easy_negatives': 0, 'hard_negatives_fam_disamb': 0}
        for split in split_names
    }
    outputs = {
        split: {'sentence_1': [], 'sentence_2': [], 'labels': []}  # context, first para, label
        for split in split_names
    }

    print("Prepping positives ...")
    for split in split_names:
        print(f"Prepping {split} split ...")
        count = _add_positive_pairs(outputs[split], splits[split], all_contexts, fp_dict, samples_per_label)
        dataset_stats[split]['positives'] = count
        print(f"Positives {split} split: {count}")

    for split in split_names:
        print(f"Prepping easy negatives {split} split ...")
        count = _add_easy_negative_pairs(outputs[split], splits[split], all_contexts, fp_dict, samples_per_label)
        dataset_stats[split]['easy_negatives'] = count
        print(f"Easy negatives {split} split: {count}")

    for split in split_names:
        print(f"Prepping hard negatives {split} split ...")
        count = _add_hard_negative_pairs(outputs[split], splits[split], all_contexts, fp_dict,
                                         negatives_dict, samples_per_label * 2)
        dataset_stats[split]['hard_negatives_fam_disamb'] = count
        print(f"Hard negatives {split} split: {count}")

        print(f"Featurising {split} split ...")
        for column in ('sentence_1', 'sentence_2'):
            outputs[split][column] = featurise_data_with_dates_flex(
                outputs[split][column], featurisation, date_featurisation,
                special_tokens, tokenizer, max_seq_length)

    with open(save_path, 'wb') as f:
        pickle.dump(outputs, f)

    print(dataset_stats)

    with open(stats_path, 'w') as f:
        json.dump(dataset_stats, f)

    return outputs["train"], outputs["dev"], outputs["test"]

def data_shuffler(all_splits_pickle):
    """
    Shuffle the rows of every split in a pickled outputs file and overwrite it.

    'sentence_1', 'sentence_2' and 'labels' are permuted together so rows stay
    aligned. The global ``random`` RNG is reseeded with 42 before each split,
    so the result is deterministic. Columns are written back as lists, and
    empty splits stay empty. Returns the shuffled dict.
    """
    with open(all_splits_pickle, 'rb') as f:
        outputs = pickle.load(f)

    columns = ('sentence_1', 'sentence_2', 'labels')
    for split in outputs:
        print(f"Shuffling {split} split ...")
        rows = list(zip(*(outputs[split][column] for column in columns)))
        random.seed(42)
        random.shuffle(rows)
        # Unpack column-wise explicitly: ``zip(*rows)`` would fail on an empty
        # split and would yield tuples instead of lists.
        for index, column in enumerate(columns):
            outputs[split][column] = [row[index] for row in rows]

    with open(all_splits_pickle, 'wb') as f:
        pickle.dump(outputs, f, protocol=4)

    return outputs

###Run as script
if __name__ == '__main__':
    # Earlier pipeline stages that produced this step's inputs (run separately):
    #   1. Load qid_family_neg_dict.json and disambiguation_dict_final_qid.json,
    #      drop empty lists, merge them (union of keys and of values) into
    #      family_disamb_dict.json, then call
    #      negatives(family_disamb_dict, save_path=".../negatives_family_disamb.pkl").
    #   2. Load all_contexts_humans_mapped_all_contexts_with_median_year.pkl and
    #      disambiguation_dict_final_qid.json, then call
    #      positives(all_contexts_dict, disamb_dict) to write the *_splits_v3.json files.
    data_root = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/"
    output_pickle = data_root + "outputs_within_hn_v3.pkl"

    # Build featurised pairs for the bi-encoder, then shuffle them in place on disk.
    train, val, test = prep_wikipedia_data_disamb_hn_for_pt_bienc(
        dataset_path=data_root,
        model="sentence-transformers/all-mpnet-base-v2",
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation="ent_mark",
        date_featurisation="none",
        batch_type="contrastive_batchhard",
        max_seq_length=256,
        samples_per_label=4,
        batch_size=256,
        small=False,
        save_path=output_pickle,
    )

    data_shuffler(output_pickle)