from models.vq_bottleneck.model import VQBFMT
from utils.task_config import TaskConfig

class VQBFMTForONNX(VQBFMT):
  """VQBFMT variant whose forward takes five positional inputs and returns one value.

  NOTE(review): judging by the name, this wrapper is presumably meant for ONNX
  export, where a flat positional signature and a single output are convenient
  -- confirm against the export script.
  """

  def __init__(self, cfg: TaskConfig):
    """Construct the parent VQBFMT from ``cfg``; no extra state is added here."""
    super().__init__(cfg)

  def forward(self, input_ids, segment_ids, attention_mask, audio_data, video_data):
    """Run the wrapped model and return the first element of its output.

    When ``self.no_align`` is set, each modality is pre-processed before the
    model call:
      * a modality disabled via ``self.no_audio`` / ``self.no_video`` is
        replaced by ``None``;
      * otherwise it goes through ``self.a_conv`` / ``self.v_conv``, with its
        last two axes swapped first unless the matching ``*_vertical`` flag
        is set.
    When ``self.no_align`` is not set, both modalities are forwarded untouched.
    """
    if self.no_align:
      # Audio branch.
      if self.no_audio:
        audio_data = None
      elif self.audio_vertical:
        audio_data = self.a_conv(audio_data)
      else:
        # permute(0, 2, 1) swaps the last two axes of a 3-D tensor;
        # contiguous() materialises the new layout before the conv.
        audio_data = self.a_conv(audio_data.permute(0, 2, 1).contiguous())

      # Video branch, same scheme as audio.
      if self.no_video:
        video_data = None
      elif self.video_vertical:
        video_data = self.v_conv(video_data)
      else:
        video_data = self.v_conv(video_data.permute(0, 2, 1).contiguous())

    # The trailing ``None, False`` fill two further positional parameters of
    # the wrapped model. NOTE(review): their meaning is not visible from this
    # file -- confirm against the definition of ``self.model``.
    outputs = self.model(
      input_ids,
      segment_ids,
      attention_mask,
      audio_data,
      video_data,
      None,
      False,
    )
    return outputs[0]