import code
import json
import tensorflow as tf
import tensorflow_datasets as tfds
import sys

from collections import Counter
from utils.config import get_output_folder

# Caps passed straight to save_to_disk: argv[1] for source-author files and
# argv[2] for target-author files, each decoded as JSON (e.g. `250` or `null`).
MAX_SOURCE_AUTHORS, MAX_TARGET_AUTHORS = json.loads(sys.argv[1]), json.loads(sys.argv[2])

def save_to_disk(max, dataset, dataset_path):
    """Persist ``dataset`` at ``dataset_path`` via ``tf.data.experimental.save``.

    When ``max`` is truthy only the first ``max`` elements are written; a falsy
    ``max`` (``None``, ``0``) writes the dataset as-is.  The parameter name
    shadows the ``max`` builtin but is kept for caller compatibility.
    """
    to_save = dataset.take(max) if max else dataset
    tf.data.experimental.save(to_save, dataset_path)
    

if __name__ == "__main__":
    # Load a reddit authorship dataset's split; argv[3] selects the output folder.
    output_folder = get_output_folder(sys.argv[3])
    split = "test"
    # TFDS split names for the query half and the target half of `split`.
    query_split, target_split = f'{split}_query', f'{split}_target'

    def dataset():
        """Load the query split and shuffle it with a fixed seed (42).

        A fresh dataset object is built on every call; with ``shuffle_files=False``
        and a fixed op seed, the selections below presumably rely on every call
        producing the same order.
        """
        dataset = tfds.load("reddit_user_id", split=query_split, shuffle_files=False)
        # A buffer as large as the split gives a full (not windowed) shuffle.
        dataset = dataset.shuffle(len(dataset), seed=42)
        return dataset
    # Materialize every shuffled query record once. Iterating the dataset already
    # yields all of it, so there is no need to load it a second time for a take().
    dataset_rows = list(dataset())

    def target_dataset():
        """Load the target split and shuffle it with a fixed seed (42).

        Like ``dataset()``, a new dataset object is built on every call.
        """
        target_dataset = tfds.load("reddit_user_id", split=target_split, shuffle_files=False)
        # A buffer as large as the split gives a full (not windowed) shuffle.
        target_dataset = target_dataset.shuffle(len(target_dataset), seed=42)
        return target_dataset
    # Materialize every shuffled target record once, without a second load.
    target_dataset_rows = list(target_dataset())

    # Helper funcs
    def get_user_ids_tensor(dataset):
        """Return a 1-D tensor holding every ``user_id`` in ``dataset``, in order.

        Always rank 1, even when ``dataset`` has zero or one element.
        """
        # Flatten each id to rank 1: tf.concat on a one-element list returns that
        # element unchanged, which would yield a scalar for a single record.
        user_id_chunks = [tf.reshape(user_id, [-1]) for user_id in dataset.map(lambda x: x['user_id'])]
        if not user_id_chunks:
            # tf.concat cannot build a tensor from an empty list.
            return tf.constant([], dtype=dataset.element_spec['user_id'].dtype)
        return tf.concat(user_id_chunks, 0)

    def get_user_ids(dataset):
        """Return the user ids of ``dataset`` as a sorted Python list."""
        return sorted(get_user_ids_tensor(dataset).numpy().tolist())

    def fetch_user_ids_filter(dataset):
        """Build a ``Dataset.filter`` predicate that keeps records whose
        ``user_id`` occurs among the user ids of ``dataset``."""
        user_ids = get_user_ids_tensor(dataset)
        def _filter(record):
            # ListDiff returns the ids of the record that are NOT in `user_ids`;
            # the record is kept when nothing is left over.
            missing, _ = tf.raw_ops.ListDiff(x=tf.reshape(record['user_id'], [-1]), y=user_ids)
            return tf.equal(tf.size(missing), 0)
        return _filter

    # Persist the whole shuffled target split, uncapped (max=None).
    full_target_path = f'{output_folder}/reddit_{split}_target/full'
    save_to_disk(None, target_dataset(), full_target_path)

    # Choose source and target authors: from the first (at most 500) shuffled query
    # records, the first half are sources and the rest are targets (disjoint sets).
    dataset_len = min(500, len(dataset_rows))
    random_source_authors = lambda: dataset().take(dataset_len // 2)
    random_target_authors = lambda: dataset().skip(dataset_len // 2).take(dataset_len - dataset_len // 2)
    save_to_disk(MAX_SOURCE_AUTHORS, random_source_authors(), f'{output_folder}/reddit_{split}_query/random/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, random_target_authors(), f'{output_folder}/reddit_{split}_query/random/target_author_posts')
    # save_to_disk keeps only the first MAX_* query records, so select target posts
    # for exactly those authors. Filtering by the uncapped set and then capping would
    # keep target posts of authors whose query posts were dropped.
    saved_random_source_authors = random_source_authors().take(MAX_SOURCE_AUTHORS) if MAX_SOURCE_AUTHORS else random_source_authors()
    saved_random_target_authors = random_target_authors().take(MAX_TARGET_AUTHORS) if MAX_TARGET_AUTHORS else random_target_authors()
    save_to_disk(MAX_SOURCE_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_random_source_authors)), f'{output_folder}/reddit_{split}_target/random/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_random_target_authors)), f'{output_folder}/reddit_{split}_target/random/target_author_posts')

    # Choose source and target authors with posts all within the most popular subreddit
    subreddit_counts = [Counter(row['subreddit'].numpy().tolist()) for row in dataset_rows]
    # The lone subreddit of every record whose posts all fall in exactly one subreddit.
    exclusive_subreddits = [next(iter(counts)) for counts in subreddit_counts if len(counts) == 1]
    # AskReddit is too heterogenous for our purposes, we'll use the next most common.
    # Filtering it out (rather than list.remove) does not raise ValueError when
    # AskReddit is absent from the top 5.
    most_common_subreddits = [
        subreddit
        for subreddit, _ in Counter(exclusive_subreddits).most_common(5)
        if subreddit != b'AskReddit'
    ]
    most_common_subreddit = most_common_subreddits[0]
    def most_common_subreddit_author_filter(record):
        """Keep records whose every ``subreddit`` entry equals ``most_common_subreddit``."""
        return tf.reduce_all(record['subreddit'] == most_common_subreddit)
    most_common_subreddit_authors_len = len(list(dataset().filter(most_common_subreddit_author_filter)))
    most_common_subreddit_source_authors = lambda: dataset().filter(most_common_subreddit_author_filter).take(most_common_subreddit_authors_len // 2)
    most_common_subreddit_target_authors = lambda: dataset().filter(most_common_subreddit_author_filter).skip(most_common_subreddit_authors_len // 2).take(most_common_subreddit_authors_len - most_common_subreddit_authors_len // 2)
    save_to_disk(MAX_SOURCE_AUTHORS, most_common_subreddit_source_authors(), f'{output_folder}/reddit_{split}_query/most_common_subreddit/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, most_common_subreddit_target_authors(), f'{output_folder}/reddit_{split}_query/most_common_subreddit/target_author_posts')
    # save_to_disk keeps only the first MAX_* query records, so select target posts
    # for exactly those authors (not the uncapped set, which would be capped in a
    # different order and drift from the saved query authors).
    saved_most_common_subreddit_source_authors = most_common_subreddit_source_authors().take(MAX_SOURCE_AUTHORS) if MAX_SOURCE_AUTHORS else most_common_subreddit_source_authors()
    saved_most_common_subreddit_target_authors = most_common_subreddit_target_authors().take(MAX_TARGET_AUTHORS) if MAX_TARGET_AUTHORS else most_common_subreddit_target_authors()
    save_to_disk(MAX_SOURCE_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_most_common_subreddit_source_authors)), f'{output_folder}/reddit_{split}_target/most_common_subreddit/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_most_common_subreddit_target_authors)), f'{output_folder}/reddit_{split}_target/most_common_subreddit/target_author_posts')

    # Choose source and target authors that post in a diverse set (>= 13) of subreddits
    def diverse_author_filter(record):
        """Keep records whose posts span at least 13 distinct subreddits."""
        unique_subreddits, _ = tf.unique(record['subreddit'])
        # tf.size works on symbolic tensors inside Dataset.filter without relying
        # on autograph's conversion of len().
        return tf.size(unique_subreddits) >= 13
    diverse_authors_len = len(list(dataset().filter(diverse_author_filter)))
    diverse_source_authors = lambda: dataset().filter(diverse_author_filter).take(diverse_authors_len // 2)
    diverse_target_authors = lambda: dataset().filter(diverse_author_filter).skip(diverse_authors_len // 2).take(diverse_authors_len - diverse_authors_len // 2)
    save_to_disk(MAX_SOURCE_AUTHORS, diverse_source_authors(), f'{output_folder}/reddit_{split}_query/diverse/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, diverse_target_authors(), f'{output_folder}/reddit_{split}_query/diverse/target_author_posts')
    # save_to_disk keeps only the first MAX_* query records, so select target posts
    # for exactly those authors rather than for the uncapped selection.
    saved_diverse_source_authors = diverse_source_authors().take(MAX_SOURCE_AUTHORS) if MAX_SOURCE_AUTHORS else diverse_source_authors()
    saved_diverse_target_authors = diverse_target_authors().take(MAX_TARGET_AUTHORS) if MAX_TARGET_AUTHORS else diverse_target_authors()
    save_to_disk(MAX_SOURCE_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_diverse_source_authors)), f'{output_folder}/reddit_{split}_target/diverse/source_author_posts')
    save_to_disk(MAX_TARGET_AUTHORS, target_dataset().filter(fetch_user_ids_filter(saved_diverse_target_authors)), f'{output_folder}/reddit_{split}_target/diverse/target_author_posts')

    # Drop into a REPL whose namespace is this script's module namespace (the
    # datasets, helpers and selection lambdas built above) for ad-hoc inspection.
    code.interact(banner=None, local=locals())

