import torch
import torch.nn as nn

# from utils.loss import LossComputeBase


def collapse_copy_scores(scores, tgt_vocab, src_vocab, offset,
                         batch_dim=1, batch_offset=None):
    """Fold copy scores of source words into their target-vocabulary slots.

    ``scores`` covers an extended dictionary in which the copy slot of source
    word ``i`` sits at index ``offset + i``. Whenever a source word maps to a
    target-vocabulary index other than 0, its copy score is added to that
    target index and the copy slot is reset to ``1e-20``.

    Args:
        scores (Tensor): scores to collapse; modified in place. After the
            batch entry is selected along ``batch_dim``, dim 0 of the slice
            is indexed with vocabulary ids (assumes one vector of extended
            vocabulary scores per batch entry -- TODO confirm with callers).
        tgt_vocab (dict): must provide ``"stoi"`` (word -> target index).
            Source words missing from it are looked up as ``"UNK"``.
        src_vocab (dict): must provide ``"itos"`` (source index -> word).
            Entry 0 is never collapsed.
        offset (int): position of the first copy slot in ``scores``
            (presumably the size of the target vocabulary -- confirm).
        batch_dim (int): ``1`` selects entries with ``scores[:, b]``; any
            other value selects them with ``scores[b]``.
        batch_offset: accepted for interface compatibility; it does not
            influence the result.

    Returns:
        Tensor: the same ``scores`` object, updated in place.
    """
    tgt_stoi = tgt_vocab["stoi"]
    src_itos = src_vocab["itos"]

    # The source -> target mapping is identical for every batch entry, so it
    # is built once instead of once per entry.
    blank = []
    fill = []
    for i in range(1, len(src_itos)):
        sw = src_itos[i]
        if sw not in tgt_stoi:
            sw = "UNK"
        ti = tgt_stoi[sw]
        # Target index 0 means "no usable target entry": leave the copy slot.
        if ti != 0:
            blank.append(offset + i)
            fill.append(ti)

    if not blank:
        return scores

    # Index tensors must live on the same device as ``scores``; CPU-only
    # indices would make the in-place ops below fail for CUDA scores.
    blank_idx = torch.tensor(blank, dtype=torch.long, device=scores.device)
    fill_idx = torch.tensor(fill, dtype=torch.long, device=scores.device)

    for b in range(scores.size(batch_dim)):
        # Both forms return views, so the in-place ops update ``scores``.
        score = scores[:, b] if batch_dim == 1 else scores[b]
        score.index_add_(0, fill_idx, score.index_select(0, blank_idx))
        score.index_fill_(0, blank_idx, 1e-20)
    return scores

class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`, extended to two copy sources.

    These networks consider copying words directly from the source
    sequence. This variant attends over two sources (named ``abs`` and
    ``context`` in the code) and computes:

    * :math:`p_{softmax}`: the standard softmax over the target dictionary,
      with the padding index masked out;
    * a three-way switch ``[p_gen, p_copy_abs, p_copy_context]`` obtained
      from a softmax over a 3-unit linear layer;
    * one copy distribution per source, taken from the attention weights
      and projected through that source's ``src_map``.

    ``forward`` returns the three weighted pieces separately; it does not
    sum them into a single extended distribution.

    .. mermaid::
       graph BT
          A[input]
          S[src_map]
          B[softmax]
          BB[switch]
          C[attn]
          D[copy]
          O[output]
          A --> B
          A --> BB
          S --> D
          C --> D
          D --> O
          B --> O
          BB --> O

    Args:
       input_size (int): size of input representation
       output_size (int): size of output vocabulary
       pad_idx (int): index of the padding token in the output vocabulary;
           its probability is forced to zero
       gpu (bool): stored as ``self.gpu`` for backward compatibility.
           Padding buffers are allocated on the device of the attention
           tensors, so ``forward`` does not consult this flag.
    """

    def __init__(self, input_size, output_size, pad_idx, gpu):
        super(CopyGenerator, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        # Three-way switch: [generate, copy from abs, copy from context].
        self.linear_copy = nn.Linear(input_size, 3)
        self.pad_idx = pad_idx
        self.gpu = gpu

    @staticmethod
    def _pad_attn(attn, src_len):
        """Right-pad 2-D ``attn`` with zero columns up to ``src_len`` columns."""
        rows, attn_len = attn.size()
        if attn_len == src_len:
            return attn
        # ``new_zeros`` keeps the device and dtype of ``attn``, so the padded
        # tensor can be combined with the other inputs without any transfer.
        padded = attn.new_zeros(rows, src_len)
        # ``attn`` is expected to be narrower than ``src_len`` here; a wider
        # tensor does not fit the slice and the assignment fails.
        padded[:, :attn_len] = attn
        return padded

    @staticmethod
    def _copy_scores(weighted_attn, src_map):
        """Project attention weights onto the copy vocabulary of ``src_map``."""
        src_len, batch, cvocab = src_map.size()
        # Rows of ``weighted_attn`` are regrouped as (tlen, batch, src_len),
        # then made batch-first for ``bmm`` against (batch, src_len, cvocab).
        scores = torch.bmm(
            weighted_attn.view(-1, batch, src_len).transpose(0, 1),
            src_map.transpose(0, 1),
        ).transpose(0, 1)  # [tgt_len, batch, cvocab]
        return scores.contiguous().view(-1, cvocab)  # [batch_by_tlen, cvocab]

    def forward(self, hidden, copy, abs_attn, context_attn, abs_src_map, context_src_map):
        """
        Compute the generation distribution and the two copy distributions,
        each already weighted by its share of the three-way switch.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
               fed to the vocabulary projection.
           copy (FloatTensor): representation fed to the switch layer,
               ``(batch x tlen, input_size)``.
           abs_attn (FloatTensor): attention over the ``abs`` source,
               ``(batch x tlen, abs_slen)``. Zero-padded on the right when
               it has fewer columns than ``abs_src_map`` has positions.
           context_attn (FloatTensor): attention over the ``context``
               source, ``(batch x tlen, context_slen)``; padded likewise.
           abs_src_map (FloatTensor): ``(abs_slen, batch, abs_cvocab)`` map
               from each ``abs`` source position to its index in the copy
               vocabulary.
           context_src_map (FloatTensor):
               ``(context_slen, batch, context_cvocab)``, same role for the
               ``context`` source.

        The rows of ``hidden``, ``copy`` and both attention tensors are
        reshaped as ``(tlen, batch, ...)``, i.e. they are expected in
        target-step-major order.

        Returns:
           tuple: ``(out_prob, abs_copy_prob, context_copy_prob)`` with shapes
           ``(batch x tlen, output_size)``, ``(batch x tlen, abs_cvocab)``
           and ``(batch x tlen, context_cvocab)``.
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        assert abs_attn.size(0) == batch_by_tlen
        assert context_attn.size(0) == batch_by_tlen

        # Align the attention widths with the source maps.
        abs_attn = self._pad_attn(abs_attn, abs_src_map.size(0))
        context_attn = self._pad_attn(context_attn, context_src_map.size(0))

        # Original probabilities; padding can never be generated.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)
        # Drop the local reference; the logits are not used past the softmax.
        del logits

        # Switch probabilities per row: [p_gen, p_copy_abs, p_copy_context].
        switch = torch.softmax(self.linear_copy(copy), -1)

        # Generation part: p_softmax(w) * p_gen  -> [batch_by_tlen, vocab]
        out_prob = torch.mul(prob, switch[:, 0].unsqueeze(1))

        # Copy parts: attention scaled by the matching switch probability and
        # projected onto each source's copy vocabulary.
        abs_copy_prob = self._copy_scores(
            torch.mul(abs_attn, switch[:, 1].unsqueeze(1)), abs_src_map)
        context_copy_prob = self._copy_scores(
            torch.mul(context_attn, switch[:, 2].unsqueeze(1)), context_src_map)

        return (out_prob, abs_copy_prob, context_copy_prob)

class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion for a generator with two copy sources.

    Args:
        vocab_size (int): stored as ``self.vocab_size``; ``forward`` does
            not read it.
        force_copy (bool): when true, the generation probability only
            counts at positions where neither source offers a copy.
        unk_index (int): index meaning "unknown" in ``target`` and
            "no copy available" in the alignment tensors.
        ignore_index (int): target value marking padding; the loss is set
            to zero at those positions.
        eps (float): small constant added before the logarithm.
    """

    def __init__(self, vocab_size, force_copy, unk_index=0,
                 ignore_index=-100, eps=1e-20):
        super(CopyGeneratorLoss, self).__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index  # padding
        self.unk_index = unk_index

    def forward(self, out_scores, abs_scores, context_scores, abs_align, context_align, target):
        """
        Args:
            out_scores (FloatTensor): ``(batch_size*tgt_len, vocab)``
                generation probabilities.
            abs_scores (FloatTensor): ``(batch_size*tgt_len, abs_cvocab)``
                copy probabilities for the first source.
            context_scores (FloatTensor):
                ``(batch_size*tgt_len, context_cvocab)`` copy probabilities
                for the second source.
            abs_align (LongTensor): ``(batch_size*tgt_len,)`` index of the
                gold token in the first copy vocabulary; ``unk_index`` when
                it cannot be copied from there.
            context_align (LongTensor): ``(batch_size*tgt_len,)``, same for
                the second source.
            target (LongTensor): ``(batch_size*tgt_len,)`` gold indices in
                the output vocabulary; ``ignore_index`` marks padding.

        Returns:
            tuple: ``(out, loss)``. ``out`` is the probability credited to
            the gold token (including ``eps``) and is not masked at padding
            positions. ``loss`` is ``-log(out)`` with padding positions set
            to zero.
        """
        # Probabilities assigned by the model to the gold targets. A negative
        # ``ignore_index`` (such as the default -100) is not a valid gather
        # index, so it is clamped to 0 for the lookup only; those positions
        # are zeroed in the loss below. Non-negative indices are unchanged.
        gather_ix = target.clamp(min=0).unsqueeze(1)
        vocab_probs = out_scores.gather(1, gather_ix).squeeze(1)

        abs_is_unk = abs_align.eq(self.unk_index)
        context_is_unk = context_align.eq(self.unk_index)

        # Probability of the gold token being copied from each source; an
        # unk alignment means there is nothing to copy, so it counts as 0.
        abs_copy_tok_probs = abs_scores.gather(
            1, abs_align.unsqueeze(1)).squeeze(1).masked_fill(abs_is_unk, 0)
        context_copy_tok_probs = context_scores.gather(
            1, context_align.unsqueeze(1)).squeeze(1).masked_fill(context_is_unk, 0)
        out = abs_copy_tok_probs + context_copy_tok_probs + self.eps

        # 1.0 where neither source provides a copy of the gold token.
        no_copy = abs_is_unk.float() * context_is_unk.float()
        if self.force_copy:
            # Forced copy. Add only probability for not-copied tokens.
            out = out + vocab_probs * no_copy
        else:
            # Add score for non-unks in target.
            out = out + vocab_probs * target.ne(self.unk_index).float()
            # Add score for when word is unk in both aligns and in target.
            out = out + vocab_probs * no_copy * target.eq(self.unk_index).float()

        loss = -out.log()
        # Drop padding.
        loss = loss.masked_fill(target.eq(self.ignore_index), 0)
        return out, loss