# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import random

from numpy.random import geometric
from . import data_utils
from . import LanguagePairDataset

logger = logging.getLogger(__name__)
from .bartnoise_dataset import collate


class LMDataset(LanguagePairDataset):
    """Language-pair dataset that can also emit denoising input/target pairs.

    On top of the ``source``/``target`` items, each example may carry
    ``denoise_source``/``denoise_target`` entries built from either the
    source (``source_unsupervised``) or the target (``target_unsupervised``).
    The noise itself is produced by :meth:`noiseFn`, which is the identity
    here and is the hook subclasses are expected to override.
    """

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None,
        source_unsupervised=False,
        source_idx_range=None,
        target_unsupervised=False,
        target_idx_range=None,
        unconditional_lm=False
    ):
        """
            All arguments up to and including ``eos`` are forwarded
            positionally, unchanged, to ``LanguagePairDataset.__init__``.

            source_unsupervised: build the denoising pair from the source item
            source_idx_range: [left_bound, right_bound] of normal tokens in the source
                (defaults to ``[1, -1]``)
            target_unsupervised: build the denoising pair from the target item
                (only consulted when ``source_unsupervised`` is false)
            target_idx_range: [left_bound, right_bound] of normal tokens in the target
                (defaults to ``[2, -1]``)
            unconditional_lm: replace ``denoise_source`` by a single ``unk`` token

            For both ranges, ``left_bound`` is used as a start index and
            ``right_bound`` is an offset added to the item length to obtain
            the (exclusive) end index.
        """
        super().__init__(
            src, src_sizes, src_dict,
            tgt, tgt_sizes, tgt_dict,
            left_pad_source, left_pad_target,
            max_source_positions, max_target_positions,
            shuffle, input_feeding,
            remove_eos_from_source, append_eos_to_target,
            align_dataset,
            append_bos, eos
        )
        self.source_unsupervised = source_unsupervised
        self.target_unsupervised = target_unsupervised
        # Build the default ranges per instance: a list literal in the
        # signature would be one object shared by every dataset.
        self.source_idx_range = [1, -1] if source_idx_range is None else source_idx_range
        self.target_idx_range = [2, -1] if target_idx_range is None else target_idx_range
        self.unconditional_lm = unconditional_lm

    def noiseFn(self, tokens):
        """Return the noised version of ``tokens``; identity in this class."""
        return tokens

    def _noise_span(self, item, idx_range):
        """Apply ``noiseFn`` to the ``idx_range`` span of ``item`` and re-attach the edges."""
        left_bound = idx_range[0]
        # The right bound is an offset from the end of the item, so it is
        # turned into an absolute index (an offset of 0 then means "to the end").
        range_end = len(item) + idx_range[1]
        noised = self.noiseFn(item[left_bound:range_end])
        return torch.cat((item[:left_bound], noised, item[range_end:]), dim=0)

    def __getitem__(self, index):
        """Build the example dict for ``index``.

        Returns:
            dict: always contains ``id``, ``source`` and ``target`` (``target``
            is ``None`` when there is no target dataset). Depending on the
            configuration it may also contain ``denoise_source``,
            ``denoise_target`` and ``alignment``.
        """
        # NOTE(review): the flags and dictionaries read below (append_bos,
        # src_dict, ...) are not assigned in this class; they are assumed to be
        # set by LanguagePairDataset.__init__ -- confirm.
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]

        # Every optional transformation works on the running src_item/tgt_item
        # so that combined flags (e.g. append_bos + append_eos_to_target) do not
        # discard the result of an earlier step.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item is not None and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            # BOS is prepended, so the token to inspect is the first one.
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'target': tgt_item,
        }

        # The source takes precedence: the target is only used for denoising
        # when source_unsupervised is off.
        if self.source_unsupervised:
            example['denoise_target'] = src_item
            example['denoise_source'] = self._noise_span(src_item, self.source_idx_range)
        elif self.target_unsupervised and tgt_item is not None:
            example['denoise_target'] = tgt_item
            example['denoise_source'] = self._noise_span(tgt_item, self.target_idx_range)

        if self.unconditional_lm:
            # Overrides any denoise_source built above with a single unk token.
            unk = self.src_dict.unk()
            example['denoise_source'] = torch.LongTensor([unk])
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        return example

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Delegates to ``collate`` from ``bartnoise_dataset`` using the source
        dictionary's pad index, ``self.eos`` and the padding/input-feeding
        flags of this dataset.

        NOTE(review): the key listing below was written for the standard
        language-pair collater; the actual contents are whatever
        ``bartnoise_dataset.collate`` returns -- confirm against that function.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:
                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
        """
        return collate(
            samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos,
            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
        )
