# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from argparse import Namespace
import json
import itertools
import logging
import os

import numpy as np

from fairseq import metrics, options, utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    data_utils,
    encoders,
    indexed_dataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    ResamplingDataset,
    SortDataset,
    ConcatSentencesDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    RawLabelDataset,
    RightPadDataset,
)

from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks.sentence_ranking import SentenceRankingTask
from fairseq.tasks import register_task

logger = logging.getLogger(__name__)

@register_task('summarization_mbart_rank')
class SummarizationFromMBARTRank(SentenceRankingTask):
    """
    Sentence-ranking task whose input dictionary is extended with language
    tags.

    On construction, one ``[<lang>]`` symbol per entry of ``--langs`` and a
    ``<mask>`` symbol are added to the dictionary (presumably so the
    vocabulary lines up with a pretrained mBART checkpoint -- confirm against
    the checkpoint that is actually loaded).

    Args:
        args (Namespace): parsed command-line arguments; ``args.langs`` must
            be a comma-separated list of language codes
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser.

        Registers everything ``SentenceRankingTask`` registers, plus the
        options declared below.
        """
        # fmt: off
        SentenceRankingTask.add_args(parser)
        parser.add_argument('--langs', required=True, metavar='LANG')
        parser.add_argument('--max-positions', type=int, help='max positions')
        parser.add_argument('--truncate-sequence', action='store_true', help='if truncate sequence by shorten-method')
        parser.add_argument('--only-encoder-for-cls', action='store_true', help='use use encoder for classification')
        # fmt: on

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        self.langs = args.langs.split(',')
        # Register one '[<lang>]' tag per language, then the '<mask>' symbol.
        for lang in self.langs:
            self.dictionary.add_symbol('[{}]'.format(lang))
        self.dictionary.add_symbol('<mask>')
        logger.info('[input] dictionary: {} types'.format(len(self.dictionary)))

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Create the task from parsed arguments.

        Loads the dictionary from ``<args.data>/input0/dict.txt`` and returns
        an instance of ``cls``.

        Raises:
            AssertionError: if ``--criterion`` is not ``sentence_ranking``
        """
        assert args.criterion == 'sentence_ranking', \
            'Must set --criterion=sentence_ranking'

        # load data dictionary
        data_dict = cls.load_dictionary(
            args,
            os.path.join(args.data, 'input0', 'dict.txt'),
            source=True,
        )
        # Instantiate ``cls`` rather than a hard-coded class so that
        # subclasses inheriting this method get an instance of themselves.
        return cls(args, data_dict)


@register_task('summarization_mbart_rank_joint')
class SummarizationFromMBARTRankJoint(SummarizationFromMBARTRank):
    """
    Ranking task trained jointly on several languages.

    For every language listed in ``--langs-for-sum`` a ranking dataset is
    loaded from its own sub-directory of the data path; the per-language
    datasets are then concatenated (and resampled for the training split).

    Args:
        args (Namespace): parsed command-line arguments; ``args.langs`` and
            ``args.langs_for_sum`` must be comma-separated lists
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        SummarizationFromMBARTRank.add_args(parser)
        parser.add_argument('--langs-for-sum', required=True, help='language for rank pretrain')
        # fmt: on

    def __init__(self, args, dictionary):
        # The parent constructor already splits --langs into ``self.langs``,
        # registers the '[<lang>]' and '<mask>' symbols and logs the
        # dictionary size, so none of that is repeated here.
        super().__init__(args, dictionary)
        self.langs_for_summ = args.langs_for_sum.split(",")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        One ranking dataset is built per language in ``--langs-for-sum`` from
        ``<data_path>/<code>`` where ``<code>`` is the part of the language
        name before the first underscore. For the training split the
        per-language datasets are resampled and concatenated; for any other
        split they are concatenated as-is and each language is additionally
        registered under ``<split>_<lang>``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch; selects which of the ':'-separated
                data paths is used and seeds the resampling
            combine (bool): forwarded to the indexed-dataset loader

        Returns:
            the dataset registered under ``self.datasets[split]``
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = self.langs_for_summ
        lang_datasets = []
        for lang in languages:
            code = lang.split('_')[0]   # en_XX -> en
            lang_path = os.path.join(data_path, code)
            logger.info("load lang {} from {}".format(lang, lang_path))
            lang_datasets.append(
                self.rank_dataset(lang_path, split, combine=combine)
            )

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            'Loaded total {} examples for all languages'.format(
                dataset_lengths.sum(),
            )
        )

        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info("Sample probability by language: ")
            logger.info({
                lang: "{0:.4f}".format(prob)
                for lang, prob in zip(languages, sample_probs)
            })
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info("Up/Down Sampling ratio by language: ")
            logger.info({
                lang: "{0:.2f}".format(ratio)
                for lang, ratio in zip(languages, size_ratio)
            })

            dataset = ConcatDataset([
                ResamplingDataset(
                    lang_dataset,
                    size_ratio=ratio,
                    seed=self.args.seed,
                    epoch=epoch,
                    # Sample with replacement only when up-sampling.
                    replace=ratio >= 1.0,
                )
                for lang_dataset, ratio in zip(lang_datasets, size_ratio)
            ])
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang, lang_dataset in zip(languages, lang_datasets):
                split_name = split + '_' + lang
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # Expand ``split`` inside --valid-subset into the combined split
            # followed by the per-language splits. Whole comma-separated
            # names are compared (not substrings) and names already present
            # are not added again, so a repeated call leaves it unchanged.
            valid_subsets = self.args.valid_subset.split(',')
            if split in valid_subsets:
                expanded = []
                for name in valid_subsets:
                    for candidate in (lang_splits if name == split else [name]):
                        if candidate not in expanded:
                            expanded.append(candidate)
                self.args.valid_subset = ','.join(expanded)

        # NOTE(review): the previous code built a SortDataset (seeded shuffle
        # plus size order) here and then immediately overwrote it with the
        # plain concatenated dataset. The effective behaviour -- registering
        # the unsorted dataset -- is kept; confirm that dropping the sort was
        # intended.
        self.datasets[split] = dataset
        return self.datasets[split]

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Create the task from parsed arguments.

        Loads the dictionary from ``<args.data>/input0/dict.txt`` and returns
        an instance of ``cls``.

        Raises:
            AssertionError: if ``--criterion`` is not ``sentence_ranking``
        """
        assert args.criterion == 'sentence_ranking', \
            'Must set --criterion=sentence_ranking'

        # load data dictionary
        data_dict = cls.load_dictionary(
            args,
            os.path.join(args.data, 'input0', 'dict.txt'),
            source=True,
        )
        # Must build ``cls``: returning the parent class here would silently
        # bypass this class's ``load_dataset`` and ``langs_for_summ``.
        return cls(args, data_dict)

    def rank_dataset(self, path, split, combine=False):
        """Build the ranking dataset stored under one directory.

        Reads ``<path>/input0/<split>`` plus ``<path>/input1/<split>`` ...
        ``<path>/input<N>/<split>`` with N = ``--num-classes``. Each option is
        concatenated in front of ``input0`` and exposed as
        ``net_input<k>``. If ``<path>/label/<split>.label`` exists, its lines
        are parsed as integers and exposed as ``target``.

        Args:
            path (str): directory holding the ``input*`` and ``label``
                sub-directories
            split (str): name of the split (e.g., train, valid, test)
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Returns:
            a ``NestedDictionaryDataset``, wrapped in a ``SortDataset`` with a
            seeded random order unless ``--no-shuffle`` is set
        """

        def get_path(kind, split):
            return os.path.join(path, kind, split)

        def make_dataset(kind, dictionary):
            return data_utils.load_indexed_dataset(
                get_path(kind, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        input0 = make_dataset('input0', self.source_dictionary)
        input_options = [
            make_dataset(
                'input{idx}'.format(idx=idx + 1),
                self.source_dictionary
            )
            for idx in range(self.args.num_classes)
        ]

        if self.args.separator_token is not None:
            input0 = PrependTokenDataset(input0, self.args.separator_token)

        src_tokens = []
        for input_option in input_options:
            if self.args.init_token is not None:
                input_option = PrependTokenDataset(input_option, self.args.init_token)
            if self.args.max_option_length is not None:
                input_option = TruncateDataset(input_option, self.args.max_option_length)
            # The candidate option comes first, followed by input0.
            src_token = ConcatSentencesDataset(input_option, input0)
            if self.args.truncate_sequence:
                src_token = maybe_shorten_dataset(
                    src_token,
                    split,
                    self.args.shorten_data_split_whitelist,
                    self.args.shorten_method,
                    self.args.max_positions,
                    self.args.seed,
                )
            src_tokens.append(src_token)

        fields = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
        }

        for idx, src_token in enumerate(src_tokens, start=1):
            fields['net_input{idx}'.format(idx=idx)] = {
                'src_tokens': RightPadDataset(
                    src_token,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': NumelDataset(src_token, reduce=False),
            }

        # Labels are optional: the target is only attached when the file
        # exists.
        label_path = '{}.label'.format(get_path('label', split))
        if os.path.exists(label_path):
            with open(label_path) as h:
                fields['target'] = RawLabelDataset(
                    [int(line.strip()) for line in h]
                )

        # An example's size is the longest of its concatenated inputs.
        nested_dataset = NestedDictionaryDataset(
            fields,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            # Seeded permutation so the order is reproducible for a given
            # --seed.
            with data_utils.numpy_seed(self.args.seed):
                shuffle = np.random.permutation(len(src_tokens[0]))
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        return dataset

    def _get_sample_prob(self, dataset_lens, alpha=1.0):
        """
        Get smoothed sampling probability by languages.

        Args:
            dataset_lens (numpy.ndarray): number of examples per language
            alpha (float): exponent applied to the size-proportional
                probabilities before renormalising. The default of 1.0 keeps
                sampling proportional to dataset size; values below 1.0
                shift probability towards the smaller datasets. This task
                does not register a command-line option for it.

        Returns:
            numpy.ndarray: sampling probabilities that sum to 1
        """
        prob = dataset_lens / dataset_lens.sum()
        smoothed_prob = prob ** alpha
        return smoothed_prob / smoothed_prob.sum()


