#!/usr/bin/python3
# Get the distribution of POS in the dataset
# and the ratio of content / function words
# -> Only sites labeled as English are used

import os
import re
import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import spacy

# English spaCy pipeline; the NER component is left out of the loaded pipeline.
nlp = spacy.load("en_core_web_lg", exclude=["ner"])

# Every path below is resolved against the parent of the directory the
# script is launched from.
os.chdir(os.pardir)
PATH_BASE = os.getcwd()

# Input (categories) and output (lex_stats, lex_figs) roots of the analysis.
PATH_ANALYSIS = os.path.join(PATH_BASE, "corpus_analysis")
PATH_CATEGORY, PATH_LEX, PATH_FIGURE = (
    os.path.join(PATH_ANALYSIS, sub_dir)
    for sub_dir in ("categories", "lex_stats", "lex_figs")
)

# CSV output directories, both nested inside PATH_LEX.
PATH_CSV_POS, PATH_CSV_CF = (
    os.path.join(PATH_LEX, sub_dir)
    for sub_dir in ("pos_dist", "content_function")
)

# Create the output directories. PATH_LEX must come first because the two
# CSV directories are nested inside it.
# Only "already exists" is tolerated; any other failure (missing parent
# directory, permission error, ...) propagates with its real cause instead
# of being reported as an existing directory.
for output_dir in (PATH_LEX, PATH_CSV_POS, PATH_CSV_CF):
    try:
        os.mkdir(output_dir)
    except FileExistsError:
        print("Directory already exists")

# Masked identifier tokens; these are left out of the POS statistics.
identifier_tokens = [
    'ID_EMAIL',
    'ID_ONION_URL',
    'ID_NORMAL_URL',
    'ID_IP_ADDRESS',
    'ID_BTC_ADDRESS',
    'ID_ETH_ADDRESS',
    'ID_LTC_ADDRESS',
    'ID_CRYPTO_MONEY',
    'ID_GENERAL_MONEY',
    'ID_LENGTH',
    'ID_WEIGHT',
    'ID_VOLUME',
    'ID_PERCENTAGE',
    'ID_VERSION',
    'ID_FILENAME',
    'ID_FILESIZE',
    'ID_TIME',
    'ID_BRAND_NAME',
    'ID_NUMBER',
]

# Universal POS tags that are counted.
pos_tags = [
    'ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET',
    'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN',
    'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X',
]

# Content-word tags (as used in the ACL paper); every remaining tag in
# pos_tags counts as a function-word tag, in pos_tags order.
content_tags = ['ADJ', 'ADV', 'NOUN', 'PROPN', 'VERB', 'X', 'NUM']
function_tags = [upos for upos in pos_tags if upos not in content_tags]

# Report files: fw receives the per-file details, fs the per-category
# summary. Both stay open for the rest of the script and are closed at the
# end, after the summary has been written.
fw = open(os.path.join(PATH_LEX, 'pos_dist_all.txt'), 'w')
fs = open(os.path.join(PATH_LEX, 'pos_dist_concise.txt'), 'w')

# os.listdir() returns entries in arbitrary order and also lists regular
# files. Sorting makes the report order reproducible; keeping only
# directories stops a stray file (e.g. .DS_Store) from being treated as a
# category and failing when its contents are listed.
categories = sorted(
    entry for entry in os.listdir(PATH_CATEGORY)
    if os.path.isdir(os.path.join(PATH_CATEGORY, entry))
)

# Per-category accumulators, filled in by the main processing loop.
token_stats = dict()
cont_cnt_means = {cat: 0 for cat in categories}
func_cnt_means = {cat: 0 for cat in categories}
cont_func_ratios = {cat: 0 for cat in categories}
pos_by_category = {cat: dict() for cat in categories}

def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when *values* is empty."""
    return sum(values) / len(values) if values else float('nan')


# Compiled once: every character outside [a-zA-Z0-9_] separates two words.
_WORD_SPLIT_RE = re.compile(r'[^a-zA-Z0-9_]')

# For every category: POS-tag each file, write the per-file counts to fw and
# store the per-category means in the accumulator dictionaries.
for category in categories:

    PATH_CUR_CATEGORY = os.path.join(PATH_CATEGORY, category)

    print(f'Currently processing {category} category...')

    fw.write(f'\n-*-*-*-*-{category}-*-*-*-*-\n\n')

    current_category_files = os.listdir(PATH_CUR_CATEGORY)
    # num_files = len(current_category_files)

    # One entry per file; averaged once the whole category has been read.
    pos_ratio_total = list()
    cont_cnt_total = list()
    func_cnt_total = list()
    # Only files with at least one function word contribute a ratio.
    cf_ratio = list()

    for filename in current_category_files:

        pos_cnt = {pos: 0 for pos in pos_tags}
        cont_cnt = 0
        func_cnt = 0

        filepath = os.path.join(PATH_CUR_CATEGORY, filename)

        # The input handle is only needed for the read itself.
        with open(filepath, 'r') as f:
            fw.write(f'{filename}:\n')
            content = f.read()

        # content_parsed = [x for x in re.split('([^a-zA-Z0-9_])', content) if x and x != ' ']
        content_parsed = [x for x in _WORD_SPLIT_RE.split(content) if x]
        content_nomask = [x for x in content_parsed if x not in identifier_tokens]
        doc = nlp(' '.join(content_nomask))

        for token in doc:

            # Exact mask tokens were removed above; this substring test also
            # drops words that merely contain a mask (e.g. a mask fused with
            # neighbouring characters).
            if any(id_token in token.text for id_token in identifier_tokens):
                continue

            token_pos = token.pos_
            pos_cnt[token_pos] += 1

            if token_pos in content_tags:
                cont_cnt += 1
            else:
                func_cnt += 1

        # Share of each POS tag in this file (all zeros for a file without
        # any counted token).
        s = sum(pos_cnt.values())
        if s == 0:
            pos_ratio = {k: 0 for k in pos_cnt}
        else:
            pos_ratio = {k: v / s for k, v in pos_cnt.items()}

        for k, v in pos_cnt.items():
            fw.write(f'{k:<8}: {v} ({pos_ratio[k]})\n')

        pos_ratio_total.append(pos_ratio)

        if func_cnt != 0:
            cont_func_ratio = cont_cnt / func_cnt
            cf_ratio.append(cont_func_ratio)
        else:
            # No function words: the ratio is undefined for this file and is
            # left out of the category mean.
            cont_func_ratio = 'undef'

        cont_cnt_total.append(cont_cnt)
        func_cnt_total.append(func_cnt)

        fw.write(f'Number of content words : {cont_cnt}\n')
        fw.write(f'Number of function words: {func_cnt}\n')
        fw.write(f'Content / Func ratio    : {cont_func_ratio}\n')
        fw.write('\n')

    # _mean() yields NaN instead of raising ZeroDivisionError when a category
    # has no files, or when none of its files has a defined ratio.
    cont_cnt_means[category] = _mean(cont_cnt_total)
    func_cnt_means[category] = _mean(func_cnt_total)
    cont_func_ratios[category] = _mean(cf_ratio)

    mean_pos_by_cat = {pos: _mean([item[pos] for item in pos_ratio_total])
                       for pos in pos_tags}

    pos_by_category[category] = mean_pos_by_cat

    # token_cnt_norm = {k : v / num_files for k, v in token_cnt.items()}
    # token_stats[category] = token_cnt_norm

# Export to CSV (POS distribution): one file per category, one row per POS
# tag holding the mean ratio of that tag.
for category in categories:

    csv_file_name = os.path.join(PATH_CSV_POS, category + "_pos.csv")

    df = pd.DataFrame.from_dict(pos_by_category[category], orient='index')
    df.to_csv(csv_file_name)

# Export to CSV (content / function word ratio): a single file with one row
# per category. The file name is fixed; it previously reused the loop
# variable of the export above, so the file was named after whichever
# category happened to be iterated last (and raised NameError when there
# were no categories).
csv_file_name = os.path.join(PATH_CSV_CF, "all_categories_cf.csv")

df = pd.DataFrame.from_dict(cont_func_ratios, orient='index')
df.to_csv(csv_file_name)


# Concise report: three "one number per category" sections followed by the
# mean POS ratios of every category.
summary_sections = (
    ('CONTENT WORD COUNTS', cont_cnt_means),
    ('FUNCTION WORD COUNTS', func_cnt_means),
    ('CONTENT-FUNCTION WORD RATIO', cont_func_ratios),
)

for heading, per_category in summary_sections:
    fs.write(f'\n-*-*-*-*-{heading}-*-*-*-*-\n\n')
    for key, value in per_category.items():
        # Each entry is followed by an empty line.
        fs.write(f'{key:<8}: {value:.3f}\n')
        fs.write('\n')

fs.write('\n-*-*-*-*-POS RATIO-*-*-*-*-\n\n')
for category, pos_dict in pos_by_category.items():
    fs.write(f'{category}:\n')
    fs.writelines(f'{pos:<8}: {ratio:.3f}\n' for pos, ratio in pos_dict.items())
    fs.write('\n')

# Both report files are complete at this point.
for report_file in (fw, fs):
    report_file.close()

# pp = pprint.PrettyPrinter(indent=4)
# pp.pprint(token_stats)


    # fs.write('Mean   : [{:<10.2f}, {:<8.2f}, {:.3f}]\n'.format(wordcnt_arr.mean(), \
    #                                                            uniqcnt_arr.mean(), \
    #                                                            ttr_arr.mean()))
    # fs.write('Median : [{:<10}, {:<8}, {:.3f}]\n'.format(np.median(wordcnt_arr), \
    #                                                      np.median(uniqcnt_arr), \
    #                                                      np.median(ttr_arr)))
    # fs.write('Min    : [{:<10}, {:<8}, {:.3f}]\n'.format(np.amin(wordcnt_arr), \
    #                                                      np.amin(uniqcnt_arr), \
    #                                                      np.amin(ttr_arr)))
    # fs.write('Max    : [{:<10}, {:<8}, {:.3f}]\n'.format(np.amax(wordcnt_arr), \
    #                                                      np.amax(uniqcnt_arr), \
    #                                                      np.amax(ttr_arr)))


# boxplot for statistics

# try:
#     os.mkdir(PATH_FIGURE)
# except OSError:
#     print('Directory already created')

# fig_wordcnt, wc_x = plt.subplots()
# wc_x.set(
#     title='Word count for each CoDA category',
#     ylabel='Word count'
# )
# wc_x.boxplot(word_counts, showfliers=False)
# wc_x.set_xticklabels(categories, rotation=60, fontsize=6)
# fig_wordcnt.set_size_inches(42,8)
# fig_wordcnt.tight_layout()

# fig_uniqcnt, uq_x = plt.subplots()
# uq_x.set(
#     title='Unique word count for each CoDA category',
#     ylabel='Unique word count'
# )
# uq_x.boxplot(uniq_counts, showfliers=False)
# uq_x.set_xticklabels(categories, rotation=60, fontsize=6)
# fig_uniqcnt.set_size_inches(42,8)
# fig_uniqcnt.tight_layout()

# fig_ttr, ttr_x = plt.subplots()
# ttr_x.set(
#     title='TTR for each CoDA category',
#     ylabel='Type-Token Ratio'
# )
# ttr_x.boxplot(ttr_ratio)
# ttr_x.set_xticklabels(categories, rotation=60, fontsize=6)
# fig_ttr.set_size_inches(42,8)
# fig_ttr.tight_layout()

# fig_wordcnt.savefig(os.path.join(PATH_FIGURE, 'wordcnt_bydomain.png'), dpi=160)
# fig_uniqcnt.savefig(os.path.join(PATH_FIGURE, 'uniqcnt_bydomain.png'), dpi=160)
# fig_ttr.savefig(os.path.join(PATH_FIGURE, 'ttr_bydomain.png'), dpi=160)

