import math
from re import A  # NOTE(review): not referenced in this module

import torch
from torch import nn
from transformers import AutoModel
from transformers.models.bert import BertConfig, BertModel
from g_mlp_pytorch import gMLP

from models.common.msu import MSU
from models.activations.func import tanh_exp
from models.common.transformer import EncoderLayer, TransformerEncoder
from utils.data_loader import ModelInput
from utils.task_config import TaskConfig
from optimizer.class_balaced_loss import cb_loss

def _log_cosh(x):
  """Numerically stable elementwise log(cosh(x)).

  Uses log(cosh(x)) = |x| + softplus(-2|x|) - log(2). This avoids evaluating
  cosh() directly, which overflows to inf for large |x| (e.g. |x| > ~89 in
  float32) and would make the loss and its gradient inf/nan.
  """
  abs_x = x.abs()
  return abs_x + torch.nn.functional.softplus(-2.0 * abs_x) - math.log(2.0)


def _class_counts(class_indices, class_size, device):
  """Count how many entries of ``class_indices`` equal each of 0..class_size-1.

  Indices >= class_size are ignored. Returns an int64 tensor of length
  ``class_size`` on ``device``.
  """
  counts = torch.bincount(class_indices.flatten(), minlength=class_size)[:class_size]
  return counts.to(device)


def calculate_loss(logits, labels, loss_type, class_size):
  """Compute the training loss of ``logits`` against ``labels``.

  Args:
    logits: model outputs.
    labels: targets, or ``None`` to skip the loss.
    loss_type: one of
      - ``'regression'``: mean log-cosh of ``logits - labels``;
      - ``'classify'``: ``cb_loss`` (class-balanced loss from
        ``optimizer.class_balaced_loss``) with per-batch class counts derived
        from ``argmax(labels, -1)``;
      - ``'bybrid'`` (historical spelling, kept) or ``'hybrid'``: log-cosh on
        column 0 plus ``cb_loss`` on columns ``1:``.
    class_size: number of classes to count for ``cb_loss``.

  Returns:
    The loss tensor, or ``None`` when ``labels`` is ``None`` or ``loss_type``
    is not one of the values above.
  """
  if labels is None:
    return None

  if loss_type == 'regression':
    return torch.mean(_log_cosh(logits - labels))

  if loss_type == 'classify':
    counts = _class_counts(torch.argmax(labels, -1), class_size, logits.device)
    return cb_loss(labels, logits, counts)

  if loss_type in ('bybrid', 'hybrid'):
    # Column 0 is treated as a regression target; for columns 1: the argmax is
    # taken as the class index.
    loss_reg = torch.mean(_log_cosh(logits[:, 0] - labels[:, 0]))
    # Counts are passed as a device tensor, the same as in the 'classify' branch.
    counts = _class_counts(torch.argmax(labels[:, 1:], -1), class_size, logits.device)
    return loss_reg + cb_loss(labels[:, 1:], logits[:, 1:], counts)

  return None

class WA_MSFModel(nn.Module):
  """BERT text encoder fused with optional audio and/or video streams.

  Pipeline (see ``forward``):
    1. The BERT layer stack is split at ``target_layer`` into ``bert_enc_f`` and
       ``bert_enc_l``. Audio and video get a pre-split encoder (``*_enc_f``) and
       a post-split encoder (``*_enc_l``). These are gMLP when ``cfg.amlp_av``
       is set, otherwise ``TransformerEncoder``.
    2. Between the two halves, either ``MSU`` runs (``cfg.use_msu``) or linear
       projections ``atol_proj``/``vtol_proj`` map the audio/video features to
       BERT's hidden size. MSU presumably maps to the hidden size as well,
       since the post encoders are built for it (TODO confirm).
    3. Max-pooled audio/video and BERT-pooled text vectors are concatenated and
       projected to ``f_emb``. ``f_emb`` is prepended to the per-token
       projections of every enabled stream and mixed by ``fused_dense``.
    4. The first fused position goes through ``dense`` -> ``cls``.

  NOTE: ``__init__`` fills unset ``cfg`` options with defaults by writing them
  back into ``cfg``, which is a side effect visible to the caller.
  """

  def __init__(self, bert_model: BertModel, cfg: TaskConfig):
    super().__init__()
    no_audio = self.no_audio = cfg.no_audio
    no_video = self.no_video = cfg.no_video
    # Defaults for unset options, written back into cfg (kept for compatibility).
    for option, default in (
        ('emb_num', 70), ('fusion_layer', 3), ('fusion_head', 16), ('cls_do_rate', 0.1),
        ('amlp_av', False), ('pre_msu', 20), ('post_msu', 4), ('use_msu', True), ('use_amlp', True),
    ):
      if getattr(cfg, option) is None:
        setattr(cfg, option, default)
    assert not (no_audio and no_video), 'at least one of audio or video must be enabled'

    def build_gmlp(seq_len=cfg.seq_limit, depth=1, dim=cfg.topic_dim, head_unit=cfg.fusion_head):
      # Circulant-matrix gMLP with tanh_exp activation; attn_dim is dim // head_unit.
      return gMLP(dim=dim, depth=depth, circulant_matrix=True, seq_len=seq_len,
                  attn_dim=dim // head_unit, act=tanh_exp)

    self.bert = bert_model
    self.config: BertConfig = bert_model.config
    self.fwd_arg = {'return_dict': False}
    hidden_size = self.config.hidden_size
    num_heads = self.config.num_attention_heads

    # Index at which the BERT stack is split (MSU / projections run in between).
    if cfg.msu_layers is not None:
      target_layer = cfg.msu_layers
    elif cfg.no_cross_msu:
      target_layer = self.config.num_hidden_layers
    else:
      target_layer = max(2, cfg.bert_freeze_layers)
    bert_enc_layers = self.bert.encoder.layer
    # These slices share layer modules with self.bert.encoder.layer. forward()
    # swaps them into the encoder temporarily and always restores the full list.
    self.bert_enc_f = nn.ModuleList(bert_enc_layers[:target_layer])
    self.bert_enc_l = nn.ModuleList(bert_enc_layers[target_layer:])

    self.amlp_av = cfg.amlp_av

    def build_av_encoders(in_dim, in_heads):
      # (pre-split encoder on the modality's own dim, post-split encoder on hidden_size)
      if cfg.amlp_av:
        return (build_gmlp(depth=cfg.pre_msu, dim=in_dim, head_unit=in_heads),
                build_gmlp(depth=cfg.post_msu, dim=hidden_size, head_unit=num_heads))
      return (TransformerEncoder(in_dim, in_heads, cfg.pre_msu),
              TransformerEncoder(hidden_size, num_heads, cfg.post_msu))

    if not no_audio:
      self.a_enc_f, self.a_enc_l = build_av_encoders(cfg.audio_dim, cfg.audio_head)
      self.a_pooler = nn.MaxPool1d(cfg.seq_limit)
    if not no_video:
      self.v_enc_f, self.v_enc_l = build_av_encoders(cfg.video_dim, cfg.video_head)
      self.v_pooler = nn.MaxPool1d(cfg.seq_limit)

    self.use_msu = cfg.use_msu
    self.use_amlp = cfg.use_amlp
    if cfg.use_msu:
      self.applier = MSU(hidden_size, None if no_audio else cfg.audio_dim,
                         None if no_video else cfg.video_dim, cfg.seq_limit, num_heads)
    else:
      if not no_audio: self.atol_proj = nn.Linear(cfg.audio_dim, hidden_size)
      if not no_video: self.vtol_proj = nn.Linear(cfg.video_dim, hidden_size)

    n_streams = 2 if no_audio or no_video else 3  # text plus each enabled modality
    self.topic_mode = cfg.topic_mode
    self.dense_topic_1 = nn.Linear(hidden_size * n_streams, cfg.topic_dim)
    self.dense_topic_2 = nn.Linear(cfg.topic_dim, cfg.topic_dim)

    self.proj_w_l = nn.Linear(hidden_size, cfg.topic_dim)
    if not no_audio: self.proj_w_a = nn.Linear(hidden_size, cfg.topic_dim)
    if not no_video: self.proj_w_v = nn.Linear(hidden_size, cfg.topic_dim)

    # Fused sequence = [f_emb] followed by seq_limit positions per enabled stream.
    fused_len = cfg.seq_limit * n_streams + 1
    self.layer_num = cfg.fusion_layer
    if self.layer_num <= 0:
      # No fusion layers: a single linear mix along the sequence axis.
      self.fused_dense = nn.Linear(fused_len, fused_len)
    elif self.use_amlp:
      self.fused_dense = build_gmlp(seq_len=fused_len, depth=self.layer_num)
    else:
      self.fused_dense = nn.ModuleList(
        [EncoderLayer(cfg.topic_dim, cfg.fusion_head) for _ in range(self.layer_num)])

    # proj_l / proj_a / proj_v are not referenced by forward(). They are kept
    # because removing them would change parameter order and state_dict keys.
    self.proj_l = nn.Linear(hidden_size, cfg.topic_dim)
    if not no_audio: self.proj_a = nn.Linear(cfg.audio_dim, cfg.topic_dim)
    if not no_video: self.proj_v = nn.Linear(cfg.video_dim, cfg.topic_dim)

    self.dense = nn.Linear(cfg.topic_dim, cfg.topic_dim)
    self.cls_do = nn.Dropout(cfg.cls_do_rate)
    self.cls = nn.Linear(cfg.topic_dim, cfg.class_size)

    self.class_size = cfg.class_size
    self.loss_type = cfg.loss_type

  def get_extended_att_mask(self, mask):
    """Convert a 0/1 attention mask into BERT's additive mask.

    Inserts two singleton dims after dim 0 (a (batch, seq) mask becomes
    (batch, 1, 1, seq); the input layout is assumed — TODO confirm). Positions
    where ``mask`` is 1 become 0 and positions where it is 0 become -10000. The
    result is cast to the dtype of the model's parameters.
    """
    extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask

  def _run_bert_layers(self, layers, hidden, ext_att_mask, head_mask):
    """Run ``hidden`` through ``layers`` using the BERT encoder's own loop.

    The encoder's ``layer`` list is swapped only for the duration of the call
    and restored in ``finally``. This keeps ``self.bert``'s module tree, and
    hence its state_dict keys, identical before and after ``forward``.
    """
    encoder = self.bert.encoder
    full_layers = encoder.layer
    encoder.layer = layers
    try:
      out, *_ = encoder(hidden, attention_mask=ext_att_mask, head_mask=head_mask, **self.fwd_arg)
    finally:
      encoder.layer = full_layers
    return out

  def _encode_av(self, encoder, data, mask_bool):
    # gMLP encoders (amlp_av) take no mask; TransformerEncoder takes the bool mask.
    return encoder(data) if self.amlp_av else encoder(data, mask_bool)

  def _present(self, text, audio, video):
    # (text, audio, video) with disabled modalities dropped, order preserved.
    return (text,) + (() if self.no_audio else (audio,)) + (() if self.no_video else (video,))

  def forward(self, input_ids, sequence_ids, att_mask, a_data, v_data, labels, use_bce=False):
    """Run the full model.

    Args:
      input_ids: BERT token ids.
      sequence_ids: token-type ids, or ``None`` to use all zeros.
      att_mask: 0/1 text attention mask. Its bool form is also used for the
        audio/video encoders and for the fused sequence.
      a_data, v_data: audio / video features; ignored when that modality is disabled.
      labels: targets for ``calculate_loss``, or ``None``.
      use_bce: unused; kept for interface compatibility.

    Returns:
      ``(logits, loss, None, None)``. ``logits`` are softmax probabilities when
      ``loss_type == 'classify'``, and ``loss`` is ``None`` without labels.
    """
    ext_att_mask = self.get_extended_att_mask(att_mask)
    head_mask = [None] * self.config.num_hidden_layers
    token_type_ids = sequence_ids if sequence_ids is not None else torch.zeros_like(input_ids)
    mask_bool = att_mask.bool()

    # Stage 1: first half of BERT and the pre-split audio/video encoders.
    l_emb = self.bert.embeddings(input_ids, token_type_ids=token_type_ids)
    l_hidden = self._run_bert_layers(self.bert_enc_f, l_emb, ext_att_mask, head_mask)
    a_hidden = None if self.no_audio else self._encode_av(self.a_enc_f, a_data, mask_bool)
    v_hidden = None if self.no_video else self._encode_av(self.v_enc_f, v_data, mask_bool)

    # Stage 2: cross-modal MSU, or plain projections to the hidden size.
    if self.use_msu:
      l_hidden, a_hidden, v_hidden = self.applier(l_hidden, a_hidden, v_hidden, att_mask)
    else:
      if not self.no_audio: a_hidden = self.atol_proj(a_hidden)
      if not self.no_video: v_hidden = self.vtol_proj(v_hidden)

    # Stage 3: second half of BERT and the post-split audio/video encoders.
    l_hidden = self._run_bert_layers(self.bert_enc_l, l_hidden, ext_att_mask, head_mask)
    a_hidden = None if self.no_audio else self._encode_av(self.a_enc_l, a_hidden, mask_bool)
    v_hidden = None if self.no_video else self._encode_av(self.v_enc_l, v_hidden, mask_bool)

    # Pooled summary of every enabled stream -> fusion embedding f_emb.
    l_pool = self.bert.pooler(l_hidden)
    a_pool = None if self.no_audio else self.a_pooler(a_hidden.permute(0, 2, 1)).permute(0, 2, 1).squeeze(1)
    v_pool = None if self.no_video else self.v_pooler(v_hidden.permute(0, 2, 1)).permute(0, 2, 1).squeeze(1)
    cat_pool = torch.cat(self._present(l_pool, a_pool, v_pool), -1)
    f_emb = self.dense_topic_2(tanh_exp(self.dense_topic_1(cat_pool)))

    l_t_hidden = self.proj_w_l(l_hidden)
    a_t_hidden = None if self.no_audio else self.proj_w_a(a_hidden)
    v_t_hidden = None if self.no_video else self.proj_w_v(v_hidden)
    fused_mdls = torch.cat((f_emb.unsqueeze(1),) + self._present(l_t_hidden, a_t_hidden, v_t_hidden), 1)

    if self.layer_num <= 0:
      # Linear mix along the sequence axis (matches the nn.Linear built in __init__).
      fused_mdls = self.fused_dense(fused_mdls.permute(0, 2, 1)).permute(0, 2, 1)
    elif self.use_amlp:
      fused_mdls = self.fused_dense(fused_mdls)
    else:
      # Position 0 (f_emb) is always attendable; each stream reuses the text mask.
      f_emb_mask = torch.ones((mask_bool.size(0), 1), dtype=torch.bool, device=mask_bool.device)
      fusion_mask = torch.cat((f_emb_mask,) + self._present(mask_bool, mask_bool, mask_bool), -1)
      for layer in self.fused_dense:
        fused_mdls = layer(fused_mdls, fusion_mask)

    fused_cls = fused_mdls[:, 0, :]
    logits = self.cls(self.cls_do(torch.tanh(self.dense(fused_cls))))

    loss = calculate_loss(logits, labels, self.loss_type, self.class_size)

    if self.loss_type == 'classify':
      logits = torch.softmax(logits, -1)

    return logits, loss, None, None

class WA_MSF(nn.Module):
  """Top-level wrapper: loads the pretrained text model and builds ``WA_MSFModel``.

  When ``cfg.no_align`` is set, audio/video sequences are first mapped to
  ``cfg.seq_limit`` positions by 1-D convolutions (``a_conv`` / ``v_conv``)
  before they reach the inner model.
  """

  def __init__(self, cfg: TaskConfig):
    super().__init__()
    self.bert = AutoModel.from_pretrained(cfg.bert_model)
    self.audio_vertical = cfg.a_conv_v
    self.video_vertical = cfg.v_conv_v
    self.no_audio = cfg.no_audio
    self.no_video = cfg.no_video
    self.no_align = cfg.no_align

    # cfg.audio_dim / cfg.video_dim are overridden only while the inner model is
    # built, and restored in `finally` even if construction raises.
    audio_dim_bk, video_dim_bk = cfg.audio_dim, cfg.video_dim
    try:
      if self.no_align:
        # "vertical" convs treat the raw sequence positions as channels (kernel 1).
        # The others treat the feature dim as channels and convolve over time;
        # their output length becomes the new per-position feature dimension.
        self.a_conv = (nn.Conv1d(cfg.audio_raw_seq_limit, cfg.seq_limit, 1, 1, 0) if cfg.a_conv_v
                       else nn.Conv1d(cfg.audio_dim, cfg.seq_limit, 20, 10, 5))
        self.v_conv = (nn.Conv1d(cfg.video_raw_seq_limit, cfg.seq_limit, 1, 1, 0) if cfg.v_conv_v
                       else nn.Conv1d(cfg.video_dim, cfg.seq_limit, 3, 2, 1))
        # Derive the new feature dims from the convs themselves. A hard-coded
        # raw // 2 is wrong for odd video lengths: k=3, s=2, p=1 yields ceil(raw / 2).
        if not cfg.a_conv_v:
          cfg.audio_dim = self._conv1d_out_len(self.a_conv, cfg.audio_raw_seq_limit)
        if not cfg.v_conv_v:
          cfg.video_dim = self._conv1d_out_len(self.v_conv, cfg.video_raw_seq_limit)
      self.model = WA_MSFModel(self.bert, cfg)
    finally:
      cfg.audio_dim = audio_dim_bk
      cfg.video_dim = video_dim_bk

  @staticmethod
  def _conv1d_out_len(conv, length):
    """Output length of ``conv`` (an ``nn.Conv1d``) for an input with ``length`` steps."""
    (kernel,), (stride,), (padding,), (dilation,) = conv.kernel_size, conv.stride, conv.padding, conv.dilation
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

  def forward(self, ipt: ModelInput, use_bce=False):
    """Optionally resample audio/video (``no_align``), then run the inner model.

    ``ipt`` is not modified. The convolved streams are kept in locals, so
    running ``forward`` on the same batch more than once is safe.
    Returns whatever ``WA_MSFModel.forward`` returns.
    """
    audio_data, video_data = ipt.audio_data, ipt.video_data
    if self.no_align:
      audio_data = None if self.no_audio else self.a_conv(
        audio_data if self.audio_vertical else audio_data.permute(0, 2, 1).contiguous())
      video_data = None if self.no_video else self.v_conv(
        video_data if self.video_vertical else video_data.permute(0, 2, 1).contiguous())

    return self.model(ipt.input_ids, ipt.segment_ids, ipt.attention_mask,
                      audio_data, video_data, ipt.labels, use_bce)
