##The goal is to add wiki context to the newspaper dataset. 
import pickle
import json
from collections import defaultdict
from itertools import combinations
import ast
import random
import hashlib
from tqdm import tqdm

# Every input/output artifact of this script lives under one directory.
_DATA_ROOT = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/"

# Pickled per-entity Wikipedia contexts (raw wikitext paragraphs).
wiki_dataset_path_contexts = _DATA_ROOT + "filtered_dict_custom_queries_labelled.pickle"
# Labelled newspaper article pairs, keyed by entity (JSON).
newspaper_dataset_path = _DATA_ROOT + "labeled_datasets_full_extended.json"
# Pickled title -> first paragraph mapping for the custom-query entities.
wiki_firstpara_path = _DATA_ROOT + "wiki_firstpara_custom_queries.pkl"
# Pickled collection of wiki titles that appear in the newspaper data.
list_of_labelled_entity_titles = _DATA_ROOT + "news_title_to_entity_dict.pkl"
# Pickled contexts for the (larger) disambiguation entity set.
all_contexts_disamb_path = _DATA_ROOT + "all_contexts.pkl"
# JSON title -> first paragraph mapping for all entities.
all_first_para_path = _DATA_ROOT + "cleaned_fp_data.json"
# JSON disambiguation dictionary: key title -> list of related titles.
disamb_dict_path = _DATA_ROOT + "disambiguation_dict_final.json"

# newspaper_dataset=json.load(open(newspaper_dataset_path, "r"))
# newspaper_dataset=newspaper_dataset["Bob Smith"]
# ##Write subset to json
# json.dump(newspaper_dataset, open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/bob_smith.json", "w"))

# exit()

def get_hash(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8.

    Used only as a compact de-duplication key, not for anything
    security-sensitive.
    """
    digest = hashlib.md5(text.encode("utf-8"))
    return digest.hexdigest()



def len_entities(text):
    """Count the characters of ``text`` taken up by ``[[...]]`` wiki links.

    Each link contributes the length of its inner text (target and, if
    present, ``|label``) plus 4 for the ``[[`` / ``]]`` brackets. Text after
    a link's closing ``]]`` is not counted. A link without a closing ``]]``
    counts everything up to the next ``[[`` or the end of the text.

    Args:
        text: Raw wikitext paragraph.

    Returns:
        Number of characters attributable to links.
    """
    # The piece before the first "[[" is always plain text. It is the empty
    # string when ``text`` starts with a link. The old code counted that
    # empty piece as an extra link in that case, inflating the result by 4.
    link_bodies = [segment.split("]]")[0] for segment in text.split("[[")[1:]]
    return sum(len(body) for body in link_bodies) + 4 * len(link_bodies)


def context_clean(text, ent):
    """Clean one wikitext paragraph and locate the first link to ``ent``.

    The paragraph is rejected (``None`` is returned) when any of these hold:

    * it starts with file/image/redirect/project markup;
    * it contains ``User:``, ``||`` or ``[[]]``;
    * after cleaning, it has no ``[[ent]]`` / ``[[ent|`` link
      (case-insensitive);
    * it contains ``=``;
    * it is shorter than 75 or longer than 10000 characters;
    * fewer than 25 characters, or less than 15%, of it lie outside links.

    Cleaning reduces external ``[http... label]`` links to their label and
    drops leftover bare URLs word by word. It truncates the text at the first
    unwanted keyword and replaces every ``[[target|label]]`` / ``[[target]]``
    link by its visible text.

    Args:
        text: Raw wikitext paragraph.
        ent: Title of the entity the paragraph should link to.

    Returns:
        ``None`` if the paragraph is rejected. Otherwise a dict with keys
        ``context`` (cleaned text), ``mention_text``, ``mention_start`` and
        ``mention_end``, such that
        ``context[mention_start:mention_end] == mention_text`` for the first
        link whose target equals ``ent`` (case-insensitive).
    """
    bad_keywords = ("[[File:", "[[Image:", "[[file:", "[[image:", "<", ":File:", "[[WP:", "#REDIRECT", "[[Wikipedia:")

    if text.startswith(bad_keywords):
        return None

    if any(x in text for x in ["User:", "||", "[[]]"]):  # Table, comments
        return None

    # Deal with external links: "[http://url label]" -> "label".
    if "http" in text:
        split = text.split("[http")
        keep = []
        if not text.startswith("[http"):
            keep.append(split[0])
            split = split[1:]
        for spl in split:
            part = spl.split("]", 1)
            left = part[0].split(" ", 1)
            if len(left) > 1:
                keep.extend(left[1:])
            keep.extend(part[1:])
        text = "".join(keep)

        # Drop any bare URLs that are still left, word by word.
        if "http" in text:
            text = " ".join(spl for spl in text.split(" ") if "http" not in spl)

    # Truncate at the first occurrence of each unwanted keyword.
    for kw in bad_keywords:
        text = text.split(kw)[0]

    lowered = text.lower()
    ent_lower = ent.lower()
    if ("[[" + ent_lower + "]]") not in lowered and ("[[" + ent_lower + "|") not in lowered:
        return None
    if "=" in text:
        return None
    if len(text) < 75 or len(text) > 10000:
        return None

    ent_len = len_entities(text)
    if len(text) - ent_len < 25:
        return None
    if ent_len / len(text) > 0.85:
        return None

    split = text.split("[[")
    if len(split) < 2:
        return None

    # Text before the first link is plain ("" when the text starts with "[[").
    context_list = [split[0]]
    mention_text = None
    men_seg_start = None

    for seg in split[1:]:
        seg_split = seg.split("]]")
        if len(seg_split) > 2:
            # Stray extra "]]": keep the link body and glue the rest together.
            # "or [''] " guards a segment made only of "]]" (previously IndexError).
            parts = [x for x in seg_split if x] or [""]
            seg_split = [parts[0], "".join(parts[1:])]

        link = seg_split[0]
        if "|" in link:
            pieces = link.split("|")
            entity, mention = pieces[0], pieces[1]
        else:
            entity = mention = link
        context_list.append(mention)

        # Record only the first link to ``ent``. Checking men_seg_start (not
        # the truthiness of mention_text) keeps the first match even when
        # its label is empty.
        if men_seg_start is None and entity.lower() == ent_lower:
            mention_text = mention
            men_seg_start = len(context_list) - 1

        if len(seg_split) == 2:
            context_list.append(seg_split[1])

    if men_seg_start is None:
        # The substring check passed but no parsed link target matched
        # (e.g. malformed nesting). Previously this went on to crash with
        # TypeError on ``None + 1``; now the paragraph is rejected instead.
        print("ERROR")
        print(ent)
        print(text)
        return None

    context = "".join(context_list)
    men_start = len("".join(context_list[:men_seg_start]))
    men_end = men_start + len(mention_text)

    assert context[men_start:men_end] == mention_text

    return {"context": context, "mention_text": mention_text, "mention_start": men_start, "mention_end": men_end}



"""We need the following:
1) Newspaper to newspaper Positives (already have) - newspaper data - DONE

2) Newspaper to newspaper Negs (already have) -  newspaper data - DONE

3) Newspaper X All contexts (or a sample) of contexts for the entity (positives)  -  newspaper data X wiki_dataset_path_contexts
For each pair in positives, sample a context from the wiki dataset and pair them up

4) Newspaper → First Para (Positives) - newspaper data X wiki_firstpara_path - DONE
For each pair in positives, pair each of the two articles with the entity's longest wiki first paragraph

5) Wiki to wiki (Pos) - across contexts -  wiki_dataset_path_contexts

6) Wiki to wiki (neg) - across contexts - hard - TOPREP - Like before, from disamb set

7) Wiki to wiki (neg) - across contexts - easy - wiki_dataset_path_contexts

8) Wiki to wiki (pos) - within context - wiki_dataset_path_contexts  - TOPREP - like before, within  context of entity in the same text, but shifted token

9) Wiki to wiki (neg) - within context - wiki_dataset_path_contexts - TOPREP - like before, within context  - entity token is just shifted, but on another entity

10) Newspaper - Wiki (neg) - easy - random sample across contexts of entities - newspaper data X wiki_dataset_path_contexts

11) Newspaper - Wiki (neg) - easy - random sample across first para of entities   - newspaper data X wiki_firstpara_path - DONE

12) Newspaper - Wiki (pos) - hard - for entities that are in disamb set , get wiki contexts - - TOPREP  - This needs disamb dict and all contexts

13) Newspaper - Wiki (pos) - hard - for entities that are in disamb set , get wiki first para - - TOPREP - This needs only disamb dict and the first para dict

 
 """

# with open(wiki_firstpara_path, "r") as f:
#     wiki_firstpara=json.load(f)

# ##Subset to only the entities in the newspaper dataset
# with open(list_of_labelled_entity_titles, "rb") as f:
#     news_titles=pickle.load(f)
# wiki_firstpara_subset={entity:wiki_firstpara[entity] for entity in news_titles if entity in wiki_firstpara}
# print("Number of entities in wiki dataset ", len(wiki_firstpara.keys()))
# print("Number of entities in wiki dataset subset ", len(wiki_firstpara_subset.keys()))
# ##save 
# pickle.dump(wiki_firstpara_subset, open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/wiki_firstpara_custom_queries.pkl", "wb"))


def wiki_url_title_to_title(url_title):
    """Convert a Wikipedia URL title fragment into a plain page title.

    The steps run in this order:

    1. Underscores become spaces.
    2. Surrounding whitespace is stripped.
    3. Surrounding single quotes are stripped.
    4. Surrounding double quotes are stripped.
    5. Backslash-escaped apostrophes are unescaped.
    """
    title = url_title.replace("_", " ").strip().strip("'").strip('"')
    # Apostrophes arrive backslash-escaped; undo the spaced variant first.
    for escaped in ("\\ '", "\\'"):
        title = title.replace(escaped, "'")
    return title

###We want to first add wiki titles in the newspaper dataset using the key:[[wiki_url1,wiki_url2],...]

# Load the labelled newspaper data. From the indexing below, each entity maps
# to a list of [article_1, article_2, label] entries — TODO confirm schema.
with open(newspaper_dataset_path, "r") as f:
    newspaper_dataset = json.load(f)

newspaper_positives = defaultdict(list)
newspaper_hard_negatives = defaultdict(list)

empty_entities_list = []
for entity, entity_pairs in newspaper_dataset.items():
    if len(entity_pairs) == 0:
        empty_entities_list.append(entity)
        continue
    # All pairs of an entity must carry the same wiki_url string as the first one.
    original_wiki_urls = entity_pairs[0][0]["wiki_url"]
    for article in entity_pairs:
        if article[2] == "different":
            target = newspaper_hard_negatives
        elif article[2] == "same":
            target = newspaper_positives
        else:
            continue
        target[entity].append(article)
        assert article[0]["wiki_url"] == article[1]["wiki_url"]
        assert article[0]["wiki_url"] == original_wiki_urls
empty_entities = len(empty_entities_list)


print("Number of entities in newspaper dataset (pos) ", len(newspaper_positives.keys()))
print("Number of entities in newspaper dataset (neg) ", len(newspaper_hard_negatives.keys()))
print("Number of entities in newspaper dataset ", len(newspaper_dataset.keys()))
print("Number of empty entities ", empty_entities)
print("Empty entities ", empty_entities_list)
# Debug peek at one known entity; guarded so other datasets do not KeyError.
if "Jerry Lewis" in newspaper_dataset:
    print("Empty Jerry Lewis, ", newspaper_dataset["Jerry Lewis"])


###newspaper entity to title dict
# news_entity_to_title_dict stays a defaultdict: later code indexes it with
# entities that have no URLs and expects an empty list back.
news_title_to_entity_dict = defaultdict(list)
news_entity_to_title_dict = defaultdict(list)
empty_url_list = []
for entity, entity_pairs in newspaper_dataset.items():
    if len(entity_pairs) == 0:
        continue

    # wiki_url holds a Python-literal list as a string; literal_eval parses it
    # without executing code.
    wiki_urls = ast.literal_eval(entity_pairs[0][0]["wiki_url"])
    if len(wiki_urls) == 0:
        empty_url_list.append(entity)
        continue

    wiki_titles = [wiki_url_title_to_title(url) for url in wiki_urls]

    # Inverted index: title -> entities linking to it.
    for title in wiki_titles:
        news_title_to_entity_dict[title].append(entity)
    news_entity_to_title_dict[entity] = wiki_titles
empty_urls = len(empty_url_list)

print("Number of entities with empty urls ", empty_urls)
print("Entities with empty urls ", empty_url_list)


news_title_to_entity_dict = {title: list(set(entities)) for title, entities in news_title_to_entity_dict.items()}

# A wiki title must not be shared by two newspaper entities.
for title, entities in news_title_to_entity_dict.items():
    assert len(entities) <= 1

for entity, titles in news_entity_to_title_dict.items():
    assert len(titles) > 0

newspaper_wiki_titles_all = set(news_title_to_entity_dict.keys())

print("Number of titles in newspaper dataset ", len(newspaper_wiki_titles_all))

# Save the titles; these are used to prep the wiki contexts and first paras.
# NOTE: despite the file name this stores the *set of titles*, not a dict.
# The `with` block guarantees the file is flushed and closed.
with open(list_of_labelled_entity_titles, "wb") as f:
    pickle.dump(newspaper_wiki_titles_all, f)


"""Combining newspaper positives with wiki first paras"""
###Load the wiki first para data for our custom queries
with open(wiki_firstpara_path, "rb") as f:
    wiki_firstpara = pickle.load(f)


def _longest_first_para_dict(titles):
    """Return {"wiki_title", "text"} for the title with the longest first paragraph.

    Only titles present in ``wiki_firstpara`` are considered. Ties go to the
    earliest title. Returns ``None`` when no title has a first paragraph.
    """
    candidates = [(wiki_firstpara[title], title) for title in titles if title in wiki_firstpara]
    if not candidates:
        return None
    text, title = max(candidates, key=lambda x: len(x[0]))
    return {"wiki_title": title, "text": text}


##This is of format entity: first para
newspaper_wiki_first_para_positives = defaultdict(list)
newspaper_wiki_first_para_negatives = defaultdict(list)
for entity in newspaper_positives:
    # Same for every pair of the entity, so compute it once.
    wiki_para_dict = _longest_first_para_dict(news_entity_to_title_dict[entity])
    if wiki_para_dict is None:
        continue
    for article_pair in newspaper_positives[entity]:
        ##Pair both articles with the first para and add a "same" label
        newspaper_wiki_first_para_positives[entity].append([article_pair[0], wiki_para_dict, "same"])
        newspaper_wiki_first_para_positives[entity].append([article_pair[1], wiki_para_dict, "same"])

print("Number of entities in newspaper dataset  with wiki para (pos) ", len(newspaper_wiki_first_para_positives.keys()))


for entity in newspaper_hard_negatives:
    wiki_para_dict = _longest_first_para_dict(news_entity_to_title_dict[entity])
    if wiki_para_dict is None:
        continue
    for article_pair in newspaper_hard_negatives[entity]:
        # Pick the (last) article classified "Negative". Reset per pair: previously
        # a pair without one silently reused the previous pair's negative article
        # (or raised NameError on the first pair).
        neg_article = None
        for article in article_pair[:2]:
            if article["classification"] == "Negative":
                neg_article = article
        if neg_article is None:
            continue
        ##Pair these two up and add a "different" label
        newspaper_wiki_first_para_negatives[entity].append([neg_article, wiki_para_dict, "different"])

print("Number of entities in newspaper dataset with wiki para (neg) ", len(newspaper_wiki_first_para_negatives.keys()))

"""Now, make positives and negs for wiki contexts"""

##load contexts
with open(wiki_dataset_path_contexts, "rb") as f:
    wiki_contexts = pickle.load(f)

print("Number of entities in wiki dataset ", len(wiki_contexts.keys()))
# Debug peek; guarded so a missing entity does not abort the run.
if "Shirley Temple" in wiki_contexts:
    print((wiki_contexts["Shirley Temple"]))

##First, clean the contexts
cleaned_wiki_contexts = defaultdict(list)
for entity, raw_contexts in wiki_contexts.items():
    for context in raw_contexts:
        cleaned_context = context_clean(context, entity)
        if cleaned_context:
            cleaned_wiki_contexts[entity].append(cleaned_context)

print("Number of entities in cleaned wiki dataset ", len(cleaned_wiki_contexts.keys()))
# .get avoids inserting an empty "Shirley Temple" entry into the defaultdict
# (which later loops would then iterate); the [0] peek is skipped when empty.
shirley_contexts = cleaned_wiki_contexts.get("Shirley Temple", [])
print(len(shirley_contexts))
if shirley_contexts:
    print(shirley_contexts[0])

##Deduplicate the contexts - if context, mention_start, mention_end, mention_text match, keep only one
for entity in cleaned_wiki_contexts:
    unique_contexts = {}
    for context in cleaned_wiki_contexts[entity]:
        key = get_hash(context["context"] + "_" + str(context["mention_start"]) + "_" + str(context["mention_end"]) + "_" + context["mention_text"])
        # Later duplicates overwrite earlier ones; first-seen order is kept.
        unique_contexts[key] = context
    deduped = list(unique_contexts.values())
    cleaned_wiki_contexts[entity] = deduped
    # No two remaining dicts may be equal. Checked in O(n) via hashable item
    # tuples (values are str/int) instead of the former O(n^2) list.count scan.
    assert len({tuple(sorted(context.items())) for context in deduped}) == len(deduped)

##We'll now pair up contexts with newspaper positives and negatives. For each pair, sample up to 5 contexts
newspaper_wiki_context_positives = defaultdict(list)
newspaper_wiki_context_negatives = defaultdict(list)

# random.seed(42)


def _contexts_for_entity(entity):
    """Concatenate the cleaned wiki contexts of every wiki title of a newspaper entity."""
    entity_contexts = []
    for title in news_entity_to_title_dict[entity]:
        if title in cleaned_wiki_contexts:
            entity_contexts.extend(cleaned_wiki_contexts[title])
    return entity_contexts


for entity in newspaper_positives:
    entity_contexts = _contexts_for_entity(entity)
    if len(entity_contexts) == 0:
        continue
    for article_pair in newspaper_positives[entity]:
        ##Sample up to 5 contexts
        sampled_contexts = random.sample(entity_contexts, min(5, len(entity_contexts)))
        article_1 = article_pair[0]
        article_2 = article_pair[1]
        for context in sampled_contexts:
            newspaper_wiki_context_positives[entity].append([article_1, context, "same"])
            newspaper_wiki_context_positives[entity].append([article_2, context, "same"])


print("Number of entities in newspaper dataset with wiki context (pos) ", len(newspaper_wiki_context_positives.keys()))


##Now negatives
for entity in newspaper_hard_negatives:
    entity_contexts = _contexts_for_entity(entity)
    if len(entity_contexts) == 0:
        continue
    for article_pair in newspaper_hard_negatives[entity]:
        ##Sample up to 5 contexts (before the negative check, so the random stream is consumed per pair as before)
        sampled_contexts = random.sample(entity_contexts, min(5, len(entity_contexts)))
        # Pick the (last) article classified "Negative". Reset per pair: previously
        # a pair without one reused the previous pair's article (or raised NameError).
        neg_article = None
        for article in article_pair[:2]:
            if article["classification"] == "Negative":
                neg_article = article
        if neg_article is None:
            continue
        for context in sampled_contexts:
            newspaper_wiki_context_negatives[entity].append([neg_article, context, "different"])

print("Number of entities in newspaper dataset with wiki context (neg) ", len(newspaper_wiki_context_negatives.keys()))

""" ###Now, positives and negatives from hard negatives """
##First, we need to load the disambiguation dict
with open(disamb_dict_path, "r") as f:
    disamb_dict = json.load(f)
# NOTE(review): all_contexts is used without de-duplication (a dedup pass
# mirroring the one for cleaned_wiki_contexts was previously disabled).
with open(all_contexts_disamb_path, "rb") as f:
    all_contexts = pickle.load(f)

###Make disamb sets - keys and values fall in the same set - make a list of sets
disamb_sets = []
for key in disamb_dict:
    disamb_set = set(disamb_dict[key])
    disamb_set.add(key)
    disamb_sets.append(disamb_set)

disamb_entities_all = set()
for disamb_set in disamb_sets:
    disamb_entities_all.update(disamb_set)

# title -> first disambiguation set (in list order) containing it. This is the
# same choice as scanning disamb_sets and breaking on the first hit, but O(1).
disamb_set_of_title = {}
for disamb_set in disamb_sets:
    for member in disamb_set:
        disamb_set_of_title.setdefault(member, disamb_set)


def _disamb_negative_titles(entity_titles):
    """Yield (title, negative_titles) for each of the entity's titles found in a disamb set.

    The title is the positive; the other members of its disambiguation set are
    negatives. Every title of the same entity is removed from the negatives so
    an entity is never labelled "different" from one of its own titles.
    """
    own_titles = set(entity_titles)
    for title in entity_titles:
        if title in disamb_set_of_title:
            yield title, disamb_set_of_title[title] - own_titles


##This is a dict of form entity1: entity2....
newspaper_contexts_wiki_hard_negatives = defaultdict(list)

for entity in newspaper_positives:
    for _, negatives in _disamb_negative_titles(news_entity_to_title_dict[entity]):
        ##Get the contexts for the negatives
        for article_pair in newspaper_positives[entity]:
            article_1 = article_pair[0]
            article_2 = article_pair[1]
            for neg_title in negatives:
                if neg_title not in all_contexts:
                    continue
                neg_contexts = all_contexts[neg_title]
                sampled_contexts = random.sample(neg_contexts, min(5, len(neg_contexts)))
                for context in sampled_contexts:
                    # Only articles classified "Positive" are paired (presumably
                    # the ones that refer to this entity).
                    if article_1["classification"] == "Positive":
                        newspaper_contexts_wiki_hard_negatives[entity].append([article_1, context, "different"])
                    if article_2["classification"] == "Positive":
                        newspaper_contexts_wiki_hard_negatives[entity].append([article_2, context, "different"])


"""Time to create hard wiki negatives for first paragraphs now"""
newspaper_wiki_first_para_wiki_negatives = defaultdict(list)
for entity in newspaper_positives:
    for _, negatives in _disamb_negative_titles(news_entity_to_title_dict[entity]):
        ##Get the first paras for the negatives
        for article_pair in newspaper_positives[entity]:
            article_1 = article_pair[0]
            article_2 = article_pair[1]
            for neg_title in negatives:
                if neg_title not in wiki_firstpara:
                    continue
                neg_firstpara_dict = {"wiki_title": neg_title, "text": wiki_firstpara[neg_title]}
                if article_1["classification"] == "Positive":
                    newspaper_wiki_first_para_wiki_negatives[entity].append([article_1, neg_firstpara_dict, "different"])
                if article_2["classification"] == "Positive":
                    newspaper_wiki_first_para_wiki_negatives[entity].append([article_2, neg_firstpara_dict, "different"])

def find_total_pairs_dict(entity_dict):
    """Return the total number of pairs stored across all entities.

    Args:
        entity_dict: Mapping of entity -> sized collection of pairs.

    Returns:
        Sum of the lengths of all values (0 for an empty mapping).
    """
    return sum(len(pairs) for pairs in entity_dict.values())

""" Only the wiki ones are now left. """


"""Wiki contexts to wiki contexts (pos) - across contexts - wiki_dataset_path_contexts"""

# Every 2-combination of up to 5 randomly sampled cleaned contexts of the
# same entity becomes a "same" pair.
wiki_positives_contexts_subset_across_contexts = defaultdict(list)
for entity, entity_contexts in cleaned_wiki_contexts.items():
    picked = random.sample(entity_contexts, min(5, len(entity_contexts)))
    for first, second in combinations(picked, 2):
        wiki_positives_contexts_subset_across_contexts[entity].append([first, second, "same"])

print(find_total_pairs_dict(wiki_positives_contexts_subset_across_contexts))

"""Wiki contexts to wiki first para (pos) - across contexts - wiki_dataset_path_contexts"""

# Up to 5 sampled contexts of an entity are each paired with that entity's
# own first paragraph.
wiki_positives_context_to_first_para_subset_across_contexts = defaultdict(list)
for entity, entity_contexts in cleaned_wiki_contexts.items():
    # Sample for every entity, even ones without a first paragraph, so the
    # random stream advances exactly as before.
    picked = random.sample(entity_contexts, min(5, len(entity_contexts)))
    if entity not in wiki_firstpara:
        continue
    for context in picked:
        first_para_dict = {"wiki_title": entity, "text": wiki_firstpara[entity]}
        wiki_positives_context_to_first_para_subset_across_contexts[entity].append([context, first_para_dict, "same"])

# Sanity check. context_clean records only the first link to an entity, so
# the code treats one cleaned context text that appears twice for the same
# entity with a different mention_start as an error, and raises. As a result,
# wiki_in_context_positives is expected to stay empty.
wiki_in_context_positives = defaultdict(list)
for entity, entity_contexts in cleaned_wiki_contexts.items():
    # Group by context text instead of materialising every 2-combination,
    # which needs O(n^2) memory for entities with many contexts.
    seen_by_text = {}
    for context in entity_contexts:
        for earlier in seen_by_text.get(context["context"], []):
            if earlier["mention_start"] != context["mention_start"]:
                wiki_in_context_positives[entity].append([earlier, context, "same"])
                print(earlier)
                print(context)
                print("Entity ", entity)
                raise ValueError("Same context, different mention start")
        seen_by_text.setdefault(context["context"], []).append(context)
        
wiki_in_context_negatives=defaultdict(list) ##These are not enough, need to add more. We can use all_contexts for this
# Find identical paragraphs (same cleaned text) that link to more than one
# entity. Two different entities linked from the same paragraph, at different
# positions, form an in-context negative pair.
context_hash_dict = defaultdict(list)

for entity, entity_contexts in cleaned_wiki_contexts.items():
    for context in entity_contexts:
        # NOTE: this tags the context dict in place, so the "entity" key also
        # appears wherever else this same dict object is referenced.
        context["entity"] = entity
        context_hash_dict[get_hash(context["context"])].append(context)

# Only paragraphs already present among the cleaned contexts are of interest.
for entity, entity_contexts in all_contexts.items():
    for context in entity_contexts:
        context_hash = get_hash(context["context"])
        if context_hash in context_hash_dict:
            context["entity"] = entity
            context_hash_dict[context_hash].append(context)

contexts_with_multiple_entities = [group for group in context_hash_dict.values() if len(group) > 1]
contexts_with_multiple_entities_total = sum(len(group) for group in contexts_with_multiple_entities)

for context_list in contexts_with_multiple_entities:
    for first, second in combinations(context_list, 2):
        if first["context"] != second["context"]:
            continue  # guard against an md5 collision
        if first["mention_start"] == second["mention_start"]:
            continue  # same mention position, not a distinct link
        if first["entity"] == second["entity"]:
            # Two contexts of the *same* entity must never be labelled
            # "different" (possible when all_contexts repeats a cleaned entity).
            continue
        wiki_in_context_negatives[first["entity"]].append([first, second, "different"])

print("Number of entities in wiki_in_context_negatives with multiple entities in context ", len(wiki_in_context_negatives.keys()))
print("Total number of pairs in wiki_in_context_negatives with multiple entities in context ", find_total_pairs_dict(wiki_in_context_negatives))


    


###Now, wiki to wiki disamb HN - hard - TOPREP - Like before, from disamb set
# Each of up to 5 sampled contexts of a disambiguated entity is paired with up
# to 5 contexts of every other member of its (first) disambiguation set.
wiki_contexts_disamb_hard_negatives = defaultdict(list)
for entity, entity_contexts in cleaned_wiki_contexts.items():
    if entity not in disamb_entities_all:
        continue
    anchors = random.sample(entity_contexts, min(5, len(entity_contexts)))
    # The first set in list order that contains the entity; it always exists
    # because the entity is in disamb_entities_all.
    home_set = next(candidate for candidate in disamb_sets if entity in candidate)
    negatives = home_set.difference({entity})
    for anchor in anchors:
        for neg_entity in negatives:
            if neg_entity not in all_contexts:
                continue
            pool = all_contexts[neg_entity]
            for neg_context in random.sample(pool, min(5, len(pool))):
                wiki_contexts_disamb_hard_negatives[entity].append([anchor, neg_context, "different"])

###Load all first paras
with open(all_first_para_path, "r") as f:
    all_first_paras = json.load(f)


###Now for first para
# Each disambiguated entity's first paragraph is paired with the first
# paragraph of every other member of its disambiguation set.
wiki_firstpara_disamb_hard_negatives = defaultdict(list)
for entity, first_para in wiki_firstpara.items():
    if entity not in disamb_entities_all:
        continue
    anchor_dict = {"wiki_title": entity, "text": first_para}
    home_set = next(candidate for candidate in disamb_sets if entity in candidate)
    for neg_entity in home_set.difference({entity}):
        if neg_entity not in all_first_paras:
            continue
        neg_dict = {"wiki_title": neg_entity, "text": all_first_paras[neg_entity]}
        wiki_firstpara_disamb_hard_negatives[entity].append([anchor_dict, neg_dict, "different"])



print(find_total_pairs_dict(wiki_in_context_negatives))

print(find_total_pairs_dict(wiki_contexts_disamb_hard_negatives))
print(find_total_pairs_dict(wiki_firstpara_disamb_hard_negatives))

# print("Number of contexts ", len(contexts_with_multiple_entities))

###Now easy wiki context negatives - random sample across contexts of entities
# Cap on how many *other* entities each anchor (a sampled context or a first
# para) is paired with.
#
# None is the default and matches the original behaviour: every anchor is
# paired with every other entity, so the output grows quadratically. That is
# up to 5 * 5 * (N - 1) context pairs and N - 1 first-para pairs per entity.
#
# Set an int to instead draw that many negative entities at random per anchor.
EASY_NEGATIVE_ENTITIES_PER_ANCHOR = None


def _easy_negative_entities(candidates, entity):
    """Return the entities other than ``entity`` to draw easy negatives from.

    With no cap, this is all of them in ``candidates`` order, and no random
    numbers are consumed. With a cap, it is a random sample of at most that size.
    """
    others = [candidate for candidate in candidates if candidate != entity]
    if EASY_NEGATIVE_ENTITIES_PER_ANCHOR is None:
        return others
    return random.sample(others, min(EASY_NEGATIVE_ENTITIES_PER_ANCHOR, len(others)))


wiki_contexts_easy_negatives = defaultdict(list)
context_entities = list(cleaned_wiki_contexts)
for entity in context_entities:
    entity_contexts = cleaned_wiki_contexts[entity]
    sampled_contexts = random.sample(entity_contexts, min(5, len(entity_contexts)))
    for context in sampled_contexts:
        for entity_neg in _easy_negative_entities(context_entities, entity):
            neg_entity_contexts = cleaned_wiki_contexts[entity_neg]
            sampled_neg_contexts = random.sample(neg_entity_contexts, min(5, len(neg_entity_contexts)))
            for neg_context in sampled_neg_contexts:
                wiki_contexts_easy_negatives[entity].append([context, neg_context, "different"])

###Now for first paras
# Build each first-para dict once and share it, instead of allocating a fresh
# copy for every one of the (potentially N^2) pairs. The contents are the same.
firstpara_dicts = {entity: {"wiki_title": entity, "text": text} for entity, text in wiki_firstpara.items()}
firstpara_entities = list(firstpara_dicts)
wiki_firstpara_easy_negatives = defaultdict(list)
for entity in firstpara_entities:
    first_para_dict = firstpara_dicts[entity]
    for entity_neg in _easy_negative_entities(firstpara_entities, entity):
        wiki_firstpara_easy_negatives[entity].append([first_para_dict, firstpara_dicts[entity_neg], "different"])


###Save some previews

###Combine all dicts (insertion order defines the order of total_pairs below)
all_dicts = {"newspaper_wiki_context_positives": newspaper_wiki_context_positives, "newspaper_wiki_context_negatives": newspaper_wiki_context_negatives,
             "newspaper_wiki_first_para_positives": newspaper_wiki_first_para_positives, "newspaper_wiki_first_para_negatives": newspaper_wiki_first_para_negatives,
             "newspaper_positives": newspaper_positives, "newspaper_hard_negatives": newspaper_hard_negatives,
             "newspaper_contexts_wiki_hard_negatives": newspaper_contexts_wiki_hard_negatives, "newspaper_wiki_first_para_wiki_negatives": newspaper_wiki_first_para_wiki_negatives,
             "wiki_positives_contexts_subset_across_contexts": wiki_positives_contexts_subset_across_contexts, "wiki_positives_context_to_first_para_subset_across_contexts": wiki_positives_context_to_first_para_subset_across_contexts,
             "wiki_in_context_negatives": wiki_in_context_negatives,
             "wiki_contexts_disamb_hard_negatives": wiki_contexts_disamb_hard_negatives, "wiki_firstpara_disamb_hard_negatives": wiki_firstpara_disamb_hard_negatives,
             "wiki_contexts_easy_negatives": wiki_contexts_easy_negatives, "wiki_firstpara_easy_negatives": wiki_firstpara_easy_negatives}

total_pairs = [find_total_pairs_dict(entity_dict) for entity_dict in all_dicts.values()]
print("Total pairs ", total_pairs)

# `with` guarantees the (large) pickle is flushed and closed.
with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_dicts.pickle", "wb") as f:
    pickle.dump(all_dicts, f)

##Keep one random entity with one random pair of each dict and save it as a json preview
all_dicts_prototype = defaultdict(dict)
for key, entity_dict in all_dicts.items():
    entities_with_pairs = [entity for entity, pairs in entity_dict.items() if pairs]
    if not entities_with_pairs:
        # Sampling from an empty population raises; leave empty categories out of the preview.
        continue
    random_sample_entity = random.choice(entities_with_pairs)
    ##Keep a single random pair of the entity (as a one-element list)
    all_dicts_prototype[key][random_sample_entity] = random.sample(entity_dict[random_sample_entity], 1)

with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_dicts_prototype.json", "w") as f:
    json.dump(all_dicts_prototype, f)

##For the first paragraph pairs, the wiki first paragraphs need to a prefix - a sentence comprising of "instance of" and aliases from wikidata


    
            






