#!/usr/bin/python3
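# Compute, for each category of the dataset, the distribution of POS tags
# and the ratio of content / function words.
# -> The masked dataset contains only sites labeled as English, so no
#    additional language filter is applied.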

import os
import re
import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import spacy

# spaCy pipeline used for tagging; the NER component is excluded because only
# token.pos_ is read by this script.
nlp = spacy.load("en_core_web_lg", exclude=["ner"])

PATH_BASE = os.getcwd()
PATH_CATEGORY = os.path.join(PATH_BASE, "DUTA_10K_Masked_Dataset")
PATH_LEX = os.path.join(PATH_BASE, "lex_stats_DUTA")
PATH_FIGURE = os.path.join(PATH_BASE, "lexical_figs")

PATH_CSV_POS = os.path.join(PATH_LEX, "pos_dist")
PATH_CSV_CF = os.path.join(PATH_LEX, "content_function")

# Create the output directories (parents first). exist_ok makes re-runs
# harmless, while real failures (permissions, a regular file in the way, ...)
# now raise here instead of being misreported as "Directory already exists"
# and surfacing later as a confusing open() error.
for _out_dir in (PATH_LEX, PATH_CSV_POS, PATH_CSV_CF):
    os.makedirs(_out_dir, exist_ok=True)

# Encodings tried, in order, when reading a page.
encodings = ['utf-8', 'windows-1250']

# Placeholder tokens inserted by the masking step. A frozenset gives O(1)
# membership tests when filtering every word of every page.
identifier_tokens = frozenset([
    'ID_EMAIL', 'ID_ONION_URL', 'ID_NORMAL_URL', 'ID_IP_ADDRESS',
    'ID_BTC_ADDRESS', 'ID_ETH_ADDRESS', 'ID_LTC_ADDRESS',
    'ID_CRYPTO_MONEY', 'ID_GENERAL_MONEY', 'ID_LENGTH', 'ID_WEIGHT',
    'ID_VOLUME', 'ID_PERCENTAGE', 'ID_VERSION', 'ID_FILENAME',
    'ID_FILESIZE', 'ID_TIME', 'ID_BRAND_NAME', 'ID_NUMBER',
])

# list of universal POS tags; kept as a list because its order determines the
# order in which tags are written to the reports.
pos_tags = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN',
            'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']

# The masked data is already English-only, so the former filter based on
# english_duta.txt is no longer applied.

# function / content words, as used in the ACL paper
content_tags = frozenset(['ADJ', 'ADV', 'NOUN', 'PROPN', 'VERB', 'X', 'NUM'])
# Everything in pos_tags that is not a content tag, in pos_tags order.
function_tags = [tag for tag in pos_tags if tag not in content_tags]

# Per-file report (pos_dist_all.txt) and per-category summary
# (pos_dist_concise.txt). UTF-8 is explicit so writing non-ASCII file or
# category names does not depend on the platform's default encoding.
fw = open(os.path.join(PATH_LEX, 'pos_dist_all.txt'), 'w', encoding='utf-8')
fs = open(os.path.join(PATH_LEX, 'pos_dist_concise.txt'), 'w', encoding='utf-8')

# Each sub-directory of PATH_CATEGORY is a category. Stray regular files
# (e.g. .DS_Store) are skipped because os.listdir() on them would fail in the
# main loop; sorting makes the report order reproducible across runs.
categories = sorted(
    entry for entry in os.listdir(PATH_CATEGORY)
    if os.path.isdir(os.path.join(PATH_CATEGORY, entry))
)

token_stats = dict()  # currently unused
# Per-category means, filled in by the main loop.
cont_cnt_means = {cat: 0 for cat in categories}
func_cnt_means = {cat: 0 for cat in categories}
cont_func_ratios = {cat: 0 for cat in categories}

# Per-category {POS tag: mean share of that tag across the category's files}.
pos_by_category = {cat: dict() for cat in categories}

# Any character that is not an ASCII letter, digit or underscore separates
# words, so masked placeholders such as ID_EMAIL survive as whole pieces.
_TOKEN_SPLIT_RE = re.compile(r'[^a-zA-Z0-9_]')
# Set view of the placeholders for O(1) membership tests per word.
_MASK_SET = frozenset(identifier_tokens)


def _read_text(filepath):
    """Return the text of *filepath* decoded with the first encoding in
    ``encodings`` that succeeds, or None when every encoding fails."""
    for enc in encodings:
        try:
            with open(filepath, 'r', encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            # Try the next candidate encoding.
            continue
    return None


def _mean(values):
    """Arithmetic mean of *values*; NaN when there is nothing to average."""
    return sum(values) / len(values) if values else float('nan')


for category in categories:

    PATH_CUR_CATEGORY = os.path.join(PATH_CATEGORY, category)

    print(f'Currently processing {category} category...')

    fw.write(f'\n-*-*-*-*-{category}-*-*-*-*-\n\n')

    # Per-file results of this category, averaged after the file loop.
    pos_ratio_total = list()   # one {tag: share} dict per processed file
    cont_cnt_total = list()
    func_cnt_total = list()
    cf_ratio = list()          # only files with at least one function word

    for files in sorted(os.listdir(PATH_CUR_CATEGORY)):

        filepath = os.path.join(PATH_CUR_CATEGORY, files)
        if not os.path.isfile(filepath):
            continue

        # A fresh read per file: an undecodable file is skipped instead of
        # silently re-using the previous file's text.
        content = _read_text(filepath)
        if content is None:
            print(f'Skipping {files}: cannot decode with any of {encodings}')
            continue

        # Drop separators and masked placeholders, re-join with single spaces.
        content_doc = ' '.join(
            x for x in _TOKEN_SPLIT_RE.split(content) if x and x not in _MASK_SET
        )

        # spaCy rejects texts longer than nlp.max_length. Skipped files must
        # not enter the averages, otherwise they count as all-zero documents.
        if len(content_doc) > nlp.max_length:
            print(f'Skipping {files}: text too long for spaCy')
            continue

        fw.write(f'{files}:\n')

        pos_cnt = {pos: 0 for pos in pos_tags}
        cont_cnt = 0
        func_cnt = 0

        for token in nlp(content_doc):

            # Placeholders glued to other characters (e.g. "xID_EMAIL") are
            # not removed by the exact-match filter above; ignore them here.
            if any(id_token in token.text for id_token in identifier_tokens):
                continue

            token_pos = token.pos_
            # Tags outside pos_tags (e.g. 'SPACE') would raise KeyError; they
            # are neither content nor function words, so they are ignored.
            if token_pos not in pos_cnt:
                continue
            pos_cnt[token_pos] += 1

            if token_pos in content_tags:
                cont_cnt += 1
            else:
                func_cnt += 1

        # Share of each tag among the counted tokens (all 0 for an empty doc).
        s = sum(pos_cnt.values())
        pos_ratio = {k: (v / s if s else 0) for k, v in pos_cnt.items()}

        for k, v in pos_cnt.items():
            fw.write(f'{k:<8}: {v} ({pos_ratio[k]})\n')

        pos_ratio_total.append(pos_ratio)

        if func_cnt != 0:
            cont_func_ratio = cont_cnt / func_cnt
            cf_ratio.append(cont_func_ratio)
        else:
            cont_func_ratio = 'undef'

        cont_cnt_total.append(cont_cnt)
        func_cnt_total.append(func_cnt)

        fw.write(f'Number of content words : {cont_cnt}\n')
        fw.write(f'Number of function words: {func_cnt}\n')
        fw.write(f'Content / Func ratio    : {cont_func_ratio}\n')
        fw.write('\n')

    # Category-level averages. NaN when no file qualified, instead of the
    # previous ZeroDivisionError for empty categories.
    cont_cnt_means[category] = _mean(cont_cnt_total)
    func_cnt_means[category] = _mean(func_cnt_total)
    cont_func_ratios[category] = _mean(cf_ratio)

    pos_by_category[category] = {
        pos: _mean([item[pos] for item in pos_ratio_total]) for pos in pos_tags
    }

# Export to CSV (POS distribution): one "<category>_pos.csv" per category,
# indexed by POS tag, holding that category's mean tag shares.
for cat in categories:
    pos_frame = pd.DataFrame.from_dict(pos_by_category[cat], orient='index')
    pos_frame.to_csv(os.path.join(PATH_CSV_POS, f"{cat}_pos.csv"))

# Export to CSV (cf ratio): one row per category with its mean
# content / function word ratio.
cf_frame = pd.DataFrame.from_dict(cont_func_ratios, orient='index')
cf_frame.to_csv(os.path.join(PATH_CSV_CF, "cf_ratio.csv"))


def _write_summary_section(header, values_by_category):
    """Write *header* to the concise report, then one '<category>: <value>'
    line (3 decimals) per category, each followed by a blank line."""
    fs.write(header)
    for cat, val in values_by_category.items():
        fs.write(f'{cat:<8}: {val:.3f}\n')
        fs.write('\n')


_write_summary_section('\n-*-*-*-*-CONTENT WORD COUNTS-*-*-*-*-\n\n', cont_cnt_means)
_write_summary_section('\n-*-*-*-*-FUNCTION WORD COUNTS-*-*-*-*-\n\n', func_cnt_means)
_write_summary_section('\n-*-*-*-*-CONTENT-FUNCTION WORD RATIO-*-*-*-*-\n\n', cont_func_ratios)

# Mean share of every POS tag, grouped by category.
fs.write('\n-*-*-*-*-POS RATIO-*-*-*-*-\n\n')
for cat, tag_shares in pos_by_category.items():
    fs.write(f'{cat}:\n')
    for tag, share in tag_shares.items():
        fs.write(f'{tag:<8}: {share:.3f}\n')
    fs.write('\n')

# Flush and release both report files (per-file report first, then summary).
for _report in (fw, fs):
    _report.close()