import itertools
import typing as T

import numpy as np
import pandas as pd

from slicing.variable import Variable as V
import util as U


# == Splitting Functions ==
def split(split_by: list[V], df: pd.DataFrame=U.RECORDS) -> T.Tuple[pd.DataFrame, list[pd.DataFrame]]:
  """Split ``df`` into one slice per combination of the values of ``split_by``.

  Combinations are the Cartesian product of ``var.values(df)`` over the
  variables, in ``itertools.product`` order. A slice keeps the rows of ``df``
  whose ``var.title`` column equals the combination's value for every variable.
  A missing (NA) value in a combination selects the rows where that column is NA.

  Args:
    split_by: Variables to split on. Each is read for ``title`` (column name),
      ``dtype`` (dtype of its id column) and ``values(df)`` (values to split on).
    df: Frame to split. The default ``U.RECORDS`` is bound once, when the
      function is defined.

  Returns:
    A pair ``(ids, slices)``. ``ids`` is a DataFrame with one column per
    variable (named by ``title``, cast to ``dtype``) and one row per
    combination. ``slices[k]`` is the sub-frame of ``df`` matching row ``k``
    of ``ids``; it may be empty.
  """
  titles = [var.title for var in split_by]
  rows, slices = [], []
  for comb in itertools.product(*[var.values(df) for var in split_by]):
    part = df
    for title, value in zip(titles, comb):
      # NaN != NaN, so equality can never select missing values; match them via isna.
      if pd.isna(value):
        part = part[part[title].isna()]
      else:
        part = part[part[title] == value]
    rows.append(list(comb))
    slices.append(part)

  # Build the frame from the rows directly instead of through np.array:
  # np.array coerces all cells to one common dtype (e.g. str mixed with NaN
  # turns NaN into the string 'nan'), and an empty product yields a 1-D array
  # that cannot be shaped into len(titles) columns.
  dtypes = {var.title: var.dtype for var in split_by}
  ids = pd.DataFrame(rows, columns=titles).astype(dtypes)
  return ids, slices