import json
import requests
import pickle
from tqdm import tqdm




def _join_english(items):
    """
    Join a non-empty list of strings as an English enumeration.

    One item is returned as-is, two items are joined with " and " (no comma),
    and three or more use commas with a serial comma before the final "and":
    "a", "a and b", "a, b, and c".

    Parameters:
        items (list): Non-empty list of strings.

    Returns:
        str: The joined enumeration.
    """
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        # Two items take no comma: "a and b", not "a, and b".
        return " and ".join(items)
    return ", ".join(items[:-1]) + ", and " + items[-1]


def create_template(title, instance_of, aliases):
    """
    Creates a formatted string that describes an entity with its title, types, and aliases in a more natural English format.

    Lists are joined English-style: "a", "a and b", "a, b, and c".

    Parameters:
        title (str): The title of the entity.
        instance_of (list): A list of types or categories of the entity. An
            empty list is rendered as "unknown type".
        aliases (list): A list of other names or aliases for the entity.

    Returns:
        str: A formatted sentence describing the entity. It always starts with
        ``title``, which add_mention_info relies on to place the mention at
        offset 0.
    """
    instance_of_formatted = _join_english(instance_of) if instance_of else "unknown type"
    # Use "" (not []) for "no aliases" so both formatted values are strings.
    alias_list = _join_english(aliases) if aliases else ""

    # instance_of_formatted can only be empty when instance_of holds empty
    # strings (e.g. [""]); the last two branches cover that case.
    if alias_list and instance_of_formatted:
        return f"{title} is of type {instance_of_formatted}. Also known as {alias_list}."
    elif instance_of_formatted:
        return f"{title} is of type {instance_of_formatted}."
    elif alias_list:
        return f"{title} is also known as {alias_list}."
    else:
        return f"{title}."

def add_mention_info(template, title):
    """
    Build the mention-span metadata for ``title``, which must open ``template``.

    The span always starts at offset 0 and covers ``len(title)`` characters.

    Parameters:
        template (str): Text expected to begin with ``title``.
        title (str): The mention text.

    Returns:
        dict: ``{"mention_start": 0, "mention_end": len(title), "mention_text": title}``.

    Raises:
        AssertionError: if ``template`` does not start with ``title``
        (the check is skipped when Python runs with ``-O``).
    """
    span_end = len(title)
    assert template.startswith(title)
    return {
        "mention_start": 0,
        "mention_end": span_end,
        "mention_text": title,
    }
    

def batch_process(items, batch_size):
    """
    Yield consecutive slices of ``items``, each at most ``batch_size`` long.

    ``items`` must support ``len()`` and slicing. Each batch is a slice of it,
    so a list yields lists and a string yields strings. The final batch may
    be shorter than ``batch_size``.
    """
    starts = range(0, len(items), batch_size)
    yield from (items[offset:offset + batch_size] for offset in starts)

def get_wikidata_ids(titles, timeout=30):
    """
    Retrieve Wikidata IDs for a list of English Wikipedia titles.

    One MediaWiki ``action=query`` request is made per title, with
    ``redirects=1`` and ``prop=pageprops``. The ``wikibase_item`` page
    property is taken as the Wikidata ID.

    Parameters:
        titles (iterable of str): English Wikipedia page titles.
        timeout (float): Per-request timeout in seconds, so a stalled
            connection cannot hang the whole run.

    Returns:
        dict: Maps each input title to its Wikidata ID. The value is None when
        the returned page has no ``pageprops`` or no ``wikibase_item`` entry.

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
        requests.Timeout: if a request exceeds ``timeout``.
    """
    url = 'https://en.wikipedia.org/w/api.php'
    wikidata_ids = {}
    # One pooled session for all requests (as get_wikipedia_redirects does),
    # closed when the loop ends.
    with requests.Session() as session:
        for title in tqdm(titles):
            params = {
                'action': 'query',
                'format': 'json',
                'titles': title,
                'prop': 'pageprops',
                'redirects': 1,
            }
            response = session.get(url, params=params, timeout=timeout)
            response.raise_for_status()
            pages = response.json()['query']['pages']
            for page_info in pages.values():
                pageprops = page_info.get('pageprops')
                wikidata_ids[title] = (
                    pageprops.get('wikibase_item') if pageprops is not None else None
                )
    return wikidata_ids





def get_wikidata_details(wikidata_id, timeout=30):
    """
    Retrieve the 'instance of' (P31) values and the English aliases of a Wikidata item.

    Note: despite its key name, ``instance_of_labels`` holds the raw QIDs of
    the P31 values, not their English labels. Mapping QIDs to labels is a
    separate step.

    Parameters:
        wikidata_id (str or None): Wikidata item ID. None or "" is accepted,
            e.g. for titles that get_wikidata_ids could not resolve.
        timeout (float): Per-request timeout in seconds.

    Returns:
        dict: ``{"wikidata_id", "instance_of_labels", "aliases"}``. If no ID
        is given, or the API response has no ``entities`` key (the response is
        then printed), ``wikidata_id`` is None and both lists are empty.
    """
    empty_result = {
        "wikidata_id": None,
        "instance_of_labels": [],
        "aliases": [],
    }
    # requests drops None-valued params, so a missing ID would otherwise send
    # a request without "ids". Skip the network round-trip instead.
    if not wikidata_id:
        return empty_result

    url = 'https://www.wikidata.org/w/api.php'
    params = {
        'action': 'wbgetentities',
        'format': 'json',
        'ids': wikidata_id,
        'props': 'claims|aliases|labels',
        'languages': 'en',  # Fetch labels and aliases in English
    }
    response = requests.get(url, params=params, timeout=timeout).json()

    if 'entities' not in response:
        print(f"Error: {response}")
        return empty_result

    # Tolerate an entity that is absent or has no claims/aliases
    # rather than raising KeyError.
    entity = response['entities'].get(wikidata_id, {})
    instances_of = []
    for claim in entity.get('claims', {}).get('P31', []):  # P31 = 'instance of'
        # Snaks of type "somevalue"/"novalue" carry no datavalue; skip them.
        datavalue = claim['mainsnak'].get('datavalue')
        if datavalue is None:
            continue
        instances_of.append(datavalue['value']['id'])  # a QID, not a label

    # Collect all aliases in English
    aliases = [alias['value'] for alias in entity.get('aliases', {}).get('en', [])]

    return {
        "wikidata_id": wikidata_id,
        "instance_of_labels": instances_of,
        "aliases": aliases
    }



def get_wikipedia_redirects(titles, timeout=30):
    """
    Fetches all redirects for a given list of Wikipedia titles.

    Each title is first resolved through redirects (``redirects=1``). Results
    are then paged with the MediaWiki ``continue`` protocol, so pages with
    more redirects than fit in one ``rdlimit=max`` batch are not silently
    truncated.

    Args:
    titles (list of str): List of Wikipedia titles.
    timeout (float): Per-request timeout in seconds.

    Returns:
    dict: Dictionary with titles as keys and list of redirect titles as values.

    Raises:
    requests.HTTPError: if the API answers with an HTTP error status.
    requests.Timeout: if a request exceeds ``timeout``.
    """
    URL = "https://en.wikipedia.org/w/api.php"
    redirects = {}

    with requests.Session() as session:
        for title in titles:
            params = {
                "action": "query",
                "format": "json",
                "titles": title,
                "redirects": 1,
                "prop": "redirects",
                "rdlimit": "max"
            }
            collected = []
            while True:
                response = session.get(URL, params=params, timeout=timeout)
                response.raise_for_status()
                data = response.json()
                page = next(iter(data['query']['pages'].values()))
                collected.extend(redirect['title'] for redirect in page.get('redirects', []))
                if 'continue' not in data:
                    break
                # Merge the continuation tokens into the next request's params.
                params = {**params, **data['continue']}
            redirects[title] = collected

    return redirects

# Load every training split. The with-block closes the file handle promptly.
# NOTE(review): pickle.load executes arbitrary code on malicious input; this
# assumes the file is a locally produced, trusted artifact.
with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_dicts.pickle", "rb") as f:
    all_data_dict = pickle.load(f)

### Keep only the splits whose items carry a Wikipedia first paragraph:
# newspaper_wiki_first_para_positives, newspaper_wiki_first_para_negatives,
# newspaper_wiki_first_para_wiki_negatives,
# wiki_positives_context_to_first_para_subset_across_contexts,
# wiki_firstpara_disamb_hard_negatives, wiki_firstpara_easy_negatives

relevant_keys=["newspaper_wiki_first_para_positives","newspaper_wiki_first_para_negatives","newspaper_wiki_first_para_wiki_negatives","wiki_positives_context_to_first_para_subset_across_contexts","wiki_firstpara_disamb_hard_negatives","wiki_firstpara_easy_negatives"]

# Shallow filter that keeps all_data_dict's key order. new_dict's values are
# the very same objects stored in all_data_dict, so in-place edits made
# through new_dict are visible in all_data_dict as well.
relevant_key_set = set(relevant_keys)
new_dict = {
    split_name: split_data
    for split_name, split_data in all_data_dict.items()
    if split_name in relevant_key_set
}
        
# ##Collect all pairs to extract entity names
# unique_titles=set()
# for key in tqdm(new_dict.keys()):
#     wiki_title_in_key=False
#     for entity in new_dict[key]:
#         for pair in new_dict[key][entity]:
#             for item in pair :
#                 if "wiki_title" in item:
#                     wiki_title_in_key=True
#                     unique_titles.add(item["wiki_title"])
#     assert wiki_title_in_key==True
#     print(key,wiki_title_in_key)
    
# # for key in new_dict.keys():
# #     for entity in new_dict[key]:
# #         print(entity)
# #         for pair in new_dict[key][entity]:
# #             for item in pair :
# #                 if "wiki_title" in item:
# #                     if item["wiki_title"]=="Malcolm Wilson (governor)":
# print(len(unique_titles))
# assert "Malcolm Wilson (governor)" in unique_titles

# ###Save title list as json
# with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/first_para_newssubset_wiki_titles.json","w") as f:
#     json.dump(list(unique_titles),f)
    
# # redirect_mapping = get_wikipedia_redirects(list(unique_titles)[:10])
# # print(redirect_mapping)

# # exit()

# # exit()

# ##########GEt data from wikidata
# # Example usage with a large list of titles. Return {title:wikidata_output}
# # unique_titles=list(unique_titles)
# title_to_id = get_wikidata_ids(list(unique_titles))

# ##Save as json
# with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/first_para_newssubset_wiki_ids.json","w") as f:
#     json.dump(title_to_id,f)

# title_to_info={}
# for title in tqdm(title_to_id.keys()):
#     print(title)
#     title_to_info[title]=get_wikidata_details(title_to_id[title])



# ##Replace instance_of_labels that contains a list of qids with a list of labels
# unique_qids=set()
# for title in title_to_info.keys():
#     for qid in title_to_info[title]["instance_of_labels"]:
#         unique_qids.add(qid)

# qid_to_label={}
# for qid in tqdm(unique_qids):
#     url = 'https://www.wikidata.org/w/api.php'
#     params = {
#         'action': 'wbgetentities',
#         'format': 'json',
#         'ids': qid,
#         'props': 'labels',
#         'languages': 'en',  # Fetch labels and aliases in English
#     }
#     response = requests.get(url, params=params).json()
#     qid_to_label[qid]=response["entities"][qid]["labels"]["en"]["value"]

# for title in title_to_info.keys():
#     if "instance_of_labels" not in title_to_info[title]:
#         title_to_info[title]["instance_of_labels"]=[]
#     title_to_info[title]["instance_of_labels"]=[qid_to_label[qid] for qid in title_to_info[title]["instance_of_labels"]  ]
    
# ###Save as json
# with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/first_para_newssubset_wiki_aliases.json","w") as f:
#     json.dump(title_to_info,f)

######
## Reload the artifacts written by the (now commented-out) Wikidata pipeline above.

## Open unique titles
with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/first_para_newssubset_wiki_titles.json", "r", encoding="utf-8") as f:
    unique_titles = json.load(f)

with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/first_para_newssubset_wiki_aliases.json", "r", encoding="utf-8") as f:
    title_to_info = json.load(f)

print(len(title_to_info))

## Print titles that have Wikidata info but are missing from unique_titles.
# unique_titles is loaded as a JSON list; test membership against a set so
# this check is O(n) instead of O(n*m).
unique_title_set = set(unique_titles)
for title in title_to_info:
    if title not in unique_title_set:
        print(title)

# assert "Malcolm Wilson (governor)" in title_to_info.keys()
# assert "Malcolm Wilson (governor)" in unique_titles

### Add the Wikidata template and mention span to every item that has a
# "wiki_title", and prefix the template to the item's text. Items are
# mutated in place.
# NOTE(review): if one item dict is reachable from several pairs/splits
# (pickle preserves shared references; not verifiable from here), it must be
# prefixed only once. Without the id() guard the template would be prepended
# repeatedly, and every assertion below would still pass.
processed_item_ids = set()
for key in new_dict.keys():
    for entity in new_dict[key]:
        for pair in new_dict[key][entity]:
            for item in pair:
                if "wiki_title" not in item or id(item) in processed_item_ids:
                    continue
                processed_item_ids.add(id(item))

                title = item["wiki_title"]
                if title not in title_to_info:
                    raise KeyError(
                        f"No Wikidata info for wiki_title {title!r} "
                        f"(split {key!r}, entity {entity!r})"
                    )
                item["wikidata_info"] = title_to_info[title]
                item["template"] = create_template(
                    title,
                    item["wikidata_info"]["instance_of_labels"],
                    item["wikidata_info"]["aliases"],
                )
                # Merge mention_start / mention_end / mention_text into the item.
                item.update(add_mention_info(item["template"], title))

                ### Add template to text
                assert isinstance(item["text"], str)
                assert isinstance(item["template"], str)
                original_text_len = len(item["text"])
                item["text"] = item["template"] + " " + item["text"]

                # The mention offsets were computed on the template; they stay
                # valid because the template is now the prefix of the text.
                assert item["text"][:item["mention_end"]] == title
                assert len(item["text"]) == original_text_len + len(item["template"]) + 1

### Write the annotated splits back into all_data_dict.
# new_dict was built as a shallow filter of all_data_dict, so its values are
# already the same objects. This update is effectively a no-op, kept so the
# saved pickle stays correct even if new_dict is ever built by copying.
all_data_dict.update(new_dict)

## Save a copy. The with-block guarantees the pickle is flushed and closed.
with open("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/all_dicts_with_wikidata_type_aliases.pickle", "wb") as f:
    pickle.dump(all_data_dict, f)
