#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import re
from dataclasses import dataclass, field
from typing import List

import pandas as pd

# Translation table keyed by DUTA category name; each value is the category
# name it is mapped to (presumably the CoDA taxonomy, going by the constant's
# name -- confirm against the dataset documentation).
# Several DUTA categories collapse into one target category (the three
# 'Counterfeit-*' entries all map to 'Financial').
# NOTE(review): no function visible in this file reads this table.
DUTA_TO_CODA_MAP = {
    'Porno': 'Porn',
    'Counterfeit-Credit-Cards': 'Financial',
    'Counterfeit-Money': 'Financial',
    'Counterfeit-Personal-Identification': 'Financial',
    'Cryptocurrency': 'Crypto',
    'Drugs': 'Drugs',
    'Cryptolocker': 'Hacking',
    'Hacking': 'Hacking',
    'Casino': 'Gambling',
    'Violence': 'Arms',  # Note that the Violence category in DUTA covers both hitmen and weapons.
    'Others': 'Others',
}

@dataclass
class DocItem:
    """A single corpus document plus the metadata attached to it."""

    # Document index. NOTE(review): iter_corpus_files stores the index exactly
    # as it appears in the file name, i.e. as a str rather than an int.
    doc_idx: int
    # File name without its '.txt' suffix.
    base_filename: str = ''
    # Path of the text file the document was read from.
    text_file_path: str = ''
    # Unprocessed contents of the text file.
    raw_text: str = ''
    # Category names assigned to the document; get_text_and_cat uses the first.
    # A default_factory gives every instance its own empty list (the previous
    # default was the empty string, which contradicted the annotation).
    accepted_category_labels: List[str] = field(default_factory=list)


def compress_whitespace_in_text(text, preserve_newline=False) -> str:
    r"""Strip *text* and collapse the whitespace runs inside it.

    Args:
        text: The string to normalize.
        preserve_newline: When false (the default), every run of whitespace,
            line breaks included, becomes a single space. When true, line
            breaks survive: runs of horizontal whitespace become one space,
            a line break followed by a horizontal whitespace character is
            rewritten as ``\n``, and three or more consecutive line breaks
            are reduced to ``\n\n``.

    Returns:
        The normalized string.
    """
    stripped = text.strip()

    if not preserve_newline:
        return re.sub(r'\s+', ' ', stripped)

    # Applied in order: each substitution works on the previous one's output.
    substitutions = (
        (r'[ \xa0\t\f\v]+', ' '),
        (r'[\n\r][ \xa0\t\f\v]', '\n'),
        (r'(?:[\n\r][ \xa0\t\f\v]*){3,}', '\n\n'),
    )
    for pattern, replacement in substitutions:
        stripped = re.sub(pattern, replacement, stripped)

    return stripped


def get_text_and_cat(doc_item, text_type):
    """Return the input text, category and file name of one document.

    Args:
        doc_item: Document record. ``raw_text``, ``text_file_path`` and
            ``accepted_category_labels`` are always read. The title attribute
            selected by *text_type* is optional: when it is missing or
            ``None`` the title is treated as empty (the ``DocItem`` dataclass
            of this module defines no title fields).
        text_type: Selects which title attribute is used:
            ``'id_removed'`` -> ``title_id_removed``;
            ``'all_id_masked_preprocessed'`` -> ``title_id_masked_preprocessed``;
            ``'all_id_masked'`` / ``'min_id_masked'`` -> ``title_id_masked``;
            ``''`` / ``'txt'`` -> ``title``.

    Returns:
        A ``(title_and_text, cat, file_name)`` tuple. ``title_and_text`` is the
        title, a newline and the raw text, or the raw text alone when the
        title is blank. ``cat`` is the first accepted category label and
        ``file_name`` is the base name of ``text_file_path``.

    Raises:
        ValueError: If *text_type* is not one of the values listed above.
        IndexError: If ``accepted_category_labels`` is empty.
    """
    title_attr_by_text_type = {
        'id_removed': 'title_id_removed',
        'all_id_masked_preprocessed': 'title_id_masked_preprocessed',
        'all_id_masked': 'title_id_masked',
        'min_id_masked': 'title_id_masked',
        '': 'title',
        'txt': 'title',
    }

    try:
        title_attr = title_attr_by_text_type[text_type]
    except (KeyError, TypeError):
        # str() keeps the message buildable even for a non-string text_type.
        raise ValueError('Invalid text type: ' + str(text_type)) from None

    # Records without the selected title attribute simply contribute no title.
    title = getattr(doc_item, title_attr, '') or ''
    text = doc_item.raw_text
    file_path = os.path.basename(doc_item.text_file_path)

    title_and_text = title + '\n' + text if title.strip() else text
    cat = doc_item.accepted_category_labels[0]

    return title_and_text, cat, file_path


def iter_corpus_files(root_dir_path, text_type):
    """Yield a ``DocItem`` for every file of one corpus variant.

    The variant directory is ``<root_dir_path>/txt_<variant>``, where
    ``<variant>`` is the first of ``all_id_masked_preprocessed``,
    ``all_id_removed``, ``all_id_masked`` or ``min_id_normalized`` that
    *text_type* starts with. Files are visited in sorted file-name order and
    must be named ``<index>-<category>.<ext>``; everything after the first
    hyphen (minus the extension) is the category, so categories may contain
    hyphens themselves.

    Because this is a generator, the errors below surface when iteration
    starts, not when the function is called.

    Args:
        root_dir_path: Directory that contains the ``txt_*`` variant folders.
        text_type: Name (or prefixed name) of the corpus variant to read.

    Yields:
        ``DocItem`` with ``doc_idx`` (the index string from the file name),
        ``base_filename`` (file name without a trailing ``.txt``),
        ``text_file_path``, ``accepted_category_labels`` (a one-element list
        holding the category) and ``raw_text`` (the file contents).

    Raises:
        ValueError: If *text_type* matches no known variant, or a file name
            contains no hyphen.
        OSError: If the variant directory or one of its files cannot be read.
    """
    # Longest prefix first: 'all_id_masked_preprocessed' also starts with
    # 'all_id_masked'.
    known_variants = (
        'all_id_masked_preprocessed',
        'all_id_removed',
        'all_id_masked',
        'min_id_normalized',
    )
    for variant in known_variants:
        if text_type.startswith(variant):
            text_dir_path = os.path.join(root_dir_path, 'txt_' + variant)
            break
    else:
        raise ValueError('Unsupported text type: ' + text_type)

    for text_filename in sorted(os.listdir(text_dir_path)):
        stem = text_filename.rsplit('.', 1)[0]
        # Split on the first hyphen only so hyphenated categories stay intact.
        anno_idx, hyphen, category = stem.partition('-')
        if not hyphen:
            raise ValueError(
                'Unexpected corpus file name (expected "<index>-<category>.txt"): '
                + text_filename)

        # Strip only a trailing '.txt', never an occurrence inside the name.
        if text_filename.endswith('.txt'):
            base_filename = text_filename[:-len('.txt')]
        else:
            base_filename = text_filename
        text_file_path = os.path.join(text_dir_path, text_filename)

        # Explicit encoding so the result does not depend on the platform locale.
        with open(text_file_path, 'r', encoding='utf-8') as f_text:
            text = f_text.read()

        yield DocItem(
            doc_idx=anno_idx,
            base_filename=base_filename,
            text_file_path=text_file_path,
            accepted_category_labels=[category],
            raw_text=text,
        )


def load_coda_darkweb_texts(data_path, text_type):
    """Load one corpus variant into a DataFrame ready for classification.

    Every document produced by ``iter_corpus_files`` is turned into a
    single-line text: title and body are joined by ``get_text_and_cat``, all
    whitespace is collapsed, and tokens starting with ``ID_`` are dropped.

    Args:
        data_path: Corpus root directory, passed to ``iter_corpus_files``.
        text_type: Corpus variant name. It is passed to both
            ``iter_corpus_files`` and ``get_text_and_cat``, so it has to be a
            value that both of them accept.

    Returns:
        The ``(df, input_text_col_name, output_label_col_name, class_names)``
        tuple produced by ``get_dataframe_items``.
    """
    input_texts = []
    category_values = []
    domains = []
    file_paths = []

    for doc_item in iter_corpus_files(data_path, text_type=text_type):
        # get_text_and_cat returns exactly three values; the previous code
        # tried to unpack four columns from them and therefore always failed.
        title_and_text, category, file_path = get_text_and_cat(doc_item, text_type)

        flat_text = compress_whitespace_in_text(title_and_text)
        # Remove all identifiers (token starting with 'ID_')
        input_texts.append(' '.join(token for token in flat_text.split()
                                    if not token.startswith('ID_')))
        category_values.append(category)
        # NOTE(review): DocItem as defined in this module carries no domain,
        # so the column is left empty unless the item exposes a `domain`
        # attribute -- confirm where the domain should come from.
        domains.append(getattr(doc_item, 'domain', ''))
        file_paths.append(file_path)

    return get_dataframe_items(input_texts, category_values, domains, file_paths)


def load_darkweb_extra_benchmark_data(data_path):
    """Yield ``(text, category, forum_name)`` for every file in *data_path*.

    File names are expected to look like ``<index>-<category>-<forum>``, e.g.
    ``"0001-Drugs-Opiate_Connect.txt"``, and are visited in ascending order of
    the integer index. ``forum_name`` is everything after the second hyphen,
    so it keeps the file extension and any further hyphens.

    Args:
        data_path: Directory holding the benchmark text files.

    Yields:
        Tuples of the file contents, the category and the forum name.

    Raises:
        ValueError: If a file name has a non-integer index prefix or fewer
            than two hyphens.
        OSError: If the directory or one of its files cannot be read.
    """
    def numeric_index(filename):
        return int(filename.split('-', 1)[0])

    for filename in sorted(os.listdir(data_path), key=numeric_index):
        # Split at most twice so hyphens inside the forum name are preserved.
        parts = filename.split('-', 2)
        if len(parts) != 3:
            raise ValueError(
                'Unexpected benchmark file name '
                '(expected "<index>-<category>-<forum>"): ' + filename)
        _, category, forum_name = parts

        file_path = os.path.join(data_path, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Yield after the file is closed so no handle stays open while the
        # consumer processes the item.
        yield text, category, forum_name


def get_dataframe_items(input_texts, category_values, domains, file_paths):
    """Assemble the parallel input sequences into a labelled DataFrame.

    Category names are encoded as integers: each label is the position of the
    name in the sorted list of distinct category names.

    Args:
        input_texts: Input text of each row.
        category_values: Category name of each row.
        domains: Value for the ``'domain'`` column of each row.
        file_paths: Value for the ``'file_path'`` column of each row.

    Returns:
        ``(df, input_text_col_name, output_label_col_name, class_names)``
        where ``df`` has the columns ``'input_text'``, ``'category'`` (integer
        labels), ``'domain'`` and ``'file_path'``, and ``class_names`` is the
        sorted list of distinct category names.

    Raises:
        AssertionError: If the four sequences differ in length.
    """
    row_count = len(input_texts)
    assert row_count == len(category_values) == len(domains) == len(file_paths)

    text_column = 'input_text'
    label_column = 'category'

    class_names = sorted(set(category_values))
    label_by_name = {name: label for label, name in enumerate(class_names)}
    encoded_labels = [label_by_name[value] for value in category_values]

    columns = {
        text_column: input_texts,
        label_column: encoded_labels,
        'domain': domains,
        'file_path': file_paths,
    }
    frame = pd.DataFrame(data=columns)

    return frame, text_column, label_column, class_names


# Running this file as a script is a no-op; the functions above are the
# module's only functionality.
if __name__ == '__main__':
    pass
