##The goal is to add wiki context to the newspaper dataset. 
import pickle
import json
from collections import defaultdict
from itertools import combinations
import ast
import random
import hashlib
from tqdm import tqdm
import os
from data_fn_disamb import featurise_data_with_dates_flex
from sentence_transformers import SentenceTransformer
from sentence_transformers.readers import InputExample
from transformers import AutoModel, AutoTokenizer
import pandas as pd
##Labelled newspaper mentions, keyed by Wikipedia title. In the code visible here this path is
##only referenced by the commented-out reindexing snippet below.
newspaper_dataset_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/labeled_datasets_full_extended.json"
##JSON of Wikipedia first-paragraph entries; its keys are matched against the keys of the
##reindexed newspaper data in prep_newspaper_disamb_ft_data.
wiki_firstpara_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/formatted_first_para_data_qid_template_people_with_qrank_3occupations.json"
##Pickled mapping from a QID to candidate QIDs, used by wiki_hard_negatives.
disamb_dict_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/negatives_family_disamb.pkl"
##JSON of the newspaper data reindexed by QID (wiki_id); written by the commented-out snippet below.
reindexed_news_data_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/newspaper_data_reindexed.json"


##Upper bound on the number of contexts each pair builder keeps per QID.
CONTEXTS_PER_ENTITY=4 ##To avoid overfitting to specific entities
# ##Open the newspaper dataset
# with open(newspaper_dataset_path, 'r') as f:
#     newspaper_data = json.load(f)
    
# ##Reindex the newspaper dict- it is currently title, make it QID (wiki_id). 
# # Eg: {"David K. E. Bruce": [[{"mention_text": "David K. E. Bruce", "wiki_id": "Q325810", "wiki_url": "['David_K._E._Bruce', 'David_K._Bruce', 'David_K.E._Bruce', 'David_K_E_Bruce', 'David_KE_Bruce', 'David_Kirkpatrick_Este_Bruce', 'David_K._E._Bruce_(Template)']", "mention_start": 1738, "mention_end": 1755, "classification": "Positive", "text_id": "text5", "text": "ee ee ee ee ee ee  WASHINGTON (AP) \u2014 President Johnson and British Prime | Minister Harold Wilson held an \u2018interesting, cordial and fruitful\u201d discussion today which covered a wide range of subjects, including Viet Nam,  Johnson himself gave newsmen this report and Wilson said he agreed with what the President said.  Johnson, noting that Wilson agrees with the U.S. position in Viet Nam, said \u201cWe appreciate very much this support.

# ##Each element is a list of contexts and each context contains a wiki_id - which is the same across all entity contexts
# newspaper_data_reindexed={}
# for title, contexts in newspaper_data.items():
#     if len(contexts)==0:
#         continue
#     print(len(contexts))
#     wiki_id=contexts[0][0]['wiki_id']
#     newspaper_data_reindexed[wiki_id]=contexts
#     ##Add title as wiki_title_replaced
#     for context in contexts:
#         context[0]['wiki_title_replaced']=title

# ###Save
# with open(reindexed_news_data_path, 'w') as f:
#     json.dump(newspaper_data_reindexed, f)
    
def remove_those_without_mention_text(context_dict):
    """Return the contexts whose 'mention_text' is not the empty string.

    Args:
        context_dict: iterable of context dicts, each carrying a 'mention_text' key
            (despite the parameter name, it is iterated like a list).

    Returns:
        A new list holding the surviving context objects in their original order.
    """
    return [entry for entry in context_dict if entry["mention_text"] != ""]

def dedup_list_of_context_dict(context_dict):
    """Drop repeated contexts, keeping the first occurrence of each.

    Contexts are compared by equality against everything kept so far (the
    entries are dicts, hence unhashable, so no set is used).

    Args:
        context_dict: iterable of context dicts (iterated like a list).

    Returns:
        A new list of the distinct contexts in first-seen order.
    """
    unique_contexts = []
    for candidate in context_dict:
        ##Skip anything equal to a context that was already kept
        if candidate in unique_contexts:
            continue
        unique_contexts.append(candidate)
    return unique_contexts

def create_positives(split_dict, fp_dict):
    """Build label-1 pairs of newspaper contexts and their matching first paragraph (FP).

    Two kinds of matching pairs are produced, in this order:
      1. Both members of every pair tagged "same" under a QID, each matched
         with ``fp_dict[qid]``. QIDs absent from ``fp_dict`` contribute nothing.
      2. The second member of every pair tagged "different" that carries a
         'neg_wiki_id' found in ``fp_dict``, matched with the FP of that
         'neg_wiki_id'.

    For each QID the collected contexts are deduplicated, contexts with an
    empty 'mention_text' are removed, and at most CONTEXTS_PER_ENTITY are kept.

    Args:
        split_dict: maps QID -> iterable of items indexed as
            (context_a, context_b, tag), where tag is "same" or "different".
        fp_dict: maps QID -> first-paragraph entry.

    Returns:
        Dict with parallel lists "sentence_1" (newspaper contexts),
        "sentence_2" (FP entries) and "labels" (all 1).
    """
    pairs = {"sentence_1": [], "sentence_2": [], "labels": []}

    ##Contexts that mention the QID they are filed under
    same_entity_contexts = defaultdict(list)
    for qid, labelled_pairs in split_dict.items():
        for pair in labelled_pairs:
            if pair[2] != "same":
                continue
            same_entity_contexts[qid].extend((pair[0], pair[1]))

        ##Dedup, drop empty mentions, then cap the number of contexts per entity
        cleaned = remove_those_without_mention_text(
            dedup_list_of_context_dict(same_entity_contexts[qid])
        )
        same_entity_contexts[qid] = cleaned[:CONTEXTS_PER_ENTITY]

    ##Match every kept context with the FP of its own QID, when one exists
    for qid, kept_contexts in same_entity_contexts.items():
        if qid not in fp_dict:
            continue
        own_first_para = fp_dict[qid]
        pairs["sentence_1"].extend(kept_contexts)
        pairs["sentence_2"].extend([own_first_para] * len(kept_contexts))
        pairs["labels"].extend([1] * len(kept_contexts))

    ##print len
    print("Number of positives: ", len(pairs["sentence_1"]))

    ##Contexts of the *other* entity in "different" pairs; they are positives
    ##with respect to the FP of their own 'neg_wiki_id'
    other_entity_contexts = defaultdict(list)
    for qid, labelled_pairs in split_dict.items():
        for pair in labelled_pairs:
            if pair[2] != "different":
                continue
            other_context = pair[1]
            if 'neg_wiki_id' in other_context and other_context['neg_wiki_id'] in fp_dict:
                other_entity_contexts[qid].append(other_context)

        cleaned = remove_those_without_mention_text(
            dedup_list_of_context_dict(other_entity_contexts[qid])
        )
        other_entity_contexts[qid] = cleaned[:CONTEXTS_PER_ENTITY]

    for kept_contexts in other_entity_contexts.values():
        for other_context in kept_contexts:
            matching_first_para = fp_dict[other_context['neg_wiki_id']]
            pairs["sentence_1"].append(other_context)
            pairs["sentence_2"].append(matching_first_para)
            pairs["labels"].append(1)

    ##Cumulative count: includes the pairs from both passes above
    print("Number of positives: ", len(pairs["sentence_1"]))

    return pairs
            
    
    
def create_easy_negatives(split_dict, fp_dict):
    """Build label-0 pairs by matching newspaper contexts with a random, unrelated first paragraph (FP).

    Two groups of contexts are used, in this order:
      1. The first member of every pair tagged "same" under a QID; the random
         FP is drawn from any key of ``fp_dict`` other than that QID.
      2. The second member of every pair tagged "different" that carries a
         'neg_wiki_id' found in ``fp_dict``; the random FP is drawn from any
         key other than that 'neg_wiki_id'.

    For each QID the collected contexts are deduplicated, contexts with an
    empty 'mention_text' are removed, and at most CONTEXTS_PER_ENTITY are kept.

    A context is skipped when ``fp_dict`` offers no key other than the one to
    avoid (empty dict, or a single entry equal to the excluded QID). Without
    that check the rejection-sampling loop would never terminate, or
    ``random.choice`` would raise IndexError on the empty key list.

    Args:
        split_dict: maps QID -> iterable of items indexed as
            (context_a, context_b, tag), where tag is "same" or "different".
        fp_dict: maps QID -> first-paragraph entry.

    Returns:
        Dict with parallel lists "sentence_1" (newspaper contexts),
        "sentence_2" (randomly drawn FP entries) and "labels" (all 0).
    """
    fp_dict_keys_list = list(fp_dict.keys())

    def has_alternative_fp(excluded_qid):
        ##Dict keys are unique, so another key exists unless the list is empty
        ##or holds nothing but the excluded QID
        if len(fp_dict_keys_list) > 1:
            return True
        return len(fp_dict_keys_list) == 1 and fp_dict_keys_list[0] != excluded_qid

    def random_fp_excluding(excluded_qid):
        ##Rejection sampling; only call after has_alternative_fp(excluded_qid)
        ##returned True, otherwise this loop cannot finish
        random_fp_qid = random.choice(fp_dict_keys_list)
        while random_fp_qid == excluded_qid:
            random_fp_qid = random.choice(fp_dict_keys_list)
        return fp_dict[random_fp_qid]

    positive_dict = defaultdict(list)
    for qid, contexts in split_dict.items():
        for context in contexts:
            if context[2] == "same":
                positive_dict[qid].append(context[0])

        ##Dedup, drop empty mentions, then shorten by CONTEXTS_PER_ENTITY
        positive_dict[qid] = dedup_list_of_context_dict(positive_dict[qid])
        positive_dict[qid] = remove_those_without_mention_text(positive_dict[qid])
        positive_dict[qid] = positive_dict[qid][:CONTEXTS_PER_ENTITY]

    paired_dict = {"sentence_1": [], "sentence_2": [], "labels": []}

    ##Pair up each positive context with a random FP other than its own QID
    for qid, pos_contexts in tqdm(positive_dict.items()):
        if not has_alternative_fp(qid):
            continue
        for pos_context in pos_contexts:
            random_fp_context = random_fp_excluding(qid)
            paired_dict["sentence_1"].append(pos_context)
            paired_dict["sentence_2"].append(random_fp_context)
            paired_dict["labels"].append(0)

    print("Number of easy negatives: ", len(paired_dict["sentence_1"]))

    negatives_dict = defaultdict(list)
    for qid, contexts in split_dict.items():
        for context in contexts:
            if context[2] == "different":
                if 'neg_wiki_id' in context[1] and context[1]['neg_wiki_id'] in fp_dict:
                    negatives_dict[qid].append(context[1])

        ##Dedup, drop empty mentions, then shorten
        negatives_dict[qid] = dedup_list_of_context_dict(negatives_dict[qid])
        negatives_dict[qid] = remove_those_without_mention_text(negatives_dict[qid])
        negatives_dict[qid] = negatives_dict[qid][:CONTEXTS_PER_ENTITY]

    ##Pair up each "different" context with a random FP other than the entity
    ##it actually mentions (its 'neg_wiki_id')
    for qid, neg_contexts in tqdm(negatives_dict.items()):
        for context in neg_contexts:
            if not has_alternative_fp(context['neg_wiki_id']):
                continue
            random_fp_context = random_fp_excluding(context['neg_wiki_id'])
            paired_dict["sentence_1"].append(context)
            paired_dict["sentence_2"].append(random_fp_context)
            paired_dict["labels"].append(0)

    ##Cumulative count: includes the pairs from both passes above
    print("Number of easy negatives: ", len(paired_dict["sentence_1"]))

    return paired_dict

def newspaper_hard_negatives(split_dict, fp_dict):
    """Build label-0 pairs from "different" newspaper contexts and the FP of the QID they are filed under.

    For every QID, the second member of each pair tagged "different" is
    collected, deduplicated, stripped of contexts with an empty
    'mention_text' and capped at CONTEXTS_PER_ENTITY. Each surviving context
    is matched with ``fp_dict[qid]``; QIDs absent from ``fp_dict`` contribute
    nothing.

    Args:
        split_dict: maps QID -> iterable of items indexed as
            (context_a, context_b, tag), where tag is "same" or "different".
        fp_dict: maps QID -> first-paragraph entry.

    Returns:
        Dict with parallel lists "sentence_1" (newspaper contexts),
        "sentence_2" (FP entries) and "labels" (all 0).
    """
    other_entity_contexts = defaultdict(list)
    for qid, labelled_pairs in split_dict.items():
        other_entity_contexts[qid].extend(
            pair[1] for pair in labelled_pairs if pair[2] == "different"
        )

        ##Dedup, drop empty mentions, then cap the number of contexts per entity
        cleaned = remove_those_without_mention_text(
            dedup_list_of_context_dict(other_entity_contexts[qid])
        )
        other_entity_contexts[qid] = cleaned[:CONTEXTS_PER_ENTITY]

    hard_pairs = {"sentence_1": [], "sentence_2": [], "labels": []}

    ##Match each kept context with the FP of the QID it was filed under
    for qid, kept_contexts in other_entity_contexts.items():
        if qid not in fp_dict:
            continue
        filed_first_para = fp_dict[qid]
        hard_pairs["sentence_1"].extend(kept_contexts)
        hard_pairs["sentence_2"].extend([filed_first_para] * len(kept_contexts))
        hard_pairs["labels"].extend([0] * len(kept_contexts))

    print("Number of newspaper hard negatives: ", len(hard_pairs["sentence_1"]))

    return hard_pairs


def wiki_hard_negatives(split_dict, fp_dict, disamb_dict):
    """Build label-0 pairs from positive newspaper contexts and the FP of a confusable entity.

    For every QID, the first member of each pair tagged "same" is collected,
    deduplicated, stripped of contexts with an empty 'mention_text' and capped
    at CONTEXTS_PER_ENTITY. For each surviving context one QID is drawn at
    random from ``disamb_dict[qid]``; if that QID has an entry in ``fp_dict``
    the context is paired with it, otherwise the context yields no pair.

    QIDs missing from ``disamb_dict``, or whose candidate list is empty, are
    skipped (``random.choice`` raises IndexError on an empty sequence).

    Args:
        split_dict: maps QID -> iterable of items indexed as
            (context_a, context_b, tag), where tag is "same" or "different".
        fp_dict: maps QID -> first-paragraph entry.
        disamb_dict: maps QID -> sequence of candidate QIDs to sample from.

    Returns:
        Dict with parallel lists "sentence_1" (newspaper contexts),
        "sentence_2" (FP entries of the sampled QIDs) and "labels" (all 0).
    """
    positives_dict = defaultdict(list)
    for qid, contexts in split_dict.items():
        for context in contexts:
            if context[2] == "same":
                positives_dict[qid].append(context[0])

        ##Dedup, drop empty mentions, then shorten
        positives_dict[qid] = dedup_list_of_context_dict(positives_dict[qid])
        positives_dict[qid] = remove_those_without_mention_text(positives_dict[qid])
        positives_dict[qid] = positives_dict[qid][:CONTEXTS_PER_ENTITY]

    paired_contexts = {"sentence_1": [], "sentence_2": [], "labels": []}

    ##Pair up the positives with FP's from disamb dict.
    for qid, pos_contexts in positives_dict.items():
        if not pos_contexts or qid not in disamb_dict:
            continue
        disamb_list_qid = disamb_dict[qid]
        ##Nothing to sample from; random.choice would raise IndexError
        if len(disamb_list_qid) == 0:
            continue
        for pos_context in pos_contexts:
            ##Randomly sample a disambiguation QID (one draw per context)
            ##NOTE(review): assumes disamb_dict[qid] never lists qid itself - confirm
            disamb_qid = random.choice(disamb_list_qid)
            if disamb_qid in fp_dict:
                paired_contexts["sentence_1"].append(pos_context)
                paired_contexts["sentence_2"].append(fp_dict[disamb_qid])
                paired_contexts["labels"].append(0)

    print("Number of wiki hard negatives: ", len(paired_contexts["sentence_1"]))

    return paired_contexts



def prep_newspaper_disamb_ft_data(wiki_firstpara_path, disamb_dict_path, reindexed_news_data_path, 
                                  featurisation, 
                                  date_featurisation,
                                  special_tokens,
                                  model,
                                  max_seq_length,
                                  output_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/newspaper_disamb_data.pkl",
                                  only_prototype=False,
                                  remake=False):
    """Build (or load from cache) the newspaper-to-Wikipedia disambiguation pair dataset.

    Steps:
      1. If ``output_path`` exists and ``remake`` is False, unpickle and return
         it immediately (``only_prototype`` is not consulted in that case).
      2. Load the tokenizer, the first-paragraph JSON, the pickled
         disambiguation dict and the reindexed newspaper JSON.
      3. Keep only newspaper QIDs that have a first paragraph and a non-empty
         context list, then assign each QID at random to train / val / test
         with probabilities 0.8 / 0.1 / 0.1 (unseeded ``random.random()``).
      4. For each split build positives, easy negatives, newspaper hard
         negatives and wiki hard negatives.
      5. If ``only_prototype``: featurise the first pair of each kind from the
         test split, write it to a CSV and return the DataFrame.
      6. Otherwise concatenate the four kinds per split, featurise both
         sentence columns, shuffle the rows, pickle the result to
         ``output_path`` and write the per-split pair counts to a JSON file.

    Args:
        wiki_firstpara_path: path of the first-paragraph JSON.
        disamb_dict_path: path of the pickled disambiguation dict.
        reindexed_news_data_path: path of the QID-indexed newspaper JSON.
        featurisation, date_featurisation, special_tokens, max_seq_length:
            forwarded unchanged to ``featurise_data_with_dates_flex``.
        model: name or path given to ``AutoTokenizer.from_pretrained``.
        output_path: pickle file used both as cache and as output.
        only_prototype: if True, produce only the small prototype DataFrame.
        remake: if True, rebuild even when ``output_path`` already exists.

    Returns:
        A tuple (train, val, test) of dicts with list-valued keys
        "sentence_1", "sentence_2" and "labels"; or the prototype DataFrame
        when ``only_prototype`` is True; or whatever object the cache pickle
        holds when it is loaded.
    """
    ##(dict key, printed label) of each split, in processing order
    split_order = (("train", "Train"), ("val", "Val"), ("test", "Test"))
    pair_fields = ("sentence_1", "sentence_2", "labels")

    if os.path.exists(output_path) and not remake:
        print("Loading from existing file")
        ##NOTE: pickle.load can execute arbitrary code; only point this at trusted files
        with open(output_path, 'rb') as cached_file:
            return pickle.load(cached_file)

    ##Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model)

    def featurise_pairs(pairs):
        ##Featurise both sentence columns; labels pass through untouched
        return {
            "sentence_1": featurise_data_with_dates_flex(pairs["sentence_1"], featurisation, date_featurisation, special_tokens, tokenizer, max_seq_length),
            "sentence_2": featurise_data_with_dates_flex(pairs["sentence_2"], featurisation, date_featurisation, special_tokens, tokenizer, max_seq_length),
            "labels": pairs["labels"],
        }

    ##Load the wiki first para data
    with open(wiki_firstpara_path, 'r') as firstpara_file:
        wiki_firstpara_data = json.load(firstpara_file)

    ##Load the disamb dict
    with open(disamb_dict_path, 'rb') as disamb_file:
        disamb_dict = pickle.load(disamb_file)

    ##Load the reindexed newspaper data
    with open(reindexed_news_data_path, 'r') as news_file:
        newspaper_data_reindexed = json.load(news_file)

    ##Subset reindexed data to keys in the fp
    wiki_fp_keys = set(wiki_firstpara_data.keys())
    print(len(newspaper_data_reindexed))
    newspaper_data_reindexed = {
        qid: contexts
        for qid, contexts in newspaper_data_reindexed.items()
        if qid in wiki_fp_keys
    }
    print(len(newspaper_data_reindexed))

    ##Drop QIDs whose context list is empty
    newspaper_data_reindexed = {
        qid: contexts
        for qid, contexts in newspaper_data_reindexed.items()
        if len(contexts) > 0
    }
    print(len(newspaper_data_reindexed))

    ##Split into train, val, test (0.8, 0.1, 0.1) - by QID
    split_data = {split_key: {} for split_key, _ in split_order}
    for qid, contexts in newspaper_data_reindexed.items():
        draw = random.random()
        if draw < 0.8:
            target_split = "train"
        elif draw < 0.9:
            target_split = "val"
        else:
            target_split = "test"
        split_data[target_split][qid] = contexts

    for split_key, split_label in split_order:
        print(f"{split_label}: ", len(split_data[split_key]))

    dataset_stats = defaultdict(dict)

    ##Create positives and negatives for each split. The builders draw from the
    ##global random state, so both the split order and the builder order matter.
    built_pairs = {}
    for split_key, split_label in split_order:
        print(split_label)
        contexts_by_qid = split_data[split_key]
        built_pairs[split_key] = {
            "positives": create_positives(contexts_by_qid, wiki_firstpara_data),
            "easy_negatives": create_easy_negatives(contexts_by_qid, wiki_firstpara_data),
            "hard_negatives": newspaper_hard_negatives(contexts_by_qid, wiki_firstpara_data),
            "wiki_hard_negatives": wiki_hard_negatives(contexts_by_qid, wiki_firstpara_data, disamb_dict),
        }
        dataset_stats[split_key] = {
            kind: len(pairs["sentence_1"])
            for kind, pairs in built_pairs[split_key].items()
        }

    if only_prototype:
        ##One positive, one easy negative, one hard negative and one wiki hard
        ##negative from the test split make up the prototype
        prototype_pairs = {
            field: [pairs[field][0] for pairs in built_pairs["test"].values()]
            for field in pair_fields
        }

        ##Featurise and save as df
        prototype_df = pd.DataFrame(featurise_pairs(prototype_pairs))
        prototype_df.to_csv("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/prototype_test.csv", index=False)

        return prototype_df

    ##Stack the four pair kinds of each split, then featurise
    featurised_splits = {}
    for split_key, _ in split_order:
        stacked = {
            field: [item for pairs in built_pairs[split_key].values() for item in pairs[field]]
            for field in pair_fields
        }
        featurised_splits[split_key] = featurise_pairs(stacked)

    ##Shuffle the rows of each split
    output = tuple(
        pd.DataFrame(featurised_splits[split_key]).sample(frac=1).to_dict(orient='list')
        for split_key, _ in split_order
    )

    with open(output_path, 'wb') as output_file:
        pickle.dump(output, output_file)

    ##Save dataset stats
    with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/dataset_stats_newspaper_disamb_lesshn.json", 'w') as stats_file:
        json.dump(dataset_stats, stats_file)

    return output


###Run as script
if __name__ == '__main__':
    ##Settings for this run; remake=True forces a rebuild even if the pickle exists
    ##NOTE(review): the output filename says "nowikihn", but
    ##prep_newspaper_disamb_ft_data currently stacks wiki hard negatives too - confirm the intended name
    run_settings = {
        "featurisation": "ent_mark",
        "date_featurisation": "none",
        "special_tokens": {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        "model": "sentence-transformers/all-mpnet-base-v2",
        "max_seq_length": 256,
        "output_path": "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/newspaper_disamb_data_nowikihn_only_people_lesscontx.pkl",
        "only_prototype": False,
        "remake": True,
    }
    train_data, val_data, test_data = prep_newspaper_disamb_ft_data(
        wiki_firstpara_path,
        disamb_dict_path,
        reindexed_news_data_path,
        **run_settings,
    )

    print("Done")
            






