# Data processing
import pandas as pd
import numpy as np
# Text preprocessing
import nltk
nltk.download('stopwords')
nltk.download('omw-1.4')
nltk.download('wordnet')
wn = nltk.WordNetLemmatizer()
# Topic model
from bertopic import BERTopic
# Dimension reduction
from umap import UMAP
import json
import time
from tqdm import tqdm
import os
import zstandard
from datetime import datetime


def get_basic_attribute(obj, attribute):
    """Return ``obj[attribute]``, treating absent/blank values as ``None``.

    ``None`` is returned when *attribute* is not contained in *obj*, or when
    the stored value is ``None`` or the empty string. Any other value
    (including falsy ones such as ``0``) is returned unchanged.
    """
    if attribute in obj:
        value = obj[attribute]
        if not (value is None or value == ""):
            return value
    return None


def filter_title(obj):
    """Tell whether *obj* should be dropped because its title is unusable.

    Returns ``True`` when the "title" attribute is missing/blank (as reported
    by ``get_basic_attribute``) or equals one of the placeholder strings
    "[removed]" / "[deleted]"; returns ``False`` otherwise.
    """
    title = get_basic_attribute(obj, "title")
    return title is None or title in ("[removed]", "[deleted]")

def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
    """Read from *reader* until the accumulated bytes decode as text.

    Each attempt reads ``chunk_size`` bytes, appends them to whatever was
    carried over from earlier attempts (initially *previous_chunk*), and tries
    to decode the result with the default codec. A chunk boundary may fall
    inside a multi-byte character, so a failed decode is retried with more
    data.

    The running total is advanced by ``chunk_size`` per attempt (not by the
    number of bytes actually returned), starting from *bytes_read*.

    Raises:
        UnicodeError: if decoding still fails once the running total exceeds
            *max_window_size*.
    """
    carried = previous_chunk
    consumed = bytes_read
    while True:
        data = reader.read(chunk_size)
        consumed += chunk_size
        if carried is not None:
            data = carried + data
        try:
            return data.decode()
        except UnicodeDecodeError:
            if consumed > max_window_size:
                raise UnicodeError(f"Unable to decode frame after reading {consumed:,} bytes")
            # Keep the undecodable bytes and extend them on the next pass.
            carried = data


def read_lines_zst(file_name):
    """Lazily yield the lines of a zstd-compressed text file.

    The file is decompressed as a stream and decoded in chunks via
    ``read_and_decode``; the decoded text is split on "\\n".

    Yields:
        ``(line, position)`` tuples, where ``line`` is the whitespace-stripped
        text of one line and ``position`` is ``file_handle.tell()`` on the
        underlying compressed file at the time of the yield (used by callers
        for progress reporting).

    A final line that is not terminated by a newline is yielded as well, as
    long as it is not blank. The decompression reader is closed even if the
    consumer stops iterating early or an error is raised.
    """
    with open(file_name, 'rb') as file_handle:
        decompressor = zstandard.ZstdDecompressor(max_window_size=2**31)
        # Context manager guarantees the reader is released on early exit too.
        with decompressor.stream_reader(file_handle) as reader:
            buffer = ''
            while True:
                chunk = read_and_decode(reader, 2**27, (2**29) * 2)

                if not chunk:
                    break
                lines = (buffer + chunk).split("\n")

                # Everything but the last piece is a complete line.
                for line in lines[:-1]:
                    yield line.strip(), file_handle.tell()

                # The last piece may be a partial line; finish it next round.
                buffer = lines[-1]

            # Do not lose the last record when the stream lacks a trailing
            # newline. This must happen before the reader is closed, because
            # closing it may also close ``file_handle``.
            tail = buffer.strip()
            if tail:
                yield tail, file_handle.tell()


def get_data(file_path):
    """Collect the titles and ids of usable records from a compressed dump.

    Every line produced by ``read_lines_zst(file_path)`` is parsed as JSON.
    Records for which ``filter_title`` returns ``True`` are skipped; for the
    rest the "title" and "id" fields are kept. A progress message is printed
    every 100,000 lines.

    Lines that are not valid JSON, or that lack a required key
    ("created_utc", "title" or "id" as applicable), are reported, counted in
    the "Bad Lines" figure of the progress message, and skipped.

    Args:
        file_path: Path of the zstd-compressed, newline-delimited JSON file.

    Returns:
        A ``pandas.DataFrame`` with the columns "title" and "id" (present even
        when no record was kept).
    """
    file_size = os.stat(file_path).st_size
    file_lines = 0
    created = None  # date of the most recently parsed record, for progress output
    bad_lines = 0
    results = []
    for line, file_bytes_processed in read_lines_zst(file_path):
        file_lines += 1
        if file_lines % 100000 == 0:
            print(
                f"{created} Line: {file_lines:,} Bad Lines: {bad_lines:,} Bytes Processed: {file_bytes_processed:,} : {(file_bytes_processed / file_size) * 100:.0f}%")
        try:
            obj = json.loads(line)
            created = datetime.utcfromtimestamp(int(obj["created_utc"])).strftime("%Y/%m/%d")
            # TODO: investigate impact of embedded media, num_crossposts
            if not filter_title(obj):
                results.append({
                    "title": obj["title"],
                    "id": obj["id"]
                })

        except (KeyError, json.JSONDecodeError) as err:
            # Format the exception instead of concatenating it to a str, which
            # raised TypeError and aborted the whole run on the first bad line.
            bad_lines += 1
            print(f"Error: {err!r}")
    # Explicit columns keep the frame's schema stable when nothing was kept.
    return pd.DataFrame(results, columns=["title", "id"])

# English stopword list from NLTK (downloaded at the top of the file).
stopwords = nltk.corpus.stopwords.words("english")
print(f'There are {len(stopwords)} default stopwords. They are {stopwords}')
# Set copy for O(1) membership tests in the per-word filters below; the list
# itself is kept so the printout above stays in NLTK's order.
stopword_set = set(stopwords)
# Running total of examples across all processed subreddits.
finished_length = 0
subreddits = [
    "personalfinance",
    "financialindependence",
    "FinancialPlanning",
    "investing",
    "wallstreetbets",
    "Wallstreetbetsnew",
    "stocks",
    "StockMarket",
    "pennystocks",
    "options",
    "RealEstate",
    "Economics",
    "realestateinvesting",
    "AskEconomics",
    "explainlikeimfive"
]
# Create the output directory up front so a missing folder cannot make the
# save step fail after a long model fit.
os.makedirs("./topic_models_initial_data", exist_ok=True)
for subreddit in tqdm(subreddits):
    data = get_data(f"./raw_data/subreddits/{subreddit}_submissions.zst")
    print(f"Subreddit has {data.shape[0]} examples")
    # Stopword removal (case-insensitive match against the stopword set)
    data['text_without_stopwords'] = data['title'].apply(
        lambda x: ' '.join([w for w in x.split() if w.lower() not in stopword_set]))
    # Lemmatization
    data['text_lemmatized'] = data['text_without_stopwords'].apply(
        lambda x: ' '.join([wn.lemmatize(w) for w in x.split() if w not in stopword_set]))
    print(f"Handling subreddit {subreddit}")
    # Initiate UMAP (fixed random_state for reproducible reductions)
    umap_model = UMAP(n_neighbors=15,
                      n_components=5,
                      min_dist=0.0,
                      metric='cosine',
                      random_state=100)
    # Initiate BERTopic
    topic_model = BERTopic(umap_model=umap_model, language="english", calculate_probabilities=True, verbose=True)
    # Run BERTopic model; the timer covers fitting and saving only
    start = time.time()
    # Pass a plain list of strings, the documented input type for fit_transform
    topics, probabilities = topic_model.fit_transform(data['text_lemmatized'].tolist())

    # Save the topic model
    topic_model.save(f"./topic_models_initial_data/{subreddit}")
    finished_length += data.shape[0]
    end = time.time()
    print(f"Iteration took {(end - start) / 60} minutes.")
