#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""

import fileinput
import logging
import math
import os
import sys
import time
from collections import namedtuple

import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import encoders


# Log records are written to stderr at INFO level and above.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
)
logger = logging.getLogger('fairseq_cli.interactive')


# Lightweight records passed between the batching code and the main loop.
# Batch: one batch for a single-input task.
Batch = namedtuple('Batch', ['ids', 'src_tokens', 'src_lengths'])
# MultiInputBatch: one batch for a two-input task (doc_* and prop_* tensors).
MultiInputBatch = namedtuple(
    'MultiInputBatch',
    ['ids', 'doc_tokens', 'doc_lengths', 'prop_tokens', 'prop_lengths'],
)
# Translation: source string plus hypotheses, positional scores and alignments.
Translation = namedtuple(
    'Translation',
    ['src_str', 'hypos', 'pos_scores', 'alignments'],
)


def buffered_read(input, buffer_size):
    """Yield the lines of *input* in lists of at most *buffer_size* items.

    *input* is handed to ``fileinput.input`` as its only file and decoded as
    UTF-8 (``fileinput`` treats ``'-'`` as standard input). Newline characters
    are stripped from both ends of every line. Each list holds exactly
    *buffer_size* lines, except possibly the last one, which carries whatever
    is left once the input is exhausted; nothing is yielded for empty input.
    """
    chunk = []
    utf8_hook = fileinput.hook_encoded("utf-8")
    with fileinput.input(files=[input], openhook=utf8_hook) as reader:
        for raw_line in reader:
            chunk.append(raw_line.strip('\n'))
            if len(chunk) < buffer_size:
                continue
            # The chunk is full: hand it over and start a fresh list.
            yield chunk
            chunk = []

    # Flush the trailing, partially filled chunk.
    if chunk:
        yield chunk


def _encode_lines(texts, dictionary, encode_fn):
    """Encode *texts* into index tensors and collect their lengths.

    Each text is passed through *encode_fn* and then through
    ``dictionary.encode_line(..., add_if_not_exist=False)``; the result is cast
    with ``.long()``.

    Returns:
        ``(tokens, lengths)``: the list of encoded tensors and a
        ``torch.LongTensor`` holding ``numel()`` of each of them.
    """
    tokens = [
        dictionary.encode_line(encode_fn(text), add_if_not_exist=False).long()
        for text in texts
    ]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    return tokens, lengths


def make_batches(lines, args, task, max_positions, encode_fn):
    """Encode raw input lines and yield them as inference batches.

    The number of inputs per line is read from ``task.input_num`` (default 1):

    * ``1``: every line is one source text; ``Batch`` tuples are yielded.
    * ``2``: every line holds two tab-separated fields. The first field is
      encoded into the ``prop_*`` tensors and the second into the ``doc_*``
      tensors (any further fields are ignored); ``MultiInputBatch`` tuples are
      yielded.

    Args:
        lines: iterable of raw input strings.
        args: parsed options; ``max_tokens`` and ``max_sentences`` are
            forwarded to the batch iterator.
        task: task providing ``source_dictionary``,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to the batch iterator.
        encode_fn: callable applied to each text before dictionary encoding.

    Raises:
        ValueError: if ``task.input_num`` is neither 1 nor 2, or if a line of
            a two-input task contains no tab separator.
    """
    lines = list(lines)
    input_num = getattr(task, 'input_num', 1)
    dictionary = task.source_dictionary

    if input_num == 1:
        src_tokens, src_lengths = _encode_lines(lines, dictionary, encode_fn)
        dataset = task.build_dataset_for_inference(src_tokens, src_lengths)
    elif input_num == 2:
        # Split every line once and validate it up front, so a malformed line
        # is reported clearly instead of surfacing as a bare IndexError.
        first_fields, second_fields = [], []
        for position, line in enumerate(lines):
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    'expected two tab-separated fields, but input line {} of '
                    'the current buffer has no tab: {!r}'.format(position, line)
                )
            first_fields.append(fields[0])
            second_fields.append(fields[1])

        prop_tokens, prop_lengths = _encode_lines(first_fields, dictionary, encode_fn)
        doc_tokens, doc_lengths = _encode_lines(second_fields, dictionary, encode_fn)
        dataset = task.build_dataset_for_inference(
            prop_tokens, prop_lengths, doc_tokens, doc_lengths
        )
    else:
        # Previously an unknown value silently produced no batches at all.
        raise ValueError(
            'unsupported task.input_num: {!r} (expected 1 or 2)'.format(input_num)
        )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        net_input = batch['net_input']
        if input_num == 1:
            yield Batch(
                ids=batch['id'],
                src_tokens=net_input['src_tokens'],
                src_lengths=net_input['src_lengths'],
            )
        else:
            yield MultiInputBatch(
                ids=batch['id'],
                doc_tokens=net_input['doc_tokens'],
                doc_lengths=net_input['doc_lengths'],
                prop_tokens=net_input['prop_tokens'],
                prop_lengths=net_input['prop_lengths'],
            )


def _build_sample(batch, input_num, use_cuda):
    """Assemble the model input for one batch.

    Args:
        batch: a ``Batch`` when ``input_num == 1`` or a ``MultiInputBatch``
            when ``input_num == 2``.
        input_num: number of inputs the task consumes per example.
        use_cuda: when true, every tensor is moved with ``.cuda()``.

    Returns:
        ``(sample, src_tokens)`` where ``sample`` is ``{'net_input': {...}}``
        and ``src_tokens`` is the tensor later used to rebuild the source
        string: ``src_tokens`` for one input, ``prop_tokens`` for two.

    Raises:
        ValueError: if ``input_num`` is neither 1 nor 2.
    """
    if input_num == 1:
        net_input = {
            'src_tokens': batch.src_tokens,
            'src_lengths': batch.src_lengths,
        }
        source_key = 'src_tokens'
    elif input_num == 2:
        net_input = {
            'doc_tokens': batch.doc_tokens,
            'doc_lengths': batch.doc_lengths,
            'prop_tokens': batch.prop_tokens,
            'prop_lengths': batch.prop_lengths,
        }
        source_key = 'prop_tokens'
    else:
        raise ValueError(
            'unsupported task.input_num: {!r} (expected 1 or 2)'.format(input_num)
        )

    if use_cuda:
        net_input = {name: tensor.cuda() for name, tensor in net_input.items()}
    return {'net_input': net_input}, net_input[source_key]


def main(args):
    """Generate hypotheses for every line of ``args.input``.

    Lines are read in buffers of ``args.buffer_size``, batched with
    ``make_batches`` and run through ``task.inference_step``. For each input
    line, in input order, up to ``args.nbest`` hypothesis strings are printed
    to stdout. When ``args.print_alignment`` is set, an ``A-<id>`` alignment
    line per hypothesis is written to stderr. The total wall-clock time is
    reported through the progress bar at the end.

    Args:
        args: parsed generation options (see ``cli_main``). ``buffer_size``
            and ``max_sentences`` may be adjusted in place.

    Raises:
        AssertionError: if ``--sampling`` is used with ``nbest != beam``, or
            if ``max_sentences`` exceeds ``buffer_size``.
    """
    start_time = time.time()
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(security): eval() executes arbitrary Python taken from the
    # --model-overrides option; only pass trusted values. If overrides are
    # always plain literals, ast.literal_eval would be the safer choice.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    # Reuse the run id stored with the checkpoint for progress logging.
    # NOTE(review): this assumes the checkpoint args carry ``mlflow_run_id``
    # whenever the CLI args do -- confirm against the training code.
    if hasattr(args, 'mlflow_run_id'):
        args.mlflow_run_id = _model_args.mlflow_run_id

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        # Tokenize first, then apply BPE (each step is skipped when disabled).
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Inverse of encode_fn: undo BPE first, then detokenize.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('The outputs will now be generated. It might take a while, depending on input size.')

    # Looked up once here instead of once per batch.
    input_num = getattr(task, 'input_num', 1)

    start_id = 0
    itr = buffered_read(args.input, args.buffer_size)
    progress = progress_bar.progress_bar(
        iterator=itr,
        log_format='tqdm',
        prefix=f"Interactive validation on {args.gen_subset}",
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        mlflow_server_url=getattr(args, 'mlflow_server_url', None),
        mlflow_experiment_name=getattr(args, 'mlflow_experiment_name', None),
        mlflow_run_id=getattr(args, 'mlflow_run_id', None),
    )
    for inputs in progress:
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            sample, src_tokens = _build_sample(batch, input_num, use_cuda)
            translations = task.inference_step(generator, models, sample)
            for i, (sample_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                # Batch ids are relative to the current buffer; offset them
                # by the number of lines consumed so far.
                results.append((start_id + sample_id, src_tokens_i, hypos))

        # sort output to match input order
        for sample_id, src_tokens_i, hypos in sorted(results, key=lambda x: x[0]):
            # Default to an empty source string so that src_str is always
            # bound, even when the task has no source dictionary.
            src_str = ''
            if src_dict is not None:
                src_str = src_dict.string(src_tokens_i, args.remove_bpe)

            # Process top predictions (slicing already clamps to len(hypos))
            for hypo in hypos[:args.nbest]:
                _, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                # Only the decoded hypothesis text is written to stdout.
                print('{}'.format(hypo_str))
                if args.print_alignment:
                    alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(sample_id, alignment_str), file=sys.stderr)

        # update running id counter
        start_id += len(inputs)
    duration = time.time() - start_time
    progress.print({'time': duration}, tag='interactive')

def cli_main():
    """Command-line entry point: parse the generation options and run ``main``.

    The parser is built with ``interactive=True`` and the parsed namespace is
    passed straight to ``main``.
    """
    generation_parser = options.get_generation_parser(interactive=True)
    main(options.parse_args_and_arch(generation_parser))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    cli_main()
