from glob import glob
import json
from tqdm import tqdm
from random import sample, shuffle, seed
import re


def open_and_clean_data(path_pattern):
    """Load position records from JSON files and aggregate them per individual.

    Every path produced by ``glob(path_pattern)`` is opened and parsed as a
    JSON list of position records. Each record is read through the keys
    ``item``, ``itemLabel``, ``linkcount``, ``start``, ``end`` and, when
    present, ``alias``.

    Args:
        path_pattern: Glob pattern selecting the JSON files to read.

    Returns:
        A dict keyed by individual id (the last ``/``-separated segment of
        ``item``). Each value is a dict with the keys ``"name text"``,
        ``"link count"``, ``"all starts"``, ``"all ends"``, ``"aliases"``
        (sorted list that always contains the label), ``"first start"``,
        ``"last end"`` and ``"aliases_lower"`` (lower-cased aliases, in the
        same order as ``"aliases"``).

    Raises:
        AssertionError: If two records for the same id carry different
            ``itemLabel`` values.
    """
    print("\n Opening and cleaning data ...")

    all_positions = []
    for file in glob(path_pattern):
        print(file)
        with open(file) as f:
            all_positions.extend(json.load(f))

    # Aggregate by individual
    all_individuals = {}
    poorly_formed_entries = 0

    for pos in all_positions:
        ind_id = pos['item'].split("/")[-1]

        # Only the first four characters (the year) of each date are used.
        # Records whose dates are missing, not sliceable or not numeric are
        # counted and skipped; any other error is allowed to propagate.
        try:
            end = int(pos['end'][0:4])
            start = int(pos['start'][0:4])
        except (KeyError, TypeError, ValueError):
            poorly_formed_entries += 1
            continue

        # Direct dict lookup: constant time instead of rebuilding a key list
        # for every record.
        individual = all_individuals.get(ind_id)
        if individual is None:
            individual = {
                "name text": pos['itemLabel'],
                "link count": int(pos['linkcount']),
                "all starts": [start],
                "all ends": [end],
                "aliases": {pos['itemLabel']}
            }
            all_individuals[ind_id] = individual
        else:
            assert individual["name text"] == pos['itemLabel']
            individual["all starts"].append(start)
            individual["all ends"].append(end)

        if "alias" in pos:
            individual["aliases"].add(pos['alias'])

    print(f"{len(all_individuals)} unique individuals")
    print(f'{poorly_formed_entries} entries removed because data was not in the right format')

    # Create range of active years
    for individual in all_individuals.values():
        individual['first start'] = min(individual['all starts'])
        individual['last end'] = max(individual['all ends'])
        # Sorted so that the alias order does not depend on set iteration
        # order, which varies between interpreter runs.
        individual['aliases'] = sorted(individual['aliases'])
        individual['aliases_lower'] = [alias.lower() for alias in individual['aliases']]

    return all_individuals


def stratify_and_sample(all_individuals):
    """Bucket individuals by link count and gather every alias to search for.

    Individuals are partitioned on their ``'link count'`` into three strata:
    30 or more, 5 to 29, and fewer than 5. All three strata are currently
    used in full, concatenated in that order. The global ``random`` generator
    is seeded with 42 as a side effect.

    Args:
        all_individuals: Dict keyed by individual id whose values provide
            ``'link count'`` and ``'aliases_lower'``.

    Returns:
        A ``(sample_names, sample_ids)`` tuple: the flattened lower-cased
        aliases of the selected individuals, and the selected ids.
    """
    print("\n Sampling names to search ...")

    # Split into strata (one pass; bucket order follows dict order)
    first_strata_ids = []
    second_strata_ids = []
    third_strata_ids = []
    for ind_id, record in all_individuals.items():
        links = record['link count']
        if links >= 30:
            first_strata_ids.append(ind_id)
        elif 30 > links >= 5:
            second_strata_ids.append(ind_id)
        elif links < 5:
            third_strata_ids.append(ind_id)

    for label, bucket in (
        ('first', first_strata_ids),
        ('second', second_strata_ids),
        ('third', third_strata_ids),
    ):
        print(f'{len(bucket)} in {label} strata')

    # Sample
    seed(42)
    # Sub-sampled alternative, kept for reference:
    #   first_strata_ids + sample(second_strata_ids, 500) + sample(third_strata_ids, 500)
    sample_ids = [*first_strata_ids, *second_strata_ids, *third_strata_ids]

    # Pull all searches
    sample_names = [
        alias
        for ind_id in sample_ids
        for alias in all_individuals[ind_id]['aliases_lower']
    ]

    print(f'Sample size: {len(sample_ids)}')
    print(f'{len(sample_names)} terms to search')

    return sample_names, sample_ids


def quick_search(sample_names, path_pattern,
                 outfile='/mnt/data02/entity/large_search_results.json'):
    """Collect every article whose text contains at least one search term.

    Each path produced by ``glob(path_pattern)`` is parsed as a JSON object
    whose values are lists of article dicts. Every article has its
    ``'article'`` text cleaned with ``clean_ocr_text`` and gains a ``'lower'``
    key holding the lower-cased cleaned text; both changes are made on the
    article dict itself. Articles whose ``'lower'`` text contains any of the
    search terms as a plain substring are kept.

    Args:
        sample_names: Iterable of search terms. They are compared against
            lower-cased text, so they are expected to be lower-case already.
        path_pattern: Glob pattern selecting the article JSON files.
        outfile: Path the matching articles are dumped to as JSON. Defaults
            to the location this function has always written to.

    Returns:
        The list of matching article dicts.
    """
    print("\n Searching for all names ...")

    # Materialise once (so a one-shot iterable is not exhausted by the first
    # article) and drop repeated terms, which cannot change the outcome.
    search_terms = list(dict.fromkeys(sample_names))

    results = []
    for path in tqdm(glob(path_pattern)):
        # Close the file as soon as it is parsed rather than holding it open
        # for the whole scan.
        with open(path) as f:
            articles_by_key = json.load(f)

        for articles in articles_by_key.values():
            for art in articles:
                art['article'] = clean_ocr_text(art['article'])
                art['lower'] = art['article'].lower()
                if any(search_term in art['lower'] for search_term in search_terms):
                    results.append(art)

    print(f'{len(results)} results found')

    with open(outfile, 'w') as f:
        json.dump(results, f, indent=4)

    return results


def clean_ocr_text(text, basic=True):
    """Normalise line breaks in OCR text and optionally strip stray symbols.

    The cleaning steps, in order:

    1. Every hyphen immediately followed by a newline is removed, re-joining
       words that were split across lines.
    2. Only when ``basic`` is False: 'é', 'ï', 'ﬁ' and 'ﬂ' are replaced by
       'e', 'i', 'fi' and 'fl', and the characters # / * @ ~ ¢ © ® ° are
       deleted.
    3. Any stretch of whitespace that starts with a newline and contains at
       least two newlines is collapsed to exactly two newlines, marking a
       paragraph boundary.
    4. Every remaining single newline is replaced by a space.

    Args:
        text: The string to clean. An empty string is returned unchanged.
        basic: When True (the default), step 2 is skipped.

    Returns:
        The cleaned string.
    """
    # Code to deal with unwanted symbols
    cleaned_text = text.replace("-\n", "")

    if not basic:
        remove_list = ["#", "/", "*", "@", "~", "¢", "©", "®", "°"]
        cleaned_text = cleaned_text.replace("é", "e").replace("ï", "i").replace("ﬁ", "fi").replace("ﬂ", "fl")
        cleaned_text = cleaned_text.translate({ord(x): None for x in remove_list})

    # Code to deal with newline and double newline
    cleaned_text = re.sub(r'(\n\s*)+\n+', '\n\n', cleaned_text)

    # After the collapse above newlines only occur alone or in pairs, so a
    # newline with no newline on either side is a lone line break. One regex
    # pass replaces them all, including at the very start or end of the text,
    # and works for empty and one-character inputs.
    return re.sub(r'(?<!\n)\n(?!\n)', ' ', cleaned_text)


def sort_search_results_by_name(search_results, sample_ids, all_individuals,
                                outfile='/mnt/data02/entity/large_results_by_name.json'):
    """Group search results under each individual whose alias they mention.

    A result belongs to an individual when any of that individual's
    ``'aliases_lower'`` occurs as a plain substring of the result's
    ``'lower'`` text. One result may therefore be listed under several
    individuals. Individuals without any result are left out of the mapping.

    Args:
        search_results: List of article dicts, each with a ``'lower'`` key.
        sample_ids: Ids of the individuals to look for.
        all_individuals: Dict keyed by individual id providing
            ``'aliases_lower'``.
        outfile: Path the grouped results are dumped to as JSON. Defaults to
            the location this function has always written to.

    Returns:
        Dict mapping individual id to the non-empty list of its results.
    """
    print("\n Separating results by name ...")

    no_results_found = 0
    total_results_found = 0
    sorted_results = {}

    for ind_id in tqdm(sample_ids):

        aliases_lower = all_individuals[ind_id]['aliases_lower']

        # Pull relevant search results
        name_results = [
            result for result in search_results
            if any(alias in result['lower'] for alias in aliases_lower)
        ]

        total_results_found += len(name_results)

        if name_results:
            sorted_results[ind_id] = name_results
        else:
            no_results_found += 1

    with open(outfile, 'w') as f:
        json.dump(sorted_results, f, indent=4)

    print(f'{no_results_found} names with no results found')
    print(f'{total_results_found} results found in total')

    return sorted_results


def _whole_word_alias_pattern(aliases_lower):
    """Compile a regex matching any alias that is not embedded in a longer word."""
    # Aliases are literal text, so they are escaped before being placed in the
    # pattern; names containing '.', '(', '+', ... would otherwise be read as
    # regex syntax. The lookarounds require that the character on each side of
    # the alias is a non-word character or the edge of the text.
    alternatives = '|'.join(re.escape(alias) for alias in aliases_lower)
    return re.compile(r'(?<!\w)(?:' + alternatives + r')(?!\w)')


def prep_for_ls(sorted_results, all_individuals, outfile):
    """Build label-studio tasks from per-individual search results.

    The function works in two passes:

    1. Results in which no alias appears as a whole word are dropped.
       ``sorted_results`` is updated in place: filtered lists replace the
       originals and individuals left with no result are deleted.
    2. For every remaining individual the results are split by article year
       (the third-from-last ``-``-separated field of the result ``'id'``)
       into "active" (year >= first start), "maybe active" (within the 10
       years before first start) and "non active" (earlier). Individuals with
       fewer than 2 active or fewer than 2 non-active results are skipped.
       At most 8 active, 8 maybe-active and 16 non-active results are kept
       after shuffling, one alias occurrence per result is picked at random
       as the span to highlight, and the list is padded with placeholder
       entries up to 32 articles.

    The chosen result dicts are modified: newlines in ``'article'`` and
    ``'lower'`` become spaces and the keys ``'spans'`` and ``'chosen_span'``
    are added. Shuffling and span selection use the global ``random``
    generator.

    Args:
        sorted_results: Dict mapping individual id to a list of result dicts
            with the keys ``'id'``, ``'headline'``, ``'byline'``,
            ``'article'`` and ``'lower'``.
        all_individuals: Dict keyed by individual id providing
            ``'name text'``, ``'first start'``, ``'last end'``, ``'aliases'``
            and ``'aliases_lower'``.
        outfile: Path the shuffled list of tasks is written to as JSON.

    Returns:
        None.
    """
    too_small = 0
    ls_outputs = []

    # Remove weird substring matches
    print("Removing weird substring matches ...")
    removed = 0
    # Iterate over a snapshot of the keys because entries are deleted below.
    for ind_id in tqdm(list(sorted_results.keys())):

        whole_word = _whole_word_alias_pattern(all_individuals[ind_id]['aliases_lower'])

        updated_results = [
            result for result in sorted_results[ind_id]
            if whole_word.search(result['lower'])
        ]

        if not updated_results:
            del sorted_results[ind_id]
            removed += 1
        else:
            sorted_results[ind_id] = updated_results

    print(f'{removed} individuals removed')

    print("Formatting for label studio ...")
    for ind_id in tqdm(list(sorted_results.keys())):

        name = all_individuals[ind_id]['name text']
        first_start = all_individuals[ind_id]['first start']
        last_end = all_individuals[ind_id]['last end']
        aliases = all_individuals[ind_id]['aliases']

        # Split into active and non-active years
        active_results = []
        maybe_active_results = []
        non_active_results = []

        for result in sorted_results[ind_id]:
            text_year = int(result['id'].split("-")[-3])

            if text_year >= first_start:
                active_results.append(result)
            elif text_year >= first_start - 10:
                maybe_active_results.append(result)
            else:
                non_active_results.append(result)

        if len(active_results) < 2 or len(non_active_results) < 2:
            too_small += 1
            continue

        shuffle(active_results)
        shuffle(maybe_active_results)
        shuffle(non_active_results)

        # Slicing is a no-op for lists that are already short enough.
        active_results = active_results[:8]
        maybe_active_results = maybe_active_results[:8]
        non_active_results = non_active_results[:16]

        chosen_results = active_results + maybe_active_results + non_active_results
        shuffle(chosen_results)

        # Find spans to highlight
        for result in chosen_results:

            result['article'] = result['article'].replace("\n", " ")
            result['lower'] = result['lower'].replace("\n", " ")

            # Offsets are found in the lower-cased text and later applied to
            # 'article'. NOTE(review): this assumes lower-casing did not
            # change the text length - confirm for non-ASCII input.
            result['spans'] = []
            for alias in all_individuals[ind_id]['aliases_lower']:
                result['spans'].extend(
                    [s.start(), s.end()]
                    for s in re.finditer(re.escape(alias), result['lower'])
                )

            # Fall back to an empty span instead of letting sample() raise on
            # an empty list.
            if result['spans']:
                result['chosen_span'] = sample(result['spans'], 1)[0]
            else:
                result['chosen_span'] = [0, 0]

        # Pad to 32 examples
        while len(chosen_results) < 32:
            chosen_results.append({
                'id': "Empty",
                'headline': 'Empty',
                'byline': 'Empty',
                'article': 'Empty',
                'chosen_span': [0, 0]
            })

        # Create ls output
        # At least two active results are guaranteed by the size check above,
        # so the first one can always provide the task id.
        art_id = active_results[0]['id']

        ls_example = {
            "id": art_id,
            "data": {
                "entity_search": name,
                "wiki_id": ind_id,
                "aliases": ", ".join(aliases),
                "active_years": "Active years: " + str(first_start) + "-" + str(last_end),
            },
            "predictions": [{"result": []}],   # for inputting spans to highlight
            }

        for i, chosen in enumerate(chosen_results):
            if chosen['id'] == "Empty":
                year = "None"
            else:
                year = chosen['id'].split("-")[-3]

            ls_example['data'][f"art_id_{i}"] = chosen['id']
            ls_example['data'][f"headline{i}"] = chosen['headline']
            ls_example['data'][f"year{i}"] = "Year of article: " + year
            ls_example['data'][f"byline{i}"] = chosen['byline']
            ls_example['data'][f"text{i}"] = chosen['article']

            span_start, span_end = chosen['chosen_span'][0], chosen['chosen_span'][1]
            ls_example['predictions'][0]['result'].append(
                {
                    "value":
                        {
                            "start": span_start,
                            "end": span_end,
                            "text": chosen['article'][span_start:span_end],
                            "labels": ["entity"]
                        },
                    "from_name": "label",
                    "to_name": f"text{i}",
                    "type": "labels"
                }
            )

        ls_outputs.append(ls_example)

    print(f'{too_small} names with <2 results during active years or <2 results during non-active years')
    print(f'{len(ls_outputs)} chosen for labelling')

    shuffle(ls_outputs)

    with open(outfile, 'w') as o:
        json.dump(ls_outputs, o, indent=4)


if __name__ == '__main__':

    # Stage 1: load and aggregate the politician records.
    all_individuals = open_and_clean_data(
        path_pattern='/mnt/data02/entity/wikidata_search_results/politicians_alias/**'
    )

    # Stages 2-5 below are currently disabled; their outputs are read back
    # from disk instead.

    # sample_names, sample_ids = stratify_and_sample(all_individuals)

    # search_results = quick_search(
    #     sample_names,
    #     path_pattern='/mnt/data02/retrieval/preprocess/random_sample/**/ocr_*'
    #     # path_pattern='/mnt/data02/retrieval/preprocess/rule_based_outputs/**/**/ocr_*'
    # )

    # with open('/mnt/data02/entity/large_search_results.json') as f:
    #     search_results = json.load(f)
    #
    # sorted_results = sort_search_results_by_name(
    #     search_results,
    #     sample_ids,
    #     all_individuals,
    # )

    # with open('/mnt/data02/entity/large_results_by_name.json') as f:
    #     sorted_results = json.load(f)
    #
    # prep_for_ls(
    #     sorted_results,
    #     all_individuals,
    #     outfile='/mnt/data02/entity/politicians_to_label_full_regex.json'
    # )

    # Stage 6: drop every task whose person already appears in the
    # congruence file, then split what is left into two halves.
    with open('/mnt/data02/entity/politicians_to_label_full_regex.json') as f:
        all_data = json.load(f)
    with open('/mnt/data02/entity/politicians_to_label_full.json') as f:
        congruence_data = json.load(f)

    print(len(congruence_data))

    congruence_people = []
    for samp in congruence_data:
        congruence_people.append(samp["data"]["entity_search"])

    output = []
    for dat in all_data:
        if dat["data"]["entity_search"] in congruence_people:
            continue
        output.append(dat)

    # When the count is odd, the second half receives the extra task.
    midpoint = len(output) // 2
    output_a, output_b = output[:midpoint], output[midpoint:]

    for out_path, payload in (
        ('/mnt/data02/entity/politicians_a.json', output_a),
        ('/mnt/data02/entity/politicians_b.json', output_b),
    ):
        with open(out_path, 'w') as o:
            json.dump(payload, o, indent=4)

