import pickle
import json
import itertools 
from tqdm import tqdm
import random 
import sys
import os
import re

import cugraph as cnx
import cudf as gd

# Put the parent and grandparent directories of this file on the import path
# (presumably so the project-local `nlp_utils` package can be resolved).
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
grandparentdir = os.path.dirname(parentdir)
sys.path.extend((parentdir, grandparentdir))

from nlp_utils.modified_sbert import cluster_fns


def entity_clean(text):
    """
    Strips punctuation and non-ASCII characters from an entity name.

    Returns the stripped name, or the placeholder "0" when the stripped name
    is at most half as long as the original `text` (this includes the empty
    string).
    """

    chars_to_remove = r'"#$%&\()*+/:;<=>@[\\]^_`{|}~.?,!\''

    clean_text = text.translate(str.maketrans('', '', chars_to_remove))

    # Filter the already-stripped text; starting again from `text` here would
    # silently throw away the punctuation removal above.
    clean_text = clean_text.encode('ascii', 'ignore').decode()

    if len(clean_text) <= 0.5 * len(text):
        clean_text = "0"

    return clean_text


def len_entities(text):

    """
    Counts number of characters in a text that are made up of entities

    Each "[[" opens an entity. Its length is the number of characters up to
    the next "]]" (or to the end of that segment if no "]]" follows), plus 4
    for the surrounding "[[" and "]]" brackets.
    """

    # The first element of the split is always the text *before* the first
    # "[[" (an empty string when the text starts with "[["), so it is never
    # an entity and must not contribute to the count.
    segments = text.split("[[")[1:]

    ent_chars = sum(len(seg.split("]]")[0]) for seg in segments)

    return ent_chars + 4 * len(segments)


def context_clean(text, ent):
    """
    Cleans one raw wiki paragraph linking to `ent` and locates the mention.

    The paragraph is rejected (None is returned) when it:
      - starts with, or is truncated by, file/image/project markup to the
        point that it no longer links to `ent`,
      - contains "User:", "||" or "[[]]" (tables, comments),
      - contains "=", is shorter than 75 or longer than 10000 characters,
      - has fewer than 25 non-entity characters or is more than 85% entities,
      - contains no parsable link whose target equals `ent` (case-insensitive).

    Otherwise external links are removed, every "[[target|label]]" /
    "[[target]]" link is replaced by its visible text, and a dict is returned:
      {"context": flattened text,
       "mention_text": visible text of the link to `ent`,
       "mention_start": start offset of the mention in "context",
       "mention_end": end offset (exclusive) of the mention in "context"}
    """

    bad_keywords = ("[[File:", "[[Image:", "[[file:", "[[image:", "<", ":File:", "[[WP:", "#REDIRECT", "[[Wikipedia:")

    if text.startswith(bad_keywords):
        return None

    if any(x in text for x in ["User:", "||", "[[]]"]):  # Table, comments
        return None

    # Deal with external links
    if "http" in text:

        # "[http://url label]" -> "label"
        split = text.split("[http")
        keep = []
        if not text.startswith("[http"):
            keep.append(split[0])
            split = split[1:]
        for spl in split:
            part = spl.split("]", 1)
            left = part[0].split(" ", 1)
            if len(left) > 1:
                keep.extend(left[1:])
            keep.extend(part[1:])
        text = "".join(keep)

        # Any bare URL left over is dropped token by token
        if "http" in text:
            text = " ".join(spl for spl in text.split(" ") if "http" not in spl)

    # Cut the paragraph at the first occurrence of any unwanted markup
    for kw in bad_keywords:
        text = text.split(kw)[0]

    ent_lower = ent.lower()
    text_lower = text.lower()
    if ("[[" + ent_lower + "]]") not in text_lower and ("[[" + ent_lower + "|") not in text_lower:
        return None

    if "=" in text:
        return None

    if len(text) < 75 or len(text) > 10000:
        return None

    ent_len = len_entities(text)

    if len(text) - ent_len < 25:
        return None

    if ent_len/len(text) > 0.85:
        return None

    split = text.split("[[")

    if len(split) < 2:
        return None

    mention_text = None
    men_seg_start = None

    # split[0] is the text before the first link ("" if the text starts
    # with "[["); every later element starts with a link target.
    context_list = [split[0]]

    for seg in split[1:]:
        seg_split = seg.split("]]")
        if len(seg_split) > 2:
            # More than one "]]" before the next link: keep the first
            # non-empty piece as the link and glue the rest back together.
            non_empty = [x for x in seg_split if x]
            if not non_empty:
                # Nothing but stray "]]" markers; there is no link to keep.
                continue
            seg_split = [non_empty[0], "".join(non_empty[1:])]

        if "|" in seg_split[0]:
            pieces = seg_split[0].split("|")
            entity, mention = pieces[0], pieces[1]
        else:
            entity = mention = seg_split[0]
        context_list.append(mention)

        # `not mention_text` (rather than `is None`) lets a later link win
        # over an earlier one whose visible text is empty.
        if entity.lower() == ent_lower and not mention_text:
            mention_text = mention
            men_seg_start = len(context_list) - 1

        if len(seg_split) == 2:
            context_list.append(seg_split[1])

    if men_seg_start is None:
        # The substring check above can succeed even though no parsed link
        # target equals `ent` (e.g. "[[[Foo]]" parses as target "[Foo").
        # Report it and reject the paragraph instead of failing on None + 1.
        print("ERROR")
        print(ent)
        print(text)
        return None

    context = "".join(context_list)
    men_start = len("".join(context_list[:men_seg_start]))
    men_end = len("".join(context_list[:men_seg_start+1]))

    # Internal consistency check on the computed offsets
    assert context[men_start:men_end] == mention_text

    data_dict = {"context": context, "mention_text": mention_text, "mention_start": men_start, "mention_end": men_end}

    return data_dict


def clean_data():
    """
    Builds the cleaned context and disambiguation datasets from the raw dumps.

    Steps, in order:
      1. Load the raw disambiguation dict (JSON) and entity -> contexts (pickle).
      2. Merge entities that differ only in the case of their first character.
      3. De-duplicate and clean every context with `context_clean`; save the result.
      4. Normalise and filter the disambiguation clusters, keeping clusters
         with at least two usable entities; save the result.
      5. Keep only contexts of entities that appear in a kept cluster and
         give every context a running integer "id".
      6. Cap every entity at 512 randomly sampled contexts; save the result.

    Progress statistics are printed along the way.

    Returns (trimmed_context_data, clean_disamb_dict).
    """

    def total_items(mapping):
        # Total number of list entries over all values of a dict
        return sum(len(val) for val in mapping.values())

    def count_multi(mapping):
        # Number of values holding at least two entries
        return sum(1 for val in mapping.values() if len(val) >= 2)

    with open('/mnt/data01/wiki_data/disambiguation_dict.json', 'rb') as disamb_file:
        disamb_dict = json.load(disamb_file)

    with open('/mnt/data01/wiki_data/clean_datasets/full_wiki_para_entity.pkl', 'rb') as context_file:
        context_data = pickle.load(context_file)

    print("Cleaning contexts ...")

    # Upper-case the first character of every entity and pool the contexts
    # of entities that collapse onto the same name.
    merged_contexts = {}
    for ent in tqdm(context_data):
        key = ent[0].upper() + ent[1:] if len(ent) > 1 else ent.upper()
        merged_contexts.setdefault(key, []).extend(context_data[ent])
    context_data = merged_contexts

    print("NEW:", total_items(context_data))
    print(len(context_data))

    dedup_contexts = {}
    clean_context_data = {}
    for ent in tqdm(context_data):

        unique_contexts = list(set(context_data[ent]))
        dedup_contexts[ent] = unique_contexts

        cleaned = [context_clean(raw, ent) for raw in unique_contexts]
        kept = [item for item in cleaned if item]

        if kept:
            clean_context_data[ent] = kept

    print("BEFORE:", total_items(context_data))
    print(len(context_data))
    print("DEDUP:", total_items(dedup_contexts))
    print(len(dedup_contexts))
    print("AFTER:", total_items(clean_context_data))
    print(len(clean_context_data))

    with open('/mnt/data01/wiki_data/final_datasets_2/full_wiki_para_entity_clean_2.pkl', 'wb') as out_file:
        pickle.dump(clean_context_data, out_file)

    # Clean disambiguation dict
    print("Cleaning disambiguation dicts ...")

    # Same first-character normalisation as for the contexts; names of a
    # single character (or empty) are dropped here.
    disamb_dict = {
        clu: [ent[0].upper() + ent[1:] for ent in members if len(ent) > 1]
        for clu, members in disamb_dict.items()
    }

    # Substrings marking wikipedia features; "#" marks links to subpages
    wiki_markers = ("disambiguation", "wikipedia", "category", "image:", "file:", "#")

    clean_disamb_dict = {}

    for ent_clu, members in tqdm(disamb_dict.items()):
        candidates = []
        for ent in members:
            # Less than 2 chars left after cleaning, or replaced by the "0" placeholder
            if len(entity_clean(ent)) < 2:
                continue
            lowered = ent.lower()
            if any(marker in lowered for marker in wiki_markers):
                continue
            # Remove if not in context data (ie. no links)
            if ent not in clean_context_data:
                continue
            candidates.append(ent)

        ent_list = list(set(candidates))

        if len(ent_list) >= 2:
            clean_disamb_dict[ent_clu] = ent_list

    print("BEFORE:", total_items(disamb_dict))
    print(len(disamb_dict))
    print(count_multi(disamb_dict))

    print("AFTER:", total_items(clean_disamb_dict))
    print(len(clean_disamb_dict))
    print(count_multi(clean_disamb_dict))

    with open('/mnt/data01/wiki_data/final_datasets_2/disambiguation_dict_clean_2.json', 'w') as out_file:
        json.dump(clean_disamb_dict, out_file)

    print("Removing entities from contexts that don't appear in disambiguation dicts")
    ents_in_disamb = [ent for members in clean_disamb_dict.values() for ent in members]
    print(len(ents_in_disamb))
    ents_in_disamb_set = set(ents_in_disamb)
    print(len(ents_in_disamb_set))

    final_context_data = {}
    next_id = 0

    for ent in tqdm(clean_context_data):
        if ent not in ents_in_disamb_set:
            continue
        contexts = clean_context_data[ent]
        # The context dicts are annotated in place with a running id
        for context in contexts:
            context["id"] = next_id
            next_id += 1
        final_context_data[ent] = list(contexts)

    print("BEFORE:", total_items(clean_context_data))
    print("AFTER:", total_items(final_context_data))
    print(len(final_context_data))

    sizes = [len(val) for val in final_context_data.values()]
    print("Max:", max(sizes))

    # How many entities have more than `threshold` contexts
    for threshold in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]:
        print(threshold, sum(1 for size in sizes if size > threshold))

    for ent, contexts in final_context_data.items():
        if len(contexts) > 50000:
            print(ent, len(contexts))

    print("Trimming contexts ...")
    trimmed_context_data = {
        ent: random.sample(contexts, 512) if len(contexts) > 512 else contexts
        for ent, contexts in final_context_data.items()
    }

    print("TRIMMED:", total_items(trimmed_context_data))
    print(len(trimmed_context_data))

    with open('/mnt/data01/wiki_data/train_splits_2/all_contexts.pkl', 'wb') as out_file:
        pickle.dump(trimmed_context_data, out_file)

    return trimmed_context_data, clean_disamb_dict


def positives(clean_disamb_dict):
    """
    Randomly assigns the entities of every disambiguation page to a split.

    Pages are visited in random order. Each page goes to 'dev' or 'test' with
    probability 1/50 each and to 'train' otherwise; all entities of the page
    that have not been assigned by an earlier page follow it. Independently,
    a page is also added to the "mid" subset with probability 1/50 and to the
    "small" subset with probability 1/500 (same split name).

    An entity that is listed on several pages is assigned only once, by the
    first page that contains it, so the train/dev/test lists are disjoint.

    The three resulting dicts ({'train': [...], 'dev': [...], 'test': [...]})
    are written as JSON to the train_splits_2 directory. Returns None.
    """

    print("Splitting data ...")

    # Entities already placed in a split. This has to be updated while
    # iterating, otherwise the "not already seen" filter below is a no-op and
    # an entity shared by two pages can leak across splits.
    already_seen = set()

    splits = {'train': [], 'dev': [], 'test': []}
    small_splits = {'train': [], 'dev': [], 'test': []}
    mid_splits = {'train': [], 'dev': [], 'test': []}

    prob_list = range(0, 50)
    small_prob_list = range(0, 500)
    mid_prob_list = range(0, 50)

    pages = list(clean_disamb_dict.keys())
    random.shuffle(pages)
    for page in tqdm(pages):
        new_ents = []
        for ent in clean_disamb_dict[page]:
            if ent not in already_seen:
                already_seen.add(ent)
                new_ents.append(ent)

        choice = random.choice(prob_list)
        mid_choice = random.choice(mid_prob_list)
        small_choice = random.choice(small_prob_list)

        if choice == 8:
            split_name = 'dev'
        elif choice == 9:
            split_name = 'test'
        else:
            split_name = 'train'

        splits[split_name].extend(new_ents)
        if small_choice == 0:
            small_splits[split_name].extend(new_ents)
        if mid_choice == 0:
            mid_splits[split_name].extend(new_ents)

    for split_dict in (splits, small_splits, mid_splits):
        print(len(split_dict['train']))
        print(len(split_dict['dev']))
        print(len(split_dict['test']))

    with open('/mnt/data01/wiki_data/train_splits_2/splits_v2.json', 'w') as f:
        json.dump(splits, f)
    with open('/mnt/data01/wiki_data/train_splits_2/small_splits_v2.json', 'w') as f:
        json.dump(small_splits, f)
    with open('/mnt/data01/wiki_data/train_splits_2/mid_splits_v2.json', 'w') as f:
        json.dump(mid_splits, f)


def negatives(clean_disamb_dict):
    """
    Builds and saves, for every entity, the list of entities it shares a
    disambiguation cluster with.

    For each entity the other members of all clusters containing it are
    pooled and de-duplicated. The resulting dict (entity -> list of
    entities) is pickled to train_splits_2/negatives.pkl. The number of
    clusters and the number of entities are printed. Returns None.
    """

    print("Creating negative dict ...")

    print(len(clean_disamb_dict))

    pooled = {}
    for clu in tqdm(clean_disamb_dict):
        members = clean_disamb_dict[clu]
        for ent in members:
            others = [other for other in members if other != ent]
            pooled.setdefault(ent, []).extend(others)

    # An entity pair can co-occur in several clusters; keep each partner once
    negatives_dict = {ent: list(set(partners)) for ent, partners in pooled.items()}

    with open('/mnt/data01/wiki_data/train_splits_2/negatives.pkl', 'wb') as out_file:
        pickle.dump(negatives_dict, out_file)

    print(len(negatives_dict))


def describe_disambig_dicts(clean_disamb_dict):
    """
    Prints the size of the largest cluster, then for every cluster with more
    than 200 entities its key, its size and its first five entities.

    Raises ValueError (from `max`) when the dict is empty.
    """

    largest = max(len(members) for members in clean_disamb_dict.values())
    print("Max:", largest)

    for clu, members in clean_disamb_dict.items():
        if len(members) <= 200:
            continue
        print(clu, len(members))
        print(members[:5])
        print("**")


def _strip_external_links(text):
    """Removes external links: "[http... label]" becomes "label", bare URLs are dropped."""

    if "http" not in text:
        return text

    # split[0] is the text before the first "[http" ("" if the text starts with one)
    split = text.split("[http")
    keep = [split[0]]
    for spl in split[1:]:
        part = spl.split("]", 1)
        left = part[0].split(" ", 1)
        if len(left) > 1:
            # Everything after the URL inside the brackets is the label
            keep.extend(left[1:])
        keep.extend(part[1:])
    text = "".join(keep)

    if "http" in text:
        text = " ".join(tok for tok in text.split(" ") if "http" not in tok)

    return text


def _flatten_internal_links(text):
    """Replaces every "[[target|label]]" / "[[target]]" link with its visible text."""

    # split[0] is the text before the first "[[" ("" if the text starts with one)
    split = text.split("[[")
    context_list = [split[0]]

    for seg in split[1:]:
        seg_split = seg.split("]]")
        if len(seg_split) > 2:
            non_empty = [x for x in seg_split if x]
            if not non_empty:
                # Nothing but stray "]]" markers; there is no link text to keep
                continue
            seg_split = [non_empty[0], "".join(non_empty[1:])]

        if "|" in seg_split[0]:
            mention = seg_split[0].split("|")[1]
        else:
            mention = seg_split[0]
        context_list.append(mention)

        if len(seg_split) == 2:
            context_list.append(seg_split[1])

    return "".join(context_list)


def clean_fp_data(final_context_data=None, clean_disamb_dict=None):
    """
    Cleans the first-paragraph texts and restricts contexts and
    disambiguation clusters to entities that have one.

    Steps, in order:
      1. Load the raw first paragraphs (entity -> one-element list of text)
         and print a few spot-check entries.
      2. Clean every text: drop "<...>" spans, file/image lines, "{...}"
         spans, external links, and replace internal links by their text.
         Texts containing "may refer" and shorter than 50 characters are
         discarded. Entities whose text is 25 characters or shorter are
         written to short_ents.json in the working directory.
      3. Remove entities without a cleaned first paragraph from
         `clean_disamb_dict`; keep clusters that still have 2+ entities.
      4. Keep only contexts of entities in a kept cluster and re-number the
         context "id" fields (the context dicts are modified in place).
      5. Save contexts (pickle), clusters (JSON) and first paragraphs (JSON)
         to the train_splits_2 directory.

    Both arguments are required despite their None defaults.

    Returns (with_fp_context_data, with_fp_disamb_dict, cleaned_fp_data).
    Raises TypeError when an argument is missing and KeyError when one of
    the spot-check entities is absent from the loaded data.
    """

    # Fail before the expensive cleaning pass rather than after it: iterating
    # over a None argument further down would raise TypeError anyway.
    if final_context_data is None or clean_disamb_dict is None:
        raise TypeError("clean_fp_data requires both final_context_data and clean_disamb_dict")

    with open('/mnt/data01/wiki_data/clean_datasets/full_wiki_fp_all_pages.pkl', 'rb') as f:
        fp_data = pickle.load(f)

    # Spot checks: show the raw entries of a few known entities
    for e in ["Wateree River", "Derry Urban Area", "Dave Culbert"]:
        print(e)
        print(fp_data[e])

    for e in ["Albert Brown (footballer, born 1862)", "Lesley Duncan", "October 6", "Mike Roach"]:
        print(e)
        print(fp_data[e])

    for e in ["Madura United F.C.", "Yoshihiko Wada", "List of volcanoes in Mexico", "Octavio Lepage"]:
        print(e)
        print(fp_data[e])

    # Clean fps
    bad_keywords = ("File:", "Image:", "file:", "image:", "Wikipedia:", "WP:")

    # Fallback used when dropping whole lines would leave (almost) nothing:
    # only the offending links are removed. Raw strings, so that "\[" is a
    # regex escape and not an invalid string escape.
    namespace_link_patterns = (
        r"\[\[File:(.*?).\]\]",
        r"\[\[:File:(.*?).\]\]",
        r"\[\[Image:(.*?).\]\]",
        r"\[\[:Image:(.*?).\]\]",
        r"\[\[image:(.*?).\]\]",
        r"\[\[file:(.*?).\]\]",
        r"\[\[Wikidata:(.*?).\]\]",
        r"\[\[WP:(.*?).\]\]",
    )

    short_ent_list = []

    cleaned_fp_data = {}
    for ent in tqdm(fp_data):
        assert len(fp_data[ent]) == 1

        # Clean entity: upper-case the first character (slicing keeps an
        # empty name empty instead of raising IndexError)
        cleaned_ent = ent[:1].upper() + ent[1:]

        # Clean text
        text = fp_data[ent][0]

        text = re.sub('<.*?>', '', text)    # Remove text between < >

        if text.startswith('\n\n'):
            text = text[2:]
        if text.endswith('\n\n'):
            text = text[:-2]
        # NOTE(review): these two checks match a single newline but cut four
        # characters, which also removes three characters of real text.
        # Left unchanged because the intent is unclear - TODO confirm.
        if text.startswith('\n'):
            text = text[4:]
        if text.endswith('\n'):
            text = text[:-4]

        # Remove any lines mentioning files or images
        clean_split = [s for s in text.split('\n') if not any(b in s for b in bad_keywords)]
        new_text = "\n".join(clean_split)

        if len(new_text.strip()) < 10:
            for pattern in namespace_link_patterns:
                text = re.sub(pattern, '', text)
        else:
            text = new_text

        # text in {} (which is often multi-line); repeated to peel off up to
        # three levels of nested braces
        text = re.sub('\n', '', text)
        for _ in range(3):
            text = re.sub('{.*?}', '', text)
        text = re.sub('<.*?>', '', text)

        # Deal with external links
        text = _strip_external_links(text)

        # Replace all internal links with text
        text = _flatten_internal_links(text)

        text = text.strip()

        if not ("may refer" in text and len(text) < 50):
            cleaned_fp_data[cleaned_ent] = text

        if len(text) <= 25:
            short_ent_list.append(ent)

    print("BEFORE:", len(fp_data))
    print("AFTER:", len(cleaned_fp_data))
    print(f'{len(short_ent_list)} short texts')

    with open('short_ents.json', 'w') as f:
        json.dump(short_ent_list, f)

    # Remove from disambiguation dict if not in fps
    count = 0
    with_fp_disamb_dict = {}
    for k in tqdm(clean_disamb_dict):
        new_list = []
        for ent in clean_disamb_dict[k]:
            if ent in cleaned_fp_data:
                new_list.append(ent)
            else:
                count += 1
                print(ent)
        if len(new_list) > 1:
            with_fp_disamb_dict[k] = new_list

    print(count, "entities missing")

    print("BEFORE:", sum([len(val) for val in list(clean_disamb_dict.values())]))
    print(len(clean_disamb_dict))
    print("AFTER:", sum([len(val) for val in list(with_fp_disamb_dict.values())]))
    print(len(with_fp_disamb_dict))

    # Remove entities that aren't in disambiguation pages
    ents_in_disamb = [item for sublist in list(with_fp_disamb_dict.values()) for item in sublist]
    ents_in_disamb_set = set(ents_in_disamb)
    print("Unique entities", len(ents_in_disamb_set))

    with_fp_context_data = {}
    counter = 0

    for ent in tqdm(final_context_data):
        if ent in ents_in_disamb_set:
            old_ent_list = final_context_data[ent]
            new_ent_list = []
            for context in old_ent_list:
                # Ids are re-assigned so that they are contiguous again
                context["id"] = counter
                counter += 1
                new_ent_list.append(context)

            with_fp_context_data[ent] = new_ent_list

    print("BEFORE:", sum([len(val) for val in list(final_context_data.values())]))
    print("AFTER:", sum([len(val) for val in list(with_fp_context_data.values())]))
    print(len(with_fp_context_data))

    with open('/mnt/data01/wiki_data/train_splits_2/all_contexts.pkl', 'wb') as f:
        pickle.dump(with_fp_context_data, f)
    with open('/mnt/data01/wiki_data/train_splits_2/disambiguation_dict_final.json', 'w') as f:
        json.dump(with_fp_disamb_dict, f)
    with open('/mnt/data01/wiki_data/train_splits_2/cleaned_fp_data.json', 'w') as f:
        json.dump(cleaned_fp_data, f)

    return with_fp_context_data, with_fp_disamb_dict, cleaned_fp_data


def clean_benchmark_data():
    """
    Reports, for each benchmark dataset, how many entities have no first
    paragraph and how many of those are not explained by a redirect.

    Entity names (benchmark keys, redirect sources and redirect targets) are
    normalised by upper-casing their first character. For every dataset the
    function prints its name, the number of normalised entities, each entity
    that is missing from the first-paragraph data and from both redirect
    tables, the number of entities missing from the first-paragraph data,
    and the number of those also missing from both redirect tables.

    Nothing is saved or returned.
    """

    def cap_first(name):
        # Upper-case the first character only (IndexError on an empty name)
        return name[0].upper() + name[1:]

    def normalise_redirects(raw):
        # Each source must map to exactly one target
        cleaned = {}
        for source in raw:
            assert len(raw[source]) == 1
            cleaned[cap_first(source)] = cap_first(raw[source][0])
        return cleaned

    with open('/mnt/data01/wiki_data/clean_datasets/full_wiki_fp_all_pages.pkl', 'rb') as fp_file:
        fp_data = pickle.load(fp_file)
    with open('/mnt/data01/wiki_data/redirect_move_dict.json') as move_file:
        redirects = json.load(move_file)
    with open('/mnt/data01/wiki_data/redirect_all_dict.json') as all_file:
        all_redirects = json.load(all_file)

    redirects_clean = normalise_redirects(redirects)
    all_redirects_clean = normalise_redirects(all_redirects)

    benchmark_names = ["ace_entities", "aqaint_entities", "msnbc_entities", "wned_wikipedia", "wned_cweb", "conll_entities"]

    for dat in benchmark_names:
        print(dat)
        with open(f'/mnt/data01/entity/benchmarks/coref/{dat}.json') as bench_file:
            raw_benchmark = json.load(bench_file)

        clean_benchmark = {}
        for ent, value in raw_benchmark.items():
            clean_benchmark[cap_first(ent)] = value

        print(len(clean_benchmark))

        # Entities that do have a first paragraph (collected but not used further)
        new_benchmark = {}

        missing = 0       # not in the first-paragraph data
        unresolved = 0    # ... and not a known redirect source either
        for ent in tqdm(clean_benchmark):
            if ent in fp_data:
                new_benchmark[ent] = clean_benchmark[ent]
                continue

            missing += 1
            if not (ent in redirects_clean or ent in all_redirects_clean):
                unresolved += 1
                print(ent)

        print(missing)
        print(unresolved)


if __name__ == "__main__":

    # Earlier pipeline stages, currently disabled; their outputs are read
    # back from disk below.
    # final_context_data, clean_disamb_dict = clean_data()
    # describe_disambig_dicts(clean_disamb_dict)

    # with_fp_context_data, with_fp_disamb_dict, cleaned_fp_data = clean_fp_data(final_context_data, clean_disamb_dict)

    # clean_benchmark_data()
    # TODO: clean_benchmark_data still needs to save its result and to use
    # the cleaned first paragraphs.

    # Load the final disambiguation clusters written by clean_fp_data
    disamb_path = '/mnt/data01/wiki_data/train_splits_2/disambiguation_dict_final.json'
    with open(disamb_path) as disamb_file:
        with_fp_disamb_dict = json.load(disamb_file)

    # Create splits
    positives(with_fp_disamb_dict)
    negatives(with_fp_disamb_dict)

    # describe_disambig_dicts(with_fp_disamb_dict)
