# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from . import LanguagePairDataset
from collections import defaultdict

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class LanguagePairMultiLDataset(LanguagePairDataset):
    """A :class:`LanguagePairDataset` whose ordering interleaves fixed-size
    chunks of sentences grouped by the last token of each source sentence.

    Args:
        sentence_batch_size (int): number of consecutive indices taken from
            each group before moving on to the next group in
            :meth:`ordered_indices`. Must be positive.

    All remaining arguments are forwarded unchanged (positionally) to
    :class:`LanguagePairDataset`.
    """

    def __init__(
        self, sentence_batch_size,
        src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None
    ):
        # Fail early: a zero chunk size would otherwise only surface as a
        # ZeroDivisionError inside ordered_indices(), and a negative one would
        # silently produce an empty ordering.
        if sentence_batch_size <= 0:
            raise ValueError(
                "sentence_batch_size must be a positive integer, "
                "got {}".format(sentence_batch_size)
            )
        super().__init__(
            src, src_sizes, src_dict,
            tgt, tgt_sizes, tgt_dict,
            left_pad_source, left_pad_target,
            max_source_positions, max_target_positions,
            shuffle, input_feeding,
            remove_eos_from_source, append_eos_to_target,
            align_dataset,
            append_bos, eos
        )
        self.sentence_batch_size = sentence_batch_size

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        Indices are sorted by target length (when target sizes exist) and then,
        stably, by source length. When ``self.shuffle`` is set, ties are broken
        randomly.

        The indices are then grouped by the integer value of the last token of
        each source sentence. That token is presumably a language tag; TODO
        confirm against the data pipeline. The result interleaves the groups
        in chunks of ``self.sentence_batch_size`` consecutive indices: chunk 0
        of every group, then chunk 1 of every group, and so on.

        Only as many chunks as the smallest group can completely fill are
        emitted, so every group contributes the same number of indices.
        Leftover indices are left out of the returned ordering.

        Returns:
            np.ndarray: 1-D ``int64`` array of dataset indices. The array is
            empty when the dataset is empty.
        """
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        indices = indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

        # Group indices by the final token of their source sentence, keeping
        # the length-sorted order within each group. Dict insertion order
        # (first appearance) fixes the order in which groups are interleaved.
        groups = defaultdict(list)
        for index in indices:
            groups[int(self.src[index][-1])].append(index)

        if not groups:
            # min() over an empty sequence would raise ValueError; an empty
            # dataset simply yields an empty ordering.
            return np.array([], dtype=np.int64)

        chunk = self.sentence_batch_size
        num_chunks = min(len(group) // chunk for group in groups.values())

        result_indices = []
        for i in range(num_chunks):
            start, end = i * chunk, (i + 1) * chunk
            for group in groups.values():
                result_indices.extend(group[start:end])
        # Explicit int64 dtype: the deprecated ``np.int`` alias was removed in
        # NumPy 1.24.
        return np.array(result_indices, dtype=np.int64)
