# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import numpy as np

import torch
import torch.nn.functional as F

from fairseq import search, utils, checkpoint_utils
from fairseq.data import data_utils

from bert import BertTokenizer
from bert import BertForPreTraining
from fairseq.logging.meters import safe_round
from fairseq.data.data_utils import post_process


class Wav2BertDecoder(object):
    """Greedy (best-path) CTC-style decoder over the output of one model.

    For every utterance in a batch, the decoder:

    1. takes the arg-max symbol at each valid (non-padded) frame;
    2. collapses consecutive repeats;
    3. removes the dictionary's ``bos`` index, which acts as the blank symbol.

    No search is performed, so the reported score is always ``0.0``.
    """

    def __init__(
        self,
        tgt_dict,
        args=None,
    ):
        """Cache the special-symbol indices of the target dictionary.

        Args:
            tgt_dict (~fairseq.data.Dictionary): target dictionary
            args: accepted for interface compatibility; not used.
        """
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        # Removed from every hypothesis in ``generate`` (used as the blank).
        self.bos = tgt_dict.bos()

        self.tgt_dict = tgt_dict

    @torch.no_grad()
    def generate(
        self,
        models,
        sample,
        tgt_bert_tokenizer=None,
        **kwargs
    ):
        """Greedily decode a batch with the first model in ``models``.

        Args:
            models: sequence of models; only ``models[0]`` is used.
            sample (dict): batch whose ``"net_input"`` entry is passed to the
                model as keyword arguments.
            tgt_bert_tokenizer: accepted for interface compatibility; unused.
            **kwargs: ignored.

        Returns:
            list: one entry per utterance. Each entry is a single-element list
            ``[{"tokens": LongTensor, "score": 0.0}]``.
        """
        model = models[0]
        net_output = model(**sample["net_input"])

        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder
        # (T, B, C) -> (B, T, C), on CPU for the per-utterance loop below.
        lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

        padding_mask = net_output["padding_mask"]
        if padding_mask is None:
            # The model reports no padding: every frame of every utterance
            # is valid, so each length is the full number of frames T.
            input_lengths = torch.full(
                (lprobs_t.size(0),), lprobs_t.size(1), dtype=torch.long
            )
        else:
            # Lengths are kept on the CPU, next to lprobs_t, so that slicing
            # below does not need a device round-trip per utterance.
            input_lengths = (~padding_mask).long().sum(-1).cpu()

        res = []
        for distribution, length in zip(lprobs_t, input_lengths):
            distribution = distribution[:length]
            # Best path: arg-max per frame, merge repeats, then drop blanks.
            # Merging must happen before blank removal so that "a <b> a"
            # still yields two separate "a" tokens.
            result = distribution.argmax(-1).unique_consecutive()
            result = result[result != self.bos]
            res.append([{"tokens": result, "score": 0.0}])

        return res
