from __future__ import annotations

import numpy as np
import pandas as pd

import util as U
from slicing.split import split
from slicing.variable import Variable as V


class Slice:
  """ A slice of the data, representing a subset of data points.

  == Attributes ==
    df: DataFrame containing the rows corresponding to this slice.
    id: Fixed values of the fixed variables in this slice (all variables in
        self.split_by), looked up by variable title.
    split_by: List of variables that don't differ between points in this slice.
    title: Name for slice (built once by get_title at construction).
    y: Values of the "sp-BLEU" column of df as a NumPy array.
       dim: n if df has n rows.

  == Methods ==
    get_title: Returns title.
    x: Returns values of specified xvars in this slice.
  """
  df: pd.DataFrame
  id: pd.Series
  split_by: list[V]
  title: str
  y: np.ndarray[U.FloatT]

  def __init__(self, df: pd.DataFrame, id: pd.Series, split_by: list[V]) -> None:
    """ Initializes slice.

    == Arguments ==
    df: Rows belonging to this slice; must contain an "sp-BLEU" column.
    id: Values of the variables in split_by, looked up by variable title.
    split_by: Variables that are held fixed within this slice.
    """
    self.df, self.id, self.split_by = df, id, split_by
    # The title depends on id and split_by, so it is computed after they are set.
    self.title = self.get_title()
    self.y = self.df.loc[:, "sp-BLEU"].to_numpy()

  def get_title(self) -> str:
    """ Returns title for slice.

    == Return Values ==
    title: "all" if split_by is empty; otherwise the values in id of the
           variables in split_by, separated by "-". Values of variables whose
           title contains "size" get a "k" suffix.
    """
    if not self.split_by:
      # Must be a plain string: __repr__ returns this value directly.
      return "all"
    parts = []
    for var in self.split_by:
      value = str(self.id[var.title])
      parts.append((value + "k") if "size" in var.title else value)
    return '-'.join(parts)

  def x(self, xvars: list[V]) -> np.ndarray[U.FloatT]:
    """ Returns values of xvars for the points in this slice.

    The columns named by the titles of xvars are cast to float; the result has
    one row per row of df and one column per variable in xvars.
    """
    columns = [var.title for var in xvars]
    return self.df.loc[:, columns].astype(float).to_numpy()

  def __len__(self) -> int:
    """ Returns the number of rows (data points) in this slice. """
    return len(self.df)

  def __repr__(self) -> str:
    """ Returns the slice title. """
    return self.title

class SliceGroup:
  """ Group of all slices when slicing by the variables in self.split_by.

  == Class Attributes ==
    GROUPS: Cache of already-built slice groups, used by get_instance.

  == Attributes ==
    ids: DataFrame containing the ids of the slices.
    slices: List of slices.
    split_by: List of variables partitioned over.

  == Static Methods ==
    get_instance: Takes a list of variables and returns a corresponding
                  (possibly cached) instance of SliceGroup.
  """
  # Keyed by a tuple of flags, one per variable in V.main(), telling whether
  # that variable is in split_by.
  GROUPS: dict[tuple[bool, ...], SliceGroup] = {}

  ids: pd.DataFrame
  slices: list[Slice]
  split_by: list[V]

  def __init__(self, split_by: list[V]) -> None:
    """ Initializes SliceGroup.

    Calls split(split_by) and wraps each returned piece in a Slice. The i-th
    row of ids is used as the id of the i-th slice.
    """
    self.split_by = split_by
    self.ids, slices = split(self.split_by)
    self.slices = [Slice(slices[i], self.ids.iloc[i], self.split_by)
                   for i in range(len(slices))]

  @staticmethod
  def get_instance(split_by: list[V]) -> SliceGroup:
    """ If the same slice group has not been yet initialized, initializes it and saves it in GROUPS.
    Otherwise, returns the previously initialized slice group.

    Two calls count as "the same slice group" when they contain the same subset
    of V.main(); the order of split_by does not affect the cache key.
    """
    flags = tuple([var in split_by for var in V.main()])
    if flags not in SliceGroup.GROUPS:
      SliceGroup.GROUPS[flags] = SliceGroup(split_by)
    return SliceGroup.GROUPS[flags]

  def __len__(self) -> int:
    """ Returns the number of slices in the group. """
    return len(self.slices)

  def __repr__(self) -> str:
    """ Returns the reprs of the split_by variables joined by "+". """
    return '+'.join(map(V.__repr__, self.split_by))

  def repr_ids(self) -> list[str]:
    """ Returns the result of __repr__ for every slice, in order. """
    # __repr__ is called directly, so each slice's return value is passed
    # through as-is.
    return [item.__repr__() for item in self.slices]