import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# All paths are resolved relative to the directory the script is started from.
PATH_BASE = os.getcwd()
PATH_LEX = os.path.join(PATH_BASE, "lex_stats_DUTA")
PATH_CATEGORY = os.path.join(PATH_BASE, "DUTA_10K_Masked_Dataset")

# Per-category CSV files are written below PATH_LEX.
PATH_CSV = os.path.join(PATH_LEX, "mask_dist")

# makedirs(..., exist_ok=True) also creates a missing PATH_LEX parent and is
# silent when the directory is already there. Any other OSError (e.g. a
# permission problem) now propagates instead of being reported as
# "Directory already exists".
os.makedirs(PATH_CSV, exist_ok=True)

# Text encodings tried, in this order, when reading a dataset file.
encodings = [
    'utf-8',
    'windows-1250',
]

# Masked placeholder tokens whose occurrences are counted in every file.
identifier_tokens = [
    'ID_EMAIL',
    'ID_ONION_URL',
    'ID_NORMAL_URL',
    'ID_IP_ADDRESS',
    'ID_BTC_ADDRESS',
    'ID_ETH_ADDRESS',
    'ID_LTC_ADDRESS',
    'ID_CRYPTO_MONEY',
    'ID_GENERAL_MONEY',
    'ID_LENGTH',
    'ID_WEIGHT',
    'ID_VOLUME',
    'ID_PERCENTAGE',
    'ID_VERSION',
    'ID_FILENAME',
    'ID_FILESIZE',
    'ID_TIME',
    'ID_BRAND_NAME',
    'ID_NUMBER',
]

def _read_text(filepath, candidate_encodings):
    """Return the text of ``filepath`` using the first encoding that decodes it.

    The encodings are tried in the given order. ``None`` is returned when
    every one of them raises ``UnicodeDecodeError``, so the caller never sees
    the content of a previously read file.
    """
    for encoding in candidate_encodings:
        try:
            with open(filepath, 'r', encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    return None


def _token_ratios(token_cnt, num_files):
    """Convert raw token counts of one category into per-token shares.

    Counts are first averaged over ``num_files`` and then divided by the sum
    of those averages. When no file was readable or no token occurred at all,
    every share is 0.0 instead of raising ``ZeroDivisionError``.
    """
    if num_files == 0:
        return {k: 0.0 for k in token_cnt}

    token_cnt_norm = {k: v / num_files for k, v in token_cnt.items()}
    token_sum = sum(token_cnt_norm.values())

    if token_sum == 0:
        return {k: 0.0 for k in token_cnt}

    return {k: v / token_sum for k, v in token_cnt_norm.items()}


# fw: detailed per-file report, fs: per-category summary (written further below).
fw = open(os.path.join(PATH_LEX, 'mask_dist_all.txt'), 'w')
fs = open(os.path.join(PATH_LEX, 'mask_dist_concise.txt'), 'w')

# Only sub-directories are treated as categories; a stray file in the dataset
# root would otherwise make os.listdir() fail inside the loop.
categories = [
    name for name in os.listdir(PATH_CATEGORY)
    if os.path.isdir(os.path.join(PATH_CATEGORY, name))
]

# category name -> {token: share of that token among all counted tokens}
token_stats = dict()

try:
    for category in categories:

        PATH_CUR_CATEGORY = os.path.join(PATH_CATEGORY, category)

        fw.write(f'\n-*-*-*-*-{category}-*-*-*-*-\n\n')

        num_files = 0
        token_cnt = {k: 0 for k in identifier_tokens}

        for filename in os.listdir(PATH_CUR_CATEGORY):

            content = _read_text(os.path.join(PATH_CUR_CATEGORY, filename), encodings)

            # Files that cannot be decoded with any encoding are skipped and
            # do not count towards num_files.
            if content is not None:

                num_files += 1
                fw.write(f'{filename}:\n')

                for identifier in identifier_tokens:
                    # The tokens contain only letters and underscores, so a
                    # plain substring count equals the regex match count.
                    occurrences = content.count(identifier)
                    token_cnt[identifier] += occurrences

                    fw.write(f'{identifier:<17}: {occurrences}\n')

            fw.write('\n')

        token_stats[category] = _token_ratios(token_cnt, num_files)
finally:
    # The detailed report is complete (or aborted); flush and release it.
    fw.close()

# Write one CSV per category: one row per token, a single column of shares.
for category in categories:

    csv_file_name = os.path.join(PATH_CSV, f"{category}_mask.csv")

    category_frame = pd.DataFrame.from_dict(token_stats[category], orient='index')
    category_frame.to_csv(csv_file_name)

    print(f'{csv_file_name} created.')


# Write the per-category summary: a header line per category followed by one
# "<token>: <share>" line per token, shares rounded to three decimals.
try:
    for category_name, ratios in token_stats.items():

        fs.write(f'\n-*-*-*-*-{category_name}-*-*-*-*-\n\n')

        for token, ratio in ratios.items():

            fs.write(f'{token:<17}: {ratio:.3f}\n')

        fs.write('\n')
finally:
    # fs is not used after this point; close it so the report is flushed
    # even if writing fails part-way.
    fs.close()
