

from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask

@register_task('translation_with_mask')
class TranslationWithMaskTask(TranslationTask):
    """Translation task whose source and target dictionaries expose a ``<mask>`` symbol.

    Behaves like :class:`TranslationTask`, but on construction makes sure each
    dictionary can represent a mask token and records the resulting indices:

    * ``src_mask_idx`` -- mask index in the source dictionary.
    * ``tgt_mask_idx`` -- mask index in the target dictionary.
    * ``mask_idx`` -- alias of ``tgt_mask_idx``, kept for backward compatibility.
    """

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        # Resolve the mask index per dictionary. Source and target vocabularies
        # may differ, so one shared attribute cannot describe both. Previously
        # the source value was silently overwritten by the target value.
        self.src_mask_idx = self._get_or_add_mask(self.src_dict)
        self.tgt_mask_idx = self._get_or_add_mask(self.tgt_dict)
        # Backward compatibility: ``mask_idx`` used to end up holding the
        # target-side index, so existing callers keep seeing that value.
        self.mask_idx = self.tgt_mask_idx

    @staticmethod
    def _get_or_add_mask(dictionary):
        """Return the ``<mask>`` index for *dictionary*, registering it if needed.

        If the dictionary defines a non-None ``mask_index`` attribute, that value
        is returned as-is. An explicit ``is None`` test is used because index 0
        is valid and must not be treated as "missing". Otherwise ``"<mask>"`` is
        registered via ``dictionary.add_symbol`` and the index it reports is
        returned.
        """
        mask_idx = getattr(dictionary, 'mask_index', None)
        if mask_idx is None:
            mask_idx = dictionary.add_symbol("<mask>")
        return mask_idx


