# -*- coding: utf-8 -*-

import collections.abc as container_abcs
from typing import List, Union

import numpy as np
import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.functional as F

from ..common.logger import LOGGER


def import_class(entry_class):
    """Resolve a dotted ``'package.module.Name'`` string to the object it names.

    Non-string arguments are assumed to already be the target object and are
    returned untouched, so callers may pass either a class or its import path.

    Raises:
        ValueError: if a string contains no ``'.'`` separator.
        ImportError / AttributeError: if the module or attribute cannot be found.
    """
    if isinstance(entry_class, str):
        import importlib
        module_path, attr_name = entry_class.rsplit('.', 1)
        return getattr(importlib.import_module(module_path), attr_name)
    return entry_class


def set_random_seed(seed, logger=LOGGER):
    """Seed every random number generator this project may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, the PyTorch
    CPU generator and, when CUDA is available, the generators of all CUDA
    devices.

    Args:
        seed: integer seed applied to all generators.
        logger: currently unused; kept for signature compatibility.
    """
    # The stdlib RNG was previously left unseeded, which made any code using
    # `random` (e.g. shuffling helpers) non-reproducible.
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def print_version(logger=LOGGER):
    """Log the PyTorch version, the host platform string and the CUDA device count.

    Args:
        logger: object with a ``logging.Logger``-style ``info`` method.
    """
    # `torch.platform` is not a documented PyTorch attribute; query the stdlib
    # module directly instead of depending on torch's internal imports.
    import platform

    logger.info('############################################################')
    logger.info('PyTorch version: %s', torch.__version__)
    logger.info('Platform: %s', platform.platform())
    logger.info('CUDA devices: %s', torch.cuda.device_count())
    logger.info('############################################################')


def summary_model(model, logger=LOGGER):
    """Log the structure and parameter counts of a model.

    If ``model`` is an ``nn.Module`` it is summarized directly (with an empty
    prefix). Otherwise every attribute in ``vars(model)`` that is an
    ``nn.Module`` is summarized, prefixed with its attribute name.

    For each module, two lines are logged: ``repr(module)`` and
    ``trainable/total`` parameter counts in millions.
    """
    def _summary(model_object, prefix=''):
        num_params, num_trainable = 0, 0
        for params in model_object.parameters():
            num_params += params.numel()
            if params.requires_grad:
                num_trainable += params.numel()
        logger.info('%s\n%s', prefix, repr(model_object))
        # "M" means millions of parameters, so scale by 1e6 rather than 2**20.
        logger.info('params count: %.2fM/%.2fM',
                    num_trainable / 1e6, num_params / 1e6)

    if isinstance(model, nn.Module):
        named_objects = [('', model)]
    else:
        named_objects = vars(model).items()

    for name, value in named_objects:
        if isinstance(value, nn.Module):
            _summary(value, name)


@jit.script
def sequence_mask(lengths: torch.Tensor, max_length: int = -1, batch_first: bool = True):
    """Build a boolean mask marking valid positions of variable-length sequences.

    Position ``t`` of sequence ``b`` is True iff ``t < lengths[b]``. With
    ``max_length == -1`` the time extent is ``max(lengths)``. The mask has
    shape ``(len(lengths), max_length)`` when ``batch_first`` and
    ``(max_length, len(lengths))`` otherwise (``lengths`` assumed 1-D).
    """
    limit = max_length
    if limit == -1:
        limit = int(lengths.max())

    steps = torch.arange(0, limit, device=lengths.device, dtype=torch.long)
    if not batch_first:
        mask = steps.unsqueeze(1) < lengths.unsqueeze(0)
    else:
        mask = steps.unsqueeze(0) < lengths.unsqueeze(1)
    return mask


def clip_and_renormalize(word_probs, epsilon):
    """Clamp probabilities into ``[epsilon, 1 - epsilon]`` and rescale the last axis to sum to 1."""
    clipped = torch.clamp(word_probs, min=epsilon, max=1.0 - epsilon)
    normalizer = clipped.sum(-1, keepdim=True)
    return clipped / normalizer


def cross_entropy_nd(inputs, targets, **kwargs):
    """Cross entropy over logits with arbitrary leading dimensions.

    ``inputs`` is flattened to ``(-1, inputs.size(-1))`` (classes on the last
    axis) and ``targets`` to ``(-1,)`` before calling ``F.cross_entropy``.
    Extra keyword arguments are forwarded unchanged. When the loss comes back
    1-D (e.g. ``reduction='none'``), it is reshaped to ``targets``' shape.
    """
    # `reshape` instead of `view`: logits produced by transpose/permute are
    # non-contiguous, and `view` cannot merge their leading dims.
    flat_inputs = inputs.reshape(-1, inputs.size(-1))
    flat_targets = targets.reshape(-1)
    entropy = F.cross_entropy(flat_inputs, flat_targets, **kwargs)
    if entropy.dim() == 1:
        entropy = entropy.reshape(targets.shape)
    return entropy


def broadcast_gather(input, dim, index, out=None):
    """``torch.gather`` that accepts an index with fewer dimensions than ``input``.

    A lower-rank ``index`` gets trailing singleton axes appended. Those axes
    are then expanded to ``input``'s trailing sizes, so each index selects a
    whole trailing slice. A full-rank index is passed to ``torch.gather``
    unchanged.
    """
    index_rank = len(index.shape)
    missing = len(input.shape) - index_rank
    if missing > 0:
        widened = index.view(index.shape + (1,) * missing)
        target_shape = (-1,) * index_rank + tuple(input.shape[index_rank:])
        index = widened.expand(*target_shape)
    return torch.gather(input, dim, index, out=out)


def sort_and_pack_sequences(seqs, lengths, pack=True, is_sorted=False):
    """Sort a batch-first padded batch by descending length and optionally pack it.

    Unless ``is_sorted`` is set, rows are reordered by descending ``lengths``.
    Zero-length rows, which ``pack_padded_sequence`` cannot represent, are
    dropped from the tail. ``unsort_indices`` still index the full original
    batch. Callers must therefore pad the batch dimension back (e.g. via
    ``original_shape`` in ``unpack_and_unsort_sequences``) before unsorting.

    Returns:
        ``(packed_seqs, unsort_indices)`` when ``pack`` is true, otherwise
        ``(seqs_sorted, length_sorted, unsort_indices)``. ``unsort_indices``
        is ``None`` when ``is_sorted`` is true.
    """
    if not is_sorted:
        length_sorted, sort_indices = lengths.sort(descending=True)
        _, unsort_indices = sort_indices.sort()
        # After a descending sort, empty sequences sit at the tail, so the
        # number of positive lengths is exactly where to cut.
        num_nonempty = int(length_sorted.gt(0).sum())
        length_sorted = length_sorted[:num_nonempty]
        sort_indices = sort_indices[:num_nonempty]
        seqs_sorted = seqs.index_select(0, sort_indices)
    else:
        seqs_sorted = seqs
        length_sorted = lengths
        unsort_indices = None

    if pack:
        # pack_padded_sequence requires lengths as a CPU tensor (or a list);
        # CUDA lengths raise, so move only the copy passed to packing.
        pack_lengths = length_sorted.cpu() if torch.is_tensor(length_sorted) else length_sorted
        packed_seqs = nn.utils.rnn.pack_padded_sequence(seqs_sorted, pack_lengths, batch_first=True)
        return packed_seqs, unsort_indices
    return seqs_sorted, length_sorted, unsort_indices


def pad_timestamps_and_batches(seqs, original_shape):
    """Zero-pad dims 0 (batch) and 1 (time) of ``seqs`` up to ``original_shape[0:2]``.

    Trailing dimensions are left untouched, and any tensor of rank >= 2 is
    accepted. The tensor is returned as-is when neither of the first two dims
    is smaller than the target. If one dim is too small and the other too
    large, the negative pad crops the larger one, as before.
    """
    missing_batch = original_shape[0] - seqs.shape[0]
    missing_steps = original_shape[1] - seqs.shape[1]
    if missing_batch > 0 or missing_steps > 0:
        # F.pad takes (left, right) pairs starting from the LAST dim. The old
        # hard-coded 6-entry spec only hit dims 0/1 for exactly 3-D tensors.
        padding = [0, 0] * (seqs.dim() - 2) + [0, missing_steps, 0, missing_batch]
        seqs = torch.nn.functional.pad(seqs, padding)
    return seqs


def unpack_and_unsort_sequences(seqs, unsort_indices, original_shape=None, unpack=True):
    """Undo ``sort_and_pack_sequences``: unpack, restore padding, then restore row order.

    Args:
        seqs: a PackedSequence when ``unpack`` is true, else a batch-first tensor.
        unsort_indices: permutation applied along dim 0, or ``None`` to skip it.
        original_shape: if given, dims 0 and 1 are padded back up to it
            (via ``pad_timestamps_and_batches``) before unsorting.
        unpack: whether ``seqs`` must first go through ``pad_packed_sequence``.
    """
    padded = nn.utils.rnn.pad_packed_sequence(seqs, batch_first=True)[0] if unpack else seqs
    if original_shape is not None:
        padded = pad_timestamps_and_batches(padded, original_shape)
    return padded if unsort_indices is None else padded.index_select(0, unsort_indices)


def pad_and_stack_1d(tensors: List[torch.Tensor], dim=-1, pad=0, device=None):
    """Stack tensors of varying leading length into one ``pad``-filled tensor.

    Args:
        tensors: tensors sharing every dimension except the first.
        dim: padded length of the new second axis; ``-1`` means the longest input.
        pad: fill value for positions beyond each tensor's length.
        device: output device; defaults to the first tensor's device.

    Returns:
        Tensor of shape ``(len(tensors), dim, *tensors[0].shape[1:])`` with the
        dtype of ``tensors[0]``.
    """
    # TODO: use torch.jit.script
    target_device = tensors[0].device if device is None else device
    width = max(len(tensor) for tensor in tensors) if dim == -1 else dim

    head = tensors[0]
    stacked = torch.full((len(tensors), width) + head.shape[1:], pad,
                         dtype=head.dtype, device=target_device)
    for row, tensor in enumerate(tensors):
        stacked[row, :tensor.size(0)] = tensor
    return stacked


def pad_and_stack_2d(tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]],
                     dim1=-1, dim2=-1, pad=0, device=None):
    """Pad and stack a 2-level ragged batch into a single tensor.

    Two input layouts are accepted:
      * a list of tensors, each ``(rows, cols, ...)``, padded along both
        leading dims (``pad_and_stack_2d_case1``);
      * a list of lists of tensors, each ``(len, ...)``, where ``dim1`` is the
        list axis and ``dim2`` the tensor length (``pad_and_stack_2d_case2``).

    ``dim1``/``dim2`` of ``-1`` mean "use the maximum found in the input".
    ``device`` defaults to the device of the first tensor.
    """
    # TODO: use torch.jit.script
    nested = not torch.is_tensor(tensors[0])
    if device is None:
        # For nested input tensors[0] is a list, which has no `.device`;
        # take the device from the first actual tensor instead.
        first_tensor = tensors[0][0] if nested else tensors[0]
        device = first_tensor.device
    if dim1 == -1:
        dim1 = max(len(tensor) for tensor in tensors)

    if nested:
        return pad_and_stack_2d_case2(tensors, dim1, dim2, pad, device)
    return pad_and_stack_2d_case1(tensors, dim1, dim2, pad, device)


def pad_and_stack_2d_case1(tensors, dim1, dim2, pad, device):
    """Stack ``(rows, cols, ...)`` tensors into ``(len(tensors), dim1, dim2, ...)``.

    ``dim2 == -1`` selects the widest ``size(1)`` among the inputs. Cells
    outside each tensor's extent hold ``pad``; dtype follows ``tensors[0]``.
    """
    cols = max(tensor.size(1) for tensor in tensors) if dim2 == -1 else dim2

    head = tensors[0]
    result = torch.full((len(tensors), dim1, cols) + head.shape[2:], pad,
                        dtype=head.dtype, device=device)
    for slot, tensor in enumerate(tensors):
        n_rows, n_cols = tensor.size(0), tensor.size(1)
        result[slot, :n_rows, :n_cols] = tensor
    return result


def pad_and_stack_2d_case2(tensors, dim1, dim2, pad, device):
    """Stack a list of lists of ``(len, ...)`` tensors into ``(len(tensors), dim1, dim2, ...)``.

    ``dim2 == -1`` selects the longest tensor across all inner lists. Unfilled
    cells hold ``pad``; dtype follows ``tensors[0][0]``.
    """
    if dim2 == -1:
        dim2 = max(len(seq) for group in tensors for seq in group)

    sample = tensors[0][0]
    result = torch.full((len(tensors), dim1, dim2) + sample.shape[1:], pad,
                        dtype=sample.dtype, device=device)
    for outer, group in enumerate(tensors):
        for inner, seq in enumerate(group):
            result[outer, inner, :seq.size(0)] = seq
    return result


def simple_collate(batch, mode='stack'):
    """Collate a list of samples into batched tensors, recursing into containers.

    Dispatch is on the type of ``batch[0]``:
      * tensor  -> ``getattr(torch, mode)(batch, 0)`` (e.g. ``'stack'`` / ``'cat'``);
      * float   -> float64 tensor;  int -> tensor with inferred integer dtype;
      * mapping -> dict collating each key of the first sample;
      * tuple (including namedtuples) -> same type, collated field by field;
      * anything else -> ``batch`` unchanged.
    """
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        return getattr(torch, mode)(batch, 0)
    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(elem, int):
        return torch.tensor(batch)
    if isinstance(elem, container_abcs.Mapping):
        return {key: simple_collate([d[key] for d in batch], mode) for key in elem}
    if isinstance(elem, tuple):
        fields = [simple_collate(list(samples), mode) for samples in zip(*batch)]
        if hasattr(elem, '_fields'):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable like plain tuples.
            return type(elem)(*fields)
        return type(elem)(fields)

    return batch


def simple_decollate(batch, mode='unbind'):
    """Split a collated batch back into per-sample items along dim 0.

    Counterpart to ``simple_collate``:
      * tensor  -> ``torch.split(batch, 1, 0)`` (pieces keep a size-1 leading
        dim) when ``mode == 'split'``, else ``torch.unbind(batch, 0)``;
      * mapping -> list of dicts, one per sample;
      * tuple (including namedtuples) -> list of tuples of the same type;
      * anything else -> ``batch`` unchanged.

    All fields of a mapping/tuple must split into the same number of items;
    this is checked with ``assert``.
    """
    if torch.is_tensor(batch):
        if mode == 'split':
            return torch.split(batch, 1, 0)
        return torch.unbind(batch, 0)
    if isinstance(batch, container_abcs.Mapping):
        data = [(key, simple_decollate(value, mode)) for key, value in batch.items()]
        size = len(data[0][1])

        assert all(len(value) == size for _, value in data)
        return [{key: value[index] for key, value in data} for index in range(size)]
    if isinstance(batch, tuple):
        item_type = type(batch)
        data = [simple_decollate(value, mode) for value in batch]
        size = len(data[0])

        assert all(len(value) == size for value in data)
        if hasattr(batch, '_fields'):
            # namedtuple constructors take one positional argument per field.
            return [item_type(*(value[index] for value in data)) for index in range(size)]
        return [item_type(value[index] for value in data) for index in range(size)]

    return batch
