import os
import sys
import pickle
import tensorflow as tf
import utils.gpt3
import utils.llm

from utils.config import get_output_folder, get_gpt3_cache_path
from utils.gpt3 import compute_max_prompt_length, prompt_packer
from utils.llm import load_engine, completion

def _clean_descriptor(text):
    """Normalize a raw completion into a lowercase, comma-separated descriptor string.

    Leading whitespace is removed, and so is the whole trailing run of
    periods, commas and whitespace. Stripping them as one run means a
    completion ending in ".\\n" or ", ." is still cleaned, which separate
    ``rstrip('.')`` / ``rstrip(',')`` calls would miss.
    """
    return text.lstrip().rstrip('., \t\r\n').lower()


def create_style_descriptor_dict(dataset_path, engine, temperature, top_p, max_tokens=20):
    """Ask the loaded LLM to describe each author's writing style and pickle the results.

    The dataset at ``dataset_path`` is loaded with ``tf.data.experimental.load``.
    For each record:

    - the record's ``body`` passages are packed into one prompt asking for
      comma-separated style adjectives;
    - ``completion`` is queried with that prompt;
    - the cleaned answer is stored under the record's ``user_id``.

    The resulting dict is pickled next to the dataset as
    ``<dataset basename>_descriptors_<engine>.pkl``.

    Args:
        dataset_path: Directory of a saved ``tf.data`` dataset. A trailing
            path separator is tolerated.
        engine: Engine name. Here it is used only to build the output file name.
        temperature: Sampling temperature passed to ``completion``.
        top_p: Nucleus-sampling parameter passed to ``completion``.
        max_tokens: Maximum number of tokens generated per description. It is
            also used to size the prompt budget. Defaults to 20.
    """
    print("Processing:", dataset_path)

    # Normalise first so "…/target_author_posts/" does not yield an empty basename.
    norm_path = os.path.normpath(dataset_path)
    output_folder = os.path.dirname(norm_path)
    basename = os.path.basename(norm_path)

    # Style-descriptor output file, written next to the dataset.
    descriptors_fn = os.path.join(output_folder, basename + f"_descriptors_{engine}.pkl")

    # Loop-invariant: the prompt budget depends only on max_tokens.
    max_prompt_length = compute_max_prompt_length(max_tokens)

    descriptors = {}

    # NOTE(review): newer TF releases also offer tf.data.Dataset.load; the
    # experimental API is kept here for compatibility with the original.
    dataset = tf.data.experimental.load(dataset_path)
    for record in dataset:
        # Used directly as the dict key; presumably a scalar tensor, so the
        # numpy value is hashable (bytes/int). TODO confirm against the dataset schema.
        target_user_id = record['user_id'].numpy()
        # Newlines are flattened so each passage stays on one line and does
        # not blur into the '\n\n' separator used between passages.
        texts = [t.decode('utf8').replace('\n', ' ') for t in record['body'].numpy().tolist()]
        prompt = prompt_packer(
            'Passage: %s',
            texts,
            "List some adjectives, comma-separated, that describe the writing style of the author of these passages:",
            (),
            separator='\n\n',
            max_prompt_length=max_prompt_length,
        )
        raw_description = completion(
            prompts=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=1.0,
            stop=None,
        )
        descriptors[target_user_id] = _clean_descriptor(raw_description)
        print("    User:", target_user_id)
        print("    Described as:", descriptors[target_user_id])

    # Save to disk
    print("Saving to disk...")
    with open(descriptors_fn, 'wb') as fp:
        pickle.dump(descriptors, fp)
    print("Done.")

if __name__ == "__main__":
    # Command-line entry point. Expected arguments:
    #   argv[1]  value resolved to the output folder by get_output_folder
    #   argv[2]  engine name passed to load_engine
    #   argv[3]  sampling temperature (float)
    #   argv[4]  top_p (float)
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 5:
        sys.exit(
            f"Usage: {os.path.basename(sys.argv[0])} "
            "OUTPUT_FOLDER_ARG ENGINE TEMPERATURE TOP_P"
        )

    output_folder = get_output_folder(sys.argv[1])
    engine = sys.argv[2]
    temperature = float(sys.argv[3])
    top_p = float(sys.argv[4])
    load_engine(engine)
    # Presumably makes utils.gpt3's prompt packing / length computation use the
    # tokenizer of the engine just loaded. TODO confirm.
    utils.gpt3.tokenizer = utils.llm.tokenizer

    # Only the last path component of the engine name goes into output file
    # names, so a name like "org/model" does not introduce a subdirectory.
    engine_name = engine.split("/")[-1]

    # Define which datasets to process
    split = "test"
    dataset_paths = [
        os.path.join(output_folder, f"reddit_{split}_query", variant, "target_author_posts")
        for variant in ("random", "most_common_subreddit", "diverse")
    ]

    # Create style descriptor dicts
    for dataset_path in dataset_paths:
        create_style_descriptor_dict(dataset_path, engine_name, temperature, top_p)