import os
import pickle
import sys
import tensorflow as tf

from utils.config import get_output_folder

# Directory that holds the pickled STRAP paraphrase dataframes.
# File APIs such as open() never expand "$USER" themselves, so the variable is
# resolved from the environment here; os.path.expandvars leaves the text
# unchanged when the variable is not set.
STRAP_PARAPHRASE_DIR = os.path.expandvars('/exp/$USER/strap_paraphrase_df')

def create_paraphrases_dataset(dataset_path):
    """Save a copy of a dataset whose ``body`` field holds STRAP paraphrases.

    The dataset at ``dataset_path`` is loaded with
    ``tf.data.experimental.load``.  For every record, the ``body`` entry is
    replaced by the ``neutral`` column values of the dataframe rows whose
    ``user_id`` equals the record's ``user_id``.  The result is saved next to
    the input as ``<basename>_paraphrase_strap``.

    The dataframe file is picked from ``STRAP_PARAPHRASE_DIR`` based on the
    path: the sampling strategy directory (``random``,
    ``most_common_subreddit`` or ``diverse``) selects the file name prefix and
    the basename (``source_author_posts`` or ``target_author_posts``) selects
    the suffix.

    Args:
        dataset_path: Path of the saved dataset to process.

    Raises:
        ValueError: If ``dataset_path`` contains none of the known strategy
            directories or its basename is not one of the known names.  While
            the dataset is being written, a ``ValueError`` is also raised
            inside the mapper when the number of paraphrases found for a user
            differs from the number of texts in the record (TensorFlow
            reports it wrapped in its own error type).
    """
    print("Processing:", dataset_path)

    # The output goes into the same folder as the input; the basename tells
    # whether these are the source or the target author posts.
    output_folder = os.path.dirname(dataset_path)
    basename = os.path.basename(dataset_path)

    # Get dataframe name.  The first matching strategy wins, in this order.
    strategy_prefixes = (
        ('/random/', 'df_query_random_'),
        ('/most_common_subreddit/', 'df_query_mcs_'),
        ('/diverse/', 'df_query_diverse_'),
    )
    prefix = next(
        (name for marker, name in strategy_prefixes if marker in dataset_path),
        None,
    )
    if prefix is None:
        raise ValueError(
            f"Cannot determine the sampling strategy from path: {dataset_path!r}"
        )

    role_suffixes = {
        'source_author_posts': 'source.pkl',
        'target_author_posts': 'target.pkl',
    }
    if basename not in role_suffixes:
        raise ValueError(
            f"Unexpected dataset name {basename!r}; "
            f"expected one of {sorted(role_suffixes)}"
        )
    dataframe_fn = prefix + role_suffixes[basename]

    # Load dataframe
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(os.path.join(STRAP_PARAPHRASE_DIR, dataframe_fn), 'rb') as fp:
        df = pickle.load(fp)

    # Define paraphrase output file names
    paraphrase_dataset_path = os.path.join(output_folder, basename + "_paraphrase_strap")

    # Define dataset mapper
    def python_map_fn(user_id, body):
        """Return the paraphrases stored in ``df`` for one record's user."""
        user_id = user_id.numpy()
        print("    User:", user_id)
        # Only the number of texts is needed, to validate the lookup below.
        num_texts = len(body.numpy())
        paraphrases = df.loc[df['user_id'] == user_id.decode('utf8'), 'neutral'].tolist()
        # Explicit raise instead of assert so the check survives `python -O`.
        if num_texts != len(paraphrases):
            raise ValueError(
                "The length of texts and paraphrases must match! "
                f"(texts: {num_texts}, paraphrases: {len(paraphrases)})"
            )
        # Pin the dtype: an empty list would otherwise become a float32
        # tensor and contradict the Tout=tf.string declared in map_fn.
        return tf.constant(paraphrases, dtype=tf.string)

    def map_fn(record):
        """Swap the record's ``body`` for the paraphrases of its user."""
        # The lookup needs eager values (.numpy()), so it runs in py_function.
        record['body'] = tf.py_function(
            func=python_map_fn,
            inp=[record['user_id'], record['body']],
            Tout=tf.string,
        )
        return record

    # Map each record
    dataset = tf.data.experimental.load(dataset_path)
    dataset = dataset.map(map_fn)

    # Save to disk
    print("Saving to disk...")
    tf.data.experimental.save(dataset, paraphrase_dataset_path)
    print("Done.")

if __name__ == "__main__":
    # The single command-line argument is forwarded to get_output_folder();
    # fail with a usage message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <argument for get_output_folder>")

    # Define which datasets to process
    output_folder = get_output_folder(sys.argv[1])
    split = "test"
    strategies = ("random", "most_common_subreddit", "diverse")
    roles = ("source_author_posts", "target_author_posts")

    # Same order as listing them by hand: for each strategy, source then target.
    dataset_paths = []
    for strategy in strategies:
        for role in roles:
            dataset_paths.append(
                f'{output_folder}/reddit_{split}_query/{strategy}/{role}'
            )

    # Create paraphrase datasets
    for dataset_path in dataset_paths:
        create_paraphrases_dataset(dataset_path)