import argparse
import ast
import json
import logging
import os
from pathlib import Path

import torch
from tqdm import tqdm

from prompt import Prompter
from utils import set_logger, LanguageDetector

def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: holds a single attribute, ``tag`` (a string, or
        ``None`` when ``--tag`` is not supplied). ``main`` uses it to build
        the input and output file names.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--tag", type=str)
    return cli.parse_args()


def _parse_expressions(raw):
    """Parse the prompter's reply into a list of expressions.

    The reply is expected to be the text of a Python list literal.

    Args:
        raw: The string returned by the prompter.

    Returns:
        list: The parsed expressions.

    Raises:
        ValueError, SyntaxError: If ``raw`` is not a valid Python literal.
        TypeError: If the literal is not a list (or tuple).
        AttributeError: If ``raw`` is not a string.
    """
    # SECURITY: the reply is model output produced from scraped post text, so
    # it is untrusted. ast.literal_eval only accepts literals and cannot run
    # code, unlike the eval() call that was used here before.
    # Strip first: Python versions before 3.10 reject leading whitespace.
    parsed = ast.literal_eval(raw.strip())
    if not isinstance(parsed, (list, tuple)):
        raise TypeError(
            f"expected a list of expressions, got {type(parsed).__name__}"
        )
    return list(parsed)


def _write_json(path, payload):
    """Serialize ``payload`` as JSON (indent 3) to ``path``, overwriting it."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=3)


def main(args):
    """Filter the posts of one input file and collect expressions for them.

    Reads ``./input/NEW_{args.tag}.json`` (a mapping of key -> post), walks the
    posts from the newest timestamp to the oldest and, for every post whose
    ``type`` is ``"text"``, records an entry in
    ``./intermediate_result/{args.tag}.json`` with these fields:

    * ``timestamp``, ``post_url``, ``body``: copied from the post.
    * ``opt_length``: whether ``50 < len(body) < 3000``.
    * ``english``: whether the language detector's top label is
      ``"__label__en"`` (only evaluated when ``opt_length`` is true).
    * ``expressions``: the list parsed from the prompter's reply (only
      requested for English posts of suitable length; left empty on failure).
    * ``need_amt``: true when more than one expression was obtained.

    Non-text posts are skipped entirely.

    Args:
        args: Namespace with a ``tag`` attribute, used to build both paths.
    """
    logger = set_logger()
    prompter = Prompter()
    lang_detector = LanguageDetector()

    input_json_path = Path(f"./input/NEW_{args.tag}.json")
    save_json_path = Path(f"./intermediate_result/{args.tag}.json")
    # Create the output directory up front so the first checkpoint cannot fail
    # on a fresh checkout.
    save_json_path.parent.mkdir(parents=True, exist_ok=True)

    with open(input_json_path, "r", encoding="utf-8") as f:
        posts = json.load(f)

    # Newest posts first.
    ordered_posts = sorted(
        posts.items(), key=lambda item: item[1]["timestamp"], reverse=True
    )

    output = {}
    # Checkpoint cadence is counted in prompted posts: those are the expensive
    # ones, and the global index ``i`` also advances on posts that are skipped.
    checkpoint_every = 10
    prompted_count = 0

    for i, (key, post) in enumerate(tqdm(ordered_posts)):
        if post["type"] != "text":
            continue

        filtered_post = {
            "timestamp": post["timestamp"],
            "post_url": post["post_url"],
            "body": post["body"],
            "expressions": [],
            "opt_length": 50 < len(post["body"]) < 3000,
            "english": False,
            "need_amt": False,
        }

        # check if the text is the optimal length
        if not filtered_post["opt_length"]:
            output[key] = filtered_post
            continue

        # check if the text is written in English
        # (newlines are flattened to spaces before detection)
        text = post["body"].replace("\n", " ")
        lang = lang_detector.predict_lang(text)
        filtered_post["english"] = lang[0][0] == "__label__en"

        if not filtered_post["english"]:
            output[key] = filtered_post
            continue

        try:
            str_list = prompter.prompt(filtered_post["body"])
            filtered_post["expressions"] = _parse_expressions(str_list)
            logger.info(f"Generated mention lists from gpt-4 (i: {i}, key: {key})")
        except Exception:
            # Best effort: a failed or malformed reply leaves ``expressions``
            # empty. ``Exception`` (not a bare ``except``) keeps Ctrl-C and
            # SystemExit working during a long run.
            logger.warning(f"Failed to get accurate list from gpt-4 (i: {i}, key: {key})")

        if len(filtered_post["expressions"]) > 1:
            filtered_post["need_amt"] = True

        output[key] = filtered_post
        prompted_count += 1

        if prompted_count % checkpoint_every == 0:
            _write_json(save_json_path, output)
            logger.info(f"Saved json (current i: {i}, current key: {key})")

    _write_json(save_json_path, output)

if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main().
    main(get_args())
