import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ChainCRF(nn.Module):
    """
    Linear-chain CRF layer. It builds the pairwise energy tensor from emission
    and transition scores and provides the negative log-likelihood loss and
    Viterbi decoding on top of it.

    The last label index (``num_label - 1``) is used as the implicit previous
    label of the first time step, both in the loss and in the decoder.
    """
    def __init__(self, num_labels=None):
        """
        Args:
            num_labels: int or None
                number of labels. When given (and non-zero), a learnable
                [num_labels, num_labels] transition matrix is created and
                initialized from a normal distribution. Otherwise no parameter
                is created and transitions must be passed to every call.
        """
        super(ChainCRF, self).__init__()
        self.num_labels = num_labels
        if num_labels:
            self.trans_matrix = Parameter(torch.empty(self.num_labels, self.num_labels))
            nn.init.normal_(self.trans_matrix)
        else:
            self.register_parameter('trans_matrix', None)

    def forward(self, emissions, transitions, mask=None):
        """
        This function receives the emission scores and outputs the summation of
        the emission scores and the transition scores.
        Args:
            emissions: Tensor
                the emission scores with shape = [batch, length, num_label]
            transitions: Tensor or None
                the transition scores, broadcastable to
                [batch, length, num_label, num_label]; indexed as
                [previous label, current label]. If None, the learned
                ``trans_matrix`` of this module is used.
            mask: Tensor or None
                the mask tensor with shape = [batch, length]

        Returns: Tensor
            the energy tensor with shape = [batch, length, num_label, num_label]
            (float32), where energy[b, t, i, j] = emissions[b, t, j] + transitions[i, j].
            Masked positions have zero energy.

        """
        if transitions is None:
            transitions = self.trans_matrix

        # [batch, length, 1, num_label] + [num_label, num_label]
        # --> [batch, length, num_label, num_label]
        output = (emissions.unsqueeze(2) + transitions).float()

        if mask is not None:
            # zero out the energies of the padded time steps
            output = output * mask.unsqueeze(2).unsqueeze(3)

        return output

    @staticmethod
    def _forward_energy(energy_transpose, mask=None):
        """
        This function calculates the forward CRF energy (log partition function).
        Args:
            energy_transpose: Tensor
                the energy tensor with shape = [length, batch, num_label, num_label]
            mask: Tensor or None
                the mask tensor with shape = [batch, length]

        Returns: Tensor
                A 1D tensor with shape = [batch] for the forward energy
        """
        # shape = [length, batch, 1]
        mask_transpose = None
        if mask is not None:
            mask_transpose = mask.unsqueeze(2).transpose(0, 1)

        length = energy_transpose.size(0)

        # The first step always starts from the last label; it is never masked.
        # shape = [batch, num_label]
        partition = energy_transpose[0, :, -1, :]

        for t in range(1, length):
            # sum (in log space) over the previous label
            # shape = [batch, num_label]
            partition_new = torch.logsumexp(energy_transpose[t] + partition.unsqueeze(2), dim=1)
            if mask_transpose is None:
                partition = partition_new
            else:
                # keep the old partition wherever the time step is padded
                mask_t = mask_transpose[t]
                partition = partition + (partition_new - partition) * mask_t

        return torch.logsumexp(partition, dim=1)

    @staticmethod
    def _target_energy(energy_transpose, target):
        """
        This function calculates the CRF energy of an hypothesis.
        Args:
            energy_transpose: Tensor
                the energy tensor with shape = [length, batch, num_label, num_label]
            target: Tensor
                the tensor of target labels with shape [batch, length]

        Returns: Tensor
                A 1D tensor with shape = [batch] for the target energies
        """
        length, batch, num_label, _ = energy_transpose.size()
        device = energy_transpose.device
        # shape = [length, batch]
        target_transpose = target.transpose(0, 1)

        # shape = [batch]; index tensors are created directly as long tensors
        batch_index = torch.arange(batch, device=device, dtype=torch.long)
        # the first step transitions from the last label
        prev_label = torch.full((batch,), num_label - 1, device=device, dtype=torch.long)
        tgt_energy = energy_transpose.new_zeros(batch)

        for t in range(length):
            curr_label = target_transpose[t]
            tgt_energy += energy_transpose[t, batch_index, prev_label, curr_label]
            prev_label = curr_label

        return tgt_energy

    def loss(self, emissions, transitions, target, mask=None):
        """
        This function calculates the CRF loss.
        Args:
            emissions: Tensor
                the emission scores with shape = [batch, length, num_label]
            transitions: Tensor or None
                the transition scores; if None, the learned ``trans_matrix``
                of this module is used.
            target: Tensor
                the tensor of target labels with shape [batch, length]
            mask: Tensor or None
                the mask tensor with shape = [batch, length]

        Returns: Tensor
                A 1D tensor with shape = [batch] for negative log likelihood loss
        """
        if transitions is None:
            transitions = self.trans_matrix

        energy = self(emissions, transitions, mask=mask)

        # shape = [length, batch, num_label, num_label]
        energy_transpose = energy.transpose(0, 1)

        log_partition = self._forward_energy(energy_transpose, mask)
        gold_score = self._target_energy(energy_transpose, target)
        return log_partition - gold_score

    def decode(self, emissions, transitions, mask=None, leading_symbolic=0, kbest=1):
        """
        This function decodes the best sequence of labels (Viterbi).
        Args:
            emissions: Tensor
                the emission scores with shape = [batch, length, num_label]
            transitions: Tensor or None
                the transition scores; if None, the learned ``trans_matrix``
                of this module is used.
            mask: Tensor or None
                the mask tensor with shape = [batch, length]
            leading_symbolic: int
                offset added to every decoded label index (set it to 0 if you
                are not sure). The label space itself is not sliced.
            kbest: int
                currently unused; only the single best sequence is returned.

        Returns: (Tensor, Tensor)
            the decoded labels with shape = [batch, length] and the probability
            of the decoded sequence with shape = [batch]. The shapes do not
            depend on the batch size or the length being 1.

        """
        if transitions is None:
            transitions = self.trans_matrix

        energy = self(emissions, transitions, mask=mask)

        # Input should be provided as (n_batch, n_time_steps, num_labels, num_labels)
        # For convenience, we need to dimshuffle to (n_time_steps, n_batch, num_labels, num_labels)
        energy_transpose = energy.transpose(0, 1)

        # log partition function, used to normalize the score of the best path
        forward_energy = self._forward_energy(energy_transpose, mask)

        length, batch_size, num_labels, _ = energy_transpose.size()

        batch_index = torch.arange(0, batch_size,
                                   device=energy_transpose.device,
                                   dtype=torch.long)

        # back_pointer[t]: best label at step t; pointer[t]: best previous label
        # for every current label; pi[t]: best score of a path ending in each label
        back_pointer = batch_index.new_zeros(length, batch_size)
        pointer = batch_index.new_zeros(length, batch_size, num_labels)
        pi = energy_transpose.new_zeros([length, batch_size, num_labels])
        # the first step transitions from the last label
        pi[0] = energy_transpose[0, :, -1, :]
        pointer[0] = -1
        for t in range(1, length):
            pi_prev = pi[t - 1]
            pi[t], pointer[t] = torch.max(energy_transpose[t] + pi_prev.unsqueeze(2), dim=1)

        # Padded steps have zero energy, so the best score is carried unchanged
        # to the last step and back-tracking from there stays on the best path.
        score, back_pointer[-1] = torch.max(pi[-1], dim=1)
        for t in reversed(range(length - 1)):
            pointer_last = pointer[t + 1]
            back_pointer[t] = pointer_last[batch_index, back_pointer[t + 1]]

        # No blanket squeeze here: it would also drop a batch (or length)
        # dimension of size 1 and make the output rank depend on the input size.
        labels = back_pointer.transpose(0, 1) + leading_symbolic
        prob = torch.clamp(torch.exp(score - forward_energy), 0, 1)

        return labels, prob
