# encoding=utf-8
import os
import sys
# Append the parent directory of this file's directory to sys.path.
# NOTE(review): presumably this lets project-local packages such as `atten`
# (imported below) resolve when the file is run directly -- confirm layout.
curPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(curPath)[0]
sys.path.append(rootPath)

import torch
import torch.nn as nn
import torchvision.models as cv_models
import torch.nn.functional as F

import math
import oss2 as oss
from io import BytesIO

from atten import utils


def myPrinter(args, str_to_be_printed):
  """Print a message unless this is a non-zero-rank worker.

  Args:
    args: object exposing `device`; `rank` is only read when `device` is
      neither 'cuda' nor 'cpu'.
    str_to_be_printed: the message handed to `print`.
  """
  # The `or` chain short-circuits, so `args.rank` is never touched for the
  # plain 'cuda' / 'cpu' devices.
  plain_device = args.device == 'cuda' or args.device == 'cpu'
  if plain_device or args.rank == 0:
    print(str_to_be_printed)

def init_weights(module):
  """Initialise one sub-module in place; meant for `nn.Module.apply`.

  * `nn.Linear` / `nn.Embedding`: Xavier-normal weights (a Linear bias is
    left untouched).
  * `nn.LayerNorm` / `nn.BatchNorm1d`: weight set to 1, bias set to 0.
  * Any other module type is left unchanged.
  """
  # NOTE(review): for an Embedding built with `padding_idx`, the Xavier draw
  # also overwrites the padding row -- confirm that this is intended.
  if isinstance(module, (nn.Linear, nn.Embedding)):
    nn.init.xavier_normal_(module.weight.data)
  elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
    module.bias.data.zero_()
    module.weight.data.fill_(1.0)


class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding followed by dropout.

  The table is stored as the buffer `pe` with shape (max_len, 1, d_model).
  `forward` slices it by `x.size(0)`, so the input must be sequence-first:
  (seq_len, batch, d_model).
  """

  def __init__(self, d_model, dropout=0.1, max_len=5000):
    """Pre-compute the sin/cos table.

    Args:
      d_model: feature dimension (odd values are supported).
      dropout: dropout probability applied after adding the encoding.
      max_len: number of positions in the pre-computed table.
    """
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half must be truncated; for even d_model the slice is a no-op.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
    pe = pe.unsqueeze(1)  # (max_len, 1, d_model): broadcasts over the batch axis
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Add the encoding of positions 0..x.size(0)-1 to `x`, then apply dropout.

    Args:
      x: tensor whose first dimension is the sequence position and whose last
        dimension is `d_model`.
    """
    x = x + self.pe[:x.size(0), :]
    return self.dropout(x)

class OneStreamSelfAttenMMEncoder(nn.Module):
  """Single-stream multimodal encoder.

  Language tokens and visual patches are concatenated (language first) into
  one sequence, given positional and segment embeddings, and encoded by a
  shared Transformer encoder.
  """

  def __init__(self, config):
    """Build the embedding layers and the Transformer encoder.

    Args:
      config: object exposing `device`, `d_tok`, `n_head`, `d_hid`, `dr` and
        `x_layers` (and whatever `myPrinter` reads).
    """
    super(OneStreamSelfAttenMMEncoder, self).__init__()
    myPrinter(config, '\t############OneStreamSelfAttenMMEncoder Model#####################')
    self.device = config.device
    self.position_embeddings = PositionalEncoding(d_model=config.d_tok)
    # Segment ids: 0 is the padding index, 2 is assigned to visual patches in
    # `forward`; language tokens use whatever the caller passes in.
    self.seg_embeddings = nn.Embedding(3, config.d_tok, padding_idx=0)
    # self-attention layer
    encoder_layer = nn.TransformerEncoderLayer(d_model=config.d_tok, nhead=config.n_head,
                                               dim_feedforward=config.d_hid, dropout=config.dr,
                                               activation='gelu')
    encoder_norm = nn.LayerNorm(config.d_tok)
    self.encoder = nn.TransformerEncoder(encoder_layer, config.x_layers, encoder_norm)

  def forward(self, lang_feats, lang_pad_mask, token_type_ids, visn_feats):
    """Encode the joint [language ; vision] sequence.

    Args:
      lang_feats: (batch, seq_len, d_tok) language token features.
      lang_pad_mask: (batch, seq_len); non-zero marks a real token.
      token_type_ids: (batch, seq_len) segment ids of the language tokens.
      visn_feats: (batch, patch_size, d_tok) visual patch features.

    Returns:
      (lang_feats, visn_feats, output), where `output` is the encoded joint
      sequence of shape (batch, seq_len + patch_size, d_tok).
    """
    (batch_size, seq_len, d_tok) = lang_feats.size()
    (_, patch_size, d_patch) = visn_feats.size()
    assert d_tok == d_patch

    # Build helper tensors on the input's device instead of `config.device`,
    # so the module keeps working when the two differ.
    device = lang_feats.device

    # Segment ids, language first: caller-provided ids, then 2 for every patch.
    visn_segment = torch.full((batch_size, patch_size), 2, dtype=torch.long, device=device)
    segment = torch.cat((token_type_ids.long(), visn_segment), dim=1)

    # Visual patches are never padding. `src_key_padding_mask` expects True at
    # positions to ignore, hence the inversion.
    visn_pad_mask = torch.ones((batch_size, patch_size), dtype=torch.bool, device=device)
    pad_mask = ~torch.cat([lang_pad_mask.bool(), visn_pad_mask], dim=1)

    # Concatenate language first so the features line up with `segment` and
    # `pad_mask`, which are both built in [language ; vision] order.
    unit_embeddings = torch.cat([lang_feats, visn_feats], dim=1)
    # PositionalEncoding indexes positions along dim 0, so switch to
    # (seq, batch, d_tok) for it and switch back afterwards.
    unit_embeddings = self.position_embeddings(unit_embeddings.permute(1, 0, 2)).permute(1, 0, 2)
    if torch.isnan(unit_embeddings.std()):
      print('features are nan after position embedding layer!')
    seg_embeddings = self.seg_embeddings(segment)
    unit_embeddings = unit_embeddings + seg_embeddings
    if torch.isnan(unit_embeddings.std()):
      print('lfeatures are nan after segmentation embedding layer!')

    output = self.encoder(unit_embeddings.permute(1, 0, 2), src_key_padding_mask=pad_mask)
    if torch.isnan(output.std()):
      print('language features are nan after self-attention layer!')
    output = output.permute(1, 0, 2)
    # NOTE(review): both values below are the same single position
    # (seq_len - 1); kept as-is because the intended slices cannot be
    # determined from this file -- confirm with the callers.
    lang_feats = output[:, seq_len - 1, :]
    visn_feats = output[:, seq_len - 1, :]
    return lang_feats, visn_feats, output

class SelfAttenFFNLayers(nn.Module):
  """Position-wise feed-forward block with a residual connection and LayerNorm.

  Computes `norm(x + dropout(linear2(activation(linear1(x)))))`.
  """

  def __init__(self,d_model,dim_feedforward,dropout,activation='gelu'):
    """
    Args:
      d_model: input/output feature dimension.
      dim_feedforward: hidden dimension of the feed-forward block.
      dropout: dropout probability applied to the second projection.
      activation: 'relu' or 'gelu'.

    Raises:
      ValueError: if `activation` is not one of the supported names.
    """
    super().__init__()
    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)
    self.norm = nn.LayerNorm(d_model)
    self.activation = self._get_activation_fn(activation)

  def forward(self, attention_output):
    """Apply the feed-forward block; the output has the same shape as the input."""
    residual = attention_output
    hidden = self.activation(self.linear1(attention_output))
    hidden = self.dropout(self.linear2(hidden))
    return self.norm(residual + hidden)

  def _get_activation_fn(self,activation):
    """Map an activation name to the corresponding functional."""
    if activation == "relu":
      return F.relu
    if activation == "gelu":
      return F.gelu
    # Fail at construction time rather than returning None, which would only
    # surface later as "'NoneType' object is not callable" inside forward().
    raise ValueError("activation should be 'relu' or 'gelu', not {!r}".format(activation))

class LXRTXLayer(nn.Module):
  """Cross-modality layer: cross-attention, per-modality self-attention, FFN.

  All public methods take and return batch-first tensors; the attention
  modules are fed sequence-first tensors via `permute(1, 0, 2)`.
  """

  def __init__(self, config):
    """
    Args:
      config: object exposing `d_tok`, `n_head`, `dr` and `d_hid`.
    """
    super().__init__()
    d_tok, n_head, drop = config.d_tok, config.n_head, config.dr
    # One attention module serves both cross-attention directions.
    self.cross_attention = nn.MultiheadAttention(d_tok, n_head, dropout=drop)
    self.lang_self_att = nn.MultiheadAttention(d_tok, n_head, dropout=drop)
    self.visn_self_att = nn.MultiheadAttention(d_tok, n_head, dropout=drop)
    self.lang_ffn = SelfAttenFFNLayers(d_model=d_tok, dim_feedforward=config.d_hid, dropout=drop)
    self.visn_ffn = SelfAttenFFNLayers(d_model=d_tok, dim_feedforward=config.d_hid, dropout=drop)

  def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
    """Language attends to vision and vision attends to language.

    The masks are forwarded as `key_padding_mask` of the attended-to modality.
    """
    lang_seq = lang_input.permute(1, 0, 2)
    visn_seq = visn_input.permute(1, 0, 2)
    lang_ctx = self.cross_attention(query=lang_seq, key=visn_seq, value=visn_seq,
                                    key_padding_mask=visn_attention_mask)[0]
    visn_ctx = self.cross_attention(query=visn_seq, key=lang_seq, value=lang_seq,
                                    key_padding_mask=lang_attention_mask)[0]
    return lang_ctx.permute(1, 0, 2), visn_ctx.permute(1, 0, 2)

  def self_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
    """Run each modality through its own self-attention module."""
    lang_seq = lang_input.permute(1, 0, 2)
    visn_seq = visn_input.permute(1, 0, 2)
    lang_ctx = self.lang_self_att(query=lang_seq, key=lang_seq, value=lang_seq,
                                  key_padding_mask=lang_attention_mask)[0]
    visn_ctx = self.visn_self_att(query=visn_seq, key=visn_seq, value=visn_seq,
                                  key_padding_mask=visn_attention_mask)[0]
    return lang_ctx.permute(1, 0, 2), visn_ctx.permute(1, 0, 2)

  def output_fc(self, lang_input, visn_input):
    """Apply the per-modality feed-forward blocks."""
    return self.lang_ffn(lang_input), self.visn_ffn(visn_input)

  def forward(self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask):
    """Cross-attention, then self-attention, then the feed-forward blocks.

    Returns:
      (lang_output, visn_output).
    """
    lang_hidden, visn_hidden = self.cross_att(lang_feats, lang_attention_mask,
                                              visn_feats, visn_attention_mask)
    lang_hidden, visn_hidden = self.self_att(lang_hidden, lang_attention_mask,
                                             visn_hidden, visn_attention_mask)
    return self.output_fc(lang_hidden, visn_hidden)

class SelfAttenFuseLayer(nn.Module):
  """Fuse language and vision by self-attention over their concatenation."""

  def __init__(self, config):
    """
    Args:
      config: object exposing `d_tok`, `n_head`, `d_hid`, `dr` and `x_layers`.
    """
    super().__init__()
    # self-attention layer
    encoder_layer = nn.TransformerEncoderLayer(d_model=config.d_tok, nhead=config.n_head,
                                               dim_feedforward=config.d_hid, dropout=config.dr,
                                               activation='gelu')
    encoder_norm = nn.LayerNorm(config.d_tok)
    self.encoder = nn.TransformerEncoder(encoder_layer, config.x_layers, encoder_norm)

  def forward(self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask):
    """Encode [language ; vision] jointly and split the result back.

    Args:
      lang_feats: (batch, seq_len, d_tok) language features.
      lang_attention_mask: (batch, seq_len) mask, passed on (concatenated) as
        `src_key_padding_mask`.
      visn_feats: (batch, n_patch, d_tok) visual features.
      visn_attention_mask: (batch, n_patch) mask, handled like the language one.

    Returns:
      (lang_feats, visn_feats) with the same sequence lengths as the inputs.
    """
    (batch_size, seq_len, d_tok) = lang_feats.size()
    fused = torch.cat((lang_feats, visn_feats), dim=1)
    pad_mask = torch.cat((lang_attention_mask, visn_attention_mask), dim=1)
    output = self.encoder(fused.permute(1, 0, 2), src_key_padding_mask=pad_mask)
    if torch.isnan(output.std()):
      print('fused features are nan after self-attention layer!')
    output = output.permute(1, 0, 2)
    # Language occupies positions [0, seq_len) of the fused sequence. Splitting
    # at seq_len - 1 would push the last language token into the visual slice.
    lang_feats = output[:, :seq_len, :]
    visn_feats = output[:, seq_len:, :]
    return lang_feats, visn_feats


class MMEncoder(nn.Module):
  """Mention-side multimodal encoder: token embedding plus a fusion encoder.

  The fusion encoder is selected by `use_one_stream` and `config.mode`.
  """

  def __init__(self, config, mode='lxr', use_one_stream=False):
    """
    Args:
      config: object exposing `device`, `n_tok`, `d_tok`, `mode` and the
        fields required by the selected fusion encoder.
      mode: stored on `self.mode`.
        NOTE(review): the encoder is chosen from `config.mode`, not from this
        argument -- confirm that this is intended.
      use_one_stream: if true, use `OneStreamSelfAttenMMEncoder`.

    Raises:
      ValueError: if `use_one_stream` is false and `config.mode` is not one
        of 'x-lxrt', 'x-self' or 'x-adapt'.
    """
    super().__init__()
    myPrinter(config, '\t############Mention Multimodal Encoder Model#####################')
    self.mode = mode
    self.device = config.device
    self.use_one_stream = use_one_stream
    self.tok_embeddings = nn.Embedding(config.n_tok,config.d_tok)
    # fuse layer
    if use_one_stream:
      self.encoder = OneStreamSelfAttenMMEncoder(config)
    elif config.mode == 'x-lxrt' or config.mode == 'x-self':
      self.encoder = DualStreamSelfAttenMMEncoder(config)
    elif config.mode == 'x-adapt':
      self.encoder = AdaptiveFuseMMEncoder(config)
    else:
      # Raise instead of `exit(0)`: a misconfiguration must not terminate the
      # process with a success status, and `exit` is a site-provided helper
      # that is not guaranteed to exist.
      raise ValueError('WRONG TYPE OF MODEL: {}'.format(config.mode))

  def forward(self, input_ids, token_type_ids=None, lang_pad_mask=None, visual_feats=None):
    """Embed the tokens, fuse them with the visual features and pool.

    Args:
      input_ids: (batch, seq_len) token ids.
      token_type_ids: optional segment ids; defaults to all ones.
      lang_pad_mask: optional padding mask; defaults to all ones.
      visual_feats: 3-D visual features (required despite the None default).

    Returns:
      The fused sequence output at position 0, i.e. `output[:, 0, :]`.
    """
    # The unpack also checks that `visual_feats` is a 3-D tensor.
    (batch_size, patch_size, _) = visual_feats.size()

    if lang_pad_mask is None:
      lang_pad_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.ones_like(input_ids)

    # Word Embeddings
    lang_embeddings = self.tok_embeddings(input_ids)
    if torch.isnan(lang_embeddings.std()):
      print('language features are nan after embedding layer!')

    # Run the fusion backbone; only the full sequence output is used here.
    _, _, output = self.encoder(lang_feats=lang_embeddings, token_type_ids=token_type_ids,
                                lang_pad_mask=lang_pad_mask,
                                visn_feats=visual_feats)
    # pool: keep the first position of the fused sequence
    return output[:, 0, :]

class EntMMEncoder(nn.Module):
  """Entity-side multimodal encoder operating on pre-computed textual features.

  The fusion encoder is selected by `use_one_stream` and `config.mode`.
  """

  def __init__(self, config, mode='lxr', use_one_stream=False):
    """
    Args:
      config: object exposing `device`, `mode` and the fields required by
        the selected fusion encoder.
      mode: stored on `self.mode`.
        NOTE(review): the encoder is chosen from `config.mode`, not from this
        argument -- confirm that this is intended.
      use_one_stream: if true, use `OneStreamSelfAttenMMEncoder`.

    Raises:
      ValueError: if `use_one_stream` is false and `config.mode` is not one
        of 'x-lxrt', 'x-self' or 'x-adapt'.
    """
    super().__init__()
    myPrinter(config, '\t############Entity Multimodal Encoder Model#####################')
    self.mode = mode
    self.device = config.device
    self.use_one_stream = use_one_stream
    # fuse layer
    if use_one_stream:
      self.encoder = OneStreamSelfAttenMMEncoder(config)
    elif config.mode == 'x-lxrt' or config.mode == 'x-self':
      self.encoder = DualStreamSelfAttenMMEncoder(config)
    elif config.mode == 'x-adapt':
      self.encoder = AdaptiveFuseMMEncoder(config)
    else:
      # Raise instead of `exit(0)`: a misconfiguration must not terminate the
      # process with a success status.
      raise ValueError('WRONG TYPE OF MODEL: {}'.format(config.mode))

  def forward(self, textual_feats, visual_feats):
    """Fuse textual and visual features and pool the first position.

    Args:
      textual_feats: (batch, seq_len, d) textual features.
      visual_feats: 3-D visual features.

    Returns:
      The fused sequence output at position 0, i.e. `output[:, 0, :]`.
    """
    # The unpack also checks that `visual_feats` is a 3-D tensor.
    (_, patch_size, _) = visual_feats.size()
    (batch_size, seq_len, _) = textual_feats.size()

    # Every textual position is a real token of segment 1. Allocate on the
    # input's device so the helpers always match the features they accompany.
    lang_pad_mask = torch.ones((batch_size, seq_len), device=textual_feats.device)
    token_type_ids = torch.ones((batch_size, seq_len), device=textual_feats.device)

    # Run the fusion backbone; only the full sequence output is used here.
    _, _, output = self.encoder(lang_feats=textual_feats, token_type_ids=token_type_ids,
                                lang_pad_mask=lang_pad_mask,
                                visn_feats=visual_feats)
    # pool: keep the first position of the fused sequence
    return output[:, 0, :]

class TXTSimpleEncoder(nn.Module):
  """Text-only encoder: token, position and segment embeddings followed by a
  Transformer encoder."""

  def __init__(self, config):
    """
    Args:
      config: object exposing `device`, `n_tok`, `d_tok`, `n_head`, `d_hid`,
        `dr` and `x_layers`.
    """
    super().__init__()
    self.device = config.device
    # embedding layer
    self.tok_embeddings = nn.Embedding(config.n_tok, config.d_tok, padding_idx=0)
    self.seg_embeddings = nn.Embedding(2, config.d_tok, padding_idx=0)
    self.position_embeddings = PositionalEncoding(d_model=config.d_tok)
    # self-attention layer
    encoder_layer = nn.TransformerEncoderLayer(d_model=config.d_tok, nhead=config.n_head,
                                               dim_feedforward=config.d_hid, dropout=config.dr,
                                               activation='gelu')
    encoder_norm = nn.LayerNorm(config.d_tok)
    self.encoder = nn.TransformerEncoder(encoder_layer, config.x_layers, encoder_norm)

  def forward(self, input_ids, token_type_ids=None, pad_mask=None):
    """Encode a batch of token sequences.

    Args:
      input_ids: (batch, seq_len) token ids.
      token_type_ids: optional segment ids (0 or 1); defaults to all ones.
      pad_mask: optional mask where non-zero marks a real token; defaults to
        all ones (nothing is padding).

    Returns:
      Tensor of shape (batch, seq_len, d_tok).
    """
    if pad_mask is None:
      pad_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.ones_like(input_ids)

    # Cast to bool BEFORE inverting. `~` on an integer mask is a bitwise NOT
    # (~1 == -2, ~0 == -1), which makes every position non-zero and therefore
    # "padding" after the cast. `src_key_padding_mask` expects True = ignore.
    pad_mask = ~pad_mask.bool()

    # Positional Word Embeddings
    lang_embeddings = self.tok_embeddings(input_ids)
    if torch.isnan(lang_embeddings.std()):
      print('language features are nan after embedding layer!')
    # PositionalEncoding indexes positions along dim 0, so switch to
    # (seq, batch, d_tok) for it and switch back afterwards.
    lang_embeddings = self.position_embeddings(lang_embeddings.permute(1, 0, 2)).permute(1, 0, 2)
    if torch.isnan(lang_embeddings.std()):
      print('language features are nan after position embedding layer!')
    seg_embeddings = self.seg_embeddings(token_type_ids)
    lang_embeddings = lang_embeddings + seg_embeddings
    if torch.isnan(lang_embeddings.std()):
      print('language features are nan after segmentation embedding layer!')

    # self-attention layer
    # input: bsz * seq_len * d_tok (permuted to seq-first for the encoder)
    # output: bsz * seq_len * d_tok
    output = self.encoder(lang_embeddings.permute(1, 0, 2), src_key_padding_mask=pad_mask)
    if torch.isnan(output.std()):
      print('language features are nan after self-attention layer!')
    output = output.permute(1, 0, 2)

    return output

class IMGSimpleEncoder(nn.Module):
  """Image encoder: average over dim 1, then a bias-free linear projection
  followed by LayerNorm."""

  def __init__(self, config):
    """
    Args:
      config: object exposing `d_tok` (and whatever `myPrinter` reads).
    """
    super().__init__()
    myPrinter(config,'\t############Simple IMG Model#####################')
    projection = nn.Linear(config.d_tok, config.d_tok, bias=False)
    normalisation = nn.LayerNorm(config.d_tok, eps=1e-12)
    self.prjLayers = nn.Sequential(projection, normalisation)

  def forward(self, imgs):
    """Pool `imgs` over dim 1 and project the result.

    A warning is printed (the computation continues) whenever the standard
    deviation of an intermediate result is NaN.
    """
    pooled = imgs.mean(dim=1)
    if torch.isnan(pooled.std()):
      print('img features are nan after mean!')
    projected = self.prjLayers(pooled)
    if torch.isnan(projected.std()):
      print('img features are nan after prjLayer!')
    return projected



class ELModel(nn.Module):
  """Entity-linking model.

  A mention (text and, depending on `mode`, an image) and its candidate
  entities are encoded separately; candidates are scored by the dot product
  between mention and entity representations, followed by a masked
  log-softmax.
  """

  def __init__(self, config, device):
    """
    Args:
      config: mapping with at least 'one_stream', 'mode', 'd_hid', 'd_ent',
        'pretrained_embs_tok' and 'pretrained_embs_ent', plus the fields the
        sub-encoders need. It is converted with `utils.dictToObj`.
      device: device that `forward` moves its inputs to.
    """
    super().__init__()
    self.one_stream = config['one_stream']
    self.args = config
    self.mode = config['mode']
    self.d_hid = config['d_hid']
    self.d_ent = config['d_ent']
    self.n_prjLayer = 2
    self.device = device
    config = utils.dictToObj(config)
    myPrinter(config, '\t############EL Model#####################')
    self.config = config

    # Image feature extractor (not created in the text-only mode).
    if self.mode != 'txt':
      self.img_feats_extractor = ImgFeatsExtractor(config=config)

    # Mention encoder.
    self.m_encoder = MentionEncoder(config=config)

    # Entity encoder.
    self.e_encoder = EntityEncoder(config=config)

    # Score combination layer (not used by `forward` in this class).
    self.score_combine = nn.Linear(2, 1, bias=False)

    # Generic initialisation first, then overwrite the embedding tables with
    # the pretrained vectors and keep them trainable.
    self.apply(init_weights)
    for n, p in self.named_parameters():
      if n.endswith('tok_embeddings.weight'):
        p.data.copy_(torch.from_numpy(config['pretrained_embs_tok']))
        p.requires_grad = True
      if n.endswith('ent_embeddings.weight'):
        p.data.copy_(torch.from_numpy(config['pretrained_embs_ent']))
        p.requires_grad = True

  def forward(self, m_txt_info,m_img_info,e_txt_info, e_img_info, e_mask, is_train=False, is_contrastive=False):
    """Score every candidate entity for every mention in the batch.

    Shapes below are taken from the original inline annotations.

    Args:
      m_txt_info: [m_seqs, m_seq_pad_masks, m_seq_segment], each bsz * seq_len.
      m_img_info: mention images, bsz * 3 * 256 * 256.
      e_txt_info: [e_ids, e_pems], each bsz * n_cand.
      e_img_info: entity images, bsz * n_cand * n_img_per_cand * 3 * 256 * 256.
      e_mask: bsz * n_cand; true marks a valid candidate.
      is_train, is_contrastive: when both are true, each mention is scored
        against the candidates of every mention in the batch.

    Returns:
      Log-probabilities: bsz * n_cand, or bsz * (bsz * n_cand) in the
      contrastive training case.
    """
    [m_seqs, m_seq_pad_masks, m_seq_segment] = m_txt_info
    m_seqs = m_seqs.to(self.device)  # bsz * seq_len
    m_seq_pad_masks = m_seq_pad_masks.to(self.device)  # bsz * seq_len
    m_seq_segment = m_seq_segment.to(self.device)  # bsz * seq_len
    m_imgs = m_img_info.to(self.device)  # bsz * 3 * 256 * 256

    [e_ids, e_pems] = e_txt_info
    e_ids = e_ids.to(self.device)  # bsz * n_cand
    e_pems = e_pems.to(self.device)  # bsz * n_cand (currently unused below)
    e_img_info = e_img_info.to(self.device)  # bsz * n_cand * n_img_per_cand * 3 * 256 * 256
    e_mask = e_mask.to(self.device)  # bsz * n_cand

    # The unpacks also check that both tensors are 2-D.
    (batch_size, _seq_len) = m_seqs.size()
    (_, _n_cand) = e_ids.size()

    # Mention side.
    if self.mode == 'txt':
      m_output = self.m_encoder(m_seqs, m_seq_segment, m_seq_pad_masks)  # bsz * 1 * d_hid
    elif self.mode == 'img':
      m_imgs = self.img_feats_extractor(m_imgs)  # bsz * patch_num * d_hidden
      m_output = self.m_encoder(input_ids=None, token_type_ids=None,
                                attention_mask=None, visual_feats=m_imgs)
    else:
      m_imgs = self.img_feats_extractor(m_imgs)  # bsz * patch_num * d_hidden
      m_output = self.m_encoder(input_ids=m_seqs, token_type_ids=m_seq_segment,
                                attention_mask=m_seq_pad_masks, visual_feats=m_imgs)

    # Entity side. The 'img' mode and the default multimodal mode share the
    # same code path.
    if self.mode == 'txt':
      e_output = self.e_encoder(e_ids)
    else:
      e_img_info = e_img_info[:, :, 0, :, :, :]  # first image per candidate: bsz * n_cand * 3 * 256 * 256
      e_imgs = e_img_info.to(self.device).view(-1, 3, 256, 256)  # (bsz * n_cand) * 3 * 256 * 256
      e_imgs = self.img_feats_extractor(e_imgs)  # (bsz * n_cand) * patch_num * d
      e_output = self.e_encoder(e_ids=e_ids, e_imgs=e_imgs)  # bsz * n_cand * d_hid

    # Similarity between each mention and its candidates.
    m_output = m_output.view(batch_size, -1, 1)  # bsz * d_hid * 1

    if is_train and is_contrastive:
      # Broadcast the entities to bsz * bsz * n_cand * d_hid so that every
      # mention (bsz * d_hid * 1) is compared with every candidate in the batch.
      simi = torch.einsum('ijkl,ild->ijk', [e_output.expand(batch_size, -1, -1, -1), m_output])
      simi = simi.view(batch_size, -1)  # bsz * (bsz * n_cand)
      score = self.masked_softmax(simi, e_mask.expand(batch_size, -1, -1).view(batch_size, -1))
    else:
      simi = torch.bmm(e_output, m_output).view(batch_size, -1)  # bsz * n_cand
      score = self.masked_softmax(simi, e_mask)
    if torch.isnan(score.std()):
      print('output is nan after the masked log softmax')
    return score

  def masked_softmax(self, tensor, mask):
    """Log-softmax over the last dim with masked-out entries suppressed.

    Despite the name, this returns log-probabilities.

    Args:
      tensor: scores.
      mask: same shape as `tensor`; true / non-zero marks entries to keep.
    """
    # Coerce to bool before inverting: `~` on an integer mask is a bitwise
    # NOT, not a logical one.
    masked_scores = tensor.masked_fill(~mask.bool(), -1e9)
    return nn.functional.log_softmax(masked_scores, dim=-1)