import torch
import torch.nn as nn


class IntentClassifierDecomposedNLI(nn.Module):
    """Score (utterance, intent) pairs with an encoder and a sigmoid head.

    Each row of the input holds ``n_intents`` token-id sequences. Every
    sequence is encoded by ``base_model``, reduced to one vector, and mapped
    to a score in ``(0, 1)`` by ``similarity_layer``.

    Args:
        base_model: Callable encoder. It is called with a 2-D ``long`` tensor
            of token ids and must return a mapping containing either
            ``'pooler_output'`` or ``'last_hidden_state'``.
        hidden_size: Size of the encoder's output vectors (input width of the
            linear scoring layer).
        dropout: Dropout probability applied before the linear scoring layer.
    """

    def __init__(self, base_model, hidden_size=768, dropout=0.5):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.similarity_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid()
        )
        # Summed (not averaged) binary cross-entropy over all pairs.
        self.loss_func = nn.BCELoss(reduction='sum')

    def forward(self, utterances, intent_concept=None, intent_action=None):
        """Compute similarity scores and the summed BCE loss.

        Args:
            utterances: Token ids of shape (batch_size, n_intents, max_len).
            intent_concept: Optional targets of shape (batch_size, n_intents).
            intent_action: Optional targets of shape (batch_size, n_intents).

        Returns:
            A ``(loss, similarity)`` tuple. ``loss`` is the sum of the concept
            and action losses; a label set that is ``None`` contributes zero.
            ``similarity`` has shape (batch_size, n_intents).

        Note:
            Both label sets are scored against the same ``similarity`` tensor.
        """
        cls_tokens = self.encode(utterances)
        similarity = self.similarity_layer(cls_tokens).squeeze(dim=2)  # (batch_size, n_intents)
        loss_concept = self._label_loss(similarity, intent_concept)
        loss_action = self._label_loss(similarity, intent_action)

        return loss_action + loss_concept, similarity

    def _label_loss(self, similarity, target):
        """Return the BCE loss against ``target``, or a zero scalar if it is None."""
        if target is None:
            # Created directly on the right device; requires_grad is kept so
            # the returned loss behaves the same when no labels are supplied.
            return torch.zeros((), device=similarity.device, requires_grad=True)
        # BCELoss rejects integer/bool targets, so align the dtype first.
        return self.loss_func(similarity, target.to(dtype=similarity.dtype))

    def encode(self, utterances):
        """Encode every sequence and return one vector per (row, intent) pair.

        Args:
            utterances: Token ids of shape (batch_size, n_intents, max_len).

        Returns:
            Tensor of shape (batch_size, n_intents, bert_emb).
        """
        bs, n_intents, seq_len = utterances.size()
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat_uttrs = utterances.reshape(-1, seq_len)  # (batch_size*n_intents, max_len)
        output = self.base_model(flat_uttrs.long())
        pooled = output['pooler_output'] if 'pooler_output' in output else None
        if pooled is None:
            # No usable pooled vector: take the hidden state of the first token.
            pooled = output['last_hidden_state'][:, 0, :]  # (batch_size*n_intents, bert_emb)
        return pooled.reshape(bs, n_intents, -1)  # (batch_size, n_intents, bert_emb)
