import numpy as np
import torch

from . import data_utils, FairseqDataset


def collate(
        samples,
        pad_idx,
        eos_idx,
        left_pad_source=True,
        left_pad_target=False,
        input_feeding=True,
        order_by_size=False
):
    """Merge a list of samples into one mini-batch dictionary.

    Args:
        samples (List[dict]): items providing ``'id'``, ``'art'``, ``'lab'``
            and, optionally, ``'ans'``. Whether the batch has a target is
            decided from the first sample only.
        pad_idx: padding index forwarded to ``data_utils.collate_tokens``.
        eos_idx: end-of-sentence index forwarded to
            ``data_utils.collate_tokens`` and to ``validate_samples_eos_bos``.
        left_pad_source (bool): ``left_pad`` flag used when merging both
            ``'art'`` and ``'lab'``.
        left_pad_target (bool): ``left_pad`` flag used when merging ``'ans'``.
        input_feeding (bool): if set and a target exists, additionally merge
            ``'ans'`` with ``move_eos_to_beginning=True`` and expose it as
            ``net_input['prev_output_tokens']``.
        order_by_size (bool): if set, every per-sample tensor in the batch is
            reordered by descending ``'art'`` length.

    Returns:
        dict: ``{}`` for an empty ``samples`` list; otherwise a dict with keys
        ``'id'``, ``'nsentences'``, ``'ntokens'``, ``'target'`` (``None``
        without ``'ans'``) and ``'net_input'`` containing ``'doc_tokens'``,
        ``'doc_lengths'``, ``'prop_tokens'``, ``'prop_lengths'`` and, when
        built, ``'prev_output_tokens'``.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
        )

    validate_samples_eos_bos(eos_idx, samples)

    # `ids` rather than `id` so the builtin is not shadowed.
    ids = torch.LongTensor([s['id'] for s in samples])
    art_tokens = merge('art', left_pad=left_pad_source)
    lab_tokens = merge('lab', left_pad=left_pad_source)
    art_lengths = torch.LongTensor([s['art'].numel() for s in samples])
    lab_lengths = torch.LongTensor([s['lab'].numel() for s in samples])

    sort_order = None
    if order_by_size:
        # sort by descending article length
        art_lengths, sort_order = art_lengths.sort(descending=True)

    def reorder(tensor):
        # Apply the batch permutation (if any) along the sample dimension.
        if sort_order is None:
            return tensor
        return tensor.index_select(0, sort_order)

    ids = reorder(ids)
    art_tokens = reorder(art_tokens)
    lab_tokens = reorder(lab_tokens)
    # Must follow the same permutation as lab_tokens, otherwise
    # 'prop_lengths' no longer lines up with the rows of 'prop_tokens'.
    lab_lengths = reorder(lab_lengths)

    prev_output_tokens = None
    target = None
    if samples[0].get('ans', None) is not None:
        target = reorder(merge('ans', left_pad=left_pad_target))
        ntokens = sum(len(s['ans']) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = reorder(merge(
                'ans',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            ))
    else:
        ntokens = sum(len(s['art']) for s in samples)

    batch = {
        'id': ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'doc_tokens': art_tokens,
            'doc_lengths': art_lengths,
            'prop_tokens': lab_tokens,
            'prop_lengths': lab_lengths,
        },
        'target': target,
    }
    if prev_output_tokens is not None:
        batch['net_input']['prev_output_tokens'] = prev_output_tokens

    return batch


def validate_samples_eos_bos(eos_idx, samples, ignore_bos_eos_asserts=True, bos_idx=0):
    """Check that every present sequence is framed by exactly one BOS and one EOS.

    For each sample, the ``'ans'``, ``'art'`` and ``'lab'`` entries that are
    present and not ``None`` are checked, in that order, for:

    * last token equal to ``eos_idx`` (an empty sequence fails this check),
    * second-to-last token different from ``eos_idx``,
    * first token equal to ``bos_idx``,
    * second token different from ``bos_idx``.

    Args:
        eos_idx: expected end-of-sentence index.
        samples (List[dict]): samples whose entries support ``len()``,
            integer indexing and ``.item()`` on the indexed element.
        ignore_bos_eos_asserts (bool): if ``True`` (default) failures are
            ignored and the function returns without scanning.
        bos_idx: expected beginning-of-sentence index (default: 0).

    Raises:
        AssertionError: on the first violated check, only when
            ``ignore_bos_eos_asserts`` is ``False``.
    """
    if ignore_bos_eos_asserts:
        # Any failure would be discarded, so there is nothing to compute.
        return

    def first_problem(tokens, input_type):
        # Return the message of the first violated check, or None if clean.
        # Length guards keep short sequences from raising IndexError.
        n_tokens = len(tokens)
        if n_tokens == 0 or tokens[-1].item() != eos_idx:
            return f'Eos not at the end of the {input_type}'
        if n_tokens >= 2 and tokens[-2].item() == eos_idx:
            return f'Eos at the end of the {input_type} twice'
        if tokens[0].item() != bos_idx:
            return f'Bos not at the beggining of the {input_type}'
        if n_tokens >= 2 and tokens[1].item() == bos_idx:
            return f'Bos at the beggining of the {input_type} twice'
        return None

    for el in samples:
        for input_type in ['ans', 'art', 'lab']:
            # .get: a sample may omit a key entirely (e.g. no 'ans').
            tokens = el.get(input_type)
            if tokens is None:
                continue
            problem = first_problem(tokens, input_type)
            if problem is not None:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(problem)


class SkynetDataset(FairseqDataset):
    """
    Dataset yielding aligned ``lab`` / ``art`` / (optional) ``ans`` items.

    Args:
        lab: indexable dataset providing the ``'lab'`` entry of each example
        lab_sizes: per-example lengths of ``lab`` (converted to a numpy array)
        src_dict: dictionary whose ``pad()`` and ``eos()`` are used when
            collating
        art: indexable dataset providing the ``'art'`` entry of each example;
            its length defines the length of this dataset
        art_sizes: per-example lengths of ``art`` (converted to a numpy
            array); also exposed as ``sizes`` and used for ordering
        ans (optional): indexable dataset providing the ``'ans'`` entry
        ans_sizes (optional): per-example lengths of ``ans``
        left_pad_source (bool, optional): forwarded to ``collate``
            (default: True).
        left_pad_target (bool, optional): forwarded to ``collate``
            (default: False).
        max_source_positions (int, optional): stored on the instance
            (default: 1024).
        max_target_positions (int, optional): stored on the instance
            (default: 1024).
        shuffle (bool, optional): start from a random permutation in
            ``ordered_indices`` (default: True).
        input_feeding (bool, optional): forwarded to ``collate``
            (default: True).
        remove_eos_from_source (bool, optional): stored on the instance;
            not used by the methods of this class (default: False).
        append_eos_to_target (bool, optional): stored on the instance;
            not used by the methods of this class (default: False).
        append_bos (bool, optional): stored on the instance; not used by the
            methods of this class (default: False).
        order_by_size (bool, optional): forwarded to ``collate``
            (default: False).
    """

    def __init__(
            self, lab, lab_sizes, src_dict,
            art, art_sizes,
            ans=None, ans_sizes=None,
            left_pad_source=True, left_pad_target=False,
            max_source_positions=1024, max_target_positions=1024,
            shuffle=True, input_feeding=True,
            remove_eos_from_source=False, append_eos_to_target=False,
            append_bos=False, order_by_size=False
    ):
        # Wrapped datasets.
        self.lab, self.art, self.ans = lab, art, ans

        # Per-example lengths, kept as numpy arrays for fancy indexing.
        self.lab_sizes = np.array(lab_sizes)
        self.art_sizes = np.array(art_sizes)
        if ans_sizes is None:
            self.ans_sizes = None
        else:
            self.ans_sizes = np.array(ans_sizes)
        self.sizes = self.art_sizes

        # Collation settings.
        self.src_dict = src_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.input_feeding = input_feeding
        self.order_by_size = order_by_size

        # Remaining options are only recorded here.
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.shuffle = shuffle
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.append_bos = append_bos

    def __getitem__(self, index):
        """Return the example at *index* as a dict; ``'ans'`` is ``None``
        when no ``ans`` dataset was given."""
        if self.ans is None:
            answer = None
        else:
            answer = self.ans[index]
        article = self.art[index]
        label = self.lab[index]
        return {
            'id': index,
            'ans': answer,
            'art': article,
            'lab': label,
        }

    def __len__(self):
        """Number of examples, taken from the ``art`` dataset."""
        return len(self.art)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: the result of the module-level ``collate`` called with this
            dataset's padding/EOS indices and collation settings. For a
            non-empty ``samples`` list it has the following keys:

                - `id` (LongTensor): example IDs
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `doc_tokens`: merged `art` entries
                  - `doc_lengths` (LongTensor): lengths of the `art` entries
                  - `prop_tokens`: merged `lab` entries
                  - `prop_lengths` (LongTensor): lengths of the `lab` entries
                  - `prev_output_tokens`: merged `ans` entries with eos moved
                    to the beginning. This key will not be present if
                    *input_feeding* is ``False`` or there is no `ans`.

                - `target`: merged `ans` entries, or ``None`` without `ans`.
        """
        pad_index = self.src_dict.pad()
        eos_index = self.src_dict.eos()
        return collate(
            samples,
            pad_idx=pad_index,
            eos_idx=eos_index,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            order_by_size=self.order_by_size,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample: the largest of its
        ``lab``, ``art`` and (if available) ``ans`` lengths. This value is
        used to enforce ``--max-tokens`` during batching."""
        lab_len = self.lab_sizes[index]
        art_len = self.art_sizes[index]
        if self.ans_sizes is None:
            ans_len = 0
        else:
            ans_len = self.ans_sizes[index]
        return max(lab_len, art_len, ans_len)

    def size(self, index):
        """Return an example's size as a ``(lab, art, ans)`` tuple, with 0 in
        the last slot when no ``ans`` sizes exist. This value is used when
        filtering a dataset with ``--max-positions``."""
        lab_len = self.lab_sizes[index]
        art_len = self.art_sizes[index]
        if self.ans_sizes is None:
            ans_len = 0
        else:
            ans_len = self.ans_sizes[index]
        return lab_len, art_len, ans_len

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        total = len(self)
        order = np.random.permutation(total) if self.shuffle else np.arange(total)
        # Stable sort by article length, so ties keep the (possibly shuffled)
        # starting order.
        by_length = np.argsort(self.art_sizes[order], kind='mergesort')
        return order[by_length]

    @property
    def supports_prefetch(self):
        """Truthy only if ``art`` and ``lab`` support prefetching and ``ans``
        either supports it too or is absent."""
        art_flag = getattr(self.art, 'supports_prefetch', False)
        if not art_flag:
            return art_flag
        lab_flag = getattr(self.lab, 'supports_prefetch', False)
        if not lab_flag:
            return lab_flag
        return getattr(self.ans, 'supports_prefetch', False) or self.ans is None

    def prefetch(self, indices):
        """Forward ``prefetch(indices)`` to every wrapped dataset that is set."""
        self.art.prefetch(indices)
        self.lab.prefetch(indices)
        if self.ans is not None:
            self.ans.prefetch(indices)
        # `align_dataset` is never assigned in this class; it is honoured
        # only if something else attached it to the instance.
        if getattr(self, 'align_dataset', False):
            self.align_dataset.prefetch(indices)
