# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import random
from fairseq.data import data_utils, FairseqDataset


logger = logging.getLogger(__name__)

def collate_prev_tokens(values, pad_idx, bos_idx=None, left_pad=False, move_eos_to_beginning=False):
    """Stack a list of 1d tensors into a single padded 2d tensor.

    Args:
        values: non-empty list of 1d tensors; the first one determines the
            dtype/device of the result.
        pad_idx: value used to fill the positions not covered by a row.
        bos_idx: value written at the first position of every row when
            *move_eos_to_beginning* is set.
        left_pad: put the padding before the row content instead of after.
        move_eos_to_beginning: shift every row right by one position,
            dropping its last element and writing *bos_idx* in front.

    Returns:
        Tensor of shape ``(len(values), max_len)``.
    """
    width = max(item.size(0) for item in values)
    padded = values[0].new_full((len(values), width), pad_idx)

    for row_idx, item in enumerate(values):
        length = len(item)
        row = padded[row_idx]
        # ``slot`` is a view into ``padded``; writing to it fills the batch.
        slot = row[width - length:] if left_pad else row[:length]
        assert slot.numel() == item.numel()
        if not move_eos_to_beginning:
            slot.copy_(item)
            continue
        slot[0] = bos_idx
        slot[1:] = item[:-1]
    return padded

def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    decoder_start_token_idx=None,
    sort_examples=True,
):
    """Merge a list of samples into a mini-batch dictionary.

    Args:
        samples (List[dict]): samples with keys ``'id'``, ``'source'`` and
            optionally ``'target'`` and ``'alignment'``.
        pad_idx (int): padding value.
        eos_idx (int): end-of-sentence value, forwarded to
            ``data_utils.collate_tokens``.
        left_pad_source (bool): pad the source on the left.
        left_pad_target (bool): pad the target (and the shifted decoder
            input) on the left.
        input_feeding (bool): also build ``prev_output_tokens``, the target
            shifted right by one position.
        decoder_start_token_idx (int, optional): value written at the first
            position of ``prev_output_tokens``; defaults to *eos_idx*.
        sort_examples (bool): sort the batch by descending source length.

    Returns:
        dict: ``{}`` for an empty *samples*; otherwise a dict with the keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths`` and, if built, ``prev_output_tokens``) and ``target``,
        plus ``alignments`` / ``align_weights`` when at least one valid
        alignment is present.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, inverse, counts = torch.unique(align_tgt, return_inverse=True, return_counts=True)
        # counts[inverse] gives, per alignment row, how often its target occurs
        return 1. / counts[inverse].float()

    sample_ids = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    src_lengths = torch.LongTensor([
        s['source'].ne(pad_idx).long().sum() for s in samples
    ])
    if sort_examples:
        # sort by descending source length
        src_lengths, sort_order = src_lengths.sort(descending=True)
        sample_ids = sample_ids.index_select(0, sort_order)
        src_tokens = src_tokens.index_select(0, sort_order)
    else:
        # Identity order: the alignment code below needs the batch-row ->
        # sample mapping even when the batch is left unsorted.
        sort_order = torch.arange(len(samples), dtype=torch.long)

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        tgt_lengths = torch.LongTensor([
            s['target'].ne(pad_idx).long().sum() for s in samples
        ])
        if sort_examples:
            target = target.index_select(0, sort_order)
            tgt_lengths = tgt_lengths.index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            if decoder_start_token_idx is None:
                decoder_start_token_idx = eos_idx
            # Pad on the same side as ``target`` so that both tensors stay
            # position-aligned when left_pad_target is enabled.
            prev_output_tokens = collate_prev_tokens(
                [s['target'] for s in samples],
                pad_idx, decoder_start_token_idx, left_pad_target, True)

            if sort_examples:
                prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        'id': sample_ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
        },
        'target': target,
    }
    if prev_output_tokens is not None:
        batch['net_input']['prev_output_tokens'] = prev_output_tokens

    if samples[0].get('alignment', None) is not None:
        bsz, tgt_sz = batch['target'].shape
        src_sz = batch['net_input']['src_tokens'].shape[1]

        # Column 0 shifts source indices by the left padding; column 1 turns
        # target indices into positions of the flattened (bsz * tgt_sz) target.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
            for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    return batch


class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        max_target_positions (int, optional): max number of tokens in the
            target sentence (default: 1024).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        append_bos (bool, optional): if set, prepends bos to the beginning of
            source/target sentence when it is absent.
        eos (int, optional): eos index handed to the collate function
            (default: ``src_dict.eos()``).
        decoder_start_token_idx (integer, optional): if set, put this token at
            the beginning of the previous-token list in the decoder
            (default: eos of the target dictionary, or of the source
            dictionary when there is no target dictionary).
        sort_examples (bool, optional): sort every mini-batch by descending
            source length (default: True).
        args (optional): configuration object; the attributes read from it
            are ``use_pegasus``, ``pooler_remove_corrupted_datapoint`` and
            ``pegasus_gsr``.

    The attribute ``is_train_set`` is initialised to ``None`` and is only read
    by :meth:`__getitem__`; it has to be assigned from outside.
    """

    # Token ids inserted into the source in place of every removed sentence;
    # according to the original author's note they decode to '# &' twice.
    PEGASUS_MASK_TOKEN_IDS = (31977, 31989) * 2
    # Vocabulary symbol treated as the sentence terminator by the Pegasus path.
    PEGASUS_SENTENCE_SPLITTER = '▁.'

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None,
        decoder_start_token_idx=None,
        sort_examples=True,
        args=None
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        self.sort_examples = sort_examples
        self.args = args
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())

        if decoder_start_token_idx is not None:
            self.decoder_start_token_idx = decoder_start_token_idx
        else:
            if self.tgt_dict:
                self.decoder_start_token_idx = self.tgt_dict.eos()
            else:
                self.decoder_start_token_idx = self.src_dict.eos()
        self.is_train_set = None
        # Lazily filled cache for the splitter's vocabulary index.
        self._pegasus_splitter_idx = None

    def __getitem__(self, index):
        """Return the example at *index* as a dict.

        The dict has the keys ``'id'``, ``'source'`` and ``'target'``
        (``None`` when the dataset has no target side) and, when an alignment
        dataset is attached, ``'alignment'``.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa.
        # Every step below works on the running src_item / tgt_item so that
        # the options compose instead of overwriting each other.
        if self.append_eos_to_target and tgt_item is not None:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            # bos belongs at the front, so the first token is the one to check
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        # getattr keeps the default ``args=None`` (and configs that lack the
        # options) working; a missing option counts as disabled.
        use_pegasus = getattr(self.args, 'use_pegasus', False)
        remove_corrupted = getattr(self.args, 'pooler_remove_corrupted_datapoint', False)
        # A pair counts as corrupted when the target is longer than the
        # source; a dataset without targets can never be corrupted.
        is_corrupted = tgt_item is not None and len(src_item) < len(tgt_item)

        if use_pegasus:
            # sometimes abstracts are longer than documents due to bad preprocessing,
            # it is native to arxiv train set
            collapse = is_corrupted and remove_corrupted
        else:
            # in arxiv summarization train set, we can remove invalid entries
            collapse = self.is_train_set and is_corrupted and remove_corrupted

        if collapse:
            # Corrupted pairs are replaced by their last tokens only.
            example = {
                'id': index,
                'source': src_item[-1].unsqueeze(0),
                'target': tgt_item[-1].unsqueeze(0),
            }
        elif use_pegasus:
            example = self.convert_to_pegasus_gapped_datapoint(index, src_item)
        else:
            example = {
                'id': index,
                'source': src_item,
                'target': tgt_item,
            }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        return example

    def _pegasus_splitter_index(self):
        """Return the vocabulary index of the sentence splitter (looked up once)."""
        cached = getattr(self, '_pegasus_splitter_idx', None)
        if cached is None:
            dictionary = self.tgt_dict if self.tgt_dict is not None else self.src_dict
            # list.index raises ValueError when the symbol is not in the vocabulary
            cached = dictionary.symbols.index(self.PEGASUS_SENTENCE_SPLITTER)
            self._pegasus_splitter_idx = cached
        return cached

    def convert_to_pegasus_gapped_datapoint(self, index, src_item):
        """Build a gap-sentence example from a single source sequence.

        The source is split into sentences at every splitter token. A random
        subset of the sentences is removed from the source, each removed
        sentence being replaced by the mask-token sequence, and the removed
        sentences, concatenated in their original order, become the target::

            sentence 1. sentence 2. sentence 3
            >>
            source: sentence 1. <mask tokens> sentence 3
            target: sentence 2.

        The number of removed sentences is ``args.pegasus_gsr`` times the
        number of sentences (at least one when possible). The tokens after
        the last splitter are never removed.

        If only one of source/target comes out empty, it is replaced by a copy
        of the other one; if both are empty the example collapses to the last
        source token on both sides.

        TODO: the output could be padded / a dedicated special token could be
        added to the dictionary instead of the hard-coded mask ids.

        Args:
            index: example id stored in the returned dict.
            src_item: 1d tensor of source token ids.

        Returns:
            dict with the keys ``'id'``, ``'source'`` and ``'target'``; source
            and target are truncated to ``max_source_positions`` and
            ``max_target_positions``.
        """
        device = src_item.device

        def as_index(positions):
            # index tensors have to be int64 whatever the dtype of the tokens
            return torch.tensor(positions, dtype=torch.long, device=device)

        splitter = self._pegasus_splitter_index()
        splitter_positions = (src_item == splitter).nonzero().view(-1).tolist()
        # Sentence i covers the positions boundaries[i] + 1 .. boundaries[i + 1].
        # The virtual boundary at -1 makes the first sentence start at token 0.
        boundaries = [-1] + splitter_positions + [len(src_item) - 1]
        num_boundaries = len(boundaries)

        num_gaps = max(1, int(self.args.pegasus_gsr * (num_boundaries - 1)))
        num_gaps = min(num_gaps, num_boundaries - 2)
        gap_starts = sorted(random.sample(range(num_boundaries - 2), num_gaps))
        gap_ends = [start + 1 for start in gap_starts]

        # Target: the removed sentences, in document order.
        target_positions = []
        for start, end in zip(gap_starts, gap_ends):
            target_positions.extend(range(boundaries[start] + 1, boundaries[end] + 1))
        gathered_target = src_item.index_select(0, as_index(target_positions))

        # Source: the kept stretches with the mask sequence between them. The
        # mask tokens are appended to the source so that a single index list
        # can address both real tokens and mask tokens.
        mask_token = torch.tensor(self.PEGASUS_MASK_TOKEN_IDS, dtype=src_item.dtype, device=device)
        src_item_with_mask_appended = torch.cat((src_item, mask_token))
        mask_positions = list(range(len(src_item), len(src_item_with_mask_appended)))
        keep_starts = [0] + gap_ends
        keep_ends = gap_starts + [num_boundaries - 1]
        source_positions = []
        for start, end in zip(keep_starts, keep_ends):
            source_positions.extend(range(boundaries[start] + 1, boundaries[end] + 1))
            source_positions.extend(mask_positions)
        # the loop adds one mask too many, after the final kept stretch
        del source_positions[-len(mask_positions):]
        gathered_source = src_item_with_mask_appended.index_select(0, as_index(source_positions))

        source_empty = gathered_source.numel() == 0
        target_empty = gathered_target.numel() == 0
        if source_empty or target_empty:
            logger.warning(
                "pegasus gap sampling produced an empty side for example %s "
                "(source len = %d, target len = %d)",
                index, len(gathered_source), len(gathered_target),
            )
            if source_empty and target_empty:
                return {
                    'id': index,
                    'source': src_item[-1].unsqueeze(0),
                    'target': src_item[-1].unsqueeze(0),
                }
            if target_empty:
                gathered_target = gathered_source[:self.max_target_positions]
            else:
                gathered_source = gathered_target[:self.max_source_positions]

        return {
            'id': index,
            'source': gathered_source[:self.max_source_positions],
            'target': gathered_target[:self.max_target_positions],
        }

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Delegates to the module-level ``collate`` function, which also
        decides on which side each tensor is padded.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs, in batch order
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
        """
        return collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            decoder_start_token_idx=self.decoder_start_token_idx,
            sort_examples=self.sort_examples,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # Two stable sorts: the result is ordered by source size, with ties
        # ordered by target size.
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

    @property
    def supports_prefetch(self):
        """Whether the source and (if present) the target dataset both
        support prefetching."""
        return (
            getattr(self.src, 'supports_prefetch', False)
            and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
        )

    def prefetch(self, indices):
        """Forward the prefetch request to every wrapped dataset."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)
