from torch import nn
from torch.nn.utils.rnn import pad_sequence
import torch, argparse
from .SelectiveHistory.SelectiveHistoryModule import SelectiveHistoryModule

def get_dim_list(start_dim, inter_dim, last_dim=None):
    """
    Create a list of dimensions based on provided inputs.

    Args:
        start_dim (int): The starting dimension to include in the list.
        inter_dim (str | int | Iterable[int] | None): Intermediate dimensions.
            Usually a comma-separated string such as "256,128". Blank entries
            (e.g. "" or a trailing comma in "256,") are ignored, so an empty
            string means "no intermediate dimensions". A single int or an
            iterable of ints is also accepted; None means no intermediate
            dimensions.
        last_dim (int, optional): The last dimension to include in the list. Defaults to None.

    Returns:
        List[int]: A list of integers representing the dimensions, including the starting and, optionally, the last dimension.

    Raises:
        ValueError: If a non-blank entry of a string ``inter_dim`` is not an integer.
    """
    if inter_dim is None:
        intermediate = []
    elif isinstance(inter_dim, str):
        # Skip blank tokens: int("") would raise on "" or on "256,".
        intermediate = [int(tok) for tok in inter_dim.split(",") if tok.strip()]
    elif isinstance(inter_dim, int):
        intermediate = [inter_dim]
    else:
        intermediate = [int(x) for x in inter_dim]

    dim_list = [start_dim] + intermediate

    # Optionally, add the last dimension.
    if last_dim is not None:
        dim_list.append(last_dim)
    return dim_list

def make_flexible_pooling_layer(start_dim, inter_dim, last_dim=None, dropout_rate=0.1, is_predict_layer=False):
    """
    Build the module list for a stack of bias-free Linear + GELU layers.

    The dimensions come from ``get_dim_list(start_dim, inter_dim, last_dim)``.
    Each consecutive pair of dimensions yields ``nn.Linear(in, out, bias=False)``
    followed by ``nn.GELU()``, and a single ``nn.Dropout(dropout_rate)`` is
    appended at the very end.

    Args:
        start_dim (int): Input dimension of the first Linear layer.
        inter_dim (str): Comma-separated intermediate dimensions, forwarded to get_dim_list.
        last_dim (int, optional): Output dimension of the last Linear layer. Defaults to None.
        dropout_rate (float, optional): Probability for the trailing Dropout. Defaults to 0.1.
        is_predict_layer (bool, optional): When True, the trailing GELU and Dropout
            are removed so the stack ends on a raw Linear projection. Defaults to False.

    Returns:
        List[nn.Module]: The modules in order (not wrapped in nn.Sequential).
    """
    dims = get_dim_list(start_dim, inter_dim, last_dim)

    layers = []
    for fan_in, fan_out in zip(dims, dims[1:]):
        layers.extend((nn.Linear(fan_in, fan_out, bias=False), nn.GELU()))
    layers.append(nn.Dropout(dropout_rate))

    # A prediction head drops the final activation and the dropout (the last two entries).
    return layers[:-2] if is_predict_layer else layers
    
def get_sentence_emb(sentence_encoder, token_ids, attention_mask, max_tokens):
    """
    Obtain sentence embeddings using a transformer-based sentence encoder.

    The encoder is called with keyword arguments ``input_ids`` and
    ``attention_mask``, and element ``[1]`` of its output is used. This is
    presumably the pooled output (e.g. HuggingFace ``pooler_output``); confirm
    this against the encoder in use.

    Args:
        sentence_encoder (nn.Module): The sentence encoder model.
        token_ids (Tensor): Input token IDs for the sentences. Shape: [batch_size, sequence_length]
        attention_mask (Tensor): Attention mask for the input tokens. Shape: [batch_size, sequence_length]
        max_tokens (int): Approximate token budget per forward pass. 0 disables chunking.
            When set, sentences are encoded in chunks of ``max_tokens // sequence_length``
            sentences, with at least one sentence per chunk.

    Returns:
        Tensor: The sentence embeddings produced by the sentence encoder. Shape: [batch_size, hidden_size]
    """
    num_sentences = token_ids.size(0)

    if max_tokens == 0 or num_sentences == 0:
        return sentence_encoder(input_ids=token_ids, attention_mask=attention_mask)[1]

    # Always process at least one sentence per chunk. A single sentence may
    # already exceed the budget, and a chunk size of 0 would never advance.
    seq_len = max(1, token_ids.size(-1))
    chunk_size = max(1, max_tokens // seq_len)

    # Step over exact chunk boundaries so the encoder is never called on an
    # empty slice, even when num_sentences is a multiple of chunk_size.
    chunks = []
    for start in range(0, num_sentences, chunk_size):
        end = start + chunk_size
        chunks.append(
            sentence_encoder(
                input_ids=token_ids[start:end],
                attention_mask=attention_mask[start:end],
            )[1]
        )
    return torch.vstack(chunks)

def listtomatrix(target, length):
    """
    Split a flat, concatenated tensor into consecutive segments and pad them into a batch.

    ``target`` is a single tensor whose dim 0 holds the sequences back to back.
    Segment i covers rows ``sum(length[:i])`` up to ``sum(length[:i+1])``
    (taken with ``narrow``, i.e. as views). The segments are then padded with
    zeros to the longest one.

    Args:
        target (Tensor): Concatenated sequences. Shape: [sum(length), *feature_dims]
        length (Iterable[int]): Number of rows belonging to each sequence, in order.

    Returns:
        Tensor: Padded batch. Shape: [len(length), max(length), *feature_dims]
    """
    def _segments():
        offset = 0
        for seg_len in length:
            yield target.narrow(0, offset, seg_len)
            offset += seg_len

    return pad_sequence(list(_segments()), batch_first=True)

def get_selective_history_module(hparams, fusion_dim, sentence_dim):
    """
    Optionally build the selective-history module and widen the fusion dimension.

    The module is enabled only when ``hparams.memory_size`` is non-zero AND
    ``hparams.selective_history_type`` is a non-empty string.

    Args:
        hparams: Config object exposing ``memory_size`` and ``selective_history_type``.
        fusion_dim (int): Current width of the fused feature vector.
        sentence_dim (int): Sentence embedding width, passed to the module and
            added to ``fusion_dim`` when enabled.

    Returns:
        tuple: ``(module, new_fusion_dim, used)``. When disabled, ``module`` is
        ``nn.Identity()``, ``fusion_dim`` is unchanged, and ``used`` is False.
    """
    enabled = hparams.memory_size != 0 and hparams.selective_history_type != ""
    if not enabled:
        return nn.Identity(), fusion_dim, False

    module = SelectiveHistoryModule(hparams.selective_history_type, sentence_dim)
    return module, fusion_dim + sentence_dim, True

def get_accustic_encoder(hparams, fusion_dim):
    """
    Optionally build the acoustic encoder selected by ``hparams.accustic_feature``.

    ``"wav2vec"`` selects ``accustic_wav2vec`` and ``"rnn"`` selects
    ``accustic_mfcc``. Both are imported lazily, constructed with
    ``hparams.dropout_rate``, and widen the fusion dimension by their
    ``accustic_dim``. Any other value disables the acoustic branch.

    Args:
        hparams: Config object exposing ``accustic_feature`` and ``dropout_rate``.
        fusion_dim (int): Current width of the fused feature vector.

    Returns:
        tuple: ``(encoder, new_fusion_dim, used)``. When disabled, ``encoder`` is
        ``nn.Identity()``, ``fusion_dim`` is unchanged, and ``used`` is False.
    """
    feature = hparams.accustic_feature
    if feature == "wav2vec":
        from ..Accustic_feature_extractor.wav2vec import accustic_wav2vec as encoder_cls
    elif feature == "rnn":
        from ..Accustic_feature_extractor.mfcc import accustic_mfcc as encoder_cls
    else:
        return nn.Identity(), fusion_dim, False

    accustic_encoder = encoder_cls(hparams.dropout_rate)
    return accustic_encoder, fusion_dim + accustic_encoder.accustic_dim, True

def get_holistic_history_module(hparams, fusion_dim, sentence_dim):
    """
    Optionally build a GRU over the history and widen the fusion dimension.

    Args:
        hparams: Config object exposing ``use_holistic_history``.
        fusion_dim (int): Current width of the fused feature vector.
        sentence_dim (int): Input and hidden size of the GRU. It is added to
            ``fusion_dim`` when the module is enabled.

    Returns:
        tuple: ``(encoder, new_fusion_dim, used)``. When disabled, ``encoder`` is
        ``nn.Identity()``, ``fusion_dim`` is unchanged, and ``used`` is False.
    """
    if not hparams.use_holistic_history:
        return nn.Identity(), fusion_dim, False

    # nn.GRU applies `dropout` only between stacked recurrent layers. With the
    # default single layer it has no effect and PyTorch emits a UserWarning,
    # so no dropout argument is passed. Parameters and init are unchanged.
    encoder = nn.GRU(
        sentence_dim,
        sentence_dim,
        batch_first=True,
    )
    return encoder, fusion_dim + sentence_dim, True

def str2bool(v):
    """
    Convert a command-line style value into a bool (for argparse ``type=``).

    Bools pass through unchanged. Strings are compared case-insensitively
    against the accepted true/false spellings.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognized spelling.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    truthy = {'yes', 'true', 't', 'y', '1'}
    falsy = {'no', 'false', 'f', 'n', '0'}

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')