# -*- coding: utf-8 -*-

import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.init as init

from ..utils import sequence_mask


class CRF(jit.ScriptModule):
    """Linear-chain conditional random field (CRF) layer.

    Learned parameters:
        transitions[i, j]:    score of moving from tag i to tag j
        start_transitions[j]: score of a sequence starting with tag j
        end_transitions[j]:   score of a sequence ending with tag j

    ``forward`` returns the log-likelihood ``score(tags) - log Z`` of gold tag sequences
    (negate it to obtain a loss); ``viterbi_decode`` returns the best tag sequence per example.
    """

    def __init__(self, num_tags):
        super().__init__()

        assert num_tags > 0, f'invalid number of tags: {num_tags}'

        self.num_tags = num_tags
        # Allocated uninitialized; reset_parameters() fills them in below.
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))

        self.reset_parameters()

    def extra_repr(self):
        return f'num_tags={self.num_tags}'

    def reset_parameters(self):
        """Xavier-normal transitions; start/end scores drawn from U(-0.1, 0.1)."""
        # torch.nn.init functions already run under no_grad, so `.data` is unnecessary.
        init.xavier_normal_(self.transitions)
        init.uniform_(self.start_transitions, -0.1, 0.1)
        init.uniform_(self.end_transitions, -0.1, 0.1)

    @jit.script_method
    def _compute_normalizer(self, emissions, seq_mask):
        """Log partition function log Z of every sequence (forward algorithm).

        Arguments:
             emissions: [seq_length, batch_size, num_tags] FloatTensor
             seq_mask: [seq_length, batch_size] bool or uint8 tensor, nonzero at valid steps
        Returns:
             [batch_size] tensor holding log Z of each sequence
        """
        transitions = self.transitions

        # alpha[b, j]: log-sum-exp of the scores of all prefixes ending in tag j
        alpha = self.start_transitions + emissions[0]
        for i in range(emissions.size(0) - 1):
            # torch.where needs a bool condition; this also accepts uint8 masks.
            mask = seq_mask[i + 1].unsqueeze(-1).to(torch.bool)

            # scores[b, p, j] = alpha[b, p] + emissions[i + 1][b, j] + transitions[p, j]
            scores = alpha.unsqueeze(2) + emissions[i + 1].unsqueeze(1) + transitions.unsqueeze(0)
            # Padded steps keep the previous alpha, freezing sequences that already ended.
            alpha = torch.where(mask, scores.logsumexp(1), alpha)

        return (alpha + self.end_transitions).logsumexp(1)

    @jit.script_method
    def _compute_score(self, emissions, tags, lengths, seq_mask):
        """Unnormalized score of the given tag sequences.

        Arguments:
             emissions: [seq_length, batch_size, num_tags] FloatTensor
             tags: [seq_length, batch_size] LongTensor
             lengths: [batch_size] LongTensor
             seq_mask: [seq_length, batch_size] bool or uint8 tensor, nonzero at valid steps
        Returns:
             [batch_size] tensor of path scores

        Padded positions of ``tags`` must still hold valid tag indices: they are used for
        indexing before their contribution is masked out.
        """
        transitions = self.transitions

        score = self.start_transitions[tags[0]]
        for i in range(tags.size(0)):
            # Emission score of the tag at step i
            step_score = emissions[i].gather(1, tags[i].unsqueeze(-1)).squeeze(-1)
            if i > 0:
                # Transition score from the previous tag into this one
                step_score = step_score + transitions[tags[i - 1], tags[i]]
            # Counted only while step i lies inside the sequence. Out-of-place adds avoid
            # in-place dtype errors when emissions and parameters differ in precision.
            score = score + step_score * seq_mask[i].type_as(step_score)

        # End transition out of each sequence's last valid tag
        last_tags = tags.gather(0, (lengths - 1).unsqueeze(0)).squeeze(0)
        return score + self.end_transitions[last_tags]

    def _normalize_args(self, emissions, tags=None, lengths=None, batch_first=True):
        """Convert inputs to time-major layout and build the lengths / mask tensors.

        Returns ``(emissions [T, B, K], tags [T, B] or None, lengths [B], seq_mask)``, where
        ``lengths`` is a LongTensor on the emissions' device and ``seq_mask`` comes from
        ``sequence_mask(lengths, T, batch_first=False)``.
        """
        if batch_first:
            emissions = emissions.transpose(1, 0)
            if tags is not None:
                tags = tags.transpose(1, 0)

        seq_length, batch_size, _ = emissions.size()

        if lengths is None:
            # Explicit dtype/device: otherwise torch.full creates a CPU tensor (and, depending
            # on the PyTorch version, a float one), which breaks gather/indexing later.
            lengths = torch.full((batch_size,), seq_length,
                                 dtype=torch.long, device=emissions.device)
        else:
            # Accepts lists too; no copy when already a long tensor on the right device.
            lengths = torch.as_tensor(lengths, dtype=torch.long, device=emissions.device)

        assert ((lengths > 0) & (lengths <= seq_length)).all(), \
            f'lengths must lie in [1, {seq_length}]'

        seq_mask = sequence_mask(lengths, seq_length, batch_first=False)
        return emissions, tags, lengths, seq_mask

    def forward(self, emissions, tags, lengths=None, batch_first=True, reduction='mean'):
        """Log-likelihood of ``tags`` under the CRF.

        Arguments:
            emissions: [batch_size, seq_length, num_tags] FloatTensor
                ([seq_length, batch_size, num_tags] when ``batch_first=False``)
            tags: [batch_size, seq_length] LongTensor (transposed likewise)
            lengths: [batch_size] sequence lengths; ``None`` means every sequence is full length
            reduction: 'mean' | 'sum' | 'none'
        Returns:
            log p(tags | emissions), reduced as requested ([batch_size] for 'none').
            This is a log-likelihood, not a loss: negate it for minimization.
        Raises:
            ValueError: if ``reduction`` is not one of the values above.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"invalid reduction: {reduction!r}, expected 'mean', 'sum' or 'none'")

        emissions, tags, lengths, seq_mask = \
            self._normalize_args(emissions, tags, lengths, batch_first=batch_first)

        denominator = self._compute_normalizer(emissions, seq_mask)
        numerator = self._compute_score(emissions, tags, lengths, seq_mask)

        result = numerator - denominator
        if reduction == 'mean':
            return result.mean()
        if reduction == 'sum':
            return result.sum()
        return result

    def viterbi_decode(self, emissions, lengths=None, batch_first=True):
        """Most likely tag sequence of every example.

        Arguments:
            emissions: [batch_size, seq_length, num_tags] FloatTensor
                (time-major when ``batch_first=False``)
            lengths: [batch_size] sequence lengths; ``None`` means every sequence is full length
        Returns:
            list of batch_size LongTensors; the i-th holds lengths[i] tags
        """
        emissions, _, lengths, seq_mask = \
            self._normalize_args(emissions, lengths=lengths, batch_first=batch_first)

        return self._viterbi_decode(emissions, lengths, seq_mask)

    @jit.script_method
    def _viterbi_decode(self, emissions, lengths, seq_mask):
        """Viterbi search over time-major inputs.

        Arguments:
             emissions: [seq_length, batch_size, num_tags] FloatTensor
             lengths: [batch_size] LongTensor
             seq_mask: [seq_length, batch_size] bool or uint8 tensor, nonzero at valid steps
        Returns:
             list of batch_size LongTensors; the b-th holds lengths[b] tags
        """
        transitions = self.transitions

        # alpha[b, j]: best score of any prefix ending in tag j
        alpha = self.start_transitions + emissions[0]
        # pointers[t][b, j]: best previous tag for tag j at step t + 1
        pointers = []
        for i in range(emissions.size(0) - 1):
            mask = seq_mask[i + 1].unsqueeze(-1).to(torch.bool)

            # scores[b, p, j] = alpha[b, p] + emissions[i + 1][b, j] + transitions[p, j]
            scores = alpha.unsqueeze(2) + emissions[i + 1].unsqueeze(1) + transitions.unsqueeze(0)
            alpha_next, indices = scores.max(1)

            alpha = torch.where(mask, alpha_next, alpha)
            pointers.append(indices)

        alpha = alpha + self.end_transitions

        best_paths = []
        for b in range(emissions.size(1)):
            last = int(lengths[b]) - 1
            # Plain ints keep `path` a well-typed List[int] under TorchScript.
            path = [0] * (last + 1)
            path[last] = int(alpha[b].argmax(dim=0))
            # Follow the back-pointers from the last valid step down to step 0,
            # filling the path in place (no reversal needed).
            for step in range(last - 1, -1, -1):
                path[step] = int(pointers[step][b, path[step + 1]])
            best_paths.append(torch.tensor(path))

        return best_paths
