#from fairseq.data import BertDictionary

from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask

from .bert_dictionary import BertDictionary
import pdb

@register_task('translation_prophetnet')
class TranslationProphetnetTask(TranslationTask):
    """Translation task for ProphetNet models.

    Extends fairseq's ``TranslationTask`` in two ways:

    * dictionaries are loaded with ``BertDictionary.load_from_file``;
    * ``build_generator`` returns a ``PRNetSequenceGenerator`` (from the
      ``prnet_sequence_generator`` module) rather than the default
      fairseq sequence generator.
    """

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary stored at *filename* as a ``BertDictionary``."""
        return BertDictionary.load_from_file(filename)

    def max_positions(self):
        """Return the max sentence length allowed by the task.

        Returns:
            tuple: ``(args.max_source_positions, args.max_target_positions)``.
        """
        return (self.args.max_source_positions, self.args.max_target_positions)

    def build_generator(self, args):
        """Build the object used to score or generate translations.

        Args:
            args: parsed option namespace. Every option is read with
                ``getattr`` and an explicit default, so missing attributes
                fall back to the defaults written below.

        Returns:
            A ``SequenceScorer`` over the target dictionary when
            ``args.score_reference`` is truthy. Otherwise a
            ``SequenceGeneratorWithAlignment`` when ``args.print_alignment``
            is truthy, else a ``PRNetSequenceGenerator``.
        """
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)

        # Prefer the sibling module inside this package, matching the
        # relative ``.bert_dictionary`` import used by this file; this is
        # presumably how the package is loaded (e.g. via fairseq --user-dir).
        # Fall back to the top-level module name so setups that place
        # ``prnet_sequence_generator`` directly on sys.path keep working.
        try:
            from .prnet_sequence_generator import (
                PRNetSequenceGenerator,
                SequenceGeneratorWithAlignment,
            )
        except ImportError:
            from prnet_sequence_generator import (
                PRNetSequenceGenerator,
                SequenceGeneratorWithAlignment,
            )

        if getattr(args, 'print_alignment', False):
            seq_gen_cls = SequenceGeneratorWithAlignment
        else:
            seq_gen_cls = PRNetSequenceGenerator
        return seq_gen_cls(
            self.target_dictionary,
            beam_size=getattr(args, 'beam', 5),
            max_len_a=getattr(args, 'max_len_a', 0),
            max_len_b=getattr(args, 'max_len_b', 200),
            min_len=getattr(args, 'min_len', 1),
            normalize_scores=(not getattr(args, 'unnormalized', False)),
            len_penalty=getattr(args, 'lenpen', 1),
            unk_penalty=getattr(args, 'unkpen', 0),
            sampling=getattr(args, 'sampling', False),
            sampling_topk=getattr(args, 'sampling_topk', -1),
            sampling_topp=getattr(args, 'sampling_topp', -1.0),
            temperature=getattr(args, 'temperature', 1.),
            diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
            diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
            match_source_len=getattr(args, 'match_source_len', False),
            no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
        )


# @register_task('translation_prnet')
# class TranslationProphetnetTask(TranslationTask):
#     def __init__(self, args, src_dict, tgt_dict):
#         super().__init__(args, src_dict, tgt_dict)

#     @classmethod
#     def load_dictionary(cls, filename):
#         return BertDictionary.load_from_file(filename)

#     def max_positions(self):
#         """Return the max sentence length allowed by the task."""
#         return (self.args.max_source_positions, self.args.max_target_positions)
    
#     def build_generator(self, args):
#         if getattr(args, 'score_reference', False):
#             from fairseq.sequence_scorer import SequenceScorer
#             return SequenceScorer(self.target_dictionary)
#         else:
#             from prnet_sequence_generator import PENetSequenceGenerator, SequenceGeneratorWithAlignment
#             pdb.set_trace()
#             if getattr(args, 'print_alignment', False):
#                 seq_gen_cls = SequenceGeneratorWithAlignment
#             else:
#                 seq_gen_cls = PENetSequenceGenerator
#             return seq_gen_cls(
#                 self.target_dictionary,
#                 beam_size=getattr(args, 'beam', 5),
#                 max_len_a=getattr(args, 'max_len_a', 0),
#                 max_len_b=getattr(args, 'max_len_b', 200),
#                 min_len=getattr(args, 'min_len', 1),
#                 normalize_scores=(not getattr(args, 'unnormalized', False)),
#                 len_penalty=getattr(args, 'lenpen', 1),
#                 unk_penalty=getattr(args, 'unkpen', 0),
#                 sampling=getattr(args, 'sampling', False),
#                 sampling_topk=getattr(args, 'sampling_topk', -1),
#                 sampling_topp=getattr(args, 'sampling_topp', -1.0),
#                 temperature=getattr(args, 'temperature', 1.),
#                 diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
#                 diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
#                 match_source_len=getattr(args, 'match_source_len', False),
#                 no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
#             )