# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import random

from numpy.random import geometric
from . import data_utils
from . import LanguagePairDataset

logger = logging.getLogger(__name__)


def collate(
    samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
    input_feeding=True,
):
    """Merge a list of samples into a single padded mini-batch.

    Samples are reordered by descending source length, and every per-sample
    tensor placed in the batch follows that order.

    Args:
        samples (List[dict]): items with the keys ``id`` and ``source`` and,
            optionally, ``target``, ``alignment``, ``denoise_source`` and
            ``denoise_target``. The presence of an optional key is decided
            from the first sample only.
        pad_idx (int): padding index passed to ``data_utils.collate_tokens``.
        eos_idx (int): end-of-sentence index passed to
            ``data_utils.collate_tokens``.
        left_pad_source (bool): pad source-side tensors on the left.
        left_pad_target (bool): pad target-side tensors on the left.
        input_feeding (bool): also build ``prev_output_tokens`` by collating
            the targets with ``move_eos_to_beginning=True``.

    Returns:
        dict: empty if *samples* is empty, otherwise a batch with the keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` and ``target``
        (``None`` without targets), plus ``alignments``/``align_weights``
        and ``denoise_net_input``/``denoise_target`` when the corresponding
        sample keys are present.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        # Pad the ``key`` entry of every sample into one 2D tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # An alignment is usable only if it is non-empty and all of its
        # indices stay below the last position of both sequences.
        if alignment is None or len(alignment) == 0:
            return False
        if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
        # The inverse indices map every entry to the count of its unique value.
        return 1. / align_tgt_c[align_tgt_i].float()

    ids = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    tgt_lengths = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
        ntokens = sum(len(s['target']) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                'target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s['source']) for s in samples)

    batch = {
        'id': ids,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
        },
        'target': target,
    }
    if prev_output_tokens is not None:
        batch['net_input']['prev_output_tokens'] = prev_output_tokens

    has_alignment = samples[0].get('alignment', None) is not None
    if has_alignment and target is None:
        # Alignment offsets are computed from the padded target, so they
        # cannot be built for a batch that has no target.
        logger.warning("alignment found but batch has no target, skipping alignment!")
    elif has_alignment:
        tgt_sz = target.shape[1]
        src_sz = src_tokens.shape[1]

        # Column 0 shifts source indices, column 1 shifts target indices.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        # Target indices are additionally shifted by row * tgt_sz.
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
            for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    if samples[0].get('denoise_source', None) is not None:
        # The denoising tensors reuse the sort order of the regular sources,
        # so row i of every tensor in the batch refers to the same sample.
        denoise_src_tokens = merge('denoise_source', left_pad=left_pad_source)
        denoise_src_tokens = denoise_src_tokens.index_select(0, sort_order)

        denoise_src_lengths = torch.LongTensor([s['denoise_source'].numel() for s in samples])
        denoise_src_lengths = denoise_src_lengths.index_select(0, sort_order)

        denoise_target = merge('denoise_target', left_pad=left_pad_target)
        denoise_target = denoise_target.index_select(0, sort_order)

        denoise_net_input = {
            'src_tokens': denoise_src_tokens,
            'src_lengths': denoise_src_lengths
        }

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            denoise_prev_output_tokens = merge(
                'denoise_target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            )
            denoise_prev_output_tokens = denoise_prev_output_tokens.index_select(0, sort_order)
            denoise_net_input['prev_output_tokens'] = denoise_prev_output_tokens
        batch['denoise_net_input'] = denoise_net_input
        batch['denoise_target'] = denoise_target
    return batch


class bartNoiseDataset(LanguagePairDataset):
    """Language-pair dataset that can also emit denoising examples.

    On top of the regular ``source``/``target`` entries, each item can carry
    a ``denoise_source``/``denoise_target`` pair: the target is the clean
    sequence and the source is the same sequence after sentence shuffling
    and span masking (see :meth:`noiseFn`).
    """

    # Upper bound on the number of tokens replaced by one mask token.
    MAX_SPAN_LENGTH = 10
    # How often a span is re-drawn when it overlaps an existing mask token.
    MAX_SPAN_RETRIES = 5

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None,
        source_unsupervised=False,
        source_idx_range=(1, -1),
        target_unsupervised=False,
        target_idx_range=(2, -1),
        p=0.2,
        masked_ratio=0.15,
        full_stop_idx=2,
        mask_idx=None
    ):
        """
            All arguments up to and including ``eos`` are forwarded
            unchanged to ``LanguagePairDataset``.

            source_unsupervised: build the denoising pair from the source
            source_idx_range: [left_bound, right_bound] of normal tokens in the source
            target_unsupervised: build the denoising pair from the target
                (only used when ``source_unsupervised`` is false)
            target_idx_range: [left_bound, right_bound] of normal tokens in the target
            p: success probability of the geometric distribution that span
                lengths are drawn from
            masked_ratio: fraction of the tokens budgeted for masking
            full_stop_idx: token index treated as the end of a sentence
            mask_idx: token index that replaces each masked span
        """
        super().__init__(
            src, src_sizes, src_dict,
            tgt, tgt_sizes, tgt_dict,
            left_pad_source, left_pad_target,
            max_source_positions, max_target_positions,
            shuffle, input_feeding,
            remove_eos_from_source, append_eos_to_target,
            align_dataset,
            append_bos, eos
        )
        self.source_unsupervised = source_unsupervised
        self.target_unsupervised = target_unsupervised
        # Copy into a per-instance list so instances never share the bounds.
        self.source_idx_range = list(source_idx_range)
        self.target_idx_range = list(target_idx_range)
        self.mask_idx = mask_idx
        self.p = p
        self.masked_ratio = masked_ratio
        self.full_stop_idx = full_stop_idx

    def sentenceShuffle(self, tokens: torch.Tensor):
        """Return *tokens* with its sentences in random order.

        A sentence ends at a ``full_stop_idx`` token that follows a
        different token. Tokens after the last such boundary form a final
        sentence, so no token is dropped. Empty input is returned as is.
        """
        if tokens.size(0) == 0:
            return tokens
        is_full_stops = (tokens == self.full_stop_idx)
        # Position i marks "non-stop followed by stop"; the next sentence
        # therefore starts at i + 2, right after the full stop.
        sentence_starts = (
            torch.nonzero(~is_full_stops[:-1] & is_full_stops[1:], as_tuple=False).squeeze(-1) + 2
        ).tolist()
        sentences = []
        sent_start = 0
        for next_sent_start in sentence_starts:
            sentences.append(tokens[sent_start:next_sent_start])
            sent_start = next_sent_start
        if sent_start < tokens.size(0):
            # Keep the trailing tokens that are not closed by a full stop.
            sentences.append(tokens[sent_start:])
        random.shuffle(sentences)
        return torch.cat(sentences, dim=0)

    def _sampleSpanLength(self, remaining):
        """Draw a span length, capped by MAX_SPAN_LENGTH and *remaining*."""
        return min(int(geometric(self.p)), self.MAX_SPAN_LENGTH, remaining)

    def maskSpan(self, tokens: torch.Tensor):
        """
        refer to SpanBERT https://arxiv.org/pdf/1907.10529.pdf

        Replace randomly placed spans of *tokens* by a single ``mask_idx``
        token each until ``int(len(tokens) * masked_ratio)`` tokens have
        been consumed. The returned tensor is therefore usually shorter
        than the input.

        Raises:
            ValueError: if masking is required but ``mask_idx`` is ``None``.
        """
        remainedMaskToken = int(tokens.size(0) * self.masked_ratio)
        if remainedMaskToken > 0 and self.mask_idx is None:
            raise ValueError("mask_idx must be set to apply span masking")
        while remainedMaskToken > 0:
            spanLength = self._sampleSpanLength(remainedMaskToken)
            # tokens shrinks with every inserted mask, so re-read its length.
            sentLen = tokens.size(0)
            start = random.randrange(0, max(0, sentLen - spanLength) + 1)
            retry = 0
            # Prefer spans that do not swallow an already inserted mask;
            # after MAX_SPAN_RETRIES attempts the last draw is used anyway.
            while (
                torch.sum(tokens[start:start + spanLength] == self.mask_idx) != 0
                and retry < self.MAX_SPAN_RETRIES
            ):
                spanLength = self._sampleSpanLength(remainedMaskToken)
                start = random.randrange(0, max(0, sentLen - spanLength) + 1)
                retry += 1

            tokens = torch.cat(
                (
                    tokens[:start],
                    torch.tensor([self.mask_idx], dtype=torch.long).to(tokens),
                    tokens[(start + spanLength):],
                ),
                0
            )
            remainedMaskToken -= spanLength
        return tokens

    def noiseFn(self, tokens):
        """Shuffle the sentences of *tokens*, then mask spans of the result."""
        shuffled_tokens = self.sentenceShuffle(tokens)
        masked_tokens = self.maskSpan(shuffled_tokens)
        return masked_tokens

    def _makeDenoiseInput(self, item, idx_range):
        """Noise the tokens of *item* inside *idx_range*, keep the rest.

        ``idx_range[0]`` is the first noised position and ``idx_range[1]``
        is added to ``len(item)`` to obtain the end of the noised region.
        """
        length = len(item)
        # Clamp the bounds so that, for very short items, the untouched
        # prefix and suffix never overlap (which would duplicate tokens).
        range_start = min(idx_range[0], length)
        range_end = max(range_start, length + idx_range[1])
        noised = self.noiseFn(item[range_start:range_end])
        return torch.cat((item[:range_start], noised, item[range_end:]), dim=0)

    def __getitem__(self, index):
        """Return the example at *index* as a dict.

        The dict always has ``id``, ``source`` and ``target`` and, depending
        on the configuration, ``denoise_source``/``denoise_target`` and
        ``alignment``.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Every step below works on the running src_item/tgt_item so that
        # combined options (e.g. append_bos + append_eos_to_target) stack
        # instead of overwriting each other.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item is not None and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            # BOS belongs at the front, so the first token is the one to test.
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'target': tgt_item,
        }

        if self.source_unsupervised:
            example['denoise_target'] = src_item
            example['denoise_source'] = self._makeDenoiseInput(src_item, self.source_idx_range)
        elif self.target_unsupervised and tgt_item is not None:
            example['denoise_target'] = tgt_item
            example['denoise_source'] = self._makeDenoiseInput(tgt_item, self.target_idx_range)
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        return example

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:
                - `id` (LongTensor): example IDs, ordered like the rows of
                  the batch (descending source length)
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `denoise_net_input` (dict) and `denoise_target`
                  (LongTensor): same layout as `net_input` and `target`,
                  built from the `denoise_source`/`denoise_target` entries
                  of the samples. Only present if the samples carry them.
        """
        return collate(
            samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos,
            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
        )
