from datasets import Dataset, DatasetDict, Features, load_dataset

# Identifier of the dataset that is read, and of the repository the processed
# result is pushed to at the end of the script.
original_enron_dataset_path = "preference-agents/EnronEmails-42K"
processed_enron_dataset_path = "preference-agents/preference-enron"

# Fetch the raw Enron corpus
enron_dataset = load_dataset(original_enron_dataset_path)

# Source columns that survive filtering, listed in the order they should
# finally appear in
columns_to_keep = [
    "from",
    "to",
    "date",
    "subject",
    "email_context",
    "content",
]
# Final layout: identical to the list above, except that 'email_context'
# appears under its new name 'previous_context'
final_column_order = [
    "previous_context" if column == "email_context" else column
    for column in columns_to_keep
]


def filter_columns(dataset, columns_to_retain):
    """
    Drops every column of a dataset split that is not explicitly retained.

    Args:
        dataset (Dataset): A single split. It must expose `column_names` as a
            flat list of names and provide a `remove_columns` method. A
            `DatasetDict` is not supported here, because its `column_names`
            is keyed by split rather than being a flat list.
        columns_to_retain (list): Names of the columns to keep. Names that do
            not exist in the split are simply ignored.

    Returns:
        Dataset: The result of `dataset.remove_columns` for all columns that
        are not listed in `columns_to_retain`.
    """
    # Build the lookup once so each membership test is O(1) instead of a
    # scan of the list for every column.
    retained = set(columns_to_retain)

    columns_to_drop = [col for col in dataset.column_names if col not in retained]
    return dataset.remove_columns(columns_to_drop)


def concatenate_email_lists(example_batch):
    """
    Flattens the address lists of a batch into comma-separated strings.

    Every entry of the 'from' and 'to' columns is joined with ", ". The batch
    is updated in place and also returned; all other columns are left as
    they are.

    Args:
        example_batch (dict): A batch of examples, mapping each column name
            to the list of values for that column.

    Returns:
        dict: The same batch, with 'from' and 'to' holding joined strings.
    """
    for address_column in ("from", "to"):
        joined_addresses = []
        for addresses in example_batch[address_column]:
            joined_addresses.append(", ".join(addresses))
        example_batch[address_column] = joined_addresses
    return example_batch


def rename_and_reorder_columns(dataset):
    """
    Renames 'email_context' to 'previous_context' and reorders columns as specified.

    The split is rebuilt from its column values in the order given by the
    module-level `final_column_order`. The feature types of the source split
    are passed on explicitly, so the rebuilt split keeps the same schema
    instead of having it re-inferred from the raw values.

    Args:
        dataset (Dataset): A single split to process. It must contain an
            'email_context' column and, after the rename, every column
            listed in `final_column_order`.

    Returns:
        Dataset: A new dataset holding exactly the columns of
        `final_column_order`, in that order.
    """
    # Rename 'email_context' to 'previous_context'
    dataset = dataset.rename_column("email_context", "previous_context")

    # Carry the schema over in the target order. Without it, a column that is
    # empty or entirely None in one split could be inferred with a different
    # type than in the other splits.
    ordered_features = Features(
        {col: dataset.features[col] for col in final_column_order}
    )

    # Create a new dataset with the columns in the desired order
    new_dataset = Dataset.from_dict(
        {col: dataset[col] for col in final_column_order},
        features=ordered_features,
    )

    return new_dataset


# Step 1: in every split, drop all columns that are not in columns_to_keep
filtered_enron_dataset = DatasetDict(
    {
        split_name: filter_columns(enron_dataset[split_name], columns_to_keep)
        for split_name in enron_dataset
    }
)

# Step 2: turn the 'from' / 'to' address lists into comma-separated strings,
# processing the examples batch-wise
transformed_enron_dataset = filtered_enron_dataset.map(
    function=concatenate_email_lists,
    batched=True,
)

# Step 3: rename 'email_context' and rebuild each split in the final column order
final_processed_dataset = DatasetDict(
    {
        split_name: rename_and_reorder_columns(transformed_enron_dataset[split_name])
        for split_name in transformed_enron_dataset
    }
)

# Step 4: publish the processed dataset under the new path
final_processed_dataset.push_to_hub(processed_enron_dataset_path)

print(
    "Filtered, transformed, and reordered dataset has been uploaded to the Hugging Face Hub."
)
