# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import logging

from . import data_utils, FairseqDataset
import torch.nn.functional as F

logger = logging.getLogger(__name__)

def collate(
    samples, vocab_size,
    pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
    input_feeding=True,
):
    """Merge a list of examples into a mini-batch, adding bag-of-words labels.

    The layout follows fairseq's language-pair collater. When present on the
    samples, per-sentence token-count vectors are added for ``nonstop_tokens``
    and ``stop_tokens``.

    Args:
        samples (List[dict]): examples with at least ``'id'`` and
            ``'source'``. Optional keys are ``'target'``, ``'alignment'``,
            ``'nonstop_tokens'``/``'num_nonstop_token'`` and
            ``'stop_tokens'``/``'num_stop_token'``.
        vocab_size (int): length of every bag-of-words count vector.
        pad_idx (int): padding index passed to ``data_utils.collate_tokens``.
        eos_idx (int): end-of-sentence index passed to ``collate_tokens``.
        left_pad_source (bool): pad source sequences on the left.
        left_pad_target (bool): pad target sequences on the left.
        input_feeding (bool): also build ``prev_output_tokens``, i.e. the
            target merged with EOS moved to the beginning.

    Returns:
        dict: the batch, with rows sorted by descending source length.
        An empty dict is returned for an empty ``samples`` list.
    """
    if len(samples) == 0:
        return {}

    def collect_bow(key):
        # Stack one (vocab_size,) count vector per sample -> (bsz, vocab_size).
        # torch.bincount gives the same int64 counts as summing a one-hot
        # matrix, without materialising a (num_tokens, vocab_size) tensor.
        # Assumes s[key] is a 1-D int64 tensor of token ids (as produced by
        # BOWDataset.__getitem__) — confirm for other producers.
        rows = []
        for s in samples:
            counts = torch.bincount(s[key], minlength=vocab_size)
            if counts.numel() != vocab_size:
                # bincount grows past minlength for ids >= vocab_size; fail
                # loudly instead of producing a mis-sized label vector.
                raise ValueError(
                    "token id out of range for vocab_size={} in '{}'".format(
                        vocab_size, key
                    )
                )
            rows.append(counts)
        return torch.stack(rows, dim=0)

    def merge(key, left_pad, move_eos_to_beginning=False):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
        # align_tgt_i maps every alignment to its unique-target slot, so this
        # gathers each alignment's target frequency directly.
        align_weights = align_tgt_c[align_tgt_i]
        return 1. / align_weights.float()

    ids = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    tgt_lengths = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
        ntokens = sum(len(s['target']) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                'target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s['source']) for s in samples)

    batch = {
        'id': ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
        },
        'target': target,
    }

    if samples[0].get('nonstop_tokens', None) is not None:
        nonstop_label_vector = collect_bow('nonstop_tokens')
        batch['nonstop_label_vector'] = nonstop_label_vector.index_select(0, sort_order)
        batch['num_nonstop_token'] = sum(s['num_nonstop_token'] for s in samples)

    if samples[0].get('stop_tokens', None) is not None:
        stop_label_vector = collect_bow('stop_tokens')
        batch['stop_label_vector'] = stop_label_vector.index_select(0, sort_order)
        batch['num_stop_token'] = sum(s['num_stop_token'] for s in samples)

    if prev_output_tokens is not None:
        batch['net_input']['prev_output_tokens'] = prev_output_tokens

    # Alignments need the padded target (and its lengths) to be offset
    # against, so they are only built when a target exists.
    if target is not None and samples[0].get('alignment', None) is not None:
        _, tgt_sz = target.shape
        src_sz = src_tokens.shape[1]

        # Shift each sample's alignment indices for left padding and, on the
        # target side, for the row position in the flattened target batch.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
            for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    return batch

class BOWDataset(FairseqDataset):
    """
    Wrap a sentence-pair dataset and attach bag-of-words token sets.

    Each example returned by the wrapped dataset is augmented with the unique
    token ids of ``nonstop_mapping[index]`` and/or ``stop_mapping[index]``.
    These are stored under ``'nonstop_tokens'`` / ``'stop_tokens'`` together
    with their counts, and :func:`collate` turns them into per-sentence count
    vectors.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to wrap. Its
            ``src_dict``, ``tgt_dict``, ``left_pad_source``,
            ``left_pad_target`` and ``input_feeding`` attributes are read when
            present (presumably a ``LanguagePairDataset`` — confirm with the
            caller).
        nonstop_mapping: indexable by example index, yielding token ids that
            support ``.tolist()``; ``None`` disables the non-stop labels.
        stop_mapping: same as ``nonstop_mapping`` but for stop-word ids;
            ``None`` disables the stop labels.
        vocab (~fairseq.data.Dictionary): vocabulary.
        seed (int): seed passed to ``data_utils.numpy_seed`` together with the
            epoch and index while fetching from the wrapped dataset.
        eos (int, optional): end-of-sentence index used when collating.
            Defaults to ``vocab.eos()``.
    """

    def __init__(
        self,
        dataset,
        nonstop_mapping,
        stop_mapping,
        vocab,
        seed,
        eos=None
    ):
        self.dataset = dataset
        self.src_dict = getattr(dataset, 'src_dict', None)
        self.tgt_dict = getattr(dataset, 'tgt_dict', None)
        self.left_pad_source = getattr(dataset, 'left_pad_source', None)
        self.left_pad_target = getattr(dataset, 'left_pad_target', None)
        self.input_feeding = getattr(dataset, 'input_feeding', None)
        self.nonstop_mapping = nonstop_mapping
        self.stop_mapping = stop_mapping

        self.sizes = dataset.sizes

        self.vocab = vocab
        self.seed = seed
        self.eos = (eos if eos is not None else vocab.eos())

        self.sent_split_tag = self.vocab.eos()

        self.epoch = 0

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    @staticmethod
    def _attach_unique_tokens(example, mapping, index, tokens_key, count_key):
        """Store the unique token ids of ``mapping[index]`` on ``example``."""
        if mapping is None:
            # No token tensor is attached, so collate() skips this label; the
            # count entry is still set to 1.
            example[count_key] = 1
            return
        unique = list(set(mapping[index].tolist()))
        example[tokens_key] = torch.tensor(unique, dtype=torch.int64)
        example[count_key] = len(unique)

    def __getitem__(self, index):
        # Seed numpy per (seed, epoch, index) so any numpy randomness in the
        # wrapped dataset's __getitem__ is reproducible.
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            example = self.dataset[index]
            self._attach_unique_tokens(
                example, self.nonstop_mapping, index,
                'nonstop_tokens', 'num_nonstop_token',
            )
            self._attach_unique_tokens(
                example, self.stop_mapping, index,
                'stop_tokens', 'num_stop_token',
            )

        return example

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        # For now only supports datasets with same underlying collater implementations
        # Fall back to the explicit vocab when the wrapped dataset exposes no
        # src_dict (previously this crashed on len(None)).
        dictionary = self.src_dict if self.src_dict is not None else self.vocab
        return collate(
            samples, vocab_size=len(dictionary),
            pad_idx=dictionary.pad(), eos_idx=self.eos,
            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.size(index)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return self.dataset.ordered_indices()
