from PushshiftDumps.scripts.filter_file import write_line_zst, read_and_decode, read_lines_zst
import os
from path_manager import PathManager
from datetime import datetime
import zstandard
import json
from analysis import get_basic_attribute
import random
from nltk.tokenize.casual import _replace_html_entities, reduce_lengthening
import unicodedata
import strip_markdown
import re
from tqdm import tqdm

# Reddit user mentions such as "u/name" or "/u/name".
REDDIT_USER = r"(?:\/?u\/\w+)"
REDDIT_USER_RE = re.compile(REDDIT_USER, flags=re.UNICODE)
# A "#" directly followed by a word character. Only the "#" itself is matched
# (the word is a lookahead), so substituting '' keeps the tag text.
HASH_RE = re.compile(r'#(?=\w+)', re.UNICODE)
# Custom URL pattern; nltk's URL regex is not usable as a stand-alone pattern.
# A match starts with "http://", "https://", "www" or a bare "name.tld" prefix
# (2-3 character TLD) and must be followed by at least one URL character or
# %XX escape. The TLD quantifier has to be written "{2,3}": the previous
# "{2-3}" is not a quantifier, so `re` treated it as literal text and the
# bare-domain branch never matched.
# NOTE: the bare-domain branch also matches tokens like "end.Next" where a
# sentence is missing the space after its period.
URL_RE = re.compile(r"""((https?:\/\/|www)|\w+\.(\w{2,3}))([\w\!#$&-;=\?\-\[\]~]|%[0-9a-fA-F]{2})+""", re.UNICODE)

def preprocess_text(text):
    """Clean one raw text field and return the cleaned string.

    ``None`` is treated as the empty string. The steps run in this order:
    nltk's ``_replace_html_entities``, replacement of Reddit user mentions
    with a space, removal of the ``#`` in front of hashtags, replacement of
    URLs with a space, ``strip_emoji``, ``remove_edits_and_updates_from_text``
    and finally ``strip_markdown.strip_markdown``.
    """
    cleaned = "" if text is None else text
    cleaned = _replace_html_entities(cleaned)
    # (pattern, replacement) pairs, applied in exactly this order.
    substitutions = (
        (REDDIT_USER_RE, ' '),
        (HASH_RE, ''),
        (URL_RE, ' '),
    )
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)
    cleaned = strip_emoji(cleaned)
    cleaned = remove_edits_and_updates_from_text(cleaned)
    return strip_markdown.strip_markdown(cleaned)


def strip_emoji(text):
    '''Remove every character whose Unicode general category is "So".

    ::param text:: string to clean
    ::type text:: str
    ::returns:: ``text`` without its "So" (Symbol, other) characters

    This is an approximation of "emoji": only the "So" category is filtered,
    so code points in any other category are kept.
    '''
    kept_chars = []
    for char in text:
        if unicodedata.category(char) == 'So':
            continue
        kept_chars.append(char)
    return ''.join(kept_chars)


def remove_edits_and_updates_from_text(text):
    """Cut ``text`` off at the first "Edit: " / "edit: " / "update: " / "Update: " marker.

    The markers are checked one after another, each against the result of the
    previous cut. A marker is ignored when nothing precedes it, so a text that
    starts with a marker is returned unchanged.
    """
    for marker in ("Edit: ", "edit: ", "update: ", "Update: "):
        before, found, _after = text.partition(marker)
        # `found` is empty when the marker is absent; `before` is empty when
        # the text starts with the marker.
        if found and before:
            text = before
    return text


def collect_merged_information(obj, comment, current_level):
    """Build a ``{"text", "context", "answer"}`` record from a post and one comment.

    ``text`` is the "title" attribute of ``obj``, or its "body" attribute when
    the title is ``None``; ``context`` is the "selftext" attribute of ``obj``;
    ``answer`` is the "body" attribute of ``comment``. Every attribute is read
    through ``get_basic_attribute`` and ``None`` is replaced by ``""``.
    ``current_level`` is accepted but not used.
    """
    question = get_basic_attribute(obj, "title")
    if question is None:
        question = get_basic_attribute(obj, "body")
    raw_fields = {
        "text": question,
        "context": get_basic_attribute(obj, "selftext"),
        "answer": get_basic_attribute(comment, "body"),
    }
    return {name: ("" if value is None else value) for name, value in raw_fields.items()}


def collect_merged_information_two_answers(obj, comment1, comment2, current_level, subreddit_name=None):
    """Build a flat record from a submission and two of its comments.

    The record holds the preprocessed question (``text``), ``context`` and
    both answers, the subreddit name, every remaining key of ``obj`` and every
    remaining key of the two comments (suffixed with ``_answer1`` /
    ``_answer2``).

    :param obj: submission dict; "title" (or "body" when the title is None)
        becomes ``text`` and "selftext" becomes ``context``.
    :param comment1: comment dict whose "body" becomes ``answer_1``.
    :param comment2: comment dict whose "body" becomes ``answer_2``.
    :param current_level: accepted for signature compatibility; not used.
    :param subreddit_name: value stored under "subreddit". When omitted, the
        module-level ``subreddit`` variable is used if it exists (it is set by
        the script's ``__main__`` loop), otherwise ``obj.get("subreddit")``.
        Previously the global was read unconditionally, which raised
        ``NameError`` whenever this module was imported instead of run.
    :return: the merged dict.
    """
    text = get_basic_attribute(obj, "title")
    if text is None:
        text = get_basic_attribute(obj, "body")
    context = get_basic_attribute(obj, "selftext")
    answer_1 = get_basic_attribute(comment1, "body")
    answer_2 = get_basic_attribute(comment2, "body")

    if subreddit_name is None:
        subreddit_name = globals().get("subreddit", obj.get("subreddit"))

    final_obj = {
        "text": preprocess_text(text),
        "context": preprocess_text(context),
        "answer_1": preprocess_text(answer_1),
        "answer_2": preprocess_text(answer_2),
        "subreddit": subreddit_name,
    }

    # Submission metadata is copied after the cleaned fields, so a key of the
    # same name in `obj` (e.g. "subreddit") overrides the value set above.
    for key, value in obj.items():
        if key not in ("title", "selftext", "comments"):
            final_obj[key] = value
    # Comment metadata is suffixed so the two answers do not collide.
    for suffix, comment in (("_answer1", comment1), ("_answer2", comment2)):
        for key, value in comment.items():
            if key not in ("body", "parent_id"):
                final_obj[key + suffix] = value

    return final_obj


def _build_object_map(submission, max_level=None):
    """Map ids to ``{"level", "object", "parent"}`` for a submission and its comments.

    The submission is entered under ``submission["id"]`` with level 0. A
    comment is added (under its "name", or ``"t1_" + id`` when the name is
    None) once its "parent_id" is already in the map, with the parent's level
    plus one. Comments whose parent level is greater than ``max_level`` are
    skipped; ``max_level=None`` disables that limit. Comments that never
    connect to the map are left out.
    """
    comments = submission["comments"]
    obj_map = {submission["id"]: {"level": 0, "object": submission, "parent": None}}
    handled_idxs = set()
    obj_map_size = len(obj_map)
    # Comments are not assumed to be ordered parent-first, so keep sweeping
    # the list until every comment is placed or a sweep adds nothing.
    while len(handled_idxs) < len(comments):
        for idx, comment in enumerate(comments):
            if idx in handled_idxs:
                continue
            comment_parent = comment["parent_id"]
            if comment_parent not in obj_map:
                continue
            parent_level = obj_map[comment_parent]["level"]
            # Guard against None: comparing an int with None raises TypeError.
            if max_level is not None and parent_level > max_level:
                continue
            comment_id = comment["name"]
            if comment_id is None:
                comment_id = "t1_" + comment["id"]
            obj_map[comment_id] = {
                "level": parent_level + 1,
                "object": comment,
                "parent": comment_parent,
            }
            handled_idxs.add(idx)
        if len(obj_map) == obj_map_size and obj_map_size != len(comments) + 1:
            # No progress in this sweep: the remaining comments are
            # unconnected or beyond max_level.
            break
        obj_map_size = len(obj_map)
    return obj_map


def create_question_answer_dataset(file_path, max_level=None):
    """Convert one zst file of submissions into a zst file of QA records.

    Each input line is a JSON submission with an embedded "comments" list.
    For every submission the direct replies (comments whose parent is the
    submission id) are selected according to the module-level
    ``which_dataset``:

    * ``"QA"``: the highest-scored direct reply is merged with the submission
      via ``collect_merged_information``.
    * ``"QA_good_good"``: the two highest-scored direct replies not written by
      the submission's author are merged via
      ``collect_merged_information_two_answers``.
    * ``"QA_good_bad"``: the highest-scored such reply plus a randomly chosen
      reply that scores at least 10 below it and at most 3.

    Besides ``which_dataset`` this function reads the module-level names
    ``subreddit`` and ``current_directory`` (all three are assigned in the
    script's ``__main__`` block); they determine the output location and the
    log messages.

    :param file_path: path of the zst-compressed input file.
    :param max_level: comments whose parent level exceeds this value are not
        linked into the comment tree; ``None`` means no limit.
    :return: number of records written.
    """
    # Count the lines by streaming instead of holding the whole file in a list.
    amount_of_lines = sum(1 for _ in read_lines_zst(file_path))
    file_size = os.stat(file_path).st_size
    file_lines = 0
    created = None
    bad_lines = 0
    lines_created = 0
    not_enough_comments = 0
    score_difference_small = 0
    level_string = f"_max_level_{max_level}" if max_level is not None else ""
    output_path = os.path.join(PathManager.get_question_answers_path(), "updated_filters_no_filter_comments_good_bad", current_directory)
    os.makedirs(output_path, exist_ok=True)
    print("created directory: " + output_path)
    output_path = os.path.join(output_path, f"{subreddit}_qa{level_string}.zst")
    with open(output_path, 'wb') as raw_file:
        handle = zstandard.ZstdCompressor().stream_writer(raw_file)
        try:
            for line, file_bytes_processed in tqdm(read_lines_zst(file_path), total=amount_of_lines):
                file_lines += 1
                if file_lines % 100000 == 0:
                    print(
                        f"{created} Line: {file_lines:,} Bad Lines: {bad_lines:,} Bytes Processed: {file_bytes_processed:,} : {(file_bytes_processed / file_size) * 100:.0f}%")
                try:
                    obj = json.loads(line)
                    created = datetime.utcfromtimestamp(int(obj["created_utc"])).strftime("%Y/%m/%d")
                    submission_id = obj["id"]
                    if "comments" not in obj:
                        continue
                    obj_map = _build_object_map(obj, max_level)
                    parent = obj_map[submission_id]
                    direct_replies = [el for el in obj_map.values() if el["parent"] == submission_id]
                    if which_dataset == "QA":
                        if not direct_replies:
                            continue
                        # max() returns the first of several top-scored
                        # replies, like a stable descending sort would.
                        best_answer = max(direct_replies, key=lambda el: el["object"]["score"])
                        new_obj = collect_merged_information(parent["object"], best_answer["object"], parent["level"])
                    else:
                        # Filter out authors answering themselves directly
                        candidates = [el for el in direct_replies
                                      if el["object"]["author"] != parent["object"]["author"]]
                        if len(candidates) < 2:
                            not_enough_comments += 1
                            continue
                        candidates.sort(key=lambda el: el["object"]["score"], reverse=True)
                        if which_dataset == "QA_good_good":
                            second_answer = candidates[1]
                        elif which_dataset == "QA_good_bad":
                            scores = [el["object"]["score"] for el in candidates]
                            best_score = scores[0]
                            # A "bad" answer scores at least 10 below the best
                            # one and at most 3 in absolute terms.
                            bad_indices = [idx for idx, value in enumerate(scores)
                                           if best_score - value >= 10 and value <= 3]
                            if not bad_indices:
                                score_difference_small += 1
                                continue
                            second_answer = candidates[random.choice(bad_indices)]
                        else:
                            print("Wrong dataset name!!!")
                            continue
                        new_obj = collect_merged_information_two_answers(
                            parent["object"], candidates[0]["object"], second_answer["object"], parent["level"])
                    write_line_zst(handle, json.dumps(new_obj))
                    lines_created += 1
                except (KeyError, json.JSONDecodeError) as err:
                    # Concatenating str and an exception raises TypeError, so
                    # format the error instead; also count the line as bad.
                    bad_lines += 1
                    print(f"Error: {err!r}")
        finally:
            # Finish the zstd frame even when processing aborts.
            handle.close()
    print(f"Lines created: {lines_created} out of {file_lines} submissions of subreddit {subreddit}")
    print(f"{not_enough_comments} entries with less than two comments available")
    print(f"{score_difference_small} entries available score difference smaller 10")
    return lines_created



# Script entry point: builds one QA file per subreddit and prints the total
# number of records written.
# NOTE: `which_dataset`, `current_directory` and the loop variable `subreddit`
# are module-level names that the functions above read directly, so they must
# keep these exact names and stay at module scope.
if __name__ == "__main__":
    subreddits = ["personalfinance", "financialindependence", "FinancialPlanning", "investing", "wallstreetbets",
                  "Wallstreetbetsnew", "stocks", "StockMarket", "pennystocks", "options", "RealEstate", "Economics",
                  "realestateinvesting", "AskEconomics", "explainlikeimfive"]
    # NOTE(review): the two sums below have 15 terms each, presumably one
    # count per subreddit in the order above; their meaning is not stated here.
    # 2186362 + 462523 + 86692 + 587147 + 5716104 + 84048 + 874304 + 57523 + 206569 + 195266 + 503554 + 31544 + 153164 + 42385 + 1289669
    # 1145969 + 133911 + 51246 + 277270 + 3700981 + 64503 + 443077 + 37974 + 135832 + 96995 + 263075 + 11671 + 84188 + 27423 + 555800
    # Dataset variant to generate; read as a global by create_question_answer_dataset.
    which_dataset = "QA_good_bad"
    # The dataset names create_question_answer_dataset recognises (this list itself is not used below).
    datasets = ["QA", "QA_good_good", "QA_good_bad"]
    # Disabled alternative: process every sub-directory of the filtered-comments path.
    # for current_directory in os.listdir(PathManager.get_filtered_all_comments_path()):
    #     if "." in current_directory:
    #         continue
    #     for subreddit in subreddits:
    #         current_path = os.path.join(PathManager.get_filtered_all_comments_path(), current_directory, f"{subreddit}_filtered_combined.zst")
    #         create_question_answer_dataset(current_path, max_level=0)
    # Input sub-directory; also used as the output sub-directory name.
    current_directory = "no_gilding_and_awards_percentile90"
    total_lines = 0
    for subreddit in subreddits:
        current_path = os.path.join(PathManager.get_filtered_all_comments_path(), current_directory, f"{subreddit}_filtered_combined.zst")
        # max_level=0 is passed to every run; see create_question_answer_dataset for its effect.
        total_lines += create_question_answer_dataset(current_path, max_level=0)
    print(f"total dataset size: {total_lines}")
    # create_question_answer_dataset(path)
    # TODO: Track stats for QA pairs created

