import abc
import argparse
import math
import os
import pandas as pd
from enum import Enum, auto
import torch.nn.functional as F
import numpy as np

from typing import Union, List

from transformers import BertTokenizer
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler, BatchSampler

from utils.feature_converter import convert_feature, BertInputData

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets.

  Subclasses are expected to provide the three ``get_*`` hooks below.
  NOTE(review): the class is not built on ``abc.ABCMeta``, so the
  ``abstractmethod`` markers are advisory only; instantiating a subclass
  that misses one of the hooks does not raise.
  """

  @abc.abstractmethod
  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    pass

  @abc.abstractmethod
  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    pass

  @abc.abstractmethod
  def get_labels(self):
    """Gets the list of labels for this data set."""
    pass

  @classmethod
  def _read_tsv(cls, input_file):
    """Reads a tab separated value file into a `pandas.DataFrame`.

    The file is decoded as UTF-8 explicitly. Without an explicit encoding
    `open` falls back to the platform's locale encoding, so the same file
    with non-ASCII text could parse differently (or fail) across machines.

    Args:
      input_file: path of the TSV file; the first row is used as the header.

    Returns:
      The parsed table.
    """
    with open(input_file, "r", encoding="utf-8") as f:
      return pd.read_table(f)

class TextDataSampler(DataProcessor):
  """Reads the ``train.tsv`` / ``dev.tsv`` / ``test.tsv`` splits of a data directory."""

  def _load_split(self, data_dir, split):
    """Parses ``<data_dir>/<split>.tsv`` and converts its rows to examples."""
    frame = self._read_tsv(os.path.join(data_dir, f"{split}.tsv"))
    return self._create_examples(frame, split)

  def get_train_examples(self, data_dir):
    """Returns the examples stored in ``train.tsv`` under `data_dir`."""
    return self._load_split(data_dir, "train")

  def get_dev_examples(self, data_dir):
    """Returns the examples stored in ``dev.tsv`` under `data_dir`."""
    return self._load_split(data_dir, "dev")

  def get_test_examples(self, data_dir):
    """Returns the examples stored in ``test.tsv`` under `data_dir`."""
    return self._load_split(data_dir, "test")

  def get_labels(self):
    """Returns the two label names, ``"0"`` and ``"1"``."""
    return ["0", "1"]

  def _create_examples(self, df: pd.DataFrame, set_type):
    """Turns every row of `df` into a `BertInputData`.

    The guid comes from the ``guid`` column when the row has one, otherwise
    from the row's index value. ``text`` becomes ``text_a``; every column to
    the right of ``text`` is collected, in order, as the label list.
    `set_type` is accepted but not used.
    """
    text_pos = df.columns.get_loc('text')

    def to_example(row_key, row):
      guid = row["guid"] if "guid" in row else row_key
      return BertInputData(
        guid=guid,
        text_a=row["text"],
        text_b=None,
        label=row[(text_pos + 1):].to_list(),
      )

    return [to_example(row_key, row) for row_key, row in df.iterrows()]

class ModalityDatasource(object):
  """Holder for the optional audio and video data that accompany the text."""

  def __init__(self) -> None:
    # Both slots start out as None; audio_data is created first, then video_data.
    self.audio_data = self.video_data = None

class ModelInput(object):
  """Attribute-style view of one batch dict.

  'input_ids', 'segment_ids' and 'attention_mask' are cast to long,
  'labels' to float. 'audio_data' and 'video_data' are optional: they are
  cast to float when present and stored as None otherwise.
  """

  def __init__(self, d_dict: dict) -> None:
    def optional_float(key):
      # Missing (or None) modalities stay None instead of being cast.
      value = d_dict.get(key)
      if value is None:
        return None
      return value.float()

    self.input_ids: Tensor = d_dict.get('input_ids').long()
    self.segment_ids: Tensor = d_dict.get('segment_ids').long()
    self.attention_mask: Tensor = d_dict.get('attention_mask').long()
    self.audio_data: Tensor = optional_float('audio_data')
    self.video_data: Tensor = optional_float('video_data')
    self.labels: Tensor = d_dict.get('labels').float()

  def to(self, device):
    """Moves every non-None attribute to `device` in place and returns self."""
    for name, value in vars(self).items():
      if value is not None:
        # Rebinding an existing key is safe while iterating over the dict.
        setattr(self, name, value.to(device))
    return self

class ExecMode(Enum):
  """Execution phase selector; values match what `auto()` would generate (1, 2, 3)."""
  TRAIN = 1
  EVAL = 2
  TEST = 3

def get_size(datum: Union[Tensor, List[Tensor]]):
  """Returns the item count of `datum`: ``size(0)`` for a tensor, ``len`` for anything else."""
  if torch.is_tensor(datum):
    return datum.size(0)
  return len(datum)

class ListOrTensorDataset(Dataset):
  """Map-style dataset that stores one DataFrame column per field.

  Each item is returned as a plain dict mapping field name to that row's
  value. The frame keeps its default positional index.
  """
  data: pd.DataFrame

  def __init__(self, data: "dict[str, Union[Tensor, List[Tensor]]]") -> None:
    """Builds the backing frame from a dict of equally sized fields.

    Args:
      data: field name -> tensor or list; must contain a 'guid' field.

    Raises:
      AssertionError: if the fields do not all have the same size.
      KeyError: if there is no 'guid' field.
    """
    sizes = [get_size(datum) for datum in data.values()]
    assert all(size == sizes[0] for size in sizes), "Size mismatch between tensors"
    self.data = pd.DataFrame.from_dict(data)
    # `DataFrame.set_index` returns a new frame, so calling it without using
    # the result never re-indexed `self.data`; its only effect was a KeyError
    # for a missing 'guid' column. Keep exactly that check and leave the
    # positional index in place, which the row lookups below rely on.
    if 'guid' not in self.data.columns:
      raise KeyError("None of ['guid'] are in the columns")

  def __getitem__(self, index):
    """Returns row(s) as a dict.

    Integer indices (Python or numpy) are positional, so negative values
    count from the end; any other index is resolved through ``.loc``.
    """
    if isinstance(index, (int, np.integer)):
      target = self.data.iloc[index]
    else:
      target = self.data.loc[index]
    return target.to_dict()

  def __len__(self):
    """Returns the number of rows."""
    return len(self.data)

def empty_tensor(total):
  """Returns a zero-filled tensor of shape ``(total, 1, 1)``."""
  shape = (total, 1, 1)
  return torch.zeros(shape)

def convert_datasrc_to_dataset(l_datasrc, datasources: ModalityDatasource, emo_cls=False, sent_bin=False):
  """
  Packs text features and optional modality data into a ListOrTensorDataset.

  Columns of the resulting dataset, in order:
    'guid', 'input_ids', 'segment_ids', 'attention_mask', 'labels',
    followed by one column per attribute of `datasources` that is not None
    (each element converted to a float32 tensor).

  'labels' holds one float32 tensor per feature, built from its `label_id`:
    emo_cls:   [label_id[0], k] where k is 0 when the largest value of
               label_id[1:] is <= 0.1, otherwise 1 + its argmax
    sent_bin:  [1, 0] when label_id[0] >= 0, otherwise [0, 1]
               (only consulted when emo_cls is false)
    otherwise: label_id as is
  """
  def to_label(raw):
    if emo_cls:
      emotions = torch.tensor(raw[1:])
      if torch.max(emotions).item() <= 0.1:
        strongest = 0
      else:
        strongest = torch.argmax(emotions).item() + 1
      encoded = [raw[0], strongest]
    elif sent_bin:
      encoded = [1, 0] if raw[0] >= 0 else [0, 1]
    else:
      encoded = raw
    return torch.tensor(encoded, dtype=torch.float32)

  # Dataset column name -> attribute name on each feature object.
  feature_attrs = {
    'guid': 'guid',
    'input_ids': 'input_ids',
    'segment_ids': 'segment_ids',
    'attention_mask': 'input_mask',
  }
  columns = {
    column: [getattr(feature, attr) for feature in l_datasrc]
    for column, attr in feature_attrs.items()
  }
  columns['labels'] = [to_label(feature.label_id) for feature in l_datasrc]

  for name, series in vars(datasources).items():
    if series is None:
      continue
    columns[name] = [torch.tensor(item, dtype=torch.float32) for item in series]

  return ListOrTensorDataset(columns)

class BalancedSampler(BatchSampler):
  """Batch sampler that puts a fixed number of samples of each class in every batch.

  A sample's class is the argmax of its label tensor. Each class gets a raw
  weight of ``(n / total) * (max_n / n ** 0.8)`` (``n`` = class size), and
  `batch_size` is split across classes proportionally to those weights.
  One epoch lasts as many batches as the slowest class needs to be drawn
  once without replacement; classes that run out earlier are topped up with
  random draws from their full index list.

  Attributes:
    labels_indices: class -> all dataset positions of that class.
    labels_indices_removal: class -> positions not yet drawn in the epoch
      currently being iterated.
    label_per_batch: class -> number of samples taken per batch.
    max_label: class that determines the epoch length.
    sampler_length: number of batches per epoch.
  """

  def __init__(self, dataset: "ListOrTensorDataset", batch_size: int) -> None:
    """Groups dataset positions by class and derives the per-batch quotas.

    Args:
      dataset: object exposing ``data['labels']`` (with ``to_list()``) and
        ``len()``; every label is a tensor whose argmax is the class.
      batch_size: total number of samples per batch.
    """
    self.batch_size = batch_size
    labels = dataset.data['labels'].to_list()
    label_count = len(labels[0])
    total_count = len(dataset)

    self.labels_indices = {i: [] for i in range(label_count)}
    for position, label in enumerate(labels):
      self.labels_indices[torch.argmax(label).item()].append(position)
    self.labels_indices_removal = {i: list(v) for i, v in self.labels_indices.items()}

    class_sizes = {i: len(v) for i, v in self.labels_indices.items()}
    max_count = max(class_sizes.values())
    # An empty class gets weight 0 instead of dividing by zero.
    self.index_weights_raw = {
      i: ((n / total_count) * (max_count / n ** (4/5))) if n else 0.0
      for i, n in class_sizes.items()
    }
    sum_index_weights_raw = sum(self.index_weights_raw.values())
    self.label_per_batch = {
      i: math.floor(batch_size * weight / sum_index_weights_raw)
      for i, weight in self.index_weights_raw.items()
    }

    # Flooring can leave slots unassigned; hand them out one at a time to the
    # class with the currently smallest quota. Empty classes are skipped
    # because they have nothing to draw from.
    populated = [i for i, n in class_sizes.items() if n]
    for _ in range(batch_size - sum(self.label_per_batch.values())):
      self.label_per_batch[min(populated, key=self.label_per_batch.get)] += 1

    print(f"batch arrangement: {self.label_per_batch}")
    self.max_label = 0
    self.sampler_length = 0
    for i, n in class_sizes.items():
      quota = self.label_per_batch[i]
      if quota == 0:
        # A class that gets no slot never appears in a batch, so it cannot
        # determine the epoch length (and must not be divided by).
        continue
      batches_needed = math.ceil(n / quota)
      if batches_needed > self.sampler_length:
        self.max_label = i
        self.sampler_length = batches_needed

  def __len__(self):
    """Returns the number of batches in one epoch."""
    return self.sampler_length

  def __iter__(self):
    """Yields `sampler_length` batches, each a list of dataset positions."""
    # Start every epoch from full pools; otherwise only the first epoch would
    # draw without replacement and later ones would find the pools depleted.
    self.labels_indices_removal = {i: list(v) for i, v in self.labels_indices.items()}
    for _ in range(self.sampler_length):
      batch = []
      for label, pool in self.labels_indices_removal.items():
        quota = self.label_per_batch[label]
        if quota == 0:
          continue
        if len(pool) >= quota:
          picked = np.random.choice(pool, quota, replace=False).tolist()
          drawn = set(picked)
          pool[:] = [idx for idx in pool if idx not in drawn]
        else:
          # Use up whatever is left (and really remove it, so it is not
          # repeated in every later batch), then top up from the full class.
          picked = list(pool)
          pool.clear()
          population = self.labels_indices[label]
          shortfall = quota - len(picked)
          # Draw with replacement only when the class is smaller than the
          # shortfall; without it np.random.choice would raise ValueError.
          picked += np.random.choice(
            population, shortfall, replace=shortfall > len(population)
          ).tolist()
        batch += picked
      yield batch


def prepare_dataloader(
  tokenizer: BertTokenizer, datasources: ModalityDatasource, args: argparse.Namespace, mode: ExecMode, lang_src_override=None, desire_batch_size=None
):
  """Builds the DataLoader for one execution mode.

  Args:
    tokenizer: passed through to `convert_feature`.
    datasources: optional audio/video data added as extra dataset columns.
    args: run configuration. Read here: ``train_batch`` / ``eval_batch`` /
      ``test_batch``, ``balanced_sampler``, ``src_lang``, ``seq_limit``,
      ``gradient_accumulation_steps``, ``emo_cls``, ``sent_bin``,
      ``no_align``, ``audio_raw_seq_limit``, ``video_raw_seq_limit``.
    mode: TRAIN reads the train split (shuffled, or batched by
      `BalancedSampler` when ``args.balanced_sampler`` is set) with the batch
      size divided by ``gradient_accumulation_steps``; EVAL and TEST read
      the dev / test split sequentially.
    lang_src_override: data directory used instead of ``args.src_lang`` when
      truthy.
    desire_batch_size: see the review note in the body; currently has no
      effect on the returned loader.

  Returns:
    A DataLoader whose batches are `ModelInput` objects.

  Raises:
    argparse.ArgumentError: if `mode` is not TRAIN/EVAL/TEST or the selected
      batch size is None.
  """
  processor = TextDataSampler()
  if mode == ExecMode.TRAIN:
    batch_base, sample_provider = args.train_batch, processor.get_train_examples
    sampler_cls = BalancedSampler if args.balanced_sampler else None
  elif mode == ExecMode.EVAL:
    batch_base, sample_provider, sampler_cls = args.eval_batch, processor.get_dev_examples, SequentialSampler
  elif mode == ExecMode.TEST:
    batch_base, sample_provider, sampler_cls = args.test_batch, processor.get_test_examples, SequentialSampler
  else:
    batch_base, sample_provider, sampler_cls = None, None, None

  if batch_base is None:
    # ArgumentError takes (argument, message); with a single positional
    # argument the constructor itself fails with TypeError.
    raise argparse.ArgumentError(None, "argument \"mode\" must be specified")

  l_samples = sample_provider(lang_src_override or args.src_lang)
  l_datasrc = convert_feature(l_samples, args.seq_limit, tokenizer)

  batch = (batch_base // args.gradient_accumulation_steps) if mode == ExecMode.TRAIN else batch_base

  dataset = convert_datasrc_to_dataset(l_datasrc, datasources, args.emo_cls, args.sent_bin)
  use_batch_sampler = sampler_cls == BalancedSampler
  if sampler_cls is None:
    sampler = None
  elif use_batch_sampler:
    sampler = sampler_cls(dataset, batch)
  else:
    sampler = sampler_cls(dataset)

  if desire_batch_size is not None:
    # NOTE(review): `batch` and `sampler` are already built at this point and
    # `batch_base` is not read again, so this assignment does not influence
    # the loader. Left untouched until the intended behavior is confirmed.
    batch_base = (len(dataset) // desire_batch_size)

  def pad_leading_dim(tensor, target_len):
    # F.pad consumes its pad tuple starting from the LAST dimension, so the
    # (0, extra) pair for dim 0 has to come after a (0, 0) pair for every
    # trailing dimension. For 1-D input this is just (0, extra).
    trailing = (0, 0) * (tensor.dim() - 1)
    return F.pad(tensor, trailing + (0, target_len - tensor.size(0)))

  def collate_fn(dlist: "list[dict]"):
    def pads(d):
      # Bring each present modality to a common length along dim 0 so that
      # default_collate can stack the samples.
      for key, raw_limit_attr in (('audio_data', 'audio_raw_seq_limit'), ('video_data', 'video_raw_seq_limit')):
        if d.get(key) is None:
          continue
        target_len = getattr(args, raw_limit_attr) if args.no_align else args.seq_limit
        d[key] = pad_leading_dim(d[key], target_len)
      return d
    return ModelInput(default_collate([pads(d) for d in dlist]))

  if use_batch_sampler:
    # A batch sampler decides batch composition itself, so batch_size and
    # shuffle must not be passed alongside it.
    loader_opt = {'batch_sampler': sampler, 'collate_fn': collate_fn}
  else:
    loader_opt = {
      'sampler': sampler,
      'collate_fn': collate_fn,
      'batch_size': batch,
      'shuffle': mode == ExecMode.TRAIN,
    }

  return DataLoader(dataset, **loader_opt)
