import os
import sys
import pickle
import tensorflow as tf

from utils.config import get_output_folder, get_gpt3_cache_path
from utils.gpt3 import create_cache, compute_max_prompt_length, prompt_packer, cached_completion

# Module-level completion cache, built once at import time from the path
# returned by get_gpt3_cache_path(); it is handed to cached_completion() below.
cache = create_cache(get_gpt3_cache_path())

def create_style_descriptor_dict(dataset_path):
    """Build and pickle a user_id -> style-descriptor mapping for one dataset.

    The dataset stored at ``dataset_path`` is loaded with
    ``tf.data.experimental.load``. For every record, the entries of
    ``record['body']`` are decoded as UTF-8, have their newlines replaced by
    spaces, and are packed into a single prompt asking for adjectives that
    describe the author's writing style. The completion (trailing periods
    removed, lower-cased) is stored under the record's ``user_id`` value.

    The resulting dict is pickled to
    ``<dirname(dataset_path)>/<basename(dataset_path)>_descriptors_gpt3.pkl``.

    Args:
        dataset_path: Path of the saved dataset to process.

    Returns:
        None. The result is written to disk and progress is printed.
    """
    print("Processing:", dataset_path)

    # The pickle lives next to the dataset and is named after it
    parent_dir, split_name = os.path.split(dataset_path)
    pickle_path = os.path.join(parent_dir, split_name + "_descriptors_gpt3.pkl")

    # Maps each record's user id to its descriptor string
    style_by_author = {}

    for example in tf.data.experimental.load(dataset_path):
        author_id = example['user_id'].numpy()

        # Flatten each passage onto a single line before packing the prompt
        passages = []
        for raw_passage in example['body'].numpy().tolist():
            passages.append(raw_passage.decode('utf8').replace('\n', ' '))

        # The completion budget also bounds how long the prompt may be
        token_budget = 20
        packed_prompt = prompt_packer(
            'Passage: %s',
            passages,
            "List some adjectives, comma-separated, that describe the writing style of the author of these passages:",
            (),
            separator='\n\n',
            max_prompt_length=compute_max_prompt_length(token_budget)
        )
        completion = cached_completion(
            engine='text-curie-001',
            prompts=packed_prompt,
            temperature=0.7,
            max_tokens=token_budget,
            top_p=1.0,
            frequency_penalty=1.0,
            presence_penalty=0.0,
            cache=cache,
            stop=None,
        )

        # Normalize: drop trailing periods and lower-case the adjective list
        description = completion.rstrip('.').lower()
        style_by_author[author_id] = description

        print("    User:", author_id)
        print("    Described as:", description)

    # Persist the mapping
    print("Saving to disk...")
    with open(pickle_path, 'wb') as sink:
        pickle.dump(style_by_author, sink)
    print("Done.")

if __name__ == "__main__":
    # The first command-line argument is passed to get_output_folder()
    output_folder = get_output_folder(sys.argv[1])
    split = "test"

    # One query dataset per sub-folder, processed in this order
    strategies = ("random", "most_common_subreddit", "diverse")
    dataset_paths = [
        f'{output_folder}/reddit_{split}_query/{strategy}/target_author_posts'
        for strategy in strategies
    ]

    # Build and save a style descriptor dict for every dataset
    for dataset_path in dataset_paths:
        create_style_descriptor_dict(dataset_path)