from torch.nn.utils.rnn import pack_padded_sequence
import librosa, torch
from torch import nn

class accustic_mfcc(nn.Module):
    """Embed raw audio waveforms using MFCC features followed by a GRU.

    Each waveform in the batch is converted to a sequence of MFCC frames with
    librosa (on the CPU), the frame sequences are stacked into one tensor, and
    the final hidden state of a single-layer GRU is returned as the embedding.
    """

    # Number of MFCC coefficients per frame. Shared by the librosa call and
    # the GRU input size so the two cannot drift apart.
    N_MFCC = 13

    def __init__(self, dropout_rate=0.1) -> None:
        """Build the GRU encoder.

        Args:
            dropout_rate: Accepted for backward compatibility. It has no
                effect: ``nn.GRU`` only applies dropout *between* stacked
                layers, and this encoder uses a single layer.
        """
        super().__init__()
        self.accustic_dim = 64
        self.accustic_rnn = nn.GRU(
            self.N_MFCC,
            self.accustic_dim,
            batch_first=True,
            # Passing a non-zero dropout to a one-layer GRU only triggers a
            # UserWarning and is never applied, so keep it at zero.
            dropout=0.0,
        )

    @staticmethod
    def _frame_params(sample_rate, window_len, step_len):
        """Convert window/step lengths in milliseconds to sample counts.

        Multiplies before dividing so that sample rates which are not a
        multiple of 1000 (e.g. 22050 Hz) are not truncated early.

        Returns:
            A ``(n_fft, hop_length)`` tuple of ints.
        """
        n_fft = int(sample_rate * window_len // 1000)
        hop_length = int(sample_rate * step_len // 1000)
        return n_fft, hop_length

    def _get_mfcc_feature(self, SPEECH_WAVEFORM, SAMPLE_RATE=16000, window_len=25, step_len=10):
        """Compute the MFCC frames of a single waveform.

        Args:
            SPEECH_WAVEFORM: Audio samples as a NumPy array, passed to librosa
                as ``y`` (assumed 1-D mono -- TODO confirm against callers).
            SAMPLE_RATE: Sampling rate in Hz.
            window_len: Analysis window length in milliseconds.
            step_len: Hop between consecutive windows in milliseconds.

        Returns:
            The librosa MFCC matrix with its two axes swapped, i.e.
            ``(frames, N_MFCC)`` for a 1-D input.
        """
        n_fft, hop_length = self._frame_params(SAMPLE_RATE, window_len, step_len)
        mfcc = librosa.feature.mfcc(
            y=SPEECH_WAVEFORM,
            sr=SAMPLE_RATE,
            n_mfcc=self.N_MFCC,
            center=False,  # no edge padding; frames start at sample 0
            n_fft=n_fft,
            hop_length=hop_length,
        )
        return torch.from_numpy(mfcc).permute(1, 0)

    def get_mfcc_feature(self, SPEECH_WAVEFORM):
        """Compute MFCC features for every waveform in a batch.

        Args:
            SPEECH_WAVEFORM: Tensor iterated along its first dimension, one
                waveform per entry (assumed ``(batch, samples)`` -- TODO
                confirm). All entries must yield the same number of frames,
                since the results are stacked.

        Returns:
            The stacked per-waveform features, moved back to the device of
            the input tensor.
        """
        device = SPEECH_WAVEFORM.device
        # librosa works on NumPy arrays; detach first because ``.numpy()``
        # raises on tensors that require grad.
        waves = SPEECH_WAVEFORM.detach().cpu().numpy()
        features = [self._get_mfcc_feature(wave) for wave in waves]
        return torch.stack(features).to(device)

    def forward(self, accustic_feature):
        """Return the GRU's final hidden state for a batch of waveforms.

        Args:
            accustic_feature: Batch of raw waveforms (see
                ``get_mfcc_feature``).

        Returns:
            The last layer's final hidden state, one vector of size
            ``accustic_dim`` per batch entry.
        """
        accustic_feature = self.get_mfcc_feature(accustic_feature)
        # librosa preserves the input precision (e.g. float64), which the GRU
        # rejects unless it matches its own parameter dtype.
        accustic_feature = accustic_feature.to(dtype=self.accustic_rnn.weight_ih_l0.dtype)
        _, hn = self.accustic_rnn(accustic_feature)
        audio_emb = hn[-1]
        return audio_emb