# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import warnings
from argparse import Namespace

import torch

from fairseq import metrics, options, utils
from fairseq.data import (
    AppendTokenDataset,
    MoveTokenDataset,
    ConcatDataset,
    data_utils,
    encoders,
    indexed_dataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    LanguagePairUnsortedDataset
)

from fairseq.tasks.translation_from_pretrained_bart import TranslationFromPretrainedBARTTask
from fairseq.tasks import register_task, FairseqTask

logger = logging.getLogger(__name__)

def load_langpair_sumdataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions,
    max_target_positions, prepend_bos=False, load_alignments=False,
    truncate_source=False
):
    """Load one split of a source/target pair as a ``LanguagePairUnsortedDataset``.

    Shards are looked up under *data_path* as ``<split>``, ``<split>1``,
    ``<split>2``, ... in either ``src-tgt`` or ``tgt-src`` naming. Only the
    first shard is loaded unless *combine* is true. When several shards are
    loaded they are concatenated, with the first one weighted by
    *upsample_primary*.

    After the optional BOS-prepending and source-truncation steps, both sides
    are wrapped in ``MoveTokenDataset`` (the source with
    ``delete_first_token=True``, the target with ``delete_first_token=False``).

    Raises:
        FileNotFoundError: if not even the first shard of *split* exists.
    """

    def locate_shard_prefix(shard_name):
        # Try "src-tgt" naming first, then "tgt-src"; the existence probe is
        # always made on the source-side file.
        for first, second in ((src, tgt), (tgt, src)):
            probe = os.path.join(
                data_path, '{}.{}-{}.{}'.format(shard_name, first, second, src)
            )
            if indexed_dataset.dataset_exists(probe, impl=dataset_impl):
                return os.path.join(
                    data_path, '{}.{}-{}.'.format(shard_name, first, second)
                )
        return None

    src_shards, tgt_shards = [], []

    shard_idx = 0
    while True:
        # The first shard carries no numeric suffix.
        shard_name = split if shard_idx == 0 else split + str(shard_idx)

        prefix = locate_shard_prefix(shard_name)
        if prefix is None:
            if shard_idx == 0:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
            break

        src_shards.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )

        # The target side is optional.
        loaded_tgt = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if loaded_tgt is not None:
            tgt_shards.append(loaded_tgt)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, shard_name, src, tgt, len(src_shards[-1])
        ))

        if not combine:
            break
        shard_idx += 1

    # Either every shard has a target side or none does.
    assert len(tgt_shards) == 0 or len(src_shards) == len(tgt_shards)

    has_target = len(tgt_shards) > 0
    if len(src_shards) == 1:
        src_dataset = src_shards[0]
        tgt_dataset = tgt_shards[0] if has_target else None
    else:
        # Only the first (primary) shard is upsampled.
        ratios = [upsample_primary] + [1] * (len(src_shards) - 1)
        src_dataset = ConcatDataset(src_shards, ratios)
        tgt_dataset = ConcatDataset(tgt_shards, ratios) if has_target else None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    if truncate_source:
        trunc_len = max_source_positions
        logger.info("Truncate source to max length %d", trunc_len)
        # Strip EOS, cut to trunc_len - 1 tokens, then put EOS back.
        without_eos = StripTokenDataset(src_dataset, src_dict.eos())
        clipped = TruncateDataset(without_eos, trunc_len - 1)
        src_dataset = AppendTokenDataset(clipped, src_dict.eos())

    src_dataset = MoveTokenDataset(
        src_dataset,
        append_first_token=True,
        delete_first_token=True,
    )
    if tgt_dataset is not None:
        tgt_dataset = MoveTokenDataset(
            tgt_dataset,
            append_first_token=True,
            delete_first_token=False,
        )

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_sizes = None if tgt_dataset is None else tgt_dataset.sizes
    return LanguagePairUnsortedDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset, eos=None
    )

@register_task('multi_task_multi_loss_from_pretrained_mbart_mspm4')
class MultiTaskMultiLossFromPretrainedMBARTTaskMPSM4(TranslationFromPretrainedBARTTask):
    """
    Translate from source language to target language with a model initialized with a multilingual pretrain.

    On top of the parent task this class adds: optional freezing of the
    encoder/decoder layers, forced prefix tokens at inference time, an
    optional extended dictionary file, and a ``train_step`` that forwards
    ``update_num`` to the criterion.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationFromPretrainedBARTTask.add_args(parser)
        parser.add_argument('--fixed-encoder-layer', action='store_true', help='fixed encoder layer during from pretrain')
        parser.add_argument('--fixed-decoder-layer', action='store_true', help='fixed decoder layer during from pretrain')
        parser.add_argument('--prefix-tokens', type=str, nargs="+", default=None, 
            help="prefix tokens used during the inference. " \
            "If is activated, it will overwrite ``prefix_size``")
        parser.add_argument('--extended-dict', type=str, default="")
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        """Set up the task and log the indices of the special and language tokens."""
        super().__init__(args, src_dict, tgt_dict)
        logger.info("bos %d, pad %d, eos %d, unk %d", 
            src_dict.index('<s>'),src_dict.index('<pad>'),
            src_dict.index('</s>'),src_dict.index('<unk>')
        )
        logger.info("en_XX {}, zh_CN {}, de_DE {}, fr_XX {}".format(
            src_dict.index('[en_XX]'),src_dict.index('[zh_CN]'),
            src_dict.index('[de_DE]'),src_dict.index('[fr_XX]'))
        )
        self.extended_dict = args.extended_dict
        self.prefix_tokens = args.prefix_tokens

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number, used to rotate through the
                ``:``-separated data paths
            combine (bool): also load the numbered shards of the split
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # src="doc", tgt="sum"
        src, tgt = self.args.source_lang, self.args.target_lang

        dataset = load_langpair_sumdataset(
            data_path, split, 
            src, self.src_dict,
            tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            # Position limits are pinned to 1000 here instead of being read
            # from args.max_source_positions / args.max_target_positions.
            max_source_positions=1000,
            max_target_positions=1000,
            truncate_source=self.args.truncate_source,
            load_alignments=self.args.load_alignments,
            # NOTE(review): the key is spelled 'preprend_bos'; unless args
            # carries an attribute with exactly that spelling this is always
            # False. Left unchanged to avoid altering preprocessing - confirm.
            prepend_bos=getattr(self.args, 'preprend_bos', False),
        )
        self.datasets[split] = dataset

        # Preview at most the first two examples; a split with fewer than two
        # examples must not fail here.
        for i in range(min(2, len(dataset))):
            print("{} {}".format(split, i), dataset[i])

    @staticmethod
    def _freeze_parameters(model, name_prefix):
        """Disable gradients for every parameter whose name starts with *name_prefix*."""
        for name, param in model.named_parameters():
            if name.startswith(name_prefix):
                param.requires_grad = False

    def build_model(self, args):
        """Build the model and optionally freeze its encoder and/or decoder layers.

        Freezing is controlled by ``args.fixed_encoder_layer`` and
        ``args.fixed_decoder_layer``; only parameters under
        ``encoder.layers`` / ``decoder.layers`` are affected.
        """
        model = super().build_model(args)

        if getattr(args, "fixed_encoder_layer", False):
            logger.info("Fix Encoder Layers!")
            self._freeze_parameters(model, "encoder.layers")

        if getattr(args, "fixed_decoder_layer", False):
            logger.info("Fix Decoder Layers!")
            self._freeze_parameters(model, "decoder.layers")

        return model

    def build_generator(self, models, args):
        """Return a ``SequenceScorer`` when ``args.score_reference`` is set, else a ``SequenceGenerator``."""
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(
                self.target_dictionary,
                # eos=self.tgt_dict.index('[{}]'.format(self.args.sum_lang))
            )
        else:
            from fairseq.sequence_generator import SequenceGenerator
            return SequenceGenerator(
                models,
                self.target_dictionary,
                beam_size=getattr(args, 'beam', 5),
                max_len_a=getattr(args, 'max_len_a', 0),
                max_len_b=getattr(args, 'max_len_b', 200),
                min_len=getattr(args, 'min_len', 1),
                normalize_scores=(not getattr(args, 'unnormalized', False)),
                len_penalty=getattr(args, 'lenpen', 1),
                unk_penalty=getattr(args, 'unkpen', 0),
                temperature=getattr(args, 'temperature', 1.),
                match_source_len=getattr(args, 'match_source_len', False),
                no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
                # eos=self.tgt_dict.index('[{}]'.format(self.args.sum_lang))
            )

    def build_dataset_for_inference(self, src_tokens, src_lengths):
        """Wrap raw source tensors for inference, appending the ``[doc_lang]`` token to each.

        NOTE(review): ``src_lengths`` is passed through unchanged even though
        every sequence grows by one token - confirm this is intended.
        """
        src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.doc_lang))
        source_tokens = [
            torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            for s_t in src_tokens
        ]
        return LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary)

    def postprocess(self):
        """Extend both dictionaries from ``--extended-dict`` and log the new token indices.

        Does nothing when no extended dictionary file was given.
        """
        if self.extended_dict == "":
            return

        print("self.extended_dict: ", self.extended_dict)
        self.src_dict.add_from_file(self.extended_dict)
        self.tgt_dict.add_from_file(self.extended_dict)

        # The first whitespace-separated field of each line is the token;
        # blank lines are skipped instead of raising IndexError.
        with open(self.extended_dict, 'r', encoding='utf-8') as fin:
            extended_tokens = [
                fields[0] for fields in (line.split() for line in fin) if fields
            ]

        for label, dictionary in (("src_dict: ", self.src_dict), ("tgt_dict: ", self.tgt_dict)):
            logger.info(label + "".join(
                "token: {} idx: {}, ".format(token, dictionary.index(token))
                for token in extended_tokens
            ))

    def set_prefix_tokens(self, sample):
        """Return the tokens to force at the start of generation, or ``None``.

        When ``--prefix-tokens`` was given, each name is looked up as
        ``[name]`` in the target dictionary and the resulting row is repeated
        for every sentence in the batch. Otherwise the first
        ``args.prefix_size`` target tokens are used, if that size is positive.
        """
        if self.prefix_tokens is not None:
            target = sample['target']
            indices = [
                self.tgt_dict.index('[{}]'.format(token))
                for token in self.prefix_tokens
            ]
            # Build the row directly with the target's dtype and device.
            row = torch.tensor(indices, dtype=target.dtype, device=target.device)
            return row.unsqueeze(0).repeat(target.size(0), 1)

        # prefix_size is a generation option and may be absent from args.
        prefix_size = getattr(self.args, 'prefix_size', 0)
        if prefix_size > 0:
            return sample['target'][:, :prefix_size]
        return None

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.reduce_metrics(logging_outputs)

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        The criterion is called with ``update_num`` as an extra keyword
        argument.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """Run *criterion* on *sample* in eval mode without tracking gradients.

        Returns:
            tuple: the loss, the sample size and the logging outputs.
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output