import torch
import torch.nn.functional as F
from torch import nn


class ProtoNetNerLoss(nn.Module):
    """Slot-filling loss for a prototypical-network tagger.

    Wraps ``proto_net`` and computes either a token-level softmax
    cross-entropy (``use_softmax=True``) or the loss of the network's CRF
    decoder (``use_softmax=False``). Gold slot ids are mapped to prototype
    indices by matching them against the ``prototypes_slots`` returned by
    ``proto_net.forward``. Slot id 0 is treated as padding and masked out.
    """

    def __init__(self, proto_net, use_softmax):
        super(ProtoNetNerLoss, self).__init__()
        self.proto_net = proto_net
        self.use_softmax = use_softmax

    def forward(self, *args):
        """Dispatch to ``forward_softmax`` or ``forward_crf``.

        Positional arguments are forwarded unchanged; both variants take
        ``(queries, intents, slots, supports, supports_intents,
        supports_slots, device)``.
        """
        if self.use_softmax:
            return self.forward_softmax(*args)
        return self.forward_crf(*args)

    @staticmethod
    def _true_prototypes(slots, prototypes_slots):
        """Map each gold slot id to the index of its matching prototype.

        For every token, returns the index of the first prototype whose slot
        id equals the token's slot id; a token matching no prototype gets 0.
        """
        # NOTE(review): assumes prototypes_slots broadcasts against
        # slots.unsqueeze(2) (B, L, 1) so the prototype axis ends up last,
        # e.g. shape (P,) or (B, 1, P) -- TODO confirm against proto_net.
        matches = (slots.unsqueeze(2) == prototypes_slots).float()
        # argmax returns the first maximal index: the first match, or 0.
        return matches.argmax(2)

    def forward_softmax(self, queries, intents, slots, supports, supports_intents, supports_slots, device):
        """Token-level cross-entropy over prototype emissions.

        Args:
            queries, supports, supports_intents, supports_slots: forwarded
                to ``proto_net.forward``.
            intents: unused; accepted for a uniform call signature.
            slots: gold slot ids of the queries, shape (B, L); 0 is padding.
            device: unused; indices live on the emissions' device. Kept for
                backward compatibility with existing callers.

        Returns:
            Scalar tensor: negative log-likelihood of the gold prototype,
            averaged over non-padding tokens (0 if every token is padding).
        """
        # emissions of shape (B, L, P), P includes 'other'
        emissions, _, prototypes_slots = self.proto_net.forward(queries, supports, supports_intents, supports_slots)
        _, _, num_prototypes = emissions.shape

        log_probs = F.log_softmax(emissions.reshape(-1, num_prototypes), dim=1)
        true_prototypes = torch.flatten(self._true_prototypes(slots, prototypes_slots))
        # Pick log p(gold prototype) per token without building an arange.
        token_nll = -log_probs.gather(1, true_prototypes.unsqueeze(1)).squeeze(1)

        slot_mask = (torch.flatten(slots) > 0).to(token_nll.dtype)
        # Normalize by the number of real tokens, not B*L, so the loss scale
        # does not depend on how much padding the batch contains (the same
        # normalization as cross-entropy with ignore_index=0 used by
        # BaselineNerLoss). clamp avoids 0/0 on an all-padding batch.
        return (slot_mask * token_nll).sum() / slot_mask.sum().clamp(min=1.0)

    def forward_crf(self, queries, intents, slots, supports, supports_intents, supports_slots, device):
        """CRF loss over prototype emissions.

        Args:
            Same as ``forward_softmax``; ``intents`` and ``device`` are unused.

        Returns:
            The mean of ``proto_net.crf_decoder.loss(...)`` computed with the
            emissions/transitions from ``proto_net`` and a padding mask.
        """
        # emissions of shape (B, L, P)
        emissions, transitions, prototypes_slots = self.proto_net.forward(queries, supports, supports_intents, supports_slots)
        true_prototypes = self._true_prototypes(slots, prototypes_slots)

        # mask loss on pad slots, assuming padding slotid is 0
        slot_mask = (slots > 0).float()

        return self.proto_net.crf_decoder.loss(emissions, transitions, true_prototypes, mask=slot_mask).mean()


class BaselineNerLoss(nn.Module):
    """Slot-filling loss for a plain (non-prototypical) tagger.

    ``base_net(input)`` produces per-token tag emissions. With
    ``use_softmax`` the loss is token-level cross-entropy; otherwise it is
    the mean of ``base_net.crf_decoder.loss``. Slot id 0 marks padding in
    both variants.
    """

    def __init__(self, base_net, use_softmax):
        super(BaselineNerLoss, self).__init__()
        self.base_net = base_net
        self.use_softmax = use_softmax

    def forward(self, *args):
        """Route the positional arguments to the configured loss variant."""
        loss_fn = self.forward_softmax if self.use_softmax else self.forward_crf
        return loss_fn(*args)

    def forward_crf(self, input, intents, slots, device):
        """Mean CRF loss; ``intents`` and ``device`` are unused.

        Positions whose slot id is 0 are excluded through the ``mask`` passed
        to ``base_net.crf_decoder.loss``; no transitions are supplied.
        """
        tag_scores = self.base_net(input)
        non_pad = slots.gt(0).float()
        per_item = self.base_net.crf_decoder.loss(tag_scores, None, slots, mask=non_pad)
        return per_item.mean()

    def forward_softmax(self, input, intents, slots, device):
        """Token-level cross-entropy; ``intents`` and ``device`` are unused.

        Returns the cross-entropy averaged over tokens whose slot id is not 0.
        """
        tag_scores = self.base_net(input)
        _, _, num_tags = tag_scores.shape
        flat_logits = tag_scores.reshape(-1, num_tags)
        flat_targets = slots.flatten()
        return F.cross_entropy(flat_logits, flat_targets, ignore_index=0)
