import os
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
import numpy as np
from pytorch_pretrained_bert import BertForTokenClassification, BertModel, BertForSequenceClassification

class BertForRationaleGeneration(nn.Module):
    """Rationale generator combined with a masked BERT sequence classifier.

    Generator: a ``BertForTokenClassification`` module with two classes per
    token. Its logits are perturbed with Gumbel noise and passed through a
    temperature-scaled softmax; the probability of class 1 is used as a soft
    per-token selection mask. A hard 0/1 mask is derived from the soft mask
    and replaces it when ``do_eval`` is set.

    Encoder: the ``BertModel`` inside a ``BertForSequenceClassification`` is
    run piece by piece (embeddings -> encoder -> pooler). The mask is
    multiplied into the embedding output for every position except the first
    one, which is left untouched.

    Classifier: the dropout and linear classifier of the same
    ``BertForSequenceClassification`` are applied to the pooled output.

    Reference:
    gumbel-softmax adapted from https://github.com/ericjang/gumbel-softmax/.
    """

    def __init__(self, args, num_labels=2, from_output=False):
        """Load the generator and encoder BERT models.

        Args:
            args: configuration object. Reads ``temperature`` always,
                ``generator_output_dir`` / ``encoder_output_dir`` when
                ``from_output`` is true, ``bert_model`` otherwise.
                ``l_lambda`` and ``c_lambda`` are read later by ``loss``.
            num_labels: number of classes of the sequence classifier.
            from_output: load both models from the output directories named
                in ``args`` instead of from ``args.bert_model``.
        """
        super(BertForRationaleGeneration, self).__init__()
        self.num_labels = num_labels
        if from_output:
            generator_source = args.generator_output_dir
            encoder_source = args.encoder_output_dir
        else:
            generator_source = encoder_source = args.bert_model
        # The generator always makes a binary (skip / select) decision per token.
        self.gen_bert = BertForTokenClassification.from_pretrained(generator_source, num_labels=2)
        self.enc_bert = BertForSequenceClassification.from_pretrained(encoder_source, num_labels=num_labels)
        self.args = args
        # Aliases of the encoder's own head; they share parameters with enc_bert.
        self.dropout = self.enc_bert.dropout
        self.classifier = self.enc_bert.classifier
        self.temperature = args.temperature

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, do_eval=False):
        """Select a rationale mask and classify the masked sequence.

        Args:
            input_ids: token id tensor, indexed as [batch, sequence].
            token_type_ids: optional segment ids; zeros are used when None.
            attention_mask: optional mask with 1 for tokens to attend to and
                0 for masked positions; ones are used when None.
            labels: accepted for interface compatibility only. It is not used
                here and is deliberately NOT passed to the generator, which
                would return a loss instead of logits.
            do_eval: when true, the hard mask is used in place of the soft one.

        Returns:
            Tuple ``(selection_loss, logits, masked_ids)``: the mask
            regulariser from ``loss``, the classifier output, and
            ``input_ids`` with unselected positions set to 0.
        """
        # Keyword arguments on purpose: the pytorch_pretrained_bert forward
        # signature is (input_ids, token_type_ids, attention_mask, labels), so
        # passing attention_mask second positionally would be read as segment ids.
        token_logits = self.gen_bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

        mask, hard_mask = self.gumbel_softmax(token_logits)
        if do_eval:
            mask = hard_mask

        # NOTE(review): padding positions are not excluded from the mask
        # before the selection loss is computed - confirm this is intended.
        selection_loss = self.loss(mask)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Expand the 2D mask to [batch_size, 1, 1, to_seq_length] so that it
        # broadcasts over heads and query positions inside the encoder.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Match the parameter dtype (fp16 compatibility), then turn the mask
        # into an additive bias: 0.0 where attending, -10000.0 where masked.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.enc_bert.bert.embeddings(input_ids, token_type_ids)
        # Scale every position except the first by its mask value. Done out of
        # place so that no tensor autograd may still need is overwritten.
        token_mask = mask.unsqueeze(-1).to(embedding_output.dtype)
        keep_first = torch.ones_like(token_mask[:, :1, :])
        embedding_output = embedding_output * torch.cat([keep_first, token_mask[:, 1:, :]], dim=1)

        encoded_layers = self.enc_bert.bert.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=False,
        )

        pooled_output = self.enc_bert.bert.pooler(encoded_layers[-1])
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        masked_ids = self.tagger(hard_mask, input_ids)

        return selection_loss, logits, masked_ids

    def gumbel_softmax(self, logits):
        """Sample a soft selection mask and derive a hard mask from it.

        Args:
            logits: generator output indexed as [batch, sequence, class];
                class index 1 is read as "select this token".

        Returns:
            Tuple ``(mask, hard_mask)``, both indexed as [batch, sequence].
            ``mask`` is the softmax probability of class 1 after adding Gumbel
            noise and dividing by ``self.temperature``. ``hard_mask`` is a
            float tensor that is 1.0 where ``mask`` equals its maximum along
            the sequence dimension and 0.0 elsewhere.
        """
        eps = 1e-9
        # Gumbel(0, 1) sample: -log(-log(U)) with U ~ Uniform(0, 1). Drawn in
        # float32 (eps would underflow in fp16) directly on the logits' device,
        # then cast to the logits' dtype.
        uniform = torch.rand(logits.size(), device=logits.device)
        gumbel = -torch.log(-torch.log(uniform + eps) + eps)
        gumbel = gumbel.to(logits.dtype)

        # Temperature-scaled softmax over the class dimension.
        probs = F.softmax((logits + gumbel) / self.temperature, dim=-1)
        mask = probs[:, :, 1]

        # NOTE(review): the hard mask keeps only the position(s) with the
        # highest selection probability in each sequence - confirm intended.
        max_z, _ = mask.max(dim=-1, keepdim=True)
        hard_mask = torch.ge(mask, max_z).float()
        return mask, hard_mask

    def loss(self, mask):
        """Return the weighted length + continuity regulariser for ``mask``.

        Args:
            mask: selection mask indexed as [batch, sequence].

        Returns:
            ``args.l_lambda * length_cost + args.c_lambda * continuity_cost``
            where ``length_cost`` is the batch mean of the per-sequence mask
            sum and ``continuity_cost`` is the batch mean of the summed
            absolute differences between adjacent mask values.
        """
        length_cost = torch.mean(torch.sum(mask, dim=1))
        # Differences between neighbouring positions; the sequence ends add nothing.
        transitions = torch.abs(mask[:, 1:] - mask[:, :-1])
        continuity_cost = torch.mean(torch.sum(transitions, dim=1))
        return self.args.l_lambda * length_cost + self.args.c_lambda * continuity_cost

    def tagger(self, hard_mask, input_ids):
        """Zero out the token ids that the hard mask did not select.

        Args:
            hard_mask: 0/1 mask with the same shape as ``input_ids``.
            input_ids: token id tensor.

        Returns:
            Long tensor equal to ``input_ids`` where ``hard_mask`` is 1 and 0
            elsewhere.
        """
        # Integer arithmetic avoids a round-trip of the ids through float32.
        return input_ids.long() * hard_mask.long()
