#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import List, Tuple, Set

import spacy

from normalizer_utils import check_overlap_with_spans, remove_overlapping_spans_from_back
import normalizer_rules as identifiers


# Maximum number of characters preprocess() keeps from its input text;
# anything beyond this length is cut off before normalization starts.
MAX_TEXT_LEN = 100000


def normalize_emails_and_urls(text, show_details=False):
    """Replace e-mail addresses, URLs and filenames in ``text`` with identifier tokens.

    URL candidates that overlap an e-mail span are discarded so that the
    domain part of an address is not additionally tagged as a URL.  The
    remaining URLs are tagged as onion URLs when the matched text ends with
    ``.onion`` (case-insensitive) and as normal URLs otherwise.

    Args:
        text: Input text.
        show_details: When true, the matched spans are printed by
            ``normalize_spans_in_text``.

    Returns:
        A ``(normalized_text, spans)`` tuple.  ``spans`` holds every candidate
        ``((start, end), span_text, identifier)`` triple that was handed to
        ``normalize_spans_in_text``, in the order e-mails, onion URLs, normal
        URLs, filenames.
    """
    # Extract the e-mails once and reuse the result both for the overlap
    # filtering below and for building the tagged e-mail spans.
    email_spans = list(identifiers.extract_emails(text))  # type: List[Tuple[Tuple[int, int], str]]
    url_cand_spans = identifiers.extract_urls(text)  # type: List[Tuple[Tuple[int, int], str]]

    # Remove url spans that overlap with email spans
    url_spans = []
    for url_cand_span in url_cand_spans:
        (start_offset, end_offset), _ = url_cand_span
        if not check_overlap_with_spans(start_offset, end_offset, email_spans):
            url_spans.append(url_cand_span)

    # Distinguish spans by inserting ID type to be used during normalization later
    email_id_spans = [(offsets, span_text, identifiers.IDENTIFIER_EMAIL)
                      for offsets, span_text in email_spans]
    onion_url_id_spans = []
    normal_url_id_spans = []

    for offsets, span_text in url_spans:
        if span_text.lower().endswith('.onion'):
            onion_url_id_spans.append((offsets, span_text, identifiers.IDENTIFIER_ONION_URL))
        else:
            normal_url_id_spans.append((offsets, span_text, identifiers.IDENTIFIER_NORMAL_URL))

    filename_id_spans = [(offsets, span_text, identifiers.IDENTIFIER_FILENAME)
                         for offsets, span_text in identifiers.extract_filenames(text)]

    spans = email_id_spans + onion_url_id_spans + normal_url_id_spans + filename_id_spans
    text = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='IDENTIFIERS')

    return text, spans


def normalize_filenames(text, show_details=False):
    """Replace filenames found in ``text`` with the filename identifier.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples built from the extractor's
    matches.
    """
    spans = []
    for offsets, matched in identifiers.extract_filenames(text):
        spans.append((offsets, matched, identifiers.IDENTIFIER_FILENAME))

    normalized = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='FILENAME')
    return normalized, spans


def normalize_ip_addresses(text, show_details=False):
    """Replace IP addresses found in ``text`` with the IP-address identifier.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples built from the extractor's
    matches.
    """
    spans = []
    for offsets, matched in identifiers.extract_ip_address(text):
        spans.append((offsets, matched, identifiers.IDENTIFIER_IP_ADDRESS))

    normalized = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='IP_ADDRESS')
    return normalized, spans


def normalize_crypto_addresses(text, show_details=False):
    """Replace cryptocurrency addresses in ``text`` with per-currency identifiers.

    The extractor tags each match with a currency type; ``'btc'``, ``'eth'``
    and ``'ltc'`` get their own identifier, any other tag gets the generic
    "other address" identifier.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples built from the matches.
    """

    def identifier_for(crypto_type):
        # Translate the extractor's currency tag into its identifier token
        if crypto_type == 'btc':
            return identifiers.IDENTIFIER_BTC_ADDRESS
        if crypto_type == 'eth':
            return identifiers.IDENTIFIER_ETH_ADDRESS
        if crypto_type == 'ltc':
            return identifiers.IDENTIFIER_LTC_ADDRESS
        return identifiers.IDENTIFIER_OTHER_ADDRESS

    spans = [(offsets, matched, identifier_for(kind))
             for offsets, matched, kind in identifiers.extract_crypto_addresses(text)]

    normalized = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='CRYPTO_ADDRESS')
    return normalized, spans


def normalize_quantities(text, show_details=False):
    """Replace unit-bearing quantities in ``text`` with identifier tokens.

    Covers crypto money, general money, length, volume, weight, percentage,
    file size and version expressions, each extracted by its own function
    from the rules module.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples of all extractors, grouped
    in table order.
    """
    # (identifier, extractor) pairs.
    # NOTE: The order is very important!!
    # NOTE(review): presumably the order decides which span survives overlap
    # removal inside normalize_spans_in_text -- confirm before reordering.
    extractor_table = (
        (identifiers.IDENTIFIER_CRYPTO_MONEY, identifiers.extract_crypto_money),
        (identifiers.IDENTIFIER_GENERAL_MONEY, identifiers.extract_general_money),
        (identifiers.IDENTIFIER_LENGTH, identifiers.extract_lengths),
        (identifiers.IDENTIFIER_VOLUME, identifiers.extract_volume),
        (identifiers.IDENTIFIER_WEIGHT, identifiers.extract_weights),
        (identifiers.IDENTIFIER_PERCENTAGE, identifiers.extract_percentage),
        (identifiers.IDENTIFIER_FILESIZE, identifiers.extract_filesize),
        (identifiers.IDENTIFIER_VERSION, identifiers.extract_version),
    )

    quantity_spans = [
        (offsets, matched, quantity_id)
        for quantity_id, extract in extractor_table
        for offsets, matched in extract(text)
    ]

    normalized = normalize_spans_in_text(text, quantity_spans, show_details=show_details, show_msg='QUANTITIES')
    return normalized, quantity_spans


def normalize_time(text, show_details=False):
    """Replace temporal expressions found in ``text`` with the time identifier.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples built from the extractor's
    matches.
    """
    spans = []
    for offsets, matched in identifiers.extract_time(text):
        spans.append((offsets, matched, identifiers.IDENTIFIER_TIME))

    normalized = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='TIME')
    return normalized, spans


def normalize_numbers(text, show_details=False):
    """Replace decimal numbers found in ``text`` with the number identifier.

    Returns a ``(normalized_text, spans)`` tuple where ``spans`` lists the
    ``(offsets, span_text, identifier)`` triples built from the extractor's
    matches.
    """
    spans = []
    for offsets, matched in identifiers.extract_decimal_numbers(text):
        spans.append((offsets, matched, identifiers.IDENTIFIER_NUMBER))

    normalized = normalize_spans_in_text(text, spans, show_details=show_details, show_msg='NUMBERS')
    return normalized, spans


def normalize_spans_in_text(text, spans: List[Tuple[Tuple[int, int], str, str]],
                            show_details=False, show_msg=''):
    """Substitute each span of ``text`` with its identifier string.

    Args:
        text: Text the span offsets refer to.
        spans: ``((start, end), span_text, id_type)`` triples.
        show_details: When true and at least one span remains, print a header
            and one line per span (start offset, identifier, matched text and
            up to 20 characters of context on each side).
        show_msg: Label used in the printed header.

    Returns:
        The text with every remaining span replaced by its ``id_type``.
    """
    # Filter the spans through the imported overlap helper, then order the
    # result by offsets.
    # NOTE(review): presumably the helper returns a non-overlapping subset --
    # confirm in normalizer_utils.
    remaining = remove_overlapping_spans_from_back(spans)
    ordered = sorted(remaining, key=lambda span: span[0])

    if show_details and ordered:
        print(f'    {show_msg} ({len(ordered)}):')
        for (begin, end), matched, id_type in ordered:
            window_start = max(0, begin-20)
            window_end = min(end+20, len(text))
            context = text[window_start:window_end].strip().replace('\n', ' ')
            print(f'       {begin:4} | {id_type} | {matched} | {context}')

    # Replace the substrings one by one starting from the back, so that the
    # offsets of spans further to the front remain valid.
    for (begin, end), _, id_type in ordered[::-1]:
        head = text[:begin]
        tail = text[end:]
        text = head + id_type + tail

    return text


def remove_long_strings(text, limit, preserve_newlines=False, show_details=False):
    """Drop every token of ``text`` whose length is ``limit`` or more.

    Args:
        text: Input text.
        limit: Tokens with ``len(token) >= limit`` are removed.
        preserve_newlines: When false, the text is split on any whitespace and
            the surviving tokens are re-joined with single spaces (line breaks
            are lost).  When true, every line is filtered on its own: tokens
            are separated by single spaces, and all line breaks and the spacing
            of the surviving tokens are kept.
        show_details: When true, print the removed tokens.

    Returns:
        The text without the over-long tokens.
    """
    dropped = []

    def keep_short(tokens):
        # Join the tokens below the limit; remember the rest for reporting.
        kept = []
        for token in tokens:
            if len(token) < limit:
                kept.append(token)
            else:
                dropped.append(token)
        return ' '.join(kept)

    if preserve_newlines:
        # Filter line by line.  Splitting the whole text on ' ' would glue
        # words on neighbouring lines into one "token" (e.g. 'foo\nbar'), so a
        # removed token would take its line break with it and short words
        # could be dropped just because the glued token exceeds the limit.
        result = '\n'.join(keep_short(line.split(' ')) for line in text.split('\n'))
    else:
        result = keep_short(text.split())

    if show_details:
        print('    Long strings:')
        for token in dropped:
            print(f'       {token}')

    return result


def tokenize_brand_names(text, title, show_details=False):
    """Split brand-name matches that also occur in ``title`` at case changes.

    Each match of ``identifiers.REGEX_BRAND_NAMES`` (group 1) is cut wherever
    a lowercase character is directly followed by an uppercase one, and the
    pieces are written back separated by single spaces.  A match is left
    untouched when its lowercased form is in
    ``identifiers.ATOMIC_BRAND_NAMES_LOWERED``, when it is not a substring of
    ``title``, or when it starts with ``'Mc'``.

    Args:
        text: Input text.
        title: Title text; only names contained in it are split.
        show_details: When true, print the processed names with some context.

    Returns:
        A ``(new_text, spans)`` tuple.  ``spans`` lists
        ``(offsets, span_text, identifier)`` triples in text order; the
        offsets are those of the matches in the text as it was passed in.
    """
    id_type = identifiers.IDENTIFIER_BRAND_NAME
    matches = list(identifiers.REGEX_BRAND_NAMES.finditer(text))

    details = []
    spans = []

    # Walk the matches back to front so that rewriting the text never shifts
    # the offsets of the matches still to be handled.
    for match in matches[::-1]:
        name = match.group(1)
        name_offsets = match.span(1)

        is_atomic = name.lower() in identifiers.ATOMIC_BRAND_NAMES_LOWERED
        if is_atomic or name not in title or name.startswith('Mc'):
            continue

        # Cut at every position where lowercase switches to uppercase
        cut_points = [idx for idx in range(1, len(name))
                      if name[idx-1].islower() and name[idx].isupper()]
        pieces = [name[lo:hi]
                  for lo, hi in zip([0] + cut_points, cut_points + [len(name)])]

        spans.append((name_offsets, name, identifiers.IDENTIFIER_BRAND_NAME))

        if show_details:
            # Context is taken from the text as rewritten so far
            match_start, match_end = match.span(0)
            context = text[max(0, match_start-20):min(match_end+20, len(text))].strip().replace('\n', ' ')
            details.append((name, match_start, context))

        text = text[:name_offsets[0]] + ' '.join(pieces) + text[name_offsets[1]:]

    # Collected back to front; restore text order
    spans.reverse()
    details.reverse()

    if details:
        print(f'    BRAND NAMES ({len(details)}):')
        for name, match_start, context in details:
            print(f'       {match_start:4} | {id_type} | {name} | {context}')

    return text, spans


def process_with_spacy(s, spacy_nlp_model):
    """Lemmatize ``s`` and drop digit, punctuation, quote and bracket tokens.

    Args:
        s: Text to process.
        spacy_nlp_model: Callable that turns a string into an iterable of
            tokens exposing ``text``, ``lemma_``, ``is_digit``, ``is_punct``,
            ``is_quote`` and ``is_bracket``.

    Returns:
        The kept tokens joined by single spaces.  Tokens written entirely in
        uppercase keep their surface form, all others are replaced by their
        lemma.
    """
    kept = []
    for token in spacy_nlp_model(s, ):
        if token.is_digit or token.is_punct or token.is_quote or token.is_bracket:
            continue
        # All-uppercase tokens keep their original text instead of the lemma
        kept.append(token.text if token.text.isupper() else token.lemma_)

    joined = ' '.join(kept)
    # Remove redundant whitespace produced by the ' '.join() command
    # (Note: spaCy treats '\n' as a separate token by default)
    return joined.replace('\n ', '\n')


# Module-level default spaCy pipeline used by preprocess() when the caller
# passes no model; stays None until preprocess() loads it on first need.
SPACY_NLP_MODEL = None


def preprocess(text, spacy_nlp=None, title='', remove_identifiers=False, show_details=False, get_extra=False):
    """Run the full normalization pipeline over ``text``.

    Args:
        text: Raw input text; only the first ``MAX_TEXT_LEN`` characters are used.
        spacy_nlp: spaCy pipeline to use.  When falsy, the module-level
            ``SPACY_NLP_MODEL`` is used and loaded on first need.
        title: Title text handed to ``tokenize_brand_names``.
        remove_identifiers: When true, identifier tokens are deleted from the
            result; otherwise they are kept in uppercase form.
        show_details: Forwarded to the individual normalization steps.
        get_extra: When true, return ``(text, extra)`` instead of just ``text``.

    Returns:
        The normalized, lowercased text (plus the currently empty ``extra``
        dict when ``get_extra`` is true).
    """
    global SPACY_NLP_MODEL

    # Fall back to the shared model, loading it the first time it is needed
    if not spacy_nlp:
        if not SPACY_NLP_MODEL:
            SPACY_NLP_MODEL = spacy.load('en_core_web_sm', exclude=["parser", "ner"])
        spacy_nlp = SPACY_NLP_MODEL

    # Truncate text
    text = text[:MAX_TEXT_LEN]

    # Identifier passes, applied in exactly this order:
    #   emails & urls (pgp keys are for now only handled by long string removal)
    #   filenames
    #   IP addresses ("192.0.0.1")
    #   cryptocurrency addresses (BTC, ETH, LTC)
    #   money/length/volume/weight/percentage/filesize/version (unit&number)
    #   temporal expressions ("2020-10-02 09:44:10", "21/09/2019")
    identifier_passes = (
        normalize_emails_and_urls,
        normalize_filenames,
        normalize_ip_addresses,
        normalize_crypto_addresses,
        normalize_quantities,
        normalize_time,
    )
    for run_pass in identifier_passes:
        text, _ = run_pass(text, show_details=show_details)

    # Tokenize brand names ("TorLink" => "Tor" & "Link", "CannabisUK" => "Cannabis" & "UK")
    text, _ = tokenize_brand_names(text, title, show_details=show_details)

    # Normalize all the remaining numbers
    text, _ = normalize_numbers(text, show_details=show_details)

    # spaCy processing, then lowercase
    # (NOTE: Identifiers "ID_XXX" will also be lowered to "id_xxx")
    text = process_with_spacy(text, spacy_nlp).lower()

    # Either delete the lowercased identifiers or restore their uppercase form
    for identifier in identifiers.IDENTIFIERS_ALL:
        replacement = '' if remove_identifiers else identifier.upper()
        text = text.replace(identifier.lower(), replacement)

    # Remove all the space-separated tokens that are too long.
    # NOTE(review): the previous comment said this runs BEFORE spaCy, but the
    # call sits after the spaCy step -- confirm which order is intended.
    text = remove_long_strings(text, limit=50, preserve_newlines=True)

    # For debugging ... (currently nothing is collected)
    extra = {}

    if get_extra:
        return text, extra
    return text


def compute_jaccard_similarity_between_two_token_sets(token_set1: Set[str], token_set2: Set[str]) -> float:
    """Return the Jaccard similarity ``|A & B| / |A | B|`` of two token sets.

    Args:
        token_set1: First token set (a ``set`` or a subclass of it).
        token_set2: Second token set (a ``set`` or a subclass of it).

    Returns:
        A value in ``[0.0, 1.0]``.  Two empty sets are treated as identical
        and yield ``1.0``.

    Raises:
        ValueError: If either argument is not a ``set`` instance.
    """
    # isinstance() instead of an exact type comparison so that subclasses of
    # set are accepted as well.
    if not isinstance(token_set1, set) or not isinstance(token_set2, set):
        raise ValueError('Input must be of type <class \'set\'>')

    union_size = len(token_set1.union(token_set2))
    if union_size == 0:
        # Both sets are empty: define the similarity as 1.0 rather than divide by zero
        return 1.0

    return len(token_set1.intersection(token_set2)) / union_size


def demo():
    """Print a handful of sample texts before and after ``preprocess``."""

    texts = [
        "To participate, you just need to send from 0.01 BTC to 20 BTC to the contribution address and"
        " we will immediately send you back 0.2 BTC to 40 BTC to the address you sent it from. (x2 back)\n\n"
        "SPECIAL OFFER:\n"
        "If you send 5+ BTC, you will be airdropped 10 BTC back +35% bonus\n\n"
        "Payment Address\n"
        "You can send BTC to the following address:\n"
        "164auQnEcxJQs5ea1WVAtpYFfaKDbDek6T"
        ,
        "iPhone 11 Pro Max 64 GB - $749\n"
        "iPhone 11 Pro Max 256 GB - $899\n"
        "iPhone 11 Pro Max 512 GB - $999"
        ,
        "The wallets have a balance between 10 ₿ and 0.01 ₿, depending on how much I want to get rid of.\n"
        "The price is always 50% of the balance.\n"
        "This wallet has a value of 1.6 BTC and received its balance on 03/20/2020."
        ,
        "Self: /index.php\n"
        "MyURL: http://mgioamqnhbbxkos4.onion:80//index.php\n"
        "Server Address: [127.0.0.1:80]\n"
        "Server Name: 'mgioamqnhbbxkos4.onion'\n"
        "Remote Address: [127.0.0.1] Port 41538"
        ,
        "Specifications:\n"
        # The trailing "\n" was missing here, which glued this line onto the
        # next one ("...NATOOperation: ...") through implicit concatenation.
        "Caliber: 7,62x51mm NATO\n"
        "Operation: Gas operated rotating bolt\n"
        "Magazine Capacity: 5 - 10 - 20 rounds\n"
        "Length: 1029 mm\n"
        "Barrel Length: 457 mm\n"
        "Weight: 5,440 kg\n"
        "Price on market 13000$"
        ,
        "Jambler.io Partner BTC Mixer Bitcoin\n"
        "Official TOR Mirror:\n"
        "overtsgjd4xmgu25uegho7p3ez47solhiri5xpylcgm2tlofbafrzwid.onion"
    ]

    # Load the model once and share it across all sample texts
    spacy_nlp_model = spacy.load('en_core_web_sm', exclude=["parser", "ner"])

    for text in texts:

        print('------- Original Text -------')
        print(text)

        text = preprocess(text, spacy_nlp=spacy_nlp_model)

        print('------- Normalized Text --------')
        print(text)
        print()


def demo2():
    """Print the Jaccard similarity of two sample texts before and after ``preprocess``."""
    text1 = (
        "iPhone 11 Pro Max 64 GB - $749\n"
        "iPhone 11 Pro Max 256 GB - $899\n"
    )
    text2 = (
        "iPhone 12 Pro Max 256 GB - $1099\n"
        "iPhone 12 Pro Max 512 GB - $1299\n"
    )
    originals = (text1, text2)

    # Similarity of the raw, whitespace-split token sets
    orig_token_sets = [set(sample.split()) for sample in originals]
    orig_jaccard_sim = compute_jaccard_similarity_between_two_token_sets(*orig_token_sets)

    for number, sample in enumerate(originals, start=1):
        print(f'------- Text {number} (original) --------')
        print(sample)
    print(f'Jaccard similarity: {orig_jaccard_sim:.2f}')
    print('\n')

    # Similarity after normalization (identifiers are kept in the text)
    normalized = [preprocess(sample) for sample in originals]
    norm_token_sets = [set(sample.split()) for sample in normalized]
    norm_jaccard_sim = compute_jaccard_similarity_between_two_token_sets(*norm_token_sets)

    for number, sample in enumerate(normalized, start=1):
        print(f'------- Text {number} (normalized) --------')
        print(sample)
    print(f'Jaccard similarity: {norm_jaccard_sim:.2f}')
    print()


# Script entry point: run the Jaccard-similarity demo.
if __name__ == '__main__':
    demo2()
