# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import numpy as np
import torch

from . import data_utils, FairseqDataset

logger = logging.getLogger(__name__)


def collate(
    samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
    input_feeding=True, num_summ_langs=1
):
    """Collate a list of extractive-summarization samples into a mini-batch.

    Each sample is a dict with at least ``'id'``, ``'source'`` (1-D token
    tensor), ``'margin_tokens'`` (list of bool masks, one per source
    sentence, each ``len(source)`` long; ``True`` means masked) and
    ``'ext_label'`` (list or tensor of per-sentence labels). ``'target'``,
    ``'alignment'`` and ``'lang_label'`` are optional.

    Samples are reordered by descending source length; every per-sample
    tensor in the returned batch follows that same order.

    Args:
        samples: list of sample dicts.
        pad_idx: padding index for token tensors.
        eos_idx: end-of-sentence index (used when shifting targets).
        left_pad_source: pad source tokens (and sentence masks) on the left.
        left_pad_target: pad target tokens on the left.
        input_feeding: also build ``prev_output_tokens`` from the target.
        num_summ_langs: not used by this function; kept for interface
            compatibility.

    Returns:
        dict: the batch, or ``{}`` when ``samples`` is empty. ``ext_label``
        is right-padded with 2 and ``ext_mask`` is ``ext_label == 2``.
    """
    if len(samples) == 0:
        return {}

    def merge_samples(
        given_samples, key, left_pad, move_eos_to_beginning=False,
        pad=pad_idx
    ):
        # pad ``s[key]`` of every given sample into one 2-D tensor
        return data_utils.collate_tokens(
            [s[key] for s in given_samples],
            pad, eos_idx, left_pad, move_eos_to_beginning,
        )

    def merge(key, left_pad, move_eos_to_beginning=False):
        return merge_samples(samples, key, left_pad, move_eos_to_beginning)

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1. / align_weights.float()

    bsz = len(samples)
    sample_ids = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    sample_ids = sample_ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    # Sentence masks: give every sample ``max_sent_num`` rows by appending
    # fully-masked rows, then pad along the source axis like ``src_tokens``
    # so each mask column lines up with the same source position.
    max_sent_num = max(len(s['margin_tokens']) for s in samples)
    if max_sent_num == 0:
        # collate_tokens cannot handle an empty list of values
        margin_tokens = torch.ones(bsz, 0, src_tokens.size(1), dtype=torch.bool)
    else:
        sent_samples = []
        for s in samples:
            masks = list(s['margin_tokens'])
            all_masked = torch.ones(s['source'].numel(), dtype=torch.bool)
            masks.extend(all_masked for _ in range(max_sent_num - len(masks)))
            sent_samples.extend({'margin_tokens': m} for m in masks)
        # Pad with 1 (True) explicitly: padded source positions must be
        # masked no matter what value pad_idx has.
        margin_tokens = merge_samples(
            sent_samples, 'margin_tokens', left_pad=left_pad_source, pad=1,
        )
        margin_tokens = torch.reshape(margin_tokens, [bsz, max_sent_num, -1])
        margin_tokens = margin_tokens.index_select(0, sort_order)

    # Per-sentence extraction labels, right-padded with 2 (the "ignore"
    # label). Built as new dicts so the caller's samples are not mutated.
    label_samples = [
        {'ext_label': torch.as_tensor(s['ext_label'], dtype=torch.long)}
        for s in samples
    ]
    ext_label = merge_samples(label_samples, 'ext_label', left_pad=False, pad=2)
    ext_label = ext_label.index_select(0, sort_order)
    ext_mask = torch.eq(ext_label, 2)

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
        ntokens = sum(len(s['target']) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                'target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s['source']) for s in samples)

    batch = {
        'id': sample_ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
            'margin': margin_tokens,  # [batch, max_sent_num, src_len]
        },
        'target': target,
        'ext_label': ext_label,
        'ext_mask': ext_mask,  # True means mask
    }

    if 'lang_label' in samples[0]:
        lang_label = torch.tensor(
            [sample['lang_label'] for sample in samples], dtype=torch.long
        )
        batch['lang_label'] = lang_label.index_select(0, sort_order)

    if prev_output_tokens is not None:
        batch['net_input']['prev_output_tokens'] = prev_output_tokens

    if samples[0].get('alignment', None) is not None:
        bsz, tgt_sz = batch['target'].shape
        src_sz = batch['net_input']['src_tokens'].shape[1]

        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
            for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    return batch


class ExtractiveDataset(FairseqDataset):
    """
    Wrap a source/target pair dataset and add the inputs needed for
    extractive summarization.

    For every example, ``source`` is split into sentences at
    ``vocab.eos()`` tokens, with the first sentence starting after the
    ``<bos>`` token. One boolean mask per sentence is attached as
    ``'margin_tokens'`` (``True`` means masked). ``'ext_label'`` holds one
    label per sentence: 1 for oracle sentences, 0 otherwise, or 2 (ignored)
    for every sentence of a denoising-language example. When language tokens
    are configured, ``'lang_label'`` (index into summ_langs + denoise_langs)
    is also added.

    Args:
        dataset: wrapped dataset. Its items must be dicts with a ``'source'``
            token tensor, and it must expose ``sizes``, ``src_dict``,
            ``left_pad_source``, ``left_pad_target``, ``input_feeding``,
            ``num_tokens``, ``size`` and ``ordered_indices`` (presumably a
            fairseq ``LanguagePairDataset`` — TODO confirm with callers).
        vocab (~fairseq.data.Dictionary): vocabulary.
        seed: seed for the per-item numpy RNG context.
        oracle_labels (optional): indexable by example index; each entry is
            a dict whose ``'label'`` key lists oracle sentence indices.
            When ``None``, every sentence gets label 0.
        eos (int, optional): EOS index handed to the collater; defaults to
            ``vocab.eos()``.
        summ_langs (list of str, optional): summarization languages, looked
            up in ``vocab`` as ``"[lang]"``.
        denoise_langs (list of str, optional): denoising languages, looked
            up the same way; their sentences are all labelled 2.
        src_lang_token_pos (int): position of the language token in
            ``source``.
    """

    def __init__(
        self,
        dataset,
        vocab,
        seed,
        oracle_labels=None,
        eos=None,
        summ_langs=None,
        denoise_langs=None,
        src_lang_token_pos=0
    ):
        self.dataset = dataset
        self.sizes = dataset.sizes

        self.vocab = vocab
        self.seed = seed
        self.eos = eos if eos is not None else vocab.eos()

        # sentences inside ``source`` are delimited by EOS tokens
        self.sent_split_tag = self.vocab.eos()
        self.epoch = 0
        self.oracle_labels = oracle_labels
        # Vocabulary indices of the "[lang]" tokens. Summarization languages
        # come first, so ``lang_label < len(self.summ_langs)`` identifies them.
        self.summ_langs = [self.vocab.index("[{}]".format(lang)) for lang in (summ_langs or [])]
        self.denoise_langs = [self.vocab.index("[{}]".format(lang)) for lang in (denoise_langs or [])]
        self.langs = self.summ_langs + self.denoise_langs
        self.src_lang_token_pos = src_lang_token_pos

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        """Return the wrapped example extended with ``margin_tokens``,
        ``ext_label`` and (when languages are configured) ``lang_label``.

        Raises:
            ValueError: if languages are configured and the token at
                ``src_lang_token_pos`` is not one of them.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            example = self.dataset[index]
            source = example['source']
            sents = self.get_sents(source)
            example['margin_tokens'] = [sent['mask'] for sent in sents]

            ext_label = [0] * len(sents)
            if self.oracle_labels is not None:
                for sid in self.oracle_labels[index].get('label', []):
                    # oracle indices beyond the (possibly truncated) source are dropped
                    if sid < len(sents):
                        ext_label[sid] = 1

            if self.langs:
                lang_token = int(source[self.src_lang_token_pos])
                lang_label = self.langs.index(lang_token)
                example['lang_label'] = lang_label
                if lang_label >= len(self.summ_langs):
                    # denoising language: no extraction supervision
                    ext_label = [2] * len(sents)

            example['ext_label'] = ext_label
        return example

    def __len__(self):
        return len(self.dataset)

    def get_sents(self, source):
        """Split ``source`` into sentences delimited by ``self.sent_split_tag``.

        The first sentence starts right after the ``<bos>`` token. Each
        following sentence starts at the delimiter that ended the previous
        one. ``source[-2]`` is always treated as a delimiter, so the last
        sentence is closed even without an explicit one.

        Returns:
            list of dicts with ``'sentence'`` (the token slice) and ``'mask'``
            (bool tensor of length ``len(source)``, ``True`` outside the
            sentence).
        """
        full_stops = source == self.sent_split_tag
        full_stops[-2] = True
        # a sentence ends at the first delimiter of each run of delimiters
        sentence_ends = (~full_stops[:-1] & full_stops[1:]).nonzero(as_tuple=False).squeeze(-1) + 1  # 1d-tensor
        num_sentences = sentence_ends.size(0)

        # NOTE(review): assumes ``source`` contains exactly one <bos>; with
        # none, the last ``boundaries[idx + 1]`` lookup below goes out of range.
        bos_index = torch.eq(source, self.vocab.bos()).nonzero(as_tuple=False).squeeze(-1) + 1
        boundaries = torch.cat((bos_index, sentence_ends), dim=0)
        sents = []
        for idx in range(num_sentences):
            # ignore language and <bos>
            start, end = boundaries[idx], boundaries[idx + 1]
            mask = torch.ones(len(source), dtype=torch.bool)
            mask[start:end] = False
            sents.append({'sentence': source[start:end], 'mask': mask})  # True means mask
        return sents

    def collater(self, samples):
        # For now only supports datasets with same underlying collater implementations
        num_summ_langs = len(self.summ_langs)
        return collate(
            samples, pad_idx=self.dataset.src_dict.pad(), eos_idx=self.eos,
            left_pad_source=self.dataset.left_pad_source, left_pad_target=self.dataset.left_pad_target,
            input_feeding=self.dataset.input_feeding,
            num_summ_langs=num_summ_langs
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.size(index)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return self.dataset.ordered_indices()
