import chunk
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F

from memformers.recurrent_training.recurrent_memory import RecurrentMemory

# pylint:disable=no-member


class NoneMemory(RecurrentMemory):
    """Placeholder recurrent memory that holds no tensors and tracks no gradients.

    All gradient-related hooks (``retain_grad``, ``backward``, ``detach_``)
    are no-ops and ``grad`` is always ``None``. Presumably used when a model
    runs without a real memory state -- TODO confirm against callers.
    """

    def __init__(self, memory_states, batch_size: int):
        super().__init__(memory_states, batch_size)

    def to(self, device: torch.device) -> "NoneMemory":
        """No-op device move; there are no tensors to transfer.

        Args:
            device: target device (ignored).

        Returns:
            ``self``, so the ``mem = mem.to(device)`` idiom used with
            tensors and modules keeps a valid memory object.
        """
        return self

    def update(self, new_memory_states) -> RecurrentMemory:
        """Replace the stored memory states.

        Args:
            new_memory_states: value stored as ``self.memory_states``.

        Returns:
            ``self``, matching the declared ``RecurrentMemory`` return type.
        """
        self.memory_states = new_memory_states
        return self

    def chunk(self, chunks: int) -> List[RecurrentMemory]:
        """Split into ``chunks`` independent empty memories.

        Args:
            chunks: number of pieces to return.

        Returns:
            A list of ``chunks`` fresh ``NoneMemory`` objects, each built with
            ``memory_states=None`` and ``batch_size=1``. States stored via
            ``update`` are not carried over. A non-positive ``chunks`` yields
            an empty list.
        """
        # NOTE(review): batch size 1 is hard-coded regardless of this
        # memory's own batch size -- confirm this is intended.
        return [NoneMemory(None, 1) for _ in range(chunks)]

    def retain_grad(self):
        """No-op: there are no tensors whose gradients could be retained."""

    def backward(self, grad):
        """No-op: nothing to back-propagate through.

        Args:
            grad: incoming gradient (ignored).
        """

    def detach_(self):
        """No-op: there is no autograd history to detach."""

    @property
    def grad(self):
        """Always ``None``; this memory never accumulates a gradient."""
        return None
