import os
import sys
import tensorflow as tf

from utils.config import get_output_folder, get_gpt3_cache_path
from utils.gpt3 import create_cache, count_tokens, compute_max_prompt_length, prompt_packer_batch, truncate, cached_completion

# Module-wide completion cache, created once at import time from the
# configured GPT-3 cache path and shared by every cached_completion call below.
cache = create_cache(
    get_gpt3_cache_path()
)

def create_paraphrases_dataset(dataset_path, *, engine='text-curie-001', temperature=0.7):
    """Rewrite the 'body' texts of a saved tf.data dataset as neutral-style paraphrases.

    The dataset is loaded with ``tf.data.experimental.load``. Each record's
    ``'body'`` entry is replaced by completions produced from prompts built
    around its texts (via ``prompt_packer_batch`` / ``cached_completion``, using
    the module-level ``cache``). The mapped dataset is written with
    ``tf.data.experimental.save`` to a sibling path named
    ``<basename>_paraphrase_neutral`` in the same directory.

    Args:
        dataset_path: Path of a dataset readable by ``tf.data.experimental.load``
            whose records contain ``'user_id'`` and ``'body'`` entries. A
            trailing path separator is tolerated.
        engine: Completion engine passed to ``cached_completion``.
        temperature: Sampling temperature passed to ``cached_completion``.
    """
    print("Processing:", dataset_path)

    # Normalize first: with a trailing separator ("…/posts/") basename() would
    # return "" and the output would land at "…/posts/_paraphrase_neutral".
    normalized_path = os.path.normpath(dataset_path)
    output_folder = os.path.dirname(normalized_path)
    basename = os.path.basename(normalized_path)

    # Paraphrase output is written next to the input dataset.
    paraphrase_dataset_path = os.path.join(output_folder, basename + "_paraphrase_neutral")

    def python_map_fn(user_id, body):
        """Paraphrase one record's texts; runs eagerly inside tf.py_function."""
        print("    User:", user_id.numpy())
        # Decode to str and flatten newlines so each passage is a single line.
        texts = [t.decode('utf8').replace('\n', ' ') for t in body.numpy().tolist()]
        if not texts:
            # max() below raises on an empty sequence; return an empty string
            # tensor so the output dtype still matches Tout=tf.string.
            return tf.constant([], dtype=tf.string)
        # Completion budget: a little above the longest input, capped at 500 tokens.
        max_tokens = min(max(map(count_tokens, texts)) + 20, 500)
        n = len(texts)
        prompts = prompt_packer_batch(
            [''] * n,
            # Distinct lists per text; `[[]] * n` would alias a single list n times.
            [[] for _ in range(n)],
            ["Passage: %s\n\nParaphrase the passage in a simple neutral style.\n\nRewrite:"] * n,
            [(truncate(text, 2000),) for text in texts],
            separator='\n\n',
            max_prompt_length=compute_max_prompt_length(max_tokens),
        )
        paraphrases = cached_completion(
            engine=engine,
            prompts=prompts,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            cache=cache,
            stop=None,
        )
        return tf.constant(paraphrases)

    def map_fn(record):
        # py_function lets the Python/HTTP-side helpers run per record.
        record['body'] = tf.py_function(
            func=python_map_fn,
            inp=[record['user_id'], record['body']],
            Tout=tf.string,
        )
        return record

    dataset = tf.data.experimental.load(dataset_path)
    dataset = dataset.map(map_fn)

    print("Saving to disk...")
    tf.data.experimental.save(dataset, paraphrase_dataset_path)
    print("Done.")

if __name__ == "__main__":
    # Usage: <script> <output folder argument> [split]
    # The first argument is resolved through get_output_folder(); the optional
    # second argument selects the split and defaults to "test".
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <output_folder> [split]")
    output_folder = get_output_folder(sys.argv[1])
    split = sys.argv[2] if len(sys.argv) > 2 else "test"

    # Source and target author posts for each query-sampling strategy, in the
    # same order as before: random, most_common_subreddit, diverse.
    dataset_paths = [
        f'{output_folder}/reddit_{split}_query/{strategy}/{role}_author_posts'
        for strategy in ('random', 'most_common_subreddit', 'diverse')
        for role in ('source', 'target')
    ]

    for dataset_path in dataset_paths:
        create_paraphrases_dataset(dataset_path)