from datasets import load_dataset, concatenate_datasets
import polars as pl
import re
import numpy as np
import random


def get_imdb_ds(min_rating: int = None, max_rating: int = None):
    """
    Load the "tasksource/imdb62" reviews as a polars DataFrame.

    min_rating: int, inclusive lower bound on "rating" (None = no bound)
    max_rating: int, inclusive upper bound on "rating" (None = no bound)

    Rows with a null text are dropped, and after the rating filters only
    authors that still have at least two rows are kept.

    return ds format:
        authorId: str   (renamed from "userId")
        text: str       (renamed from "content")
        textId: str     (renamed from "reviewId")
        other...
    """
    # This dataset only has a train split.
    train_split = load_dataset("tasksource/imdb62")["train"]

    # Hand the underlying Arrow table over to polars.
    frame = pl.from_arrow(train_split.data.table)
    frame = frame.rename(
        {"userId": "authorId", "content": "text", "reviewId": "textId"}
    )
    frame = frame.with_columns(
        pl.col("authorId").cast(pl.String()).alias("authorId"),
        pl.col("textId").cast(pl.String()).alias("textId"),
    )
    frame = frame.filter(pl.col("text").is_not_null())

    # Optional rating bounds; both ends are inclusive.
    if min_rating is not None:
        frame = frame.filter(pl.col("rating") >= min_rating)
    if max_rating is not None:
        frame = frame.filter(pl.col("rating") <= max_rating)

    # Must run last: the per-author row count depends on the filters above.
    author_row_count = pl.len().over("authorId")
    return frame.filter(author_row_count >= 2)


def get_blog_ds(gender=None, min_age=None, max_age=None):
    """
    Load the "tasksource/blog_authorship_corpus" posts as a polars DataFrame.

    gender: str, "male" or "female" (None = keep both)
    min_age: int, inclusive (None = no lower bound)
    max_age: int, inclusive (None = no upper bound)

    Posts are kept only when their approximate word count (number of
    space-separated pieces after turning every newline into a space) is
    strictly between 50 and 600 and they contain no character in the
    U+4E00..U+9FFF range. After all filters, only authors that still have at
    least two posts are kept.

    raises:
        ValueError: if gender is neither None, "male" nor "female". The check
            runs before the dataset is loaded so bad input fails fast.

    return ds format:
        authorId: str   (renamed from "id")
        text: str       (stripped of leading/trailing whitespace)
        textId: str     ("<authorId>_<date>")
        gender: str
        age: compared numerically against min_age / max_age
             (NOTE(review): previously documented as str -- confirm dtype)
        topic: str
        word_length: approximate word count used for the length filter
    """
    # Validate with a real exception (asserts vanish under `python -O`) and
    # do it up front, before the expensive dataset load.
    if gender is not None and gender not in ("male", "female"):
        raise ValueError(f"Unknown gender: {gender}")

    ds = load_dataset("tasksource/blog_authorship_corpus")
    ds = ds["train"]  # This dataset only has train split

    ds = ds.remove_columns(["sign"])
    # to polars

    ds = (
        pl.from_arrow(ds.data.table)
        .rename({"id": "authorId"})
        .with_columns(authorId=pl.col("authorId").cast(pl.String()))
        .with_columns(
            textId=pl.concat_str(
                pl.col("authorId"), pl.lit("_"), pl.col("date").cast(pl.String())
            )
        )
        .with_columns(text=pl.col("text").str.strip_chars())
        .with_columns(
            # `str.replace` only substitutes the first match; `replace_all`
            # is needed so every line break counts as a word separator.
            word_length=pl.col("text")
            .str.replace_all("\n", " ", literal=True)
            .str.split(" ")
            .list.len()
        )
        .filter((pl.col("word_length") < 600) & (pl.col("word_length") > 50))
        .filter(~pl.col("text").str.contains("[\u4e00-\u9fff]"))  # Filter out Chinese
    )
    if gender is not None:
        ds = ds.filter(pl.col("gender") == gender)
    if min_age is not None:
        ds = ds.filter(pl.col("age") >= min_age)
    if max_age is not None:
        ds = ds.filter(pl.col("age") <= max_age)
    # Must run last: the per-author post count depends on the filters above.
    ds = ds.filter(pl.len().over("authorId") >= 2)

    return ds


def _distribute(count, num_pos):
    r = np.full(num_pos, count // num_pos)
    r[: count % num_pos] += 1
    return r


def make_n_author_m_shot_ds(ds, seed, num_authors, num_shots, num_test_rows):
    """
    Build a reproducible N-author / M-shot train/test split from ds.

    ds: polars DataFrame with an "authorId" column (e.g. the output of
        get_imdb_ds / get_blog_ds)
    seed: int, drives both the author selection and the per-author sampling
    num_authors: number of distinct authors drawn without replacement
    num_shots: number of training rows per selected author
    num_test_rows: total number of test rows; spread as evenly as possible
        over the selected authors, earlier-selected authors getting the
        extra rows when it does not divide evenly

    Return:
    train_ds: subset containing num_shots for each of num_authors
    test_ds: only num_test_rows in total, evenly distributed among authors

    Both frames are ordered by author in selection order.

    Raises:
    ValueError: if a selected author has fewer rows than num_shots plus its
        share of the test rows.
    """
    # Sort so the candidate list does not depend on set iteration order.
    full_authors = sorted(list(set(ds["authorId"])))
    rng = random.Random(seed)
    # random.Random.randint includes BOTH endpoints, while RandomState only
    # accepts seeds in [0, 2**32 - 1]; hence the "- 1" on the upper bound.
    seed_for_selected_authors = rng.randint(0, 2**32 - 1)
    selected_authors = np.random.RandomState(seed_for_selected_authors).choice(
        full_authors, num_authors, replace=False
    )
    filtered_ds = ds.filter(ds["authorId"].is_in(selected_authors))

    # Per-author test quota, in selection order.
    num_test_per_author = {
        author: n
        for author, n in zip(selected_authors, _distribute(num_test_rows, num_authors))
    }
    # support = rows available per author, required = rows we need to draw.
    support = (
        filtered_ds.group_by("authorId")
        .agg(pl.len())
        .rename({"len": "support"})
        .with_columns(
            required=num_shots
            + pl.col("authorId").replace(num_test_per_author, return_dtype=pl.UInt32)
        )
    )
    #  if any support < required: fail on the first offending author
    for d in support.filter(pl.col("support") < pl.col("required")).iter_rows(
        named=True
    ):
        raise ValueError(
            f"Not enough rows for author {d['authorId']}. Has {d['support']}, needs {num_shots} for training and {num_test_per_author[d['authorId']]} for testing"
        )
    # Same inclusive-endpoint caveat as above.
    seed_for_extracted_ds = rng.randint(0, 2**32 - 1)
    # One independent sampling seed per author.
    seeds = {
        author: s
        for author, s in zip(
            selected_authors,
            np.random.RandomState(seed_for_extracted_ds).randint(0, 2**32, num_authors),
        )
    }
    # Draw (shots + test quota) shuffled rows per author, without replacement.
    extracted_ds = (
        filtered_ds.group_by("authorId", maintain_order=True)
        .map_groups(
            lambda df: df.sample(
                num_shots + num_test_per_author[df["authorId"][0]],
                with_replacement=False,
                shuffle=True,
                seed=seeds[df["authorId"][0]],
            )
        )
        .sort("authorId", maintain_order=True)
    )
    # Reorder the author groups to follow selection order; the stable sort
    # keeps the shuffled row order inside each author.
    selected_author_indices = {author: i for i, author in enumerate(selected_authors)}
    extracted_ds = (
        extracted_ds.with_columns(
            authorIdx=pl.col("authorId").replace(
                selected_author_indices, return_dtype=pl.UInt32
            )
        )
        .sort("authorIdx", maintain_order=True)
        .drop("authorIdx")
    )
    # Within each author the first num_shots sampled rows are training rows,
    # the remainder are test rows.
    train_ds = extracted_ds.filter(pl.int_range(pl.len()).over("authorId") < num_shots)
    test_ds = extracted_ds.filter(pl.int_range(pl.len()).over("authorId") >= num_shots)
    # Sanity check: exactly num_shots rows per author, in selection order.
    assert train_ds["authorId"].to_list() == selected_authors.repeat(num_shots).tolist()
    # NOTE(review): the disjointness check below is disabled in the original;
    # presumably textId is not guaranteed unique -- confirm before enabling.
    # assert (
    #     len(set(train_ds["textId"].to_list()) & set(test_ds["textId"].to_list())) == 0
    # )
    return train_ds, test_ds


if __name__ == "__main__":
    # Manual smoke test of the loaders and the split builder.
    ds = get_imdb_ds()
    # make_n_author_m_shot_ds has no defaults, so every argument must be given.
    # One shot plus one test row per author needs only two rows per author,
    # which get_imdb_ds already guarantees. The values are illustrative.
    train_ds, test_ds = make_n_author_m_shot_ds(
        ds, seed=0, num_authors=5, num_shots=1, num_test_rows=5
    )
    print(train_ds, test_ds)
    ds = get_blog_ds()
    # Number of retained posts per author.
    print(ds.group_by("authorId").agg(pl.len()))
