from parser.models.feature_hmm import FeatureHMM
from torch.nn.functional import dropout
import torch
import torch.nn as nn

from parser.modules import Elmo


class CRFAE(nn.Module):
    """CRF autoencoder: an ELMo-based linear-chain CRF encoder paired with a
    feature HMM that acts as the reconstruction decoder.
    """

    def __init__(self, args, feature_hmm):
        """
        Build the encoder (ELMo representation, emission scorer and CRF
        start/transition/end scores) and attach the feature HMM decoder.

        Args:
            args: configuration object; this constructor reads ``layer``,
                ``dropout``, ``without_fd_repr``, ``n_pretrained``,
                ``n_bottleneck`` and ``n_labels``.
            feature_hmm: decoder exposing ``feature_scorer(words)`` and
                ``save(path)``; presumably an ``nn.Module`` so its parameters
                are registered as a submodule (TODO confirm).
        """
        super().__init__()
        self.args = args

        self.encoder = Elmo(layer=args.layer,
                            dropout=args.dropout,
                            fd_repr=not args.without_fd_repr)

        # encoder
        self.represent_ln = nn.LayerNorm(args.n_pretrained)
        if args.n_bottleneck <= 0:
            # direct projection from the pretrained representation to labels
            self.encoder_emit_scorer = nn.Linear(args.n_pretrained,
                                                 args.n_labels)
        else:
            # two-layer MLP through a bottleneck of width n_bottleneck
            self.encoder_emit_scorer = nn.Sequential(
                nn.Linear(args.n_pretrained, args.n_bottleneck),
                nn.LeakyReLU(), nn.Linear(args.n_bottleneck, args.n_labels))
        self.encoder_emit_ln = nn.LayerNorm(args.n_labels)
        # CRF scores: start[j], transitions[prev, next], end[j]
        self.start = nn.Parameter(torch.randn((args.n_labels, )))
        self.transitions = nn.Parameter(
            torch.randn((args.n_labels, args.n_labels)))
        self.end = nn.Parameter(torch.randn((args.n_labels, )))

        # decoder
        self.feature_hmm = feature_hmm

    def forward(self, words, chars):
        """
        Compute the encoder's per-token emission scores.

        Args:
            words: not used by this method.
            chars: input handed directly to the ELMo encoder.

        Returns:
            torch.Tensor: [batch_size, seq_len, n_labels] layer-normalized
            emission scores.
        """
        # [batch_size, seq_len, n_pretrained]
        represent = self.represent_ln(self.encoder(chars))
        # [batch_size, seq_len, n_labels]
        return self.encoder_emit_ln(self.encoder_emit_scorer(represent))

    def loss(self, words, encoder_emits, mask):
        """
        CRF-autoencoder objective, summed over the batch.

        Two forward recursions are run in float64:
        ``log_alpha`` over encoder CRF scores only (the encoder's log partition
        function), and ``log_beta`` over encoder CRF scores plus decoder
        emission scores. The result is ``sum_b (log_alpha_b - log_beta_b)``;
        if the decoder scores are log-probabilities of the words given the
        tags, this is the negative log reconstruction likelihood.

        Args:
            words (torch.Tensor): [batch_size, seq_len], passed to
                ``feature_hmm.feature_scorer``.
            encoder_emits (torch.Tensor): [batch_size, seq_len, n_labels]
            mask (torch.Tensor): [batch_size, seq_len] bool; positions where it
                is False leave the recursions unchanged (assumed to be
                trailing padding — TODO confirm).

        Returns:
            torch.Tensor: scalar float32 loss.
        """
        _, seq_len, _ = encoder_emits.shape

        encoder_emits = encoder_emits.double()
        # assumed [batch_size, seq_len, n_labels]
        decoder_emits = self.feature_hmm.feature_scorer(words).double()

        start = self.start.double()
        transitions = self.transitions.double()
        end = self.end.double()

        # start
        # [1, n_labels] + [batch_size, n_labels]
        log_alpha = start.unsqueeze(0) + encoder_emits[:, 0]
        # [batch_size, n_labels]
        log_beta = log_alpha + decoder_emits[:, 0]

        for i in range(1, seq_len):
            step_mask = mask[:, i]
            # [batch_size, 1, n_labels] + [1, n_labels, n_labels]
            crf_scores = encoder_emits[:, i].unsqueeze(
                1) + transitions.unsqueeze(0)
            # [batch_size, n_labels, 1] + [batch_size, n_labels, n_labels]
            alpha_scores = log_alpha.unsqueeze(-1) + crf_scores
            # [batch_size, n_labels, 1] + [batch_size, n_labels, n_labels]
            #   + [batch_size, 1, n_labels]
            beta_scores = log_beta.unsqueeze(
                -1) + crf_scores + decoder_emits[:, i].unsqueeze(1)

            # only sentences that still have a token at position i advance
            log_alpha[step_mask] = torch.logsumexp(alpha_scores,
                                                   dim=1)[step_mask]
            log_beta[step_mask] = torch.logsumexp(beta_scores,
                                                  dim=1)[step_mask]

        # end
        # [batch_size, n_labels] + [1, n_labels]
        alpha_scores = log_alpha + end.unsqueeze(0)
        beta_scores = log_beta + end.unsqueeze(0)

        # [batch_size]
        log_alpha = torch.logsumexp(alpha_scores, dim=-1)
        log_beta = torch.logsumexp(beta_scores, dim=-1)

        return (log_alpha - log_beta).sum().float()

    def crf_loss(self, encoder_emits, labels, mask):
        """
        Negative log-likelihood of (pseudo) labels under the encoder CRF.

        Computed in float64 as ``log Z - score(labels)`` and summed over the
        batch.

        Args:
            encoder_emits (torch.Tensor): [batch_size, seq_len, n_labels]
            labels (torch.Tensor): [batch_size, seq_len]
            mask (torch.Tensor): [batch_size, seq_len] bool; the last real
                token is taken to be ``mask.sum(-1) - 1``, so the mask is
                assumed to be a contiguous prefix of True values.

        Returns:
            torch.Tensor: scalar float32 loss.
        """
        encoder_emits = encoder_emits.double()

        start = self.start.double()
        transitions = self.transitions.double()
        end = self.end.double()

        # compute log Z with the forward algorithm
        batch_size, seq_len, n_labels = encoder_emits.shape
        # [1, n_labels] + [batch_size, n_labels]
        log_score = start.unsqueeze(0) + encoder_emits[:, 0]
        for i in range(1, seq_len):
            step_mask = mask[:, i]
            # [batch_size, n_labels, 1] + [1, n_labels, n_labels]
            #   + [batch_size, 1, n_labels]
            score = log_score.unsqueeze(-1) + transitions.unsqueeze(
                0) + encoder_emits[:, i].unsqueeze(1)
            log_score[step_mask] = torch.logsumexp(score, dim=1)[step_mask]
        log_p = torch.logsumexp(log_score + end.unsqueeze(0), dim=-1).sum()

        # compute the score of the given labels
        batch = torch.arange(batch_size, device=encoder_emits.device)
        last_pos = mask.sum(-1) - 1
        # start and end scores
        score = (start[labels[:, 0]] + end[labels[batch, last_pos]]).sum()
        # emission scores of the real tokens
        score += torch.gather(encoder_emits[mask],
                              dim=-1,
                              index=labels[mask].unsqueeze(-1)).sum()
        # transition scores between consecutive real tokens
        for i in range(1, seq_len):
            step_mask = mask[:, i]
            score += transitions[labels[:, i - 1][step_mask],
                                 labels[:, i][step_mask]].sum()
        return (log_p - score).float()

    def predict(self, words, encoder_emits, mask):
        """
        Viterbi decoding over combined encoder CRF and decoder emission scores.

        Args:
            words (torch.Tensor): [batch_size, seq_len]
            encoder_emits (torch.Tensor): [batch_size, seq_len, n_labels]
            mask (torch.Tensor): [batch_size, seq_len] bool, assumed to be a
                contiguous prefix of True values (TODO confirm).

        Returns:
            torch.Tensor: [batch_size, seq_len] long tags; positions where
            ``mask`` is False are 0.
        """
        batch_size, seq_len, n_labels = encoder_emits.shape
        device = encoder_emits.device

        decoder_emits = self.feature_hmm.feature_scorer(words)

        # sentence lengths; column of `path` that receives the best final tag
        last_next_position = mask.sum(1)

        # back-pointers: path[b, i, j] is the best previous tag given tag j at
        # position i. The extra column lets position `length` hold the best
        # final tag (at index 0) for every sentence.
        # [batch_size, seq_len + 1, n_labels]
        path = encoder_emits.new_zeros((batch_size, seq_len + 1, n_labels),
                                       dtype=torch.long)

        # start
        # [batch_size, n_labels]
        score = self.start.unsqueeze(
            0) + encoder_emits[:, 0] + decoder_emits[:, 0]

        for i in range(1, seq_len):
            step_mask = mask[:, i]
            # [batch_size, n_labels, 1] + [1, n_labels, n_labels]
            #   + [batch_size, 1, n_labels] + [batch_size, 1, n_labels]
            candidates = (score.unsqueeze(-1) + self.transitions.unsqueeze(0)
                          + encoder_emits[:, i].unsqueeze(1)
                          + decoder_emits[:, i].unsqueeze(1))
            # maximise over the previous tag: [batch_size, n_labels]
            best_score, best_prev = torch.max(candidates, dim=1)
            path[:, i] = best_prev
            score[step_mask] = best_score[step_mask]
            # padding positions point back to tag 0
            path[:, i][~step_mask] = 0

        # end
        score = score + self.end.unsqueeze(0)

        batch = torch.arange(batch_size, dtype=torch.long, device=device)
        path[batch, last_next_position, 0] = torch.argmax(score, dim=-1)

        # follow back-pointers from the extra column down to position 0
        tags = encoder_emits.new_zeros((batch_size, seq_len), dtype=torch.long)
        # [batch_size, 1]
        pre_tags = encoder_emits.new_zeros((batch_size, 1), dtype=torch.long)
        for i in range(seq_len, 0, -1):
            pre_tags = torch.gather(path[:, i], 1, pre_tags)
            tags[:, i - 1] = pre_tags.squeeze(-1)
        return tags

    def save(self, path):
        """
        Save the model to disk.

        The feature HMM is first saved with its own ``save`` to
        ``path + ".feature_hmm"``. Then a dict with keys ``"args"`` and
        ``"state_dict"`` is written to ``path`` with ``torch.save``.

        Args:
            path (str): destination file for the model state.
        """
        # save feature hmm firstly
        feature_hmm_path = path + ".feature_hmm"
        self.feature_hmm.save(feature_hmm_path)
        state = {"args": self.args, "state_dict": self.state_dict()}
        torch.save(state, path)

    @classmethod
    def load(cls, path):
        """
        Load a model written by ``save``.

        Tensors are mapped to CUDA when it is available, otherwise to the CPU.

        Args:
            path (str): file previously passed to ``save``.

        Returns:
            CRFAE: the reconstructed model, moved to the chosen device.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        feature_hmm_path = path + ".feature_hmm"
        feature_hmm = FeatureHMM.load(feature_hmm_path).to(device)

        # NOTE(review): the checkpoint holds a pickled `args` object, so this
        # unpickles arbitrary data — only load trusted files. On PyTorch
        # versions whose torch.load defaults to weights_only=True this call
        # may need weights_only=False.
        state = torch.load(path, map_location=device)
        # __init__ rebuilds every submodule (including the ELMo encoder)
        # before the saved weights are applied
        crf_ae = cls(state["args"], feature_hmm).to(device)
        # strict=False: missing/unexpected keys are ignored instead of raising
        crf_ae.load_state_dict(state["state_dict"], strict=False)

        return crf_ae
