# -*- coding: utf-8 -*-

import numpy as np


def sequence_mask(lengths, maxlen=None):
    """Build a 0/1 mask marking the valid positions of each sequence.

    Args:
        lengths: Array-like of sequence lengths (converted with np.asarray).
        maxlen: Size of the last mask axis. Defaults to ``max(lengths)``,
            or 0 when ``lengths`` is empty. Lengths larger than ``maxlen``
            are simply truncated by the mask.

    Returns:
        np.uint8 array of shape ``lengths.shape + (maxlen,)`` where entry
        ``[..., k]`` is 1 iff ``k < lengths[...]``.
    """
    lengths = np.asarray(lengths)
    if maxlen is None:
        # Guard the empty case: ndarray.max() raises on zero-size arrays.
        maxlen = lengths.max() if lengths.size else 0

    positions = np.arange(0, maxlen, dtype=lengths.dtype)
    # Broadcast positions (maxlen,) against lengths (..., 1).
    return (positions < np.expand_dims(lengths, axis=-1)).astype(np.uint8)


def pad_2d_values(in_values, dim1=None, dim2=None, dtype=np.int64, pad=0):
    """Copy a ragged list of sequences into a dense, padded 2-D array.

    Args:
        in_values: Sequence of sequences (rows may differ in length).
        dim1: Number of output rows; inferred as ``len(in_values)`` if None.
        dim2: Number of output columns; inferred as the longest row
            (0 for empty input) if None.
        dtype: dtype of the output array.
        pad: Fill value for positions not covered by ``in_values``.

    Returns:
        Array of shape ``(dim1, dim2)``. Rows/columns beyond the given
        dimensions are truncated; missing ones are filled with ``pad``.
    """
    # Infer each dimension independently so an explicitly given one is kept.
    if dim1 is None:
        dim1 = len(in_values)
    if dim2 is None:
        dim2 = max([len(row) for row in in_values] or [0])

    out_values = np.full((dim1, dim2), pad, dtype=dtype)
    # zip with range(dim1) stops at whichever of dim1 / len(in_values) is smaller.
    for i, row in zip(range(dim1), in_values):
        n = min(dim2, len(row))
        out_values[i, :n] = row[:n]
    return out_values


def pad_3d_values(in_values, dim1=None, dim2=None, dim3=None, dtype=np.int64, pad=0):
    """Copy a doubly-ragged nested list into a dense, padded 3-D array.

    Args:
        in_values: Sequence of sequences of sequences; lengths may differ
            at both inner levels (inner levels may also be empty).
        dim1: Size of axis 0; inferred as ``len(in_values)`` if None.
        dim2: Size of axis 1; inferred as the longest ``in_values[i]``
            (0 for empty input) if None.
        dim3: Size of axis 2; inferred as the longest ``in_values[i][j]``
            (0 if there is none) if None.
        dtype: dtype of the output array.
        pad: Fill value for positions not covered by ``in_values``.

    Returns:
        Array of shape ``(dim1, dim2, dim3)``. Entries beyond the given
        dimensions are truncated; missing ones are filled with ``pad``.
    """
    # Infer each dimension independently so an explicitly given one is kept,
    # and fall back to 0 so empty outer/inner lists don't crash max().
    if dim1 is None:
        dim1 = len(in_values)
    if dim2 is None:
        dim2 = max([len(group) for group in in_values] or [0])
    if dim3 is None:
        dim3 = max([len(item) for group in in_values for item in group] or [0])

    out_values = np.full((dim1, dim2, dim3), pad, dtype=dtype)
    # zip with range(...) truncates to the smaller of the target size and data size.
    for i, group in zip(range(dim1), in_values):
        for j, item in zip(range(dim2), group):
            n = min(dim3, len(item))
            out_values[i, j, :n] = item[:n]
    return out_values
