from transformers import BertTokenizer
from functools import reduce
from logging import getLogger
import torch

logger = getLogger(__name__)

class BertInputData:
  """One raw (untokenized) example for sequence or sequence-pair classification.

  Attributes:
    guid: unique identifier of the example.
    text_a: untokenized text of the first sequence; required for every task.
    text_b: untokenized text of the second sequence, or None for
      single-sequence tasks.
    label: the example's label, or None when it is not known
      (e.g. test examples).
  """

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Build a BertInputData, storing every argument unchanged as an attribute."""
    self.guid, self.text_a = guid, text_a
    self.text_b, self.label = text_b, label


class InputFeature:
  """Model-ready features derived from a single example.

  Attributes:
    guid: identifier copied from the source example.
    input_ids: token ids (``convert_feature`` in this module stores a
      torch tensor padded/truncated to the maximum sequence length).
    input_mask: attention mask aligned with ``input_ids``.
    segment_ids: token-type ids aligned with ``input_ids``.
    label_id: the example's label(s) in numeric form, as supplied by the caller.
    sentences: optional list of nested InputFeature objects (per the
      annotation); ``convert_feature`` leaves it as None.
  """

  def __init__(self, guid, input_ids, input_mask, segment_ids, label_id, sentences: "list[InputFeature]" = None):
    """Store all provided values unchanged as attributes."""
    self.guid = guid
    # The three sequence-aligned encodings, kept side by side.
    self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
    self.label_id = label_id
    self.sentences = sentences

def split_by_delimiter(ipt: "list[str]", *, delimiters=(".", "?", "!")):
  """Split a token list into sentences that end at delimiter tokens.

  Each delimiter token is kept as the last token of the sentence it closes, and
  a new (initially empty) sentence is started after it. As a result the output
  always contains at least one list, and its last element is ``[]`` whenever
  ``ipt`` is empty or ends with a delimiter.

  Args:
    ipt: tokens to split (any iterable of tokens).
    delimiters: collection of tokens that terminate a sentence; membership is
      tested with ``in``. Defaults to ".", "?" and "!".

  Returns:
    A list of token lists, in input order.

  >>> split_by_delimiter(["hi", ".", "ok", "?", "go"])
  [['hi', '.'], ['ok', '?'], ['go']]
  """
  sentences = [[]]
  for token in ipt:
    sentences[-1].append(token)
    if token in delimiters:
      # Close the current sentence; subsequent tokens start a new one.
      sentences.append([])
  return sentences


def convert_feature(utterances: "list[BertInputData]", seq_length_max: int, tokenizer: BertTokenizer):
  """Tokenize examples into fixed-length InputFeature objects.

  Each example is encoded with ``tokenizer.encode_plus``, padded to and
  truncated at ``seq_length_max``. When an example has ``text_b`` it is
  encoded as the second sequence of a pair, so its segment ids distinguish
  the two sequences. Without ``text_b`` the encoding is single-sequence.

  Args:
    utterances: examples to convert.
    seq_length_max: length every id / mask / segment tensor is padded or
      truncated to.
    tokenizer: tokenizer used to encode the text.

  Returns:
    One InputFeature per example, in input order. ``label_id`` is a list of
    floats (``float()`` applied to each element of ``utterance.label``), or
    None when the example has no label.
  """
  features: "list[InputFeature]" = []

  for utterance in utterances:
    # Labels are optional (e.g. test examples); iterating None would raise TypeError.
    if utterance.label is None:
      label_id = None
    else:
      label_id = [float(label) for label in utterance.label]

    # text_pair=None is equivalent to encoding text_a alone. getattr keeps
    # duck-typed examples without a text_b attribute working.
    tokens = tokenizer.encode_plus(
      utterance.text_a,
      text_pair=getattr(utterance, "text_b", None),
      padding='max_length',
      max_length=seq_length_max,
      truncation=True,
    )
    token_ids = torch.tensor(tokens["input_ids"])
    input_mask = torch.tensor(tokens["attention_mask"])
    # Some tokenizers do not emit token_type_ids; fall back to all-zero segments.
    segment_ids = torch.tensor(tokens["token_type_ids"] if "token_type_ids" in tokens else [0] * seq_length_max)

    # Internal invariants: padding='max_length' + truncation should give exact lengths.
    assert len(token_ids) == seq_length_max
    assert len(input_mask) == seq_length_max
    assert len(segment_ids) == seq_length_max

    features.append(InputFeature(utterance.guid, token_ids, input_mask, segment_ids, label_id))

  return features
