#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
9/4/2025
Author: Katarina

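Preprocesses the messages of each channel and then samples from them:
 - remove URLs
 - remove user mentions
 - remove duplicate sentences (a sentence that occurs more than once in a
   channel is dropped everywhere it occurs)
After this, split the messages into paragraphs, remove duplicates, and sample
at most 2000 messages and 2500 paragraphs per channel, preferring ones of at
least 3 sentences (if possible).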
"""

import json
import os
import random as rd
import re
from collections import Counter

import pandas as pd
from nltk.tokenize import sent_tokenize
from urlextract import URLExtract

# Shared URL extractor, built once at import time and reused by
# preprocess_message() for every message.
url_extractor = URLExtract()

def preprocess_message(text):
    """Strip URLs and user mentions from a single message.

    Every URL reported by the module-level ``url_extractor`` is deleted
    (all occurrences), then every ``@`` followed by 3 or more word
    characters is deleted.

    Args:
        text: the raw message text.

    Returns:
        The message with URLs and mentions removed. Nothing else is
        touched, so a removal can leave doubled or trailing spaces.
    """
    # Longest first, so removing a short URL cannot clip a longer URL that
    # contains it and leave its tail behind.
    for url in sorted(url_extractor.find_urls(text), key=len, reverse=True):
        # Plain substring removal (same effect as re.sub on re.escape(url)).
        text = text.replace(url, '')
    # Remove every mention as a whole token in one pass. Deleting each
    # mention's literal text one at a time also cut it out of longer
    # mentions, e.g. '@abc @abcdef' used to become ' def'.
    text = re.sub(r'@\w{3,}', '', text)
    return text

def remove_duplicate_sentences(messages):
    """Drop every sentence that occurs more than once across ``messages``.

    Each message is split with ``sent_tokenize``. A sentence counts as a
    duplicate if its exact text appears two or more times in total, whether
    within one message or across messages. In that case *all* of its
    occurrences are removed, not just the extra copies. Sentences of a
    single character are dropped as well. The surviving sentences of each
    message are re-joined with single spaces, and messages left with no
    sentences are omitted.

    Args:
        messages: list of message strings.

    Returns:
        A new list of cleaned messages, in input order.
    """
    # Tokenize each message once and reuse the result for both passes.
    tokenized = [sent_tokenize(message) for message in messages]
    # Counter instead of list.count() per sentence, which was quadratic in
    # the total number of sentences of the channel.
    counts = Counter(sentence for sentences in tokenized for sentence in sentences)

    unduplicated_messages = []
    for sentences in tokenized:
        kept = [s for s in sentences if counts[s] == 1 and len(s) > 1]
        if kept:
            unduplicated_messages.append(' '.join(kept))
    return unduplicated_messages

def split_in_paragraphs(message):
    """Split a message on newlines into chunks of more than one sentence.

    The message is split on runs of newlines. A line whose own text
    tokenizes to at most one sentence is not emitted by itself: it is glued
    onto the front of the next line. Every emitted chunk ends with a
    trailing space.

    Args:
        message: a string.

    Returns:
        A list of strings, with 1 or more elements whenever ``message``
        contains any non-whitespace text. Short lines at the end of the
        message have no following line to join. They are appended to the
        last emitted chunk, or emitted on their own if there is none.
    """
    splitted_message = []
    current_paragraph = ""
    for paragraph in re.split('\n+', message):
        current_paragraph = f'{current_paragraph}{paragraph} '
        # NOTE(review): only the newest line is tokenized, not the
        # accumulated text, so a run of one-sentence lines never reaches
        # the threshold on its own; confirm this is intended.
        if len(sent_tokenize(paragraph)) > 1:
            splitted_message.append(current_paragraph)
            current_paragraph = ""
    # This leftover used to be dropped silently, so e.g. a one-line,
    # one-sentence message produced [] despite the documented contract.
    if current_paragraph.strip():
        if splitted_message:
            splitted_message[-1] += current_paragraph
        else:
            splitted_message.append(current_paragraph)
    return splitted_message

def sample_messages(messages, max_n_messages, min_sentences=3):
    """Pick at most ``max_n_messages`` items, preferring long ones.

    If there are no more than ``max_n_messages`` items, ``messages`` itself
    is returned unchanged. Otherwise, items with at least ``min_sentences``
    sentences (per ``sent_tokenize``) are preferred:

    - If there are more long items than slots, a random sample of them is
      returned.
    - Otherwise all long items are kept, and the remaining slots are filled
      with a random sample of the distinct shorter items.

    Args:
        messages: list of strings (messages or paragraphs).
        max_n_messages: upper bound on the number of items returned.
        min_sentences: minimum sentence count for an item to count as long.

    Returns:
        A list of at most ``max_n_messages`` strings.
    """
    if len(messages) <= max_n_messages:
        return messages
    long_messages = [x for x in messages if len(sent_tokenize(x)) >= min_sentences]
    if len(long_messages) > max_n_messages:
        return rd.sample(long_messages, max_n_messages)
    long_set = set(long_messages)
    # Distinct short items in input order. The old set difference iterated
    # in an arbitrary, per-process order (string hashing is randomized), so
    # rd.sample was not reproducible even with a fixed seed.
    short_messages = list(dict.fromkeys(x for x in messages if x not in long_set))
    # Deduplication can leave fewer short items than slots to fill, which
    # made rd.sample raise ValueError; take as many as exist instead.
    n_fill = min(max_n_messages - len(long_messages), len(short_messages))
    return long_messages + rd.sample(short_messages, n_fill)

def main():
    """Preprocess every ``.json`` channel file and write the results.

    For each ``<name>.json`` in ``data/1_raw_messages``, the function:

    1. Loads the file. It is presumably a JSON list of message strings,
       since its elements are passed one by one to ``preprocess_message``;
       TODO confirm.
    2. Strips URLs and mentions.
    3. Removes repeated sentences.
    4. Splits the messages into paragraphs.
    5. Samples up to 2500 paragraphs and 2000 messages.
    6. Writes them to ``data/2_preprocessed/paragraphs/<name>.json`` and
       ``data/2_preprocessed/full_messages/<name>.json``.

    Per-file counts are saved to ``preprocessing_stats.csv`` in the current
    working directory.
    """
    in_directory = 'data/1_raw_messages'
    out_directory = 'data/2_preprocessed'
    messages_directory = os.path.join(out_directory, 'full_messages')
    paragraphs_directory = os.path.join(out_directory, 'paragraphs')
    # Create the output folders up front instead of failing on the first write.
    os.makedirs(messages_directory, exist_ok=True)
    os.makedirs(paragraphs_directory, exist_ok=True)

    # Only real .json files (not e.g. 'x.json.bak'), in a stable order so the
    # stats CSV is reproducible and the progress counter counts only the
    # files that are actually processed.
    files = sorted(f for f in os.listdir(in_directory) if f.endswith('.json'))
    stats = {'username': [],
             'n_messages': [],
             'n_paragraphs': []}
    for i, file in enumerate(files, start=1):
        with open(os.path.join(in_directory, file), 'r', encoding='utf-8') as f:
            messages = json.load(f)
        messages = [preprocess_message(x) for x in messages]
        messages = remove_duplicate_sentences(messages)
        paragraphs = []
        for message in messages:
            paragraphs += split_in_paragraphs(message)
        paragraphs = sample_messages(paragraphs, 2500)
        messages = sample_messages(messages, 2000)

        # One stats row per channel, keyed by the file name without '.json'.
        stats['username'].append(os.path.splitext(file)[0])
        stats['n_messages'].append(len(messages))
        stats['n_paragraphs'].append(len(paragraphs))

        # Dump the sampled messages and paragraphs under the same file name.
        with open(os.path.join(messages_directory, file), 'w', encoding='utf-8') as f:
            json.dump(messages, f)
        with open(os.path.join(paragraphs_directory, file), 'w', encoding='utf-8') as f:
            json.dump(paragraphs, f)

        # Counting from 1: the old `i+1 % 10 == 0` parsed as
        # `i + (1 % 10) == 0` and therefore never printed.
        if i % 10 == 0:
            print(f'{i} out of {len(files)} files processed!')

    pd.DataFrame(stats).to_csv('preprocessing_stats.csv', index=False)


# Run the preprocessing only when executed as a script, not on import.
if __name__ == '__main__':
    main()
