from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math
import os

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F

from unilm.modeling import UniLMPreTrainedModel, UniLMModel, BertPredictionHeadTransform, BertLMPredictionHead


class BertTableMatchHead(nn.Module):
  """Dot-product matcher between header vectors and transformed value vectors.

  Only the values are passed through ``BertPredictionHeadTransform``; the
  headers are used exactly as supplied by the caller.
  """

  def __init__(self, config):
    super(BertTableMatchHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)

  def forward(self, headers, values, mask):
    """Score every (value, header) pair and zero out masked headers.

    Args:
      headers: header vectors; indexed below as (batch, num_headers, hidden)
        -- assumed from the unsqueeze pattern, TODO confirm with caller.
      values: value vectors; indexed below as (batch, num_values, hidden)
        -- assumed, TODO confirm with caller.
      mask: per-header mask, broadcast over the value axis; it is cast to
        the score dtype and multiplied in.

    Returns:
      The pairwise dot products multiplied by the mask. Masked entries
      become 0 (they are not pushed to a large negative value).
    """
    projected = self.transform(values)

    # A new axis 2 on the values and a new axis 1 on the headers let
    # broadcasting pair every value with every header before the reduction
    # over the last (feature) axis.
    pairwise = headers.unsqueeze(1) * projected.unsqueeze(2)
    score = pairwise.sum(-1)

    keep = mask.unsqueeze(1).type_as(score)
    return score * keep


class AttentionPooling(nn.Module):
  """Collapse the second-to-last axis with learned softmax attention weights.

  Each vector is scored by ``Linear -> Tanh -> Linear(hidden_size, 1)``; the
  scores are softmax-normalised over the last axis of the score tensor and
  used to take a weighted sum of the input vectors.
  """

  def __init__(self, hidden_size):
    super(AttentionPooling, self).__init__()
    self.linear = nn.Linear(hidden_size, hidden_size, bias=True)
    self.activation = nn.Tanh()
    self.get_weight = nn.Linear(hidden_size, 1, bias=True)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, vectors, mask):
    """Return the attention-weighted sum of ``vectors``.

    Args:
      vectors: tensor whose last axis is the feature axis and whose
        second-to-last axis is the one being pooled.
      mask: tensor matching ``vectors`` without its feature axis; entries
        equal to 1 are kept, entries equal to 0 are suppressed.

    Returns:
      ``vectors`` with the second-to-last axis summed out under the
      attention weights.
    """
    hidden = self.activation(self.linear(vectors))
    logits = self.get_weight(hidden).squeeze(-1)

    # Additive masking: positions with mask == 0 receive a -10000 offset so
    # their softmax weight is effectively zero.
    penalty = (1.0 - mask.type_as(logits)) * -10000.0
    logits.add_(penalty)

    attention = self.softmax(logits)
    weighted = attention.unsqueeze(-1) * vectors
    return weighted.sum(-2)


class BertTableMatchHeadV2(nn.Module):
  """Table-match head with separate attention pooling for headers and values.

  Headers and values are each pooled by their own ``AttentionPooling``
  module, passed through their own ``BertPredictionHeadTransform``, and then
  compared pairwise with a dot product.
  """

  def __init__(self, config):
    super(BertTableMatchHeadV2, self).__init__()
    self.headers_pooling = AttentionPooling(config.hidden_size)
    self.values_pooling = AttentionPooling(config.hidden_size)
    self.headers_transform = BertPredictionHeadTransform(config)
    self.values_transform = BertPredictionHeadTransform(config)

  def forward(self, headers, headers_mask, values, values_mask, mask):
    """Score every (value, header) pair and zero out masked headers.

    Args:
      headers: per-token header vectors, pooled with ``headers_mask``
        -- presumably (batch, num_headers, num_tokens, hidden); TODO confirm.
      headers_mask: token mask handed to the header pooling module.
      values: per-token value vectors, pooled with ``values_mask``
        -- presumably (batch, num_values, num_tokens, hidden); TODO confirm.
      values_mask: token mask handed to the value pooling module.
      mask: per-header mask, broadcast over the value axis; it is cast to
        the score dtype and multiplied in.

    Returns:
      Pairwise dot products multiplied by the mask. Masked entries become 0
      (they are not pushed to a large negative value).
    """
    headers = self.headers_pooling(headers, headers_mask)
    # Values must go through their own pooling module. Pooling them with
    # ``headers_pooling`` would leave ``values_pooling`` unused and untrained.
    values = self.values_pooling(values, values_mask)
    headers = self.headers_transform(headers)
    values = self.values_transform(values)

    # A new axis 1 on the headers and a new axis 2 on the values let
    # broadcasting pair every value with every header before the reduction
    # over the last (feature) axis.
    score = torch.sum(headers.unsqueeze(1) * values.unsqueeze(2), -1)
    masked_score = score * mask.unsqueeze(1).type_as(score)

    return masked_score


class BertPreTrainingHeads(nn.Module):
  """Bundles the LM prediction head with a 2-way sequence classifier."""

  def __init__(self, config, bert_model_embedding_weights):
    super(BertPreTrainingHeads, self).__init__()
    self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
    self.seq_relationship = nn.Linear(config.hidden_size, 2)

  def forward(self, sequence_output, pooled_output):
    """Apply both heads.

    Args:
      sequence_output: input to the LM prediction head.
      pooled_output: input to the 2-way linear classifier.

    Returns:
      A tuple ``(prediction_scores, seq_relationship_score)``.
    """
    return (
      self.predictions(sequence_output),
      self.seq_relationship(pooled_output),
    )


class UniLMForStructPreTraining(UniLMPreTrainedModel):
  """UniLM encoder with masked-LM, next-sentence and table-match heads.

  With training labels, ``forward`` returns a scalar loss; otherwise it
  returns the raw scores of the LM and sequence-relationship heads.
  """

  def __init__(self, config):
    super(UniLMForStructPreTraining, self).__init__(config)
    self.bert = UniLMModel(config)
    self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
    self.match_head = BertTableMatchHeadV2(config)

    # Both criteria return per-element losses; masking and normalisation are
    # applied afterwards by ``loss_mask_and_normalize``.
    self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
    self.crit_next_sent = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
    self.num_labels = 2
    self.init_weights()

  @staticmethod
  def loss_mask_and_normalize(loss, mask):
    """Return the mask-weighted mean of ``loss``.

    The mask is cast to the loss dtype, multiplied in, and the result is
    divided by ``sum(mask) + 1e-5`` so an all-zero mask yields 0 rather than
    a division by zero.
    """
    mask = mask.type_as(loss)
    loss = loss * mask
    denominator = torch.sum(mask) + 1e-5
    return (loss / denominator).sum()

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, is_random_next=None,
              masked_lm_labels=None, masked_pos=None, masked_weights=None, is_next_mask=None,
              match_labels=None, val_indexs=None, val_index_mask=None, col_indexs=None,
              col_index_mask=None, match_label_mask=None):
    """Run the encoder and either return scores or a training loss.

    Args:
      input_ids, token_type_ids, attention_mask: forwarded unchanged to the
        ``UniLMModel`` encoder.
      is_random_next: next-sentence labels; must be given (together with
        ``masked_lm_labels``) to take the training path.
      masked_lm_labels: target token ids for the masked positions.
      masked_pos: indices along the sequence axis of the masked tokens.
      masked_weights: per-masked-position weights used to mask the LM loss.
      is_next_mask: mask applied to the next-sentence loss.
      match_labels: optional targets for the table-match head.
      val_indexs, col_indexs: token positions of each value / column
        -- per the original comment, (batch, num_pair, max_token_num).
      val_index_mask, col_index_mask: token masks passed to the match head.
      match_label_mask: mask applied to the match scores and match loss.

    Returns:
      * Without ``masked_lm_labels`` or ``is_random_next``: the tuple
        ``(prediction_scores, seq_relationship_score)`` over the full
        sequence.
      * Otherwise: a scalar loss equal to the masked-LM loss, plus the match
        loss when ``match_labels`` is given. The next-sentence loss is
        computed but is not part of the returned value.
    """
    sequence_output, pooled_output = self.bert(
      input_ids, token_type_ids, attention_mask)

    if masked_lm_labels is None or is_random_next is None:
      # Inference path: this is the only place the full-sequence LM head is
      # needed; the training path below applies the head to the masked
      # positions only.
      prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
      return prediction_scores, seq_relationship_score

    def gather_seq_out_by_pos(seq, pos):
      # Select seq[b, pos[b, i], :] for every batch row b and index i.
      return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1)))

    def gather_seq_out_by_pos_average(seq, pos, mask=None):
      # pos/mask: (batch, num_pair, max_token_num)
      batch_size, max_token_num = pos.size(0), pos.size(-1)
      # Flatten the pair/token axes to gather once, then restore them:
      # (batch, num_pair, max_token_num, seq.size(-1))
      pos_vec = torch.gather(seq, 1, pos.view(batch_size, -1).unsqueeze(
        2).expand(-1, -1, seq.size(-1))).view(batch_size, -1, max_token_num, seq.size(-1))
      if mask is None:
        # No averaging requested: return the per-token vectors.
        return pos_vec
      # Masked mean over the token axis -> (batch, num_pair, seq.size(-1))
      mask = mask.type_as(pos_vec)
      pos_vec_masked_sum = (
          pos_vec * mask.unsqueeze(3).expand_as(pos_vec)).sum(2)
      return pos_vec_masked_sum / mask.sum(2, keepdim=True).expand_as(pos_vec_masked_sum)

    # masked lm
    sequence_output_masked = gather_seq_out_by_pos(
      sequence_output, masked_pos)
    prediction_scores_masked, seq_relationship_score = self.cls(
      sequence_output_masked, pooled_output)
    # CrossEntropyLoss expects the class axis at dim 1, hence the transpose.
    masked_lm_loss = self.crit_mask_lm(
      prediction_scores_masked.transpose(1, 2), masked_lm_labels)
    masked_lm_loss = self.loss_mask_and_normalize(
      masked_lm_loss, masked_weights)
    _, predict_index = prediction_scores_masked.max(-1)
    num_correct_tokens = torch.mul(predict_index.eq(masked_lm_labels).long(), masked_weights).sum()
    num_mask_tokens = masked_weights.sum()

    # next sentence
    next_sentence_loss = self.crit_next_sent(
      seq_relationship_score.view(-1, self.num_labels), is_random_next.view(-1))
    next_sentence_loss = self.loss_mask_and_normalize(next_sentence_loss, is_next_mask)
    _, predict_relationships = seq_relationship_score.max(-1)
    num_correct_relationship = (predict_relationships.eq(is_random_next).type_as(next_sentence_loss) * is_next_mask.type_as(next_sentence_loss)).sum()

    # NOTE(review): ``rets`` and ``infos`` are bookkeeping only; neither the
    # next-sentence loss nor the accuracy counters are returned to the caller.
    rets = [masked_lm_loss, next_sentence_loss]
    infos = [num_correct_tokens, num_mask_tokens, num_correct_relationship, is_next_mask.sum().type_as(next_sentence_loss)]

    total_loss = masked_lm_loss

    if match_labels is not None:
      # The index masks are not applied here; they are handed to the match
      # head together with the un-averaged per-token vectors.
      val_output = gather_seq_out_by_pos_average(sequence_output, val_indexs)
      col_output = gather_seq_out_by_pos_average(sequence_output, col_indexs)

      # The class axis for the loss below is dim 1 of match_score
      # -- presumably (batch, num_vals, num_cols); TODO confirm against the
      # match head.
      match_score = self.match_head(col_output, col_index_mask, val_output, val_index_mask, match_label_mask)

      match_loss = self.crit_mask_lm(match_score, match_labels)
      match_loss = self.loss_mask_and_normalize(match_loss, match_label_mask)
      rets.append(match_loss)

      _, predict_index = match_score.max(1)
      num_correct_match = torch.mul(predict_index.eq(match_labels).long(), match_label_mask).sum()
      num_match = match_label_mask.sum()

      infos.append(num_correct_match)
      infos.append(num_match)

      total_loss = total_loss + match_loss

    # Only add the match loss when it was actually computed, so that calling
    # without ``match_labels`` returns the masked-LM loss instead of failing
    # on a missing list element.
    return total_loss
