import json
import os
import pickle
import sys
import tensorflow as tf

from utils.config import get_output_folder, get_gpt3_cache_path, make_safe_filename
from utils.gpt3 import create_cache, count_tokens, compute_max_prompt_length, prompt_packer_batch, truncate, cached_completion

# Module-level response cache shared by every cached_completion call made in
# style_transfer; its location comes from utils.config.get_gpt3_cache_path().
cache = create_cache(get_gpt3_cache_path())

def _decode_texts(body):
    """Decode a tensor of UTF-8 byte strings into one-line Python strings.

    Newlines are replaced by spaces so each text stays on a single line when it
    is embedded in a prompt.
    """
    return [t.decode('utf8').replace('\n', ' ') for t in body.numpy().tolist()]


def style_transfer(dataset_path, para_mode, descriptors_mode):
    """Rewrite every source record's posts toward each target user's style.

    For each target user, every record of the paraphrased source dataset has its
    texts rewritten with ``text-curie-001``. The prompts are assembled by
    ``prompt_packer_batch`` from that target user's (paraphrase, original) post
    pairs. The mapped dataset is then saved with ``tf.data.experimental.save``.

    Args:
        dataset_path: Directory containing ``source_author_posts_paraphrase_<para_mode>``,
            ``target_author_posts`` and ``target_author_posts_paraphrase_<para_mode>``.
            In descriptors mode it must also contain
            ``target_author_posts_descriptors_gpt3.pkl``.
        para_mode: Paraphrase variant. It selects the input dataset directories
            and forms the output method name ``gpt3_<para_mode>``.
        descriptors_mode: If truthy, the prompts ask for a rewrite that is
            "more <descriptors>" using the target user's style descriptors, and
            ``_style_descriptors`` is appended to the method name.

    Output:
        One dataset per target user, written to
        ``<dataset_path>/source_author_posts_style_transferred/<method>/<safe target user id>``.
        Each saved record has ``body`` replaced by the styled texts and
        ``user_id`` set to the target user's id.

    Raises:
        ValueError: If the target and target-paraphrase datasets are not aligned
            by ``user_id``.
        KeyError: If descriptors mode is on and a target user has no descriptors.
    """
    print("Processing:", dataset_path)

    # Define style_transfer output file names
    style_transfer_method = f"gpt3_{para_mode}"
    if descriptors_mode:
        style_transfer_method += "_style_descriptors"
    style_transfer_dataset_path = os.path.join(dataset_path, f"source_author_posts_style_transferred/{style_transfer_method}")

    # Input dataset locations
    source_dataset_path = os.path.join(dataset_path, f'source_author_posts_paraphrase_{para_mode}')
    target_dataset_path = os.path.join(dataset_path, 'target_author_posts')
    target_paraphrase_dataset_path = os.path.join(dataset_path, f'target_author_posts_paraphrase_{para_mode}')

    # Load style descriptors once. The file is the same for every target user,
    # so there is no reason to re-read it inside the loop below.
    descriptors = {}
    if descriptors_mode:
        descriptors_fn = os.path.join(dataset_path, 'target_author_posts_descriptors_gpt3.pkl')
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted, locally generated files.
        with open(descriptors_fn, 'rb') as fp:
            descriptors = pickle.load(fp)

    # For each target user. zip() stops at the shorter of the two datasets.
    target_dataset = tf.data.experimental.load(target_dataset_path)
    target_paraphrase_dataset = tf.data.experimental.load(target_paraphrase_dataset_path)
    for target_record, target_paraphrase_record in zip(target_dataset, target_paraphrase_dataset):
        target_user_id = target_record['user_id'].numpy()
        paraphrase_user_id = target_paraphrase_record['user_id'].numpy()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if target_user_id != paraphrase_user_id:
            raise ValueError(
                f"Target and paraphrase datasets are misaligned: "
                f"{target_user_id!r} != {paraphrase_user_id!r}"
            )
        print("    Style Transfer Target:", target_user_id)
        target_texts = _decode_texts(target_record['body'])
        target_paraphrase_texts = _decode_texts(target_paraphrase_record['body'])

        # In-context examples and templates depend only on the target user, so
        # build them once here instead of once per source record.
        if descriptors_mode:
            target_descriptors = descriptors[target_user_id]
            # Zip only the target lists, so the example pool is not capped by
            # the number of source texts being rewritten.
            examples = [
                (paraphrase, target_descriptors, text)
                for paraphrase, text in zip(target_paraphrase_texts, target_texts)
            ]
            example_template = 'Here is some text: {%s} Here is a rewrite of the text that is more %s: {%s}'
            query_template = "Here is some text: {%s} Here is a rewrite of the text that is more %s: {"
        else:
            target_descriptors = None
            examples = list(zip(target_paraphrase_texts, target_texts))
            example_template = 'Here is some text: {%s} Here is a rewrite of the text: {%s}'
            query_template = "Here is some text: {%s} Here is a rewrite of the text: {"

        # Define dataset mapper
        def python_map_fn(user_id, body):
            """Eagerly rewrite one source record's texts (runs inside tf.py_function)."""
            print("        Style Transfer Source:", user_id.numpy())
            texts = _decode_texts(body)
            if not texts:
                # max() below would fail on an empty list, and tf.constant([])
                # defaults to float32, which would not match Tout=tf.string.
                return tf.constant([], dtype=tf.string)
            # Allow 20 tokens more than the longest source text, capped at 500.
            max_tokens = min(max(map(count_tokens, texts)) + 20, 500)
            if descriptors_mode:
                query_args = [(truncate(text, 2000), target_descriptors) for text in texts]
            else:
                query_args = [(truncate(text, 2000),) for text in texts]
            num_texts = len(texts)
            prompts = prompt_packer_batch(
                [example_template] * num_texts,
                [examples] * num_texts,
                [query_template] * num_texts,
                query_args,
                separator=' ',
                max_prompt_length=compute_max_prompt_length(max_tokens),
            )
            styled_texts = cached_completion(
                engine='text-curie-001',
                prompts=prompts,
                temperature=0.7,
                max_tokens=max_tokens,
                top_p=1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                cache=cache,
                # The query template ends with "{", so stop at the closing brace.
                stop=["}"],
            )
            return tf.constant(styled_texts)

        def map_fn(record):
            # Replace the body with the styled rewrite and relabel the record
            # with the target user's id.
            record['body'] = tf.py_function(func=python_map_fn, inp=[record['user_id'], record['body']], Tout=tf.string)
            record['user_id'] = target_user_id
            return record

        # Map each record
        source_dataset = tf.data.experimental.load(source_dataset_path)
        source_dataset = source_dataset.map(map_fn)

        # Save to disk (this is where python_map_fn actually runs)
        print("Saving to disk...")
        tf.data.experimental.save(source_dataset, os.path.join(style_transfer_dataset_path, make_safe_filename(target_user_id.decode('utf8'))))
        print("Done.")

def main(argv=None):
    """Style-transfer the reddit ``query`` subsets selected on the command line.

    Usage::

        <script> OUTPUT PARA_MODE DESCRIPTORS_MODE_JSON [SPLIT]

    ``OUTPUT`` is passed to ``get_output_folder``. ``DESCRIPTORS_MODE_JSON`` is
    parsed with ``json.loads`` (e.g. ``true`` / ``false``). ``SPLIT`` defaults
    to ``"test"``.

    Exits with a usage message on missing or malformed arguments, instead of
    crashing with an IndexError or JSONDecodeError traceback.
    """
    argv = sys.argv if argv is None else argv
    usage = "usage: <script> OUTPUT PARA_MODE DESCRIPTORS_MODE_JSON [SPLIT]"
    if len(argv) < 4:
        sys.exit(usage)
    para_mode = argv[2]
    try:
        descriptors_mode = json.loads(argv[3])
    except json.JSONDecodeError:
        sys.exit(f"DESCRIPTORS_MODE_JSON must be JSON (e.g. true/false), got {argv[3]!r}\n{usage}")
    split = argv[4] if len(argv) > 4 else "test"
    output_folder = get_output_folder(argv[1])

    # Create style transferred datasets, one per query subset
    for subset in ("random", "most_common_subreddit", "diverse"):
        style_transfer(f'{output_folder}/reddit_{split}_query/{subset}', para_mode, descriptors_mode)


if __name__ == "__main__":
    main()