# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import logging

from argparse import Namespace
import numpy as np

from fairseq.data import (
    MoveTokenDataset,
    data_utils,
    AppendTokenDataset,
    TruncateDataset,
    StripTokenDataset,
    BOWDataset
)

from .summarization_from_pretrained_mbart_mspm4 import SummarizationFromPretrainedMBARTTaskMPSM4, load_langpair_sumdataset
from fairseq.tasks import register_task

# NOTE(review): not referenced anywhere in the visible part of this module;
# kept because other modules may import it.
EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


@register_task('summarization_from_pretrained_mbart_mspm4_w_bow')
class SummarizationFromPretrainedMBARTwBowMspm4(SummarizationFromPretrainedMBARTTaskMPSM4):
    """
    Summarization task that wraps the parent task's dataset in a ``BOWDataset``.

    For every split, the regular summarization dataset is loaded with
    ``load_langpair_sumdataset`` and combined with two additional indexed
    datasets read from ``<data_path>/<split>.nonstop`` and
    ``<data_path>/<split>.stop`` (both encoded with the source dictionary).
    NOTE(review): presumably these hold the non-stop-word / stop-word tokens
    used as bag-of-words targets -- confirm against ``BOWDataset``.

    Args:
        args: parsed command-line arguments of the task
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    This task adds no command-line arguments of its own; it registers exactly
    the arguments of :class:`SummarizationFromPretrainedMBARTTaskMPSM4`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser (delegates to the parent task)."""
        # fmt: off
        SummarizationFromPretrainedMBARTTaskMPSM4.add_args(parser)
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)

    def _load_bow_side_dataset(self, data_path, split, suffix):
        """Load ``<data_path>/<split>.<suffix>`` as an indexed dataset over the source dictionary."""
        path = os.path.join(data_path, '{}.{}'.format(split, suffix))
        dataset = data_utils.load_indexed_dataset(path, self.src_dict, self.args.dataset_impl)
        if dataset is None:
            # fairseq's load_indexed_dataset returns None when nothing is found
            # at the given path. How BOWDataset treats a None input is not
            # visible from this file, so only surface the situation here
            # instead of changing control flow.
            logger.warning("no indexed dataset found at %s", path)
        return dataset

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split and store it in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects which of the
                ':'-separated data directories in ``args.data`` is used
                (round-robin)
            combine (bool): forwarded to ``load_langpair_sumdataset``
            **kwargs: accepted for interface compatibility; unused here
        """
        # str.split always yields at least one element, so no emptiness check
        # is needed before the modulo below.
        paths = self.args.data.split(':')
        data_path = paths[(epoch - 1) % len(paths)]

        src, tgt = self.args.source_lang, self.args.target_lang

        # This code used to read only the misspelled attribute 'preprend_bos',
        # which silently ignored a correctly spelled ``prepend_bos`` setting.
        # Prefer the correct spelling but keep honoring the legacy one.
        prepend_bos = getattr(
            self.args, 'prepend_bos', getattr(self.args, 'preprend_bos', False)
        )

        summ_ds = load_langpair_sumdataset(
            data_path, split,
            src, self.src_dict,
            tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 1024),
            max_target_positions=getattr(self.args, 'max_target_positions', 1024),
            truncate_source=self.args.truncate_source,
            load_alignments=self.args.load_alignments,
            prepend_bos=prepend_bos,
        )

        nonstop_ds = self._load_bow_side_dataset(data_path, split, 'nonstop')
        stop_ds = self._load_bow_side_dataset(data_path, split, 'stop')

        dataset = BOWDataset(summ_ds, nonstop_ds, stop_ds, self.src_dict, self.args.seed)
        self.datasets[split] = dataset

        size = len(dataset)
        logger.info("split: %s dataset size: %s", split, size)
        # Dump the first example as a sanity check; skip it for an empty split,
        # where indexing element 0 would fail.
        if size > 0:
            logger.info("%s", dataset[0])
