# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import itertools
import logging
import os
import re

from argparse import Namespace
import json
import numpy as np

from fairseq.data import (
    AppendTokenDataset,
    MoveTokenDataset,
    ConcatDataset,
    data_utils,
    indexed_dataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    encoders
)

from fairseq.tasks.translation_from_pretrained_bart import TranslationFromPretrainedBARTTask
from fairseq.tasks import register_task
from fairseq import utils, metrics
import sys
sys.path.append("/opt/tiger/sumtest/crossLingualTransfer")
from xnlg.src.evaluation.rouge import rouge_eval

# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably carried over from the BLEU-based translation task -- confirm
# before removing.
EVAL_BLEU_ORDER = 4

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def load_langpair_sumdataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions,
    max_target_positions, prepend_bos=False, load_alignments=False,
    truncate_source=False
):
    """Assemble a :class:`LanguagePairDataset` from indexed shards on disk.

    Shards are looked up as ``<split>.<src>-<tgt>.<lang>`` (or with the pair
    reversed) under *data_path*; additional shards are named ``<split>1``,
    ``<split>2``, ... and are only read when *combine* is true.

    Args:
        data_path: directory holding the indexed dataset files.
        split: base name of the split.
        src, tgt: suffixes identifying the source and target sides.
        src_dict, tgt_dict: dictionaries handed to the dataset loaders.
        combine: keep loading numbered shards after the first one.
        dataset_impl: implementation name forwarded to the fairseq loaders.
        upsample_primary: sampling ratio applied to the first shard when
            several shards are concatenated.
        left_pad_source, left_pad_target, max_source_positions,
        max_target_positions: forwarded to :class:`LanguagePairDataset`.
        prepend_bos: prepend each dictionary's BOS symbol to both sides.
        load_alignments: also load ``<split>.align.<src>-<tgt>`` if present.
        truncate_source: strip EOS, cut the source to
            ``max_source_positions - 1`` tokens and re-append EOS.

    Returns:
        The resulting :class:`LanguagePairDataset`.

    Raises:
        FileNotFoundError: if not even the first shard can be found.
    """

    def locate_prefix(shard_name, first, second):
        # Path stem '<data_path>/<shard>.<first>-<second>.' when the
        # source-side file for that naming exists, otherwise None.
        stem = os.path.join(data_path, '{}.{}-{}.'.format(shard_name, first, second))
        found = indexed_dataset.dataset_exists(stem + src, impl=dataset_impl)
        return stem if found else None

    source_shards, target_shards = [], []
    shard_no = 0
    while True:
        # The first shard carries the bare split name, later ones a numeric suffix.
        shard_name = split if shard_no == 0 else split + str(shard_no)

        # Try "<src>-<tgt>" naming first and fall back to "<tgt>-<src>".
        prefix = locate_prefix(shard_name, src, tgt) or locate_prefix(shard_name, tgt, src)
        if prefix is None:
            if shard_no == 0:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
            break

        source_shards.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )

        loaded_target = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if loaded_target is not None:
            target_shards.append(loaded_target)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, shard_name, src, tgt, len(source_shards[-1])
        ))

        if not combine:
            break
        shard_no += 1

    # Either every shard has a target side or none of them does.
    assert not target_shards or len(source_shards) == len(target_shards)

    have_targets = len(target_shards) > 0
    if len(source_shards) == 1:
        src_dataset = source_shards[0]
        tgt_dataset = target_shards[0] if have_targets else None
    else:
        ratios = [upsample_primary] + [1] * (len(source_shards) - 1)
        src_dataset = ConcatDataset(source_shards, ratios)
        tgt_dataset = ConcatDataset(target_shards, ratios) if have_targets else None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    if truncate_source:
        trunc_len = max_source_positions
        logger.info("Truncate source to max length %d", trunc_len)
        # Drop EOS, truncate leaving one slot free, then put EOS back.
        without_eos = StripTokenDataset(src_dataset, src_dict.eos())
        shortened = TruncateDataset(without_eos, trunc_len - 1)
        src_dataset = AppendTokenDataset(shortened, src_dict.eos())

    # Source: first token is appended at the end and removed from the front;
    # target: first token is appended at the end and kept at the front.
    src_dataset = MoveTokenDataset(
        src_dataset,
        append_first_token=True,
        delete_first_token=True
    )
    if tgt_dataset is not None:
        tgt_dataset = MoveTokenDataset(
            tgt_dataset,
            append_first_token=True,
            delete_first_token=False
        )

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_sizes = None if tgt_dataset is None else tgt_dataset.sizes
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset, eos=None
    )

@register_task('summarization_from_pretrained_mbart_mspm4')
class SummarizationFromPretrainedMBARTTaskMPSM4(TranslationFromPretrainedBARTTask):
    """
    Summarize a source document into a target summary with a model
    initialized from a multilingual pretrained checkpoint.

    On top of the parent translation task this class adds optional ROUGE-L
    reporting during validation, optional freezing of encoder/decoder layers,
    helpers for dumping encoder/document embeddings and support for extending
    the dictionaries after construction (see :meth:`postprocess`).

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source side
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target side

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationFromPretrainedBARTTask.add_args(parser)

        # options for reporting ROUGE during validation
        parser.add_argument('--eval-rouge', action='store_true',
                            help='evaluation with ROUGE scores')

        parser.add_argument('--eval-rouge-detok-args', type=str, metavar='JSON', default="{}",
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-rouge-remove-bpe', nargs='?', const='@@ ', default="sentencepiece",
                            help='remove BPE before computing ROUGE')
        parser.add_argument('--eval-rouge-args', type=str, metavar='JSON',
                            help='generation args for ROUGE scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-rouge-print-samples', action='store_true',
                            help='print sample generations during validation')
        parser.add_argument('--eval-language', default='en',
                            help='the evaluation language')
        parser.add_argument('--loose-load', action="store_true")

        parser.add_argument('--summ-langs', nargs='+', type=str, help='languages with summarization data, e.g. en_XX', default=[])
        parser.add_argument('--unsupervised-langs', nargs='+', type=str, help='languages with unsupervised data, e.g. zh_CN', default=[])

        parser.add_argument('--use-lang-classifier', action="store_true", default=False)
        parser.add_argument('--lang-cls-unsup-lang-no-grad', action="store_true", default=False)
        parser.add_argument("--train-adapter", action="store_true", default=False)
        parser.add_argument("--eval-batchs", type=int, default=None)
        parser.add_argument("--doc-lang", type=str, default=None, help='e.g., en_XX')
        # fmt: on

        parser.add_argument('--prefix-tokens', type=str, nargs="+", default=None,
            help="prefix tokens used during the inference. " \
            "If is activated, it will overwrite ``prefix_size``")
        parser.add_argument('--extended-dict', type=str, default="")

    def __init__(self, args, src_dict, tgt_dict):
        """Create the task and log the indices of special and language symbols."""
        super().__init__(args, src_dict, tgt_dict)
        logger.info("bos %d, pad %d, eos %d, unk %d",
                src_dict.index('<s>'),src_dict.index('<pad>'),
                src_dict.index('</s>'),src_dict.index('<unk>')
                )
        logger.info("en_XX {}, zh_CN {}, de_DE {}, fr_XX {}".format(
            src_dict.index('[en_XX]'),src_dict.index('[zh_CN]'),
            src_dict.index('[de_DE]'),src_dict.index('[fr_XX]'))
        )
        # Language names (without brackets) forced as decoding prefix; see
        # set_prefix_tokens().
        self.prefix_tokens = args.prefix_tokens

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch, used to pick one of the ``:``-separated
                data paths in round-robin fashion
            combine (bool): also load numbered shards of the split
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # src="doc", tgt="sum"
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_sumdataset(
            data_path, split,
            src, self.src_dict,
            tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 1024),
            max_target_positions=getattr(self.args, 'max_target_positions', 1024),
            truncate_source=self.args.truncate_source,
            load_alignments=self.args.load_alignments,
            # NOTE(review): the attribute name is spelled 'preprend_bos'; no
            # option of that name is registered in this class, so this is
            # False unless set elsewhere. Kept as-is to preserve behavior.
            prepend_bos=getattr(self.args, 'preprend_bos', False),
        )
        # Show the first example as a sanity check; skip empty splits instead
        # of failing on the index lookup.
        dataset = self.datasets[split]
        if len(dataset) > 0:
            logger.info("%s first example: %s", split, dataset[0])

    def build_model(self, args):
        """Build the model, optionally freeze layers and set up ROUGE scoring.

        When ``args.fixed_encoder_layer`` / ``args.fixed_decoder_layer`` are
        set, parameters whose names start with ``encoder.layers`` /
        ``decoder.layers`` get ``requires_grad = False``. When
        ``args.eval_rouge`` is set, a tokenizer and a sequence generator are
        created for use in :meth:`valid_step`.

        This method used to be defined twice in the class; the later
        definition shadowed the earlier one, so the freezing logic never ran.
        Both parts now live here.
        """
        model = super().build_model(args)

        freeze_rules = (
            ("fixed_encoder_layer", "encoder.layers", "Fix Encoder Layers!"),
            ("fixed_decoder_layer", "decoder.layers", "Fix Decoder Layers!"),
        )
        for flag, name_prefix, message in freeze_rules:
            if not getattr(args, flag, False):
                continue
            logger.info(message)
            for name, param in model.named_parameters():
                if name.startswith(name_prefix):
                    param.requires_grad = False

        if getattr(args, 'eval_rouge', False):
            # `or '{}'` covers options that are present but None.
            detok_args = json.loads(getattr(args, 'eval_rouge_detok_args', '{}') or '{}')
            # NOTE: add_args() in this class does not register an
            # --eval-rouge-detok option, so the tokenizer name is None unless
            # the attribute is set elsewhere.
            self.tokenizer = encoders.build_tokenizer(Namespace(
                tokenizer=getattr(args, 'eval_rouge_detok', None),
                **detok_args
            ))

            gen_args = json.loads(getattr(args, 'eval_rouge_args', '{}') or '{}')
            self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
        return model

    def build_generator(self, models, args):
        """Return a SequenceScorer (``args.score_reference``) or a SequenceGenerator."""
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)

        from fairseq.sequence_generator import SequenceGenerator
        return SequenceGenerator(
            models,
            self.target_dictionary,
            beam_size=getattr(args, 'beam', 5),
            max_len_a=getattr(args, 'max_len_a', 0),
            max_len_b=getattr(args, 'max_len_b', 200),
            min_len=getattr(args, 'min_len', 1),
            normalize_scores=(not getattr(args, 'unnormalized', False)),
            len_penalty=getattr(args, 'lenpen', 1),
            unk_penalty=getattr(args, 'unkpen', 0),
            temperature=getattr(args, 'temperature', 1.),
            match_source_len=getattr(args, 'match_source_len', False),
            no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths):
        """Wrap raw source tensors in a dataset, appending the ``--doc-lang`` tag.

        The index of ``[<doc_lang>]`` in the source dictionary is appended to
        every source sequence.
        """
        src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.doc_lang))
        source_tokens = [
            torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) for s_t in src_tokens
        ]
        # NOTE(review): src_lengths is passed through unchanged although each
        # sequence grew by one token -- confirm callers account for this.
        return LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary)

    def dump_source_embedding(self, models, sample, model_args):
        """Return encoder hidden states for *sample*, averaged over *models*.

        If ``model_args.encoder_output_layer`` exists and is not None, the
        hidden states of that entry of ``encoder_states`` are used instead of
        the final encoder output.
        """
        net_input = sample['net_input']
        encoder_output_layer = getattr(model_args, "encoder_output_layer", None)

        model_embeddings = []
        for model in models:
            output_tuple, _ = model.encoder(
                src_tokens=net_input['src_tokens'],
                src_lengths=net_input['src_lengths'],
                return_all_hiddens=True
            )

            if encoder_output_layer is not None:
                encoder_hidden = output_tuple.encoder_states[encoder_output_layer]
            else:
                encoder_hidden = output_tuple.encoder_out

            # [T, B, C] -> [B, T, C]
            model_embeddings.append(torch.transpose(encoder_hidden, 1, 0))
        # Stack on a new trailing axis and average it away (ensemble mean).
        return torch.mean(torch.stack(model_embeddings, dim=-1), dim=-1)

    def dump_document_embedding(self, models, sample, model_args):
        """Return the document-level state for *sample*, averaged over *models*.

        The state is read from the ``extra`` output under the key
        ``'<doc_state>_doc_state'``. For ``model_args.doc_state == "fused"``
        the full model is run with ``features_only=True``; otherwise only the
        encoder is run.
        """
        net_input = sample['net_input']
        state_key = '{}_doc_state'.format(model_args.doc_state)

        model_embeddings = []
        for model in models:
            if model_args.doc_state == "fused":
                _, extra = model(
                    src_tokens=net_input['src_tokens'],
                    src_lengths=net_input['src_lengths'],
                    features_only=True
                )
            else:
                _, extra = model.encoder(
                    src_tokens=net_input['src_tokens'],
                    src_lengths=net_input['src_lengths'],
                    return_all_hiddens=True
                )

            model_embeddings.append(extra[state_key])
        return torch.mean(torch.stack(model_embeddings, dim=-1), dim=-1)

    def valid_step(self, sample, model, criterion):
        """Run the parent validation step and, with ``--eval-rouge``, add ROUGE-L.

        The language is derived from the first target token of the first
        sentence in the batch (e.g. ``[en_XX]`` -> ``en``) and is used as a
        prefix for the logged keys ``<lang>_RL_P``, ``<lang>_RL_R`` and
        ``<lang>_RL_F`` (percentages rounded to two decimals).
        """
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.args.eval_rouge:
            lang_id = sample['target'][0, :1]
            lang_token = self.src_dict.string(lang_id)
            lang = lang_token.split("_")[0]  # like "[en"
            # startswith() is safe even if the decoded token is empty.
            if lang.startswith("["):
                lang = lang[1:]
            rouge = self._inference_with_rouge(self.sequence_generator, sample, model, lang)

            logging_output['{}_RL_P'.format(lang)] = round(rouge['rouge_l_precision'] * 100, 2)
            logging_output['{}_RL_R'.format(lang)] = round(rouge['rouge_l_recall'] * 100, 2)
            logging_output['{}_RL_F'.format(lang)] = round(rouge['rouge_l_f_score'] * 100, 2)
        return loss, sample_size, logging_output

    def _inference_with_rouge(self, generator, sample, model, lang):
        """Generate summaries for *sample* and score them against the targets.

        Generation is prefixed with the first target token of every sentence.
        Language tags such as ``[en_XX]`` are removed from all strings, and
        for ``lang == "zh"`` references and hypotheses are split into single
        characters before scoring.

        Returns:
            The result of ``rouge_eval``; if scoring raises, a dict with
            ``rouge_l_f_score``, ``rouge_l_precision`` and ``rouge_l_recall``
            all set to 0.0.
        """
        lang_tag = re.compile(r"\[[a-z]{2}_[A-Z]{2}\]")

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args.eval_rouge_remove_bpe,
                # Use different placeholder strings for unknown tokens in
                # references and hypotheses so that they never count as a
                # match.
                unk_string=(
                    "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"
                ),
                extra_symbols_to_ignore={
                    generator.eos,
                    0
                }
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        def removeLangTag(s: str):
            return lang_tag.sub("", s).strip()

        def splitChar(s: str):
            # "ab cd" -> "a b c d"
            return " ".join(char for token in s.strip().split() for char in token.strip())

        gen_out = self.inference_step(generator, [model], sample, sample['target'][:, :1])
        srcs, hyps, refs = [], [], []
        for i, hypos in enumerate(gen_out):
            hyps.append(decode(hypos[0]['tokens']))
            refs.append(decode(
                utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                escape_unk=True,  # don't count <unk> as matches to the hypo
            ))
            srcs.append(decode(
                utils.strip_pad(sample['net_input']['src_tokens'][i], self.src_dict.pad()),
                escape_unk=True,  # don't count <unk> as matches to the hypo
            ))
        srcs = [removeLangTag(src) for src in srcs]
        refs = [removeLangTag(ref) for ref in refs]
        hyps = [removeLangTag(hyp) for hyp in hyps]
        if self.args.eval_rouge_print_samples:
            for (bid, iid) in enumerate(sample['id']):
                if iid < 3:
                    logger.info("id: {}".format(iid))
                    logger.info('source: ' + srcs[bid])
                    logger.info('reference: ' + refs[bid])
                    logger.info('hypothesis: ' + hyps[bid] + "\n")

        if lang == "zh":
            refs = [splitChar(ref) for ref in refs]
            hyps = [splitChar(hyp) for hyp in hyps]
        # Empty hypotheses are replaced by a single placeholder token before
        # being passed to the scorer.
        hyps = [hyp if hyp.strip() != "" else "P" for hyp in hyps]
        try:
            return rouge_eval(hyps, refs, zh=(lang == "zh"))
        except Exception as e:
            # Deliberately best-effort: a scoring failure must not abort
            # validation.
            logger.info("Exception when calculating rouge scores: {}".format(e))
            logger.info("Return 0 as score")
            return {
                "rouge_l_f_score": 0.0,
                "rouge_l_precision": 0.0,
                "rouge_l_recall": 0.0
            }

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs; with ``--eval-rouge`` also log mean ROUGE-L.

        Every key containing ``"RL"`` that occurs in any logging output is
        averaged over the outputs that actually contain it. Keys are no longer
        taken from the first output only, and outputs lacking a key no longer
        count as zeros.
        """
        super().reduce_metrics(logging_outputs, criterion)
        if not self.args.eval_rouge or not logging_outputs:
            return

        rouge_values = {}
        for log in logging_outputs:
            for key, value in log.items():
                if "RL" in key:
                    rouge_values.setdefault(key, []).append(value)

        for key, values in rouge_values.items():
            metrics.log_scalar(key, sum(values) / len(values))

    def set_prefix_tokens(self, sample):
        """Return the prefix tokens to force during generation, or None.

        If prefix tokens were configured, each name is looked up as
        ``[<name>]`` in the target dictionary and the resulting row is
        repeated for every sentence in the batch, matching the dtype and
        device of ``sample['target']``. Otherwise the first
        ``args.prefix_size`` target tokens are used when that size is
        positive.
        """
        target = sample['target']
        if self.prefix_tokens is not None:
            tokens = [self.tgt_dict.index('[{}]'.format(token)) for token in self.prefix_tokens]
            prefix = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0)
            # .to(tensor) adopts both dtype and device of the target.
            return prefix.repeat(target.size(0), 1).to(target)

        if self.args.prefix_size > 0:
            return target[:, :self.args.prefix_size]
        return None

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def postprocess(self):
        """Extend both dictionaries from ``--extended-dict`` and log new indices.

        Does nothing when ``--extended-dict`` is empty. With ``--eval-rouge``
        the ``vocab_size`` of the validation sequence generator is updated to
        the new target dictionary size.
        """
        extended_dict = self.args.extended_dict
        if extended_dict == "":
            return

        logger.info("TASK.extended_dict: %s", extended_dict)
        self.src_dict.add_from_file(extended_dict)
        self.tgt_dict.add_from_file(extended_dict)

        # First whitespace-separated field of each line is the token; blank
        # lines are skipped instead of raising IndexError.
        with open(extended_dict, 'r', encoding='utf-8') as fin:
            extended_tokens = [line.split()[0] for line in fin if line.strip()]

        for label, dictionary in (("src_dict: ", self.src_dict), ("tgt_dict: ", self.tgt_dict)):
            logger.info(label + "".join(
                "token: {} idx: {}, ".format(token, dictionary.index(token))
                for token in extended_tokens
            ))

        # update sequence generator
        if getattr(self.args, 'eval_rouge', False):
            self.sequence_generator.vocab_size = len(self.tgt_dict)
