# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import logging
import json

from fairseq.data import (
    LanguagePairDataset,
    ExtractiveDataset,
)

from fairseq.tasks.translation_from_pretrained_bart import TranslationFromPretrainedBARTTask
from .summarization_from_pretrained_wo_langtag import load_langpair_sumdataset
from fairseq.tasks import register_task

logger = logging.getLogger(__name__)

@register_task('extractive_summarization')
class ExtSummarizationTask(TranslationFromPretrainedBARTTask):
    """Extractive summarization task on top of the pretrained-BART translation task.

    Registered under the task name ``extractive_summarization``. The task

    * loads a document/summary language-pair dataset together with per-example
      oracle labels read from ``<split>.label.jsonl`` and wraps both in an
      :class:`ExtractiveDataset`;
    * can freeze the encoder and/or decoder layers of the model it builds;
    * can register a two-class classification head on the model, whose name is
      later passed to the model in :meth:`extraction_inference`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationFromPretrainedBARTTask.add_args(parser)
        parser.add_argument('--fixed-encoder-layer', action='store_true', help='fixed encoder layer during from pretrain')
        parser.add_argument('--fixed-decoder-layer', action='store_true', help='fixed decoder layer during from pretrain')
        parser.add_argument('--extraction-head-name', type=str, help='name of the classification head used for sentence extraction')
        parser.add_argument('--sent-k', type=int, help='select top k sentences as predictions', default=2)

        parser.add_argument('--summ-langs', nargs='+', type=str, help='languages with summarization data', default=[])
        parser.add_argument('--denoise-langs', nargs='+', type=str, help='languages with denoise data', default=[])
        parser.add_argument('--use-lang-classifier', action="store_true", default=False)
        parser.add_argument('--loose-load', action="store_true")
        parser.add_argument('--src-lang-token-pos', type=int, default=0, help='the position of the lang token in the source sequence')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        """Initialise the parent task and log the indices of the special symbols.

        Args:
            args: parsed command-line arguments.
            src_dict: source dictionary; its indices for ``<s>``, ``<pad>``,
                ``</s>`` and ``<unk>`` are logged.
            tgt_dict: target dictionary.
        """
        super().__init__(args, src_dict, tgt_dict)

        logger.info(
            "bos %d, pad %d, eos %d, unk %d",
            src_dict.index('<s>'), src_dict.index('<pad>'),
            src_dict.index('</s>'), src_dict.index('<unk>')
        )

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The split is stored in ``self.datasets[split]`` as an
        :class:`ExtractiveDataset` wrapping the language-pair dataset and the
        oracle labels read from ``<data_path>/<split>.label.jsonl`` (one JSON
        value per line).

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects which of the
                ``:``-separated data paths is used (round-robin).
            combine (bool): forwarded to ``load_langpair_sumdataset``.
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # src="doc", tgt="sum"
        src, tgt = self.args.source_lang, self.args.target_lang

        summ_ds = load_langpair_sumdataset(
            data_path, split,
            src, self.src_dict,
            tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 1024),
            max_target_positions=getattr(self.args, 'max_target_positions', 1024),
            truncate_source=self.args.truncate_source,
            load_alignments=self.args.load_alignments,
            # NOTE(review): the attribute name is spelled 'preprend_bos' here;
            # it is kept as-is because correcting it could change how data is
            # loaded -- confirm whether 'prepend_bos' was intended.
            prepend_bos=getattr(self.args, 'preprend_bos', False),
            append_source_id=True
        )

        # One JSON document per line; read as UTF-8 regardless of the locale.
        oracle_file = os.path.join(data_path, '{}.label.jsonl'.format(split))
        with open(oracle_file, 'r', encoding='utf-8') as fin:
            oracle_labels = [json.loads(line.strip()) for line in fin]

        rank_ds = ExtractiveDataset(
            summ_ds, self.src_dict, self.args.seed,
            oracle_labels=oracle_labels,
            summ_langs=self.args.summ_langs,
            denoise_langs=self.args.denoise_langs
        )

        self.datasets[split] = rank_ds

        # Log up to two examples as a sanity check; bounded by the dataset
        # size so that tiny splits do not raise IndexError.
        for i in range(min(2, len(rank_ds))):
            sample = rank_ds[i]
            if 'margin_tokens' in sample:
                # Only the first entry is shown to keep the preview short.
                sample['margin_tokens'] = sample['margin_tokens'][0]
            logger.info("%s %d %s", split, i, sample)

    def build_model(self, args):
        """Build the model, optionally freezing layers and adding the extraction head.

        Args:
            args: parsed arguments. ``fixed_encoder_layer`` /
                ``fixed_decoder_layer`` disable gradients for parameters whose
                names start with ``encoder.layers`` / ``decoder.layers``;
                a non-empty ``extraction_head_name`` registers a two-class
                classification head under that name.

        Returns:
            The model built by the parent task, after the adjustments above.
        """
        model = super().build_model(args)

        frozen_prefixes = []
        if getattr(args, "fixed_encoder_layer", False):
            logger.info("Fix Encoder Layers!")
            frozen_prefixes.append("encoder.layers")
        if getattr(args, "fixed_decoder_layer", False):
            logger.info("Fix Decoder Layers!")
            frozen_prefixes.append("decoder.layers")

        if frozen_prefixes:
            # str.startswith accepts a tuple, so one pass covers both prefixes.
            prefixes = tuple(frozen_prefixes)
            for name, param in model.named_parameters():
                if name.startswith(prefixes):
                    param.requires_grad = False

        head_name = getattr(args, 'extraction_head_name', None)
        if head_name:
            model.register_classification_head(head_name, num_classes=2)

        return model

    def build_generator(self, models, args):
        """Build the sequence scorer or sequence generator used at inference time.

        Args:
            models: models handed to the ``SequenceGenerator`` (not used when
                scoring references).
            args: generation arguments; ``score_reference`` selects the scorer,
                the remaining attributes configure beam search and fall back
                to the defaults given below when absent.

        Returns:
            A ``SequenceScorer`` if ``args.score_reference`` is set, otherwise
            a ``SequenceGenerator``.
        """
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(
                self.target_dictionary,
            )

        from fairseq.sequence_generator import SequenceGenerator
        return SequenceGenerator(
            models,
            self.target_dictionary,
            beam_size=getattr(args, 'beam', 5),
            max_len_a=getattr(args, 'max_len_a', 0),
            max_len_b=getattr(args, 'max_len_b', 200),
            min_len=getattr(args, 'min_len', 1),
            normalize_scores=(not getattr(args, 'unnormalized', False)),
            len_penalty=getattr(args, 'lenpen', 1),
            unk_penalty=getattr(args, 'unkpen', 0),
            temperature=getattr(args, 'temperature', 1.),
            match_source_len=getattr(args, 'match_source_len', False),
            no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
            # eos=self.tgt_dict.index('[{}]'.format(self.args.sum_lang))  # eos: beginning of sentence token
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths):
        """Wrap raw source tensors in a dataset, appending the language tag token.

        The tag ``[<lang>]`` is looked up in the source dictionary, where
        ``<lang>`` is ``args.doc_lang`` when that attribute is set and
        ``args.source_lang`` (the language used by :meth:`load_dataset`)
        otherwise.

        Args:
            src_tokens: iterable of 1-D token tensors.
            src_lengths: lengths forwarded unchanged to ``LanguagePairDataset``.

        Returns:
            A ``LanguagePairDataset`` over the tagged source sequences.
        """
        # `doc_lang` is not declared by this task's add_args, so fall back to
        # `source_lang` instead of failing with AttributeError when it is absent.
        lang = getattr(self.args, 'doc_lang', None) or self.args.source_lang
        src_lang_id = self.source_dictionary.index('[{}]'.format(lang))
        source_tokens = [
            torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            for s_t in src_tokens
        ]
        # NOTE(review): src_lengths is passed through as given even though
        # each sequence grew by one token -- confirm callers account for that.
        dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary)
        return dataset

    def extraction_inference(self, models, sample):
        """Run every model with the extraction head and average their outputs.

        Args:
            models: iterable of models; each is called with
                ``sample['net_input']`` and
                ``classification_head_name=self.args.extraction_head_name``.
            sample: batch dictionary containing a ``'net_input'`` mapping.

        Returns:
            A tuple ``(mean_output, extras)``: the element-wise mean of the
            models' first return values, and the list of their second return
            values in model order.
        """
        outputs = []
        extras = []
        for model in models:
            x, extra = model(
                **sample['net_input'],
                classification_head_name=self.args.extraction_head_name
            )
            outputs.append(x)
            extras.append(extra)
        # Stack on a new trailing axis and reduce over it to average the ensemble.
        return torch.mean(torch.stack(outputs, dim=-1), dim=-1), extras
