import torch, torchaudio
from torch import nn
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

class accustic_wav2vec(nn.Module):
    """Pooled utterance embeddings from a pretrained wav2vec 2.0 encoder.

    Waveforms are resampled to 16 kHz, normalized per utterance (when the
    checkpoint's feature extractor has ``do_normalize`` set), encoded by
    ``Wav2Vec2Model`` and pooled over time into one ``accustic_dim``-sized
    vector per utterance.

    Args:
        dropout_rate: probability for ``self.dropout``. NOTE(review): the
            dropout module is built but ``forward`` never applies it, so this
            value currently has no effect on the output.
        orig_sample_rate: sample rate (Hz) of waveforms passed to ``forward``.
        pooling_mode: ``"mean"``, ``"sum"`` or ``"max"`` pooling over time.
        model_name: checkpoint used for both feature extractor and encoder.
    """

    # Rate the wav2vec2 checkpoint expects; inputs are resampled to it.
    _TARGET_SAMPLE_RATE = 16_000

    def __init__(
        self,
        dropout_rate=0.1,
        orig_sample_rate=8_000,
        pooling_mode="mean",
        model_name="facebook/wav2vec2-base",
    ) -> None:
        super().__init__()
        # Kept for its config (``do_normalize``) and for callers that access it.
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(model_name)
        self.resampler = torchaudio.transforms.Resample(
            orig_sample_rate, self._TARGET_SAMPLE_RATE
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.pooling_mode = pooling_mode
        self.accustic_dim = self.wav2vec2.config.hidden_size
        # The encoder is fully trainable by default; call freeze_model() to
        # keep the convolutional feature encoder fixed while fine-tuning.

    def freeze_model(self):
        """Freeze the convolutional feature encoder of ``self.wav2vec2``.

        ``freeze_feature_extractor`` is the deprecated alias of
        ``freeze_feature_encoder`` (it only warns and forwards), so calling
        both was redundant; only the supported method is used.
        """
        self.wav2vec2.freeze_feature_encoder()

    def merged_strategy(
            self,
            hidden_states,
            mode="mean"
    ):
        """Pool frame-level features over dim 1 (the time/frame axis).

        Args:
            hidden_states: tensor whose dim 1 indexes frames.
            mode: ``"mean"``, ``"sum"`` or ``"max"``.

        Returns:
            ``hidden_states`` reduced over dim 1.

        Raises:
            ValueError: if ``mode`` is not a supported pooling method.
        """
        if mode == "mean":
            return torch.mean(hidden_states, dim=1)
        if mode == "sum":
            return torch.sum(hidden_states, dim=1)
        if mode == "max":
            return torch.max(hidden_states, dim=1).values
        # ValueError subclasses Exception, so existing handlers keep working.
        raise ValueError(
            "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

    @staticmethod
    def _zero_mean_unit_var_norm(waveforms, eps=1e-7):
        """Normalize each row of a ``(batch, time)`` tensor independently.

        Same formula as the wav2vec2 feature extractor,
        ``(x - mean) / sqrt(var + 1e-7)`` with population variance, but the
        statistics are computed per utterance instead of over the batch.
        """
        mean = waveforms.mean(dim=-1, keepdim=True)
        var = waveforms.var(dim=-1, keepdim=True, unbiased=False)
        return (waveforms - mean) / torch.sqrt(var + eps)

    def forward(self, accustic_feature):
        """Embed a batch of waveforms.

        Args:
            accustic_feature: waveforms at ``orig_sample_rate``, shaped
                ``(batch, time)``. A single ``(time,)`` waveform is treated as
                a batch of one; any dims after the batch axis are flattened
                into time.

        Returns:
            Pooled encoder features, one row per utterance.
        """
        if accustic_feature.dim() == 1:
            accustic_feature = accustic_feature.unsqueeze(0)
        bs = accustic_feature.size(0)

        waveforms = self.resampler(accustic_feature).reshape(bs, -1)
        # The HF extractor always produced float32 input values.
        waveforms = waveforms.to(torch.float32)
        # Normalize per utterance so an embedding does not depend on which
        # other clips share its batch (the extractor saw the whole tensor as
        # a single example and normalized across the batch).
        if getattr(self.feature_extractor, "do_normalize", True):
            waveforms = self._zero_mean_unit_var_norm(waveforms)

        # Gradients flow into wav2vec2; wrap the call in torch.no_grad() to
        # use it as a frozen feature extractor instead.
        hidden_states = self.wav2vec2(waveforms)[0]
        return self.merged_strategy(hidden_states, self.pooling_mode)