import inspect
from itertools import zip_longest

import torch
from typing import Any, Optional

from torch import Tensor

from coli.basic_tools.common_utils import NoPickle


class AutoBatchMixin(object):
    """Mixin that queues per-example inputs and runs ``forward`` on them as one batch.

    Call :meth:`add_input` once per example, then :meth:`calculate_results`
    to push every not-yet-processed input through ``self(...)`` in a single
    call.  Outputs accumulate in :attr:`results` across calls until
    :meth:`refresh` clears them.  The annotations on ``forward``'s parameters
    control how each argument column is batched (see
    :meth:`calculate_results`).
    """
    parameters: Any
    __call__: Any
    forward: Any

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One tuple of positional arguments per queued example.
        self.pending_inputs = []
        # How many leading entries of ``pending_inputs`` have already been run.
        self.has_processed = 0
        # A Tensor, or a tuple of batched values when forward returns a tuple.
        self.results: Any = None

        # Wrapped in NoPickle (project helper); presumably this keeps it out of
        # pickled state -- calculate_results rebuilds it when missing/empty.
        self.stack_types = NoPickle(self.get_stack_types())

    def get_stack_types(self):
        """Return the annotation of each parameter of ``self.forward``, in order.

        Unannotated parameters yield ``inspect.Parameter.empty``.
        """
        signature = inspect.signature(self.forward)
        return [param.annotation for param in signature.parameters.values()]

    def add_input(self, *args):
        """Queue one example's positional arguments for the next batch.

        Returns:
            The example's position among all queued inputs.  Its output is
            expected at this row of :attr:`results` -- assuming ``forward``
            preserves batch order along dim 0 (not checked here).
        """
        result_idx = len(self.pending_inputs)
        self.pending_inputs.append(args)
        return result_idx

    def calculate_results(self):
        """Run all inputs queued since the last call through ``forward`` as one batch.

        Argument *k* of every new input is gathered into one column, which is
        batched according to the annotation of ``forward``'s *k*-th parameter:

        * annotated ``list`` (a list subclass, or ``List[...]``): passed as a list;
        * a column of Tensors: combined with ``torch.stack``;
        * a column of ints/floats: converted with ``torch.tensor`` on the
          device of the module's first parameter (CPU if it has none);
        * anything else: passed through unchanged (as a tuple).

        The output -- a Tensor, or a tuple of Tensors/lists/tuples -- is
        concatenated onto :attr:`results` along dim 0.  Does nothing when no
        new inputs were queued.

        Raises:
            TypeError: if results already exist and ``forward`` returns
                something other than a Tensor or tuple, or an output element
                cannot be concatenated.
        """
        if not getattr(self, "stack_types", None):
            # Rebuild when missing or empty (presumably after unpickling,
            # given the NoPickle wrapper).
            self.stack_types = NoPickle(self.get_stack_types())
        if self.has_processed == len(self.pending_inputs):
            return

        # Parameter-free modules would make a bare ``next`` raise StopIteration.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        def is_list_annotation(expect_type):
            # ``List[int]`` / ``list[int]`` expose the bare class via __origin__;
            # non-class annotations (None, strings, Union[...]) are never lists
            # and must not reach issubclass, which would raise TypeError.
            origin = getattr(expect_type, "__origin__", expect_type)
            return isinstance(origin, type) and issubclass(origin, list)

        def smart_stack(column, expect_type):
            if is_list_annotation(expect_type):
                # Hand forward a real list, as its annotation promises.
                return list(column)
            first = column[0]
            if isinstance(first, Tensor):
                return torch.stack(column)
            if isinstance(first, (int, float)):
                return torch.tensor(column, device=device)
            return column

        def smart_cat(old, new):
            if isinstance(old, Tensor):
                return torch.cat([old, new], dim=0)
            if isinstance(old, (list, tuple)):
                return old + new
            raise TypeError("Invalid concat")

        new_inputs = self.pending_inputs[self.has_processed:]
        # zip(*rows) transposes the queued rows into per-argument columns.
        # zip_longest pads the shorter side with None (always at the end):
        # - a None column means forward declares more parameters than were
        #   passed (e.g. defaulted ones); drop it so forward's default applies.
        # - a None type means more arguments than declared parameters;
        #   batch that column without a type hint.
        batch_data = [smart_stack(column, expect_type)
                      for column, expect_type in zip_longest(zip(*new_inputs), self.stack_types)
                      if column is not None]
        outputs = self(*batch_data)
        if self.results is None:
            self.results = outputs
        elif isinstance(outputs, Tensor):
            self.results = smart_cat(self.results, outputs)
        elif isinstance(outputs, tuple):
            self.results = tuple(smart_cat(old_result, result)
                                 for old_result, result in zip(self.results, outputs))
        else:
            raise TypeError(
                f"forward must return a Tensor or a tuple, got {type(outputs).__name__}")
        self.has_processed = len(self.pending_inputs)

    def refresh(self):
        """Discard all queued inputs and accumulated results."""
        self.has_processed = 0
        self.pending_inputs = []
        self.results = None


class AutoBatchModule(AutoBatchMixin, torch.nn.Module):
    """``torch.nn.Module`` base class with AutoBatchMixin's input queueing.

    The mixin is listed first so its ``__init__`` runs and then chains to
    ``torch.nn.Module.__init__`` via ``super()``.
    """


class AutoBatchSequential(AutoBatchMixin, torch.nn.Sequential):
    """``torch.nn.Sequential`` with AutoBatchMixin's input queueing.

    Constructor arguments pass through AutoBatchMixin.__init__ unchanged
    to ``torch.nn.Sequential``.
    """
