import pickle
import json
import random
import numpy as np
from collections import defaultdict
import os

data_path="/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_dicts_with_wikidata_type_aliases.pickle"

# Load the pickled pair dictionaries. The context manager closes the file
# handle as soon as unpickling finishes instead of leaving it open.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only point
# data_path at files that come from a trusted source.
with open(data_path, "rb") as pickle_file:
    all_data = pickle.load(pickle_file)


###Count pairs in all keys
def find_total_pairs_dict(entity_dict):
    """Return the total number of pairs held across all entities.

    ``entity_dict`` maps each entity to a sized collection of pairs; the
    result is the sum of the lengths of those collections (0 for an empty
    mapping).
    """
    return sum(len(entity_pairs) for entity_pairs in entity_dict.values())

# for key in all_data:
#     print(key,find_total_pairs_dict(all_data[key]))
    

# Sampling proportion for each pair dictionary, looked up by the dictionary's
# name. The proportion is multiplied by that dictionary's total pair count to
# obtain its target sample size.
key_prop_dict = {
    # Keys prefixed "newspaper_"
    "newspaper_wiki_context_positives": 0.1,
    "newspaper_wiki_context_negatives": 0.1,
    "newspaper_wiki_first_para_positives": 0.3,
    "newspaper_wiki_first_para_negatives": 0.3,
    "newspaper_positives": 1,
    "newspaper_hard_negatives": 0.2,
    "newspaper_contexts_wiki_hard_negatives": 0.01,
    "newspaper_wiki_first_para_wiki_negatives": 1,
    # Keys prefixed "wiki_"
    "wiki_positives_contexts_subset_across_contexts": 1,
    "wiki_positives_context_to_first_para_subset_across_contexts": 1,
    "wiki_in_context_negatives": 1,
    "wiki_contexts_disamb_hard_negatives": 0.25,
    "wiki_firstpara_disamb_hard_negatives": 1,
    "wiki_contexts_easy_negatives": 0.0025,
    "wiki_firstpara_easy_negatives": 0.05,
}


###Build the sample: one {entity: [pairs]} dict per key of all_data
sample_data = {}

# Fixed seed so repeated runs draw the same sample.
random.seed(42)

for key in all_data:
    entity_pairs = all_data[key]
    total_pairs = find_total_pairs_dict(entity_pairs)
    # Target number of pairs to keep for this dictionary.
    sample_size = int(total_pairs * key_prop_dict[key])
    print(key, sample_size)
    ###Stratify by entity
    ##Arrange entities by number of pairs in descending order
    entities = sorted(
        entity_pairs,
        key=lambda name: len(entity_pairs[name]),
        reverse=True,
    )
    ###Now, keep filling up the sample - stratified by entity
    sample_data[key] = {}
    entity_count = len(entities)
    for entity in entities:
        pairs = entity_pairs[entity]
        # In-place shuffle: this also reorders the list stored in all_data.
        random.shuffle(pairs)
        if sample_size > 0:
            # Split the remaining budget evenly over the entities still to
            # be visited (integer division, no float round-trip).
            quota = sample_size // entity_count
            sample_data[key][entity] = pairs[:quota]
            # NOTE(review): the budget shrinks by the full quota even when
            # the entity holds fewer pairs than that, so the final sample can
            # end up smaller than the target printed above.
            sample_size -= quota
            entity_count -= 1
        else:
            # No budget for this dictionary: take 1-10 pairs from each entity.
            # This must be an `else`: with a separate `if sample_size<=0`
            # check, the last entity (whose quota uses up the remaining
            # budget) had its quota sample overwritten by this fallback.
            sample_data[key][entity] = pairs[:random.randint(1, 10)]
            
    # print(key,len(sample_data[key]))
    
##For each key, entity and pair, rename the "context" field of both items to "text"
for key in sample_data:
    for entity in sample_data[key]:
        for pair in sample_data[key][entity]:
            for item in (pair[0], pair[1]):
                # Item already carries "text" and no "context": nothing to rename.
                if "context" not in item and "text" in item:
                    continue
                # The item dict is mutated in place, so nothing has to be
                # written back into the pair or the list (the write-back would
                # also fail if a pair were an immutable tuple).
                item["text"] = item.pop("context")


 
##Report how many pairs were sampled for each key
for dict_name, sampled_entities in sample_data.items():
    print(dict_name, find_total_pairs_dict(sampled_entities))
    
###Flatten: drop the top-level keys and merge the pair lists of each entity
###across all of them into a single {entity: [pairs]} dict
sample_data_entity_dict = {}
for sampled_entities in sample_data.values():
    for entity_name, sampled_pairs in sampled_entities.items():
        sample_data_entity_dict.setdefault(entity_name, []).extend(sampled_pairs)
        

    
###Check if each pair has the required fields
print("Checking if each pair has the required fields")
required_fields = ("mention_text", "mention_start", "mention_end", "text")
for entity, entity_pairs in sample_data_entity_dict.items():
    for pair_index, pair in enumerate(entity_pairs):
        for field in required_fields:
            # Both items of the pair (positions 0 and 1) must carry the field.
            for position in (0, 1):
                if field not in pair[position]:
                    # Raised explicitly rather than via `assert` so the check
                    # still runs under `python -O`, and so the message says
                    # which entity/pair/field is at fault. The exception type
                    # is the same one a failing `assert` would produce.
                    raise AssertionError(
                        "entity {!r}, pair {}, item {}: missing required field {!r}".format(
                            entity, pair_index, position, field
                        )
                    )



###Write the merged sample to disk as JSON
output_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sample_fine_tuning_data.json"
with open(output_path, "w") as json_out:
    json.dump(sample_data_entity_dict, json_out)

##Read the file back; sample_data_entity_dict now holds the JSON-decoded content
with open(output_path, "r") as json_in:
    sample_data_entity_dict = json.load(json_in)

   



