#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
from typing import List, Tuple

import spacy
from bert_text_classifier import identifiers, util_text

# Maximum number of characters kept by preprocess() when truncate=True.
MAX_TEXT_LEN = 100_000


def _is_onion_url(url):
    """Return True when the host part of *url* ends in ``.onion``.

    Accepts bare hosts ("abc.onion") as well as full URLs carrying a scheme,
    userinfo, port, path, query or fragment ("http://u@abc.onion:80/x?y#z").
    """
    host = url.strip().lower()
    if '://' in host:
        host = host.split('://', 1)[1]
    # Keep only the authority part (everything before path/query/fragment).
    for separator in ('/', '?', '#'):
        host = host.split(separator, 1)[0]
    host = host.rsplit('@', 1)[-1]  # drop "user:pass@"
    host = host.split(':', 1)[0]    # drop ":port"
    return host.rstrip('.').endswith('.onion')


def normalize_emails_and_urls(text, show_details=False):
    """Replace e-mail addresses and URLs in *text* with identifier placeholders.

    URL candidates that overlap an e-mail address are discarded, so the e-mail
    wins. URLs whose host ends in ``.onion`` are tagged with
    ``identifiers.IDENTIFIER_ONION_URL``. All other URLs are tagged with
    ``identifiers.IDENTIFIER_NORMAL_URL``. E-mails are tagged with
    ``identifiers.IDENTIFIER_EMAIL``.

    Args:
        text: input text.
        show_details: forwarded to ``normalize_spans_in_text`` (prints matches).

    Returns:
        ``(normalized_text, spans)``, where *spans* is the list of
        ``(offset, span_text, identifier)`` triples (emails first, then onion
        URLs, then normal URLs) that was passed to ``normalize_spans_in_text``.
    """
    # Materialize once: the e-mail spans are consulted for every URL candidate
    # and reused for the final span list, so a one-shot iterator would break.
    email_spans = list(identifiers.extract_emails(text))  # type: List[Tuple[Tuple[int, int], str]]
    url_cand_spans = identifiers.extract_urls(text)  # type: List[Tuple[Tuple[int, int], str]]

    # Remove URL spans that overlap with e-mail spans.
    url_spans = []
    for url_cand_span in url_cand_spans:
        (start_offset, end_offset), _ = url_cand_span
        if not util_text.check_overlap_with_spans(start_offset, end_offset, email_spans):
            url_spans.append(url_cand_span)

    # Distinguish spans by attaching the ID type used during normalization.
    email_id_spans = [(offset, span_text, identifiers.IDENTIFIER_EMAIL)
                      for offset, span_text in email_spans]
    onion_url_id_spans = []
    normal_url_id_spans = []

    for offset, span_text in url_spans:
        if _is_onion_url(span_text):
            onion_url_id_spans.append((offset, span_text, identifiers.IDENTIFIER_ONION_URL))
        else:
            normal_url_id_spans.append((offset, span_text, identifiers.IDENTIFIER_NORMAL_URL))

    spans = email_id_spans + onion_url_id_spans + normal_url_id_spans
    text = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='IDENTIFIERS')

    return text, spans


def normalize_filenames(text, show_details=False):
    """Replace file names found in *text* with ``identifiers.IDENTIFIER_FILENAME``.

    Returns:
        ``(normalized_text, spans)``. *spans* is the list of
        ``(offset, span_text, IDENTIFIER_FILENAME)`` triples built from
        ``identifiers.extract_filenames`` and handed to ``normalize_spans_in_text``.
    """
    found = identifiers.extract_filenames(text)
    tagged = [(offset, matched, identifiers.IDENTIFIER_FILENAME) for offset, matched in found]
    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='FILENAME')
    return normalized, tagged


def normalize_ip_addresses(text, show_details=False):
    """Replace IP addresses found in *text* with ``identifiers.IDENTIFIER_IP_ADDRESS``.

    Returns:
        ``(normalized_text, spans)``. *spans* is the list of
        ``(offset, span_text, IDENTIFIER_IP_ADDRESS)`` triples built from
        ``identifiers.extract_ip_address`` and handed to ``normalize_spans_in_text``.
    """
    found = identifiers.extract_ip_address(text)
    tagged = [(offset, matched, identifiers.IDENTIFIER_IP_ADDRESS) for offset, matched in found]
    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='IP_ADDRESS')
    return normalized, tagged


def normalize_crypto_addresses(text, show_details=False):
    """Replace cryptocurrency addresses in *text* with per-currency identifiers.

    ``identifiers.extract_crypto_addresses`` yields ``(offset, text, type)``
    triples. Types ``'btc'``, ``'eth'`` and ``'ltc'`` map to their dedicated
    identifiers. Any other type maps to ``identifiers.IDENTIFIER_OTHER_ADDRESS``.

    Returns:
        ``(normalized_text, spans)`` with *spans* the tagged triples handed to
        ``normalize_spans_in_text``.
    """
    known_types = {
        'btc': identifiers.IDENTIFIER_BTC_ADDRESS,
        'eth': identifiers.IDENTIFIER_ETH_ADDRESS,
        'ltc': identifiers.IDENTIFIER_LTC_ADDRESS,
    }

    tagged = [
        (offset, matched, known_types.get(kind, identifiers.IDENTIFIER_OTHER_ADDRESS))
        for offset, matched, kind in identifiers.extract_crypto_addresses(text)
    ]

    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='CRYPTO_ADDRESS')
    return normalized, tagged


def normalize_quantities(text, show_details=False):
    """Replace quantities (unit + number) in *text* with category identifiers.

    Every extractor runs on the same input text. The resulting spans are
    concatenated in the order of the table below before overlap resolution in
    ``normalize_spans_in_text``, so that order must be kept.

    Returns:
        ``(normalized_text, spans)`` with *spans* the concatenated tagged triples.
    """
    # (identifier, extractor) pairs.
    # NOTE: The order is very important!!
    extractors_by_id = (
        (identifiers.IDENTIFIER_CRYPTO_MONEY, identifiers.extract_crypto_money),
        (identifiers.IDENTIFIER_GENERAL_MONEY, identifiers.extract_general_money),
        (identifiers.IDENTIFIER_LENGTH, identifiers.extract_lengths),
        (identifiers.IDENTIFIER_VOLUME, identifiers.extract_volume),
        (identifiers.IDENTIFIER_WEIGHT, identifiers.extract_weights),
        (identifiers.IDENTIFIER_PERCENTAGE, identifiers.extract_percentage),
        (identifiers.IDENTIFIER_FILESIZE, identifiers.extract_filesize),
        (identifiers.IDENTIFIER_VERSION, identifiers.extract_version),
    )

    tagged = [
        (offset, matched, category)
        for category, extract in extractors_by_id
        for offset, matched in extract(text)
    ]

    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='QUANTITIES')
    return normalized, tagged


def normalize_time(text, show_details=False):
    """Replace temporal expressions in *text* with ``identifiers.IDENTIFIER_TIME``.

    Returns:
        ``(normalized_text, spans)``. *spans* is the list of
        ``(offset, span_text, IDENTIFIER_TIME)`` triples built from
        ``identifiers.extract_time`` and handed to ``normalize_spans_in_text``.
    """
    found = identifiers.extract_time(text)
    tagged = [(offset, matched, identifiers.IDENTIFIER_TIME) for offset, matched in found]
    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='TIME')
    return normalized, tagged


def normalize_numbers(text, show_details=False):
    """Replace decimal numbers in *text* with ``identifiers.IDENTIFIER_NUMBER``.

    Returns:
        ``(normalized_text, spans)``. *spans* is the list of
        ``(offset, span_text, IDENTIFIER_NUMBER)`` triples built from
        ``identifiers.extract_decimal_numbers`` and handed to
        ``normalize_spans_in_text``.
    """
    found = identifiers.extract_decimal_numbers(text)
    tagged = [(offset, matched, identifiers.IDENTIFIER_NUMBER) for offset, matched in found]
    normalized = normalize_spans_in_text(text, tagged,
                                         show_details=show_details,
                                         show_msg='NUMBERS')
    return normalized, tagged


def normalize_spans_in_text(text, spans: List[Tuple[Tuple[int, int], str, str]],
                            show_details=False, show_msg=''):
    """Substitute each span of *text* with its identifier string.

    Spans are first filtered through
    ``util_text.remove_overlapping_spans_from_back`` (presumably this keeps
    earlier spans and drops later overlapping ones; verify in util_text). The
    survivors are ordered by their ``(start, end)`` offsets and replaced from
    last to first. Each substitution changes the string length only after the
    offsets still to be used, so those offsets stay valid.

    Args:
        text: text the span offsets refer to.
        spans: ``((start, end), span_text, identifier)`` triples.
        show_details: print every surviving span with about 20 characters of
            context on each side.
        show_msg: heading printed before the detail listing.

    Returns:
        The text with every surviving span replaced by its identifier.
    """
    kept = sorted(util_text.remove_overlapping_spans_from_back(spans),
                  key=lambda span: span[0])

    if show_details and kept:
        print(f'    {show_msg} ({len(kept)}):')
        for (begin, finish), matched, identifier in kept:
            window = text[max(0, begin - 20):finish + 20]
            window = window.strip().replace('\n', ' ')
            print(f'       {begin:4} | {identifier} | {matched} | {window}')

    result = text
    for (begin, finish), _matched, identifier in kept[::-1]:
        result = result[:begin] + identifier + result[finish:]
    return result


def remove_long_strings(text, limit, preserve_newlines=False, show_details=False):
    """Drop over-long pieces from *text* and re-join the rest with single spaces.

    A piece is "too long" when ``len(piece) >= limit``.

    Args:
        text: input text.
        limit: minimum length at which a piece is removed.
        preserve_newlines: when False, split on any whitespace, so newlines
            and whitespace runs collapse into single spaces. When True, split
            on single spaces only, so newlines and repeated spaces survive.
            Each newline-separated piece of a chunk is then judged on its own:
            an over-long piece is removed, but the newline(s) and short words
            glued to it are kept. A chunk that is a single over-long word
            (no newline) is dropped entirely.
        show_details: print every removed piece.

    Returns:
        The filtered text.
    """
    removed = []
    kept = []

    if not preserve_newlines:
        for token in text.split():
            (removed if len(token) >= limit else kept).append(token)
    else:
        for chunk in text.split(' '):
            lines = chunk.split('\n')
            too_long = [line for line in lines if len(line) >= limit]
            if not too_long:
                kept.append(chunk)
                continue
            removed.extend(too_long)
            if len(lines) == 1:
                # A plain over-long word: drop it without leaving an empty slot.
                continue
            # Blank out only the over-long lines so the '\n' separators survive.
            kept.append('\n'.join(line if len(line) < limit else '' for line in lines))

    if show_details:
        print('    Long strings:')
        for piece in removed:
            print(f'       {piece}')

    return ' '.join(kept)


def iter_spacy_tokens(s, spacy_nlp_model, no_stopwords=False, no_common_words=False):
    """Yield the tokens of ``spacy_nlp_model(s)`` that pass the content filters.

    Always skipped:
    - digits, punctuation, quotes and brackets;
    - the possessive suffix ("'s" or "’s").

    Optionally skipped:
    - common words (``no_common_words``), matched on the lowercased lemma
      against ``util_text.ENGLISH_COMMON_WORDS``;
    - stopwords (``no_stopwords``), matched on the lowercased text against
      ``util_text.ENGLISH_STOPWORDS``.
    """
    for tok in spacy_nlp_model(s):
        if tok.is_digit or tok.is_punct or tok.is_quote or tok.is_bracket:
            continue
        if no_common_words and tok.lemma_.lower() in util_text.ENGLISH_COMMON_WORDS:
            continue
        if no_stopwords and tok.text.lower() in util_text.ENGLISH_STOPWORDS:
            continue
        if tok.text in ("'s", "’s"):
            continue
        yield tok


def process_with_spacy(s, spacy_nlp_model, no_stopwords=False, no_common_words=False, lemma=True):
    """Run *s* through spaCy and return the surviving tokens joined by spaces.

    Tokens are filtered by ``iter_spacy_tokens``. With ``lemma=True`` each token
    is replaced by its lemma, except all-uppercase tokens (e.g. the ``ID_...``
    identifier placeholders), which are kept verbatim.

    A post-pass corrects a spaCy mis-tokenization in which the number
    placeholder is glued to a following ``ID_`` token by a closing bracket or
    parenthesis: ``<NUM>]ID_X`` becomes ``<NUM> ID_X``. Occurrences where the
    placeholder is directly preceded by the matching opening bracket
    (``[<NUM>]ID_X``) are left alone. Each occurrence is judged on its own.

    Args:
        s: text to process.
        spacy_nlp_model: callable spaCy pipeline.
        no_stopwords, no_common_words: forwarded to ``iter_spacy_tokens``.
        lemma: emit lemmas instead of surface forms.

    Returns:
        The re-joined text, with the extra space after each newline removed.
    """
    def token_text(token):
        if lemma and not token.text.isupper():
            return token.lemma_
        return token.text

    text = ' '.join(
        token_text(token)
        for token in iter_spacy_tokens(s, spacy_nlp_model,
                                       no_stopwords=no_stopwords,
                                       no_common_words=no_common_words)
    )

    # =======================================
    #    Correct spaCy's mis-tokenization
    # =======================================
    id_num = identifiers.IDENTIFIER_NUMBER
    replacement = id_num + ' '
    for opener, closer in (('[', ']'), ('(', ')')):
        # Match "<id_num><closer>" followed by "ID_" and not preceded by
        # <opener>. The lookahead leaves "ID_" unconsumed, so chained
        # occurrences ("<NUM>]<NUM>]ID_X") are all fixed.
        pattern = rf'(?<!{re.escape(opener)}){re.escape(id_num)}{re.escape(closer)}(?=ID_)'
        # A callable replacement avoids backslash interpretation in id_num.
        text = re.sub(pattern, lambda _match: replacement, text)

    # Remove redundant whitespace produced by the ' '.join() command
    # (Note: spaCy treats '\n' as a separate token by default)
    return text.replace('\n ', '\n')


# Default spaCy pipeline shared by preprocess() calls that do not supply their
# own ``spacy_nlp``. It stays None until preprocess() first needs it and loads it.
SPACY_NLP_MODEL = None


def _get_default_spacy_model():
    """Return the shared default spaCy pipeline, loading it on first use."""
    global SPACY_NLP_MODEL
    if SPACY_NLP_MODEL is None:
        SPACY_NLP_MODEL = spacy.load('en_core_web_sm', exclude=["parser", "ner"])
    return SPACY_NLP_MODEL


def preprocess(text, title='', spacy_nlp=None, minimal_normalize=False, remove_identifiers=False,
               spacy_preproc=True, show_details=False, lowercase=True,
               truncate=True, remove_long_tokens=True,
               get_extra=False, long_token_limit=50):
    """Normalize raw *text* into the form fed to the classifier.

    Pipeline:
    1. truncate;
    2. replace identifiers (emails/URLs, filenames, IP and crypto addresses,
       quantities, time expressions, numbers) with ``ID_...`` placeholders;
    3. spaCy tokenization/lemmatization;
    4. lowercase;
    5. remove or re-uppercase placeholders;
    6. drop over-long tokens.

    Args:
        text: raw input text.
        title: unused; kept for backward compatibility.
        spacy_nlp: spaCy pipeline to use. When None, a shared
            ``en_core_web_sm`` pipeline (parser and NER excluded) is loaded
            once and cached in ``SPACY_NLP_MODEL``.
        minimal_normalize: when True, only emails/URLs, IP addresses and
            crypto addresses are normalized. Filenames, quantities, time
            expressions and remaining numbers are left as-is.
        remove_identifiers: delete the placeholders instead of keeping them.
        spacy_preproc: run the spaCy step.
        show_details: print what each normalization step matched.
        lowercase: lowercase the output. Placeholders are restored to
            uppercase unless ``remove_identifiers`` is set.
        truncate: cut the input to ``MAX_TEXT_LEN`` characters first.
        remove_long_tokens: drop tokens of ``long_token_limit`` or more
            characters (newlines preserved).
        get_extra: also return a debug dict (currently always empty).
        long_token_limit: threshold for ``remove_long_tokens`` (default 50).

    Returns:
        The processed text, or ``(text, extra)`` when ``get_extra`` is True.
    """
    if truncate:
        text = text[:MAX_TEXT_LEN]

    # Each normalizer runs on the output of the previous one, so the order
    # below decides which pattern claims a piece of text first.
    text, _ = normalize_emails_and_urls(text, show_details=show_details)
    if not minimal_normalize:
        text, _ = normalize_filenames(text, show_details=show_details)
    text, _ = normalize_ip_addresses(text, show_details=show_details)
    text, _ = normalize_crypto_addresses(text, show_details=show_details)
    if not minimal_normalize:
        # Money/length/volume/weight/percentage/filesize/version (unit & number).
        text, _ = normalize_quantities(text, show_details=show_details)
        # Temporal expressions ("2020-10-02 09:44:10", "21/09/2019").
        text, _ = normalize_time(text, show_details=show_details)
        # All the remaining numbers.
        text, _ = normalize_numbers(text, show_details=show_details)

    if spacy_preproc:
        model = spacy_nlp if spacy_nlp is not None else _get_default_spacy_model()
        text = process_with_spacy(text, model)

    # NOTE: identifiers "ID_XXX" are lowered to "id_xxx" here as well.
    if lowercase:
        text = text.lower()

    if remove_identifiers or lowercase:
        # Longest first, so an identifier that is a prefix of another one
        # cannot clobber part of the longer one.
        ordered_ids = sorted(identifiers.IDENTIFIERS_ALL, key=len, reverse=True)
        if remove_identifiers:
            for identifier in ordered_ids:
                text = text.replace(identifier.lower() if lowercase else identifier, '')
        else:
            # Restore the normalized IDs to their uppercase form.
            for identifier in ordered_ids:
                text = text.replace(identifier.lower(), identifier.upper())

    # Remove the space-separated tokens that are too long. This runs AFTER the
    # spaCy step, on its output.
    if remove_long_tokens:
        text = remove_long_strings(text, limit=long_token_limit, preserve_newlines=True)

    if get_extra:
        extra = {}  # For debugging ...
        return text, extra
    return text

