import os
import json
from datetime import datetime
import zstandard
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import transformers
import logging.handlers
from tqdm.auto import tqdm
import sys
from torch.utils.data import Dataset


class ListDataset(Dataset):
    """Expose an in-memory sequence through the ``Dataset`` indexing protocol.

    The wrapped object is kept by reference in ``original_list``; it is not
    copied, so later mutations of the sequence are visible through the dataset.
    """

    def __init__(self, original_list):
        # Stored as-is; only ``len()`` and item access are ever used on it.
        self.original_list = original_list

    def __len__(self):
        """Return the number of wrapped items."""
        items = self.original_list
        return len(items)

    def __getitem__(self, i):
        """Return the item stored at position ``i`` of the wrapped sequence."""
        items = self.original_list
        return items[i]


# Module-wide logger: INFO and above goes both to the console (stream handler)
# and to a size-rotated file under ./logs (16 MiB per file, 5 backups kept).
log = logging.getLogger("bot")
log.setLevel(logging.INFO)
log_formatter = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s')
log_str_handler = logging.StreamHandler()
log_str_handler.setFormatter(log_formatter)
log.addHandler(log_str_handler)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs("logs", exist_ok=True)
log_file_handler = logging.handlers.RotatingFileHandler(os.path.join("logs", "bot.log"), maxBytes=1024 * 1024 * 16,
                                                        backupCount=5)
log_file_handler.setFormatter(log_formatter)
log.addHandler(log_file_handler)


def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
    """Read from ``reader`` until the accumulated bytes decode as text.

    A chunk boundary can fall in the middle of a multi-byte character, in
    which case decoding fails and another chunk is read and appended before
    retrying.

    Args:
        reader: Object with a ``read(size)`` method returning ``bytes``.
        chunk_size: Number of bytes requested from ``reader`` per read.
        max_window_size: Give up once more than this many bytes were read
            without obtaining a decodable buffer.
        previous_chunk: Bytes left over from an earlier failed attempt; they
            are prepended to the newly read data.
        bytes_read: Number of bytes already consumed by earlier attempts.

    Returns:
        The decoded text; an empty string when the reader is exhausted and
        nothing is pending.

    Raises:
        UnicodeError: If the data still cannot be decoded after more than
            ``max_window_size`` bytes, or the reader is exhausted while the
            pending bytes remain undecodable.
    """
    pending = b"" if previous_chunk is None else previous_chunk
    while True:
        chunk = reader.read(chunk_size)
        # Count what was actually returned, not what was requested.
        bytes_read += len(chunk)
        pending += chunk
        try:
            return pending.decode()
        except UnicodeDecodeError as err:
            # An empty read means end of stream: more reads cannot complete
            # the broken sequence, so fail now instead of spinning.
            if not chunk or bytes_read > max_window_size:
                raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes") from err
            log.info(f"Decoding error with {bytes_read:,} bytes, reading another chunk")


def read_lines_zst(file_name):
    """Yield the lines of a zstandard-compressed text file.

    The file is decompressed in 2**27-byte chunks; text after the last
    newline of a chunk is carried over and joined with the next chunk.

    Args:
        file_name: Path of the ``.zst`` file to read.

    Yields:
        ``(line, position)`` tuples where ``line`` is the whitespace-stripped
        line and ``position`` is ``tell()`` of the underlying compressed file
        at the time the line is yielded (usable as a progress indicator).
    """
    with open(file_name, 'rb') as file_handle:
        buffer = ''
        reader = zstandard.ZstdDecompressor(max_window_size=2 ** 31).stream_reader(file_handle)
        try:
            while True:
                chunk = read_and_decode(reader, 2 ** 27, (2 ** 29) * 2)

                if not chunk:
                    break
                lines = (buffer + chunk).split("\n")

                for line in lines[:-1]:
                    yield line.strip(), file_handle.tell()

                # The last element is an incomplete line (or '' if the chunk
                # ended exactly on a newline); finish it with the next chunk.
                buffer = lines[-1]

            # A final line without a trailing newline would otherwise be lost.
            if buffer.strip():
                yield buffer.strip(), file_handle.tell()
        finally:
            # Also runs when the consumer abandons the generator or an
            # exception propagates, so the decompressor is always released.
            reader.close()


def write_line_zst(handle, line):
    """Write ``line`` followed by a newline to ``handle`` as UTF-8 bytes.

    ``handle`` only needs a ``write`` method accepting ``bytes``.
    """
    encoded_line = line.encode('utf-8')
    newline = "\n".encode('utf-8')
    for piece in (encoded_line, newline):
        handle.write(piece)


def get_basic_attribute(obj, attribute):
    """Return ``obj[attribute]``, or ``None`` when it is absent or blank.

    A value counts as blank when it is ``None`` or equal to the empty string;
    any other value (including ``0`` or ``False``) is returned unchanged.
    """
    if attribute in obj:
        value = obj[attribute]
        if not (value is None or value == ""):
            return value
    return None


def evaluate_whether_question_is_found(texts):
    """Classify each text as question / non-question.

    Uses the module-level ``pipe`` text-classification pipeline (created in
    the ``__main__`` section), truncating inputs to 512 tokens and running in
    batches of 512 with a progress bar.

    Args:
        texts: List of strings to classify.

    Returns:
        A list with exactly one ``(is_question, score)`` tuple per input, in
        input order. ``is_question`` is ``True`` only for the label
        ``"question"``. A label other than ``"question"`` / ``"non_question"``
        yields ``(False, 0)``.
    """
    if not texts:
        return []
    dataset = ListDataset(texts)
    results = []
    for prediction in tqdm(pipe(dataset, truncation=True, max_length=512, batch_size=512), total=len(dataset)):
        label = prediction["label"]
        if label == "question":
            results.append((True, prediction["score"]))
        elif label == "non_question":
            results.append((False, prediction["score"]))
        else:
            # Callers index the results by input position, so an unexpected
            # label must still produce an entry to keep the lists aligned.
            results.append((False, 0))
    return results


def _classify_and_write_batch(handle, objs, text_fields):
    """Classify the text fields of ``objs``, annotate each object and write it to ``handle``."""
    if not objs:
        return
    annotations = []
    for attribute, prefix in text_fields:
        texts = []
        for obj in objs:
            value = get_basic_attribute(obj, attribute)
            # Missing / blank fields are classified as the empty string.
            texts.append("" if value is None else value)
        results = evaluate_whether_question_is_found(texts)
        if len(results) != len(objs):
            # Fail before writing anything rather than attach scores to the wrong records.
            raise RuntimeError(
                f"Expected {len(objs)} predictions for '{attribute}' but got {len(results)}")
        annotations.append((prefix, results))
    for idx, obj in enumerate(objs):
        for prefix, results in annotations:
            obj[f"{prefix}_is_question"] = results[idx][0]
            obj[f"{prefix}_score"] = results[idx][1]
        write_line_zst(handle, json.dumps(obj))


def iterate_over_file(path, submission=True, batch_size=32768):
    """Annotate every record of a compressed dump with question-detection results.

    Each line of ``path`` is parsed as JSON. Records are collected into
    batches, classified, extended with ``<field>_is_question`` and
    ``<field>_score`` keys and written as JSON lines to
    ``data/output_data/<subreddit>_question.zst``. The output name is built
    from the module-level ``subreddit`` variable, which must be set before
    this function is called.

    Args:
        path: Path of the zstandard-compressed, newline-delimited JSON input.
        submission: When true, the ``title`` and ``selftext`` fields are
            classified (output prefixes ``title`` and ``self_text``);
            otherwise the ``body`` field is classified (prefix ``body``).
        batch_size: Number of records classified and written per batch.

    Returns:
        Always ``0``.

    Raises:
        RuntimeError: If the classifier returns a different number of
            results than texts it was given.
    """
    file_size = os.stat(path).st_size
    file_lines = 0
    created = None
    bad_lines = 0
    if submission:
        text_fields = (("title", "title"), ("selftext", "self_text"))
    else:
        text_fields = (("body", "body"),)
    output_path_question = os.path.join("data/output_data", f"{subreddit}_question.zst")
    objs = []
    # Context managers guarantee the compressed frame is finished and the file
    # is closed even when classification or writing raises.
    with open(output_path_question, 'wb') as raw_handle, \
            zstandard.ZstdCompressor().stream_writer(raw_handle) as handle_question:
        for line, file_bytes_processed in read_lines_zst(path):
            file_lines += 1
            if file_lines % 100000 == 0:
                print(
                    f"{created} Line: {file_lines:,} Bad Lines: {bad_lines:,} Bytes Processed: {file_bytes_processed:,} : {(file_bytes_processed / file_size) * 100:.0f}%", flush=True)
            # Only parsing problems count as bad lines; failures of the model
            # or the writer must propagate instead of being swallowed here.
            try:
                obj = json.loads(line)
                created = datetime.utcfromtimestamp(int(obj["created_utc"])).strftime("%Y/%m/%d")
            except (ValueError, KeyError, TypeError, OverflowError, OSError) as e:
                bad_lines += 1
                print(f"Bad line: {e}")
                continue
            objs.append(obj)
            # Flush on the number of collected records, so a bad line can
            # never cause a flush to be skipped.
            if len(objs) >= batch_size:
                _classify_and_write_batch(handle_question, objs, text_fields)
                objs = []
        # Records left over after the last full batch must be written too.
        _classify_and_write_batch(handle_question, objs, text_fields)
    return 0


if __name__ == "__main__":
    subreddits = ["personalfinance", "financialindependence", "FinancialPlanning", "investing", "wallstreetbets",
                  "Wallstreetbetsnew", "stocks", "StockMarket", "pennystocks", "options", "RealEstate", "Economics",
                  "realestateinvesting", "AskEconomics", "explainlikeimfive"]

    # `pipe` is read as a module-level global by evaluate_whether_question_is_found.
    pipe = transformers.pipeline("text-classification", model="huaen/question_detection", device=0)

    os.makedirs("data/output_data", exist_ok=True)
    # Only submissions are processed; add False to the list to include comments.
    submission_or_comment = [True]
    # `subreddit` is read as a module-level global by iterate_over_file to
    # build the output file name.
    for subreddit in subreddits:
        for is_submission in submission_or_comment:
            submission_string = "submissions" if is_submission else "comments"
            file_name = f"{subreddit}_{submission_string}.zst"
            path = os.path.join("data", file_name)
            iterate_over_file(path, submission=is_submission)
            print(f"Finished {subreddit}!", flush=True)
            # NOTE(review): this break and the one below stop the run after
            # the first subreddit's first file — confirm this is intended.
            break
        break
