import torch.nn as nn


class IntentClassifierNLI(nn.Module):
    """Score intent x utterance pairs with an encoder plus a sigmoid head.

    Every pair is encoded by ``base_model``; the resulting pair embedding is
    passed through dropout, a linear projection to one unit and a sigmoid, so
    each pair receives a similarity score in (0, 1).

    ``base_model`` is called positionally as ``base_model(input_ids,
    attention_mask)`` and its result is indexed by the string keys
    ``'pooler_output'`` / ``'last_hidden_state'`` (presumably a
    HuggingFace-style encoder -- confirm against the caller).
    """

    def __init__(self, base_model, hidden_size=768, dropout=0.5):
        """
        :param base_model: encoder producing the pair embeddings
        :param hidden_size: size of the embedding returned by ``base_model``
        :param dropout: dropout probability applied before the linear layer
        """
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.similarity_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid(),
        )

    def encode(self, pairs, attention_masks):
        """
        Encode every intent x utterance pair into a single embedding.

        :param pairs: intent x utterance pair ids (bs, n_pairs, max_len)
        :param attention_masks: intent x utterance pair attention masks (bs, n_pairs, max_len)
        :return: pair embeddings (bs, n_pairs, emb_size)
        """
        bs, n_pairs, seq_len = pairs.size()
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. transposed or sliced batches), while ``reshape`` only
        # copies when it has to.
        flat_pairs = pairs.reshape(-1, seq_len)  # (batch_size*n_pairs, max_len)
        flat_masks = attention_masks.reshape(-1, seq_len)  # (batch_size*n_pairs, max_len)
        output = self.base_model(flat_pairs.long(), flat_masks.long())

        # Prefer the pooled representation; fall back to the first token's
        # hidden state when the key is absent or its value is None.
        pooled = output['pooler_output'] if 'pooler_output' in output else None
        if pooled is not None:
            cls_tokens = pooled  # (batch_size*n_pairs, emb_size)
        else:
            cls_tokens = output['last_hidden_state'][:, 0, :]  # (batch_size*n_pairs, emb_size)
        return cls_tokens.reshape(bs, n_pairs, -1)  # (batch_size, n_pairs, emb_size)

    def forward(self, pairs, attention_masks):
        """
        Compute a similarity score for every intent x utterance pair.

        :param pairs: intent x utterance pair ids (bs, n_pairs, max_len)
        :param attention_masks: intent x utterance pair attention mask (bs, n_pairs, max_len)
        :return: pair similarities (bs, n_pairs)
        """
        cls_tokens = self.encode(pairs, attention_masks)
        # The head emits (bs, n_pairs, 1); drop the trailing singleton dim.
        similarity = self.similarity_layer(cls_tokens).squeeze(dim=2)  # (batch_size, n_pairs)
        return similarity


# class IntentClassifierContrstiveNLI(nn.Module):
#     def __init__(self, base_model, hidden_size=768, dropout=0.5):
#         super(IntentClassifierContrstiveNLI, self).__init__()
#         self.base_model = base_model
#         self.hidden_size = hidden_size
#         self.similarity_layer = nn.Sequential(
#             nn.Dropout(dropout),
#             nn.Linear(self.hidden_size, 1)
#         )
#         self.loss_func = nn.CrossEntropyLoss(reduction='none')
#
#     def forward(self, pairs=None, labels=None):
#         # utterances (batch_size, n_intents, n_uttrs, max_len)
#         if labels is not None:
#             bs, n_intents, n_uttrs, seq_len = pairs.size()
#             assert n_uttrs == n_intents
#
#             flat_uttrs = pairs.view(-1, seq_len)  # (batch_size*n_uttrs*n_intents, max_len)
#             cls_tokens = self.base_model(flat_uttrs.long())[1]  # (batch_size*n_uttrs*n_intents, bert_emb)
#             logits = self.similarity_layer(cls_tokens).reshape(-1, n_intents, n_uttrs)  # (batch_size, n_intents, n_uttrs)
#
#             loss_uttrs = self.loss_func(logits, labels)
#             loss_intents = self.loss_func(torch.transpose(logits, dim0=1, dim1=2), labels)
#             loss = torch.mean(loss_uttrs + loss_intents / 2, dim=(0, 1))
#             return loss, logits
#         else:
#             _, n_intents, seq_len = pairs.size()
#             flat_uttrs = pairs.view(-1, seq_len)  # (batch_size*n_intents, max_len)
#             cls_tokens = self.base_model(flat_uttrs.long())[1]  # (batch_size*n_intents, bert_emb)
#             logits = self.similarity_layer(cls_tokens).reshape(-1, n_intents)  # (batch_size, n_intents)
#             return logits


# def get_model(cfg, base_model):
#     if cfg.model.model_type == 'nli_ca':
#         return IntentClassifierDecomposedNLI(base_model, hidden_size=cfg.model.embedding_dim, dropout=cfg.model.dropout)
#     elif cfg.model.model_type == 'nli_strict':
#         return IntentClassifierNLI(base_model, hidden_size=cfg.model.embedding_dim, dropout=cfg.model.dropout)
#     elif cfg.model.model_type == 'nli_contrastive':
#         return IntentClassifierContrstiveNLI(base_model, hidden_size=cfg.model.embedding_dim, dropout=cfg.model.dropout)
#     else:
#         ValueError("Unknown model type")
