import numpy as np
import torch
from torch import nn


class ProtoNetPredictor(nn.Module):
    """Inference wrapper that tags sequences with a prototypical network.

    Prototypes must be computed from a support set with :meth:`set_prototypes`
    before :meth:`forward` is called. Decoding is either a per-position argmax
    over the emissions (``use_softmax=True``) or the network's CRF decoder.
    """

    def __init__(self, proto_net, use_softmax, max_token_length, batch_size=64):
        """
        Args:
            proto_net: network exposing ``encode``, ``compute_prototypes``,
                ``predict_from_prototypes`` and, for CRF decoding,
                ``crf_decoder.decode``.
            use_softmax: if True decode with argmax, otherwise with the CRF.
            max_token_length: padded sequence length used to build the
                validity mask (presumably equal to the emissions' sequence
                dimension — TODO confirm).
            batch_size: number of examples per batch yielded by :meth:`iter`.
        """
        super(ProtoNetPredictor, self).__init__()
        self.proto_net = proto_net
        self.batch_size = batch_size
        self.use_softmax = use_softmax
        self.max_token_length = max_token_length
        # Filled in by set_prototypes(); forward() refuses to run without them.
        self.prototypes = None
        self.prototypes_slots = None

    def iter(self, data, device):
        """Yield device-resident batches, each sorted by descending length.

        Args:
            data: sequence of 7 tensors sharing their first dimension, in the
                order (tokens, lengths, mask, sub_word_ids,
                sub_word_valid_lengths, intents, slots).
            device: device every batch tensor is moved to.

        Yields:
            ``((tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths),
            intents, slots)`` for consecutive slices of ``self.batch_size``
            examples. All tensors are reordered so ``lengths`` is descending.
        """
        size = data[0].size(0)
        for start_idx in range(0, size, self.batch_size):
            excerpt = slice(start_idx, start_idx + self.batch_size)
            batch = [d[excerpt].to(device) for d in data]
            # Sort by length, longest first (presumably required by a
            # packed-sequence encoder downstream).
            lengths, order = torch.sort(batch[1], descending=True)
            tokens, _, mask, sub_word_ids, sub_word_valid_lengths, intents, slots = (
                t[order] for t in batch)
            yield ((tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths),
                   intents,
                   slots)

    def set_prototypes(self, supports):
        """Encode a support set and store the resulting prototypes.

        Args:
            supports: ``(encoder_inputs, intents, slots)``. ``encoder_inputs``
                is unpacked into ``proto_net.encode``.
        """
        supports_embed = self.proto_net.encode(*supports[0])
        prototypes, prototypes_slots, _ = self.proto_net.compute_prototypes(
            supports_embed, supports[1], supports[2])
        self.prototypes = prototypes
        self.prototypes_slots = prototypes_slots

    def _check_prototypes(self):
        """Raise a clear error if set_prototypes() has not been called."""
        # getattr keeps instances pickled before these attributes existed working.
        if getattr(self, 'prototypes', None) is None or \
                getattr(self, 'prototypes_slots', None) is None:
            raise RuntimeError("set_prototypes() must be called before forward()")

    def _valid_mask(self, valid_lengths, device):
        """Return a (B, max_token_length) bool mask, True where pos < length."""
        positions = torch.arange(self.max_token_length, device=device)
        return positions < valid_lengths.unsqueeze(1)

    def forward(self, input, device):
        """Predict slot ids for ``input``.

        Returns ``(preds, emissions)`` on the softmax path and
        ``(preds, score)`` on the CRF path.
        """
        if self.use_softmax:
            return self.forward_softmax(input, device)
        else:
            return self.forward_crf(input, device)

    def forward_softmax(self, input, device):
        """Argmax decoding; padded positions (beyond ``input[1]``) become 0."""
        self._check_prototypes()
        # emissions of shape (B, L, P)
        emissions, _, prototypes_slots = self.proto_net.predict_from_prototypes(
            input, self.prototypes, self.prototypes_slots)
        # Emission columns correspond to the slot list returned alongside them,
        # so index that list (as forward_crf does), not the stored one.
        preds = prototypes_slots[torch.argmax(emissions, dim=2)]
        # NOTE(review): this masks with token lengths (input[1]), whereas
        # forward_crf and NERPredictor use sub_word_valid_lengths (input[-1]).
        # Confirm which one is intended here.
        lengths = input[1]
        mask = self._valid_mask(lengths, device)
        preds = preds * mask.long()
        return preds, emissions

    def forward_crf(self, input, device):
        """CRF decoding; padded positions (beyond ``input[-1]``) become 0."""
        self._check_prototypes()
        emissions, transitions, prototypes_slots = self.proto_net.predict_from_prototypes(
            input, self.prototypes, self.prototypes_slots)

        sub_word_valid_lengths = input[-1]
        # The decoder receives a float mask, as in the original implementation.
        mask = self._valid_mask(sub_word_valid_lengths, device).float()

        preds, score = self.proto_net.crf_decoder.decode(emissions, transitions, mask=mask)
        preds = prototypes_slots[preds]
        preds = preds * mask.long()
        return preds, score


class NERPredictor(nn.Module):
    """Inference wrapper that tags sequences with a base NER network.

    Label indices predicted by the network are translated into target slot ids
    through a lookup table installed with :meth:`set_mapping`.
    """

    def __init__(self, base_net, use_softmax, max_token_length):
        """
        Args:
            base_net: callable returning emissions for an input tuple. It must
                also expose ``crf_decoder.decode`` for CRF decoding.
            use_softmax: if True decode with argmax, otherwise with the CRF.
            max_token_length: padded sequence length used to build the
                validity mask.
        """
        super(NERPredictor, self).__init__()
        self.base_net = base_net
        self.use_softmax = use_softmax
        self.mapping = None
        self.max_token_length = max_token_length

    def set_mapping(self, tgt_slots):
        """Install the label-index -> slot-id lookup table.

        Indices 0 and 1 map to themselves; index ``i + 2`` maps to
        ``tgt_slots[i]``.
        """
        # Build as int64 directly: going through float32 (torch.Tensor) would
        # silently round ids above 2**24.
        self.mapping = torch.as_tensor(np.append([0, 1], tgt_slots), dtype=torch.long)

    def _valid_mask(self, sub_word_valid_lengths, device):
        """Return a float (B, max_token_length) mask, 1.0 where pos < length."""
        positions = torch.arange(self.max_token_length, device=device)
        return (positions < sub_word_valid_lengths.unsqueeze(1)).float()

    def forward(self, input, device):
        """Predict target slot ids for ``input``.

        Returns ``(pred_slots, emissions)`` on the softmax path and
        ``(pred_slots, score)`` on the CRF path.

        Raises:
            RuntimeError: if :meth:`set_mapping` has not been called.
        """
        if self.mapping is None:
            raise RuntimeError("set_mapping() must be called before forward()")
        if self.use_softmax:
            pred_slots, score = self.forward_softmax(input, device)
        else:
            pred_slots, score = self.forward_crf(input, device)
        # mapping is a plain attribute, not a buffer, so module.to() never moves
        # it. Bring it next to the predictions before indexing. mapping[0] == 0,
        # so masked (padded) positions stay 0.
        pred_slots = self.mapping.to(pred_slots.device)[pred_slots]
        return pred_slots, score

    def forward_crf(self, input, device='cpu'):
        """CRF-decode label indices; positions beyond ``input[-1]`` become 0."""
        emissions = self.base_net(input)

        sub_word_valid_lengths = input[-1]
        mask = self._valid_mask(sub_word_valid_lengths, device)

        preds, score = self.base_net.crf_decoder.decode(emissions, None, mask=mask)
        preds = preds * mask.long()
        return preds, score

    def forward_softmax(self, input, device='cpu'):
        """Argmax-decode label indices; positions beyond ``input[-1]`` become 0."""
        emissions = self.base_net(input)

        sub_word_valid_lengths = input[-1]
        mask = self._valid_mask(sub_word_valid_lengths, device)

        preds = torch.argmax(emissions, dim=2)
        preds = preds * mask.long()
        return preds, emissions
