import torch
from torch import nn

class SelectiveHistoryModule(nn.Module):
    """Condense a padded utterance history into one memory embedding.

    The module wraps an attention layer (selected by ``attention_type``) that
    is queried with the current utterance and attends over the history
    utterances, ignoring padded history slots through a boolean mask.
    """

    def __init__(self, attention_type, sentence_dim, num_heads=6):
        """
        Build the attention layer used to read from the history.

        Args:
            attention_type (str): ``"MHA"`` for ``MultiHeadAttentionModel`` or
                ``"TAA"`` for ``TimeAwareAttention``.
            sentence_dim (int): Passed to the attention layer as ``hidden_size``.
            num_heads (int): Number of heads for the ``"MHA"`` variant; ignored
                by ``"TAA"``. Defaults to 6, the previously hard-coded value.

        Raises:
            ValueError: If ``attention_type`` is not one of the supported names.
        """
        super().__init__()
        # The attention implementations are imported lazily so that only the
        # selected variant has to be importable.
        if attention_type == "MHA":
            from .MultiHeadAttention import MultiHeadAttentionModel
            self.memory_attention = MultiHeadAttentionModel(
                hidden_size=sentence_dim,
                num_heads=num_heads,
            )
        elif attention_type == "TAA":
            from .TimeAwareAttention import TimeAwareAttention
            self.memory_attention = TimeAwareAttention(hidden_size=sentence_dim)
        else:
            # Fail at construction time instead of with an AttributeError on
            # the first forward pass.
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; "
                "expected 'MHA' or 'TAA'."
            )

    def forward(self, current_emb, history_emb, history_size):
        """
        Forward pass of the SelectiveHistoryModule.

        Args:
            current_emb (Tensor): Embeddings for the current utterance. Shape: [batch_size, embedding_dim]
            history_emb (Tensor): Embeddings for historical utterances. Shape: [batch_size, max_history_size, embedding_dim]
            history_size (Tensor): The number of historical utterances for each sample in the batch. Shape: [batch_size]

        Returns:
            Tensor: Memory embeddings that capture selective historical information,
            flattened to shape [batch_size, -1].
        """
        b_size = current_emb.size(0)

        # The mask must be as wide as the padded history tensor itself, not
        # merely as wide as the longest history present in this batch;
        # otherwise the two disagree whenever the batch is padded further.
        history_mask = self.make_make(
            history_size, b_size, max_len=history_emb.size(1)
        )

        # The attention layer returns a pair; only the first element (the
        # attended memory) is used here.
        memory_emb, _ = self.memory_attention(
            memory=history_emb,
            current_utterance=current_emb.unsqueeze(1),
            attention_mask=history_mask
        )

        return memory_emb.reshape((b_size, -1))

    def make_make(self, history_size, b_size, max_len=None):
        """
        Build a boolean mask marking the valid (non-padded) history slots.

        Args:
            history_size (Tensor): Number of valid history entries per sample. Shape: [batch_size]
            b_size (int): Batch size; the number of rows in the mask.
            max_len (int, optional): Number of columns in the mask. When
                omitted, the largest value in ``history_size`` is used (0 for
                an empty ``history_size``).

        Returns:
            Tensor: Boolean mask of shape [b_size, max_len] where entry
            ``[i, j]`` is True iff ``j < history_size[i]``.
        """
        if max_len is None:
            # Tensor.max() runs as a single reduction; the builtin max() would
            # iterate element by element and raise on an empty tensor.
            max_len = int(history_size.max()) if history_size.numel() > 0 else 0

        positions = torch.arange(
            max_len,
            device=history_size.device,
            dtype=history_size.dtype
        )

        # Compare every slot index against the per-sample valid length.
        positions = positions.unsqueeze(0).expand(b_size, max_len)
        return positions < history_size.unsqueeze(1)