#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import logging
import json
import os
import sys

import torch

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.logging import progress_bar
from fairseq.data import encoders


def main(args):
    """Check that a checkpoint path was supplied, then run extraction to stdout.

    Args:
        args: parsed fairseq generation arguments; ``args.path`` must be set.

    Returns:
        Whatever ``_main`` returns.

    Raises:
        AssertionError: if ``args.path`` is None (not checked under ``python -O``).
    """
    assert args.path is not None, '--path required for sentence extraction!'
    return _main(args, sys.stdout)


def _extract_hypothesis(src_tokens, margin, ranked_sent_ids, sent_k, src_dict, remove_bpe):
    """Join the text of the ``sent_k`` top-ranked sentences of one example.

    Args:
        src_tokens: 1-D tensor of source token ids for the example.
        margin: per-sentence token mask ``[max_sent_num, src_len]``; the
            positions where a row is False are the tokens of that sentence.
        ranked_sent_ids: sentence indices sorted by descending score.
        sent_k: number of leading entries of ``ranked_sent_ids`` to consider.
        src_dict: dictionary used to turn token ids back into text.
        remove_bpe: BPE symbol forwarded to ``src_dict.string``.

    Returns:
        The selected sentence strings joined with single spaces. Sentences
        whose mask selects no tokens, or whose text is blank, are skipped.
    """
    sent_strs = []
    for sid in ranked_sent_ids[:sent_k]:
        token_positions = (~margin[sid]).nonzero(as_tuple=False).squeeze(-1)
        if token_positions.numel() == 0:
            continue
        sent_tokens = torch.index_select(src_tokens, dim=0, index=token_positions)
        sent_str = src_dict.string(sent_tokens, remove_bpe)
        if sent_str.strip() != "":
            sent_strs.append(sent_str)
    return " ".join(sent_strs)


def _main(args, output_file):
    """Rank source sentences with the loaded model(s) and print the top ones.

    For every batch of ``args.gen_subset``, ``task.extraction_inference``
    yields per-sentence scores. For each example, the ``args.sent_k``
    highest-scoring sentences are concatenated (H-) and detokenized (D-). The
    source (S-) and target (T-) lines are also printed unless ``args.quiet``.

    Args:
        args: parsed fairseq generation arguments (see ``cli_main``).
        output_file: text stream that receives both the log records and the
            ``S-/T-/H-/D-`` output lines.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.dump_embedding')

    utils.import_user_module(args)

    # Fall back to a token-based batch size when no batching limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries; accessing source_dictionary may raise on some tasks.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. This is tolerable for a local CLI, but never pass untrusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Handle tokenization and BPE (each built exactly once)
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        # extras: list of extra information provided by each model
        # x: [bs, max_sent_num, 2]; channel 1 is used as the sentence score.
        # Pure inference, so skip building an autograd graph.
        with torch.no_grad():
            x, _extras = task.extraction_inference(models, sample)

        # Rank sentences by score. Masked sentences (ext_mask == 1) are pushed
        # toward the bottom with a large negative score.
        pred_score = x[:, :, 1].masked_fill(sample['ext_mask'], -1000)
        pred_sent_id = torch.argsort(pred_score, dim=-1, descending=True)
        # NOTE(review): if an example has fewer than sent_k unmasked sentences,
        # masked ones can still be selected. They are skipped only when their
        # margin row selects no tokens or renders blank; confirm this is intended.

        has_target = sample.get('target') is not None
        for i, sample_id in enumerate(sample['id'].tolist()):
            hypo_str = _extract_hypothesis(
                sample['net_input']['src_tokens'][i],
                sample['net_input']['margin'][i],  # [max_sent_num, src_len]
                pred_sent_id[i],
                args.sent_k,
                src_dict,
                args.remove_bpe,
            )
            detok_hypo_str = decode_fn(hypo_str)
            if args.quiet:
                continue

            target_str = None
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    # Source tokens are padded with the source dictionary's pad id.
                    src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], src_dict.pad())
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        # Use the target dictionary's own EOS: these are target
                        # ids, and src_dict may be None here.
                        extra_symbols_to_ignore={
                            tgt_dict.eos(),
                        }
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if src_dict is not None:
                print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
            if has_target:
                print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

            # original hypothesis (after tokenization and BPE)
            print('H-{}\t{}'.format(sample_id, hypo_str), file=output_file)
            # detokenized hypothesis
            print('D-{}\t{}\n'.format(sample_id, detok_hypo_str), file=output_file)


def cli_main():
    """Command-line entry point: parse generation-style arguments and run ``main``."""
    main(options.parse_args_and_arch(options.get_generation_parser()))

# Run the CLI only when this file is executed directly, not when imported.
if __name__ == '__main__':
    cli_main()
