"""Script for encoding the whole challenge"""
import configargparse
import contextlib
import sys
import functools
from fairseq.models.roberta import RobertaModel


def encode_line(line, encoder, max_len, stats):
    """Strip and BPE-encode one line, truncating it to at most ``max_len`` tokens.

    Args:
        line: Raw text line; surrounding whitespace (including the trailing
            newline) is stripped before encoding.
        encoder: Object whose ``encode(str)`` returns a sliceable token
            sequence (in this script RoBERTa's hub interface, presumably
            returning a tensor of token ids — confirm against fairseq).
        max_len: Maximum number of tokens to keep, or ``None`` for no limit.
        stats: Counter dict updated in place; must already contain the keys
            ``"num_processed"``, ``"num_empty"`` and ``"num_shortened"``.

    Returns:
        The (possibly truncated) encoded tokens, or ``None`` if the line is
        empty after stripping.
    """
    line = line.strip()
    stats["num_processed"] += 1
    if not line:
        stats["num_empty"] += 1
        return None

    tokens = encoder.encode(line)
    # --max-line-len defaults to None; comparing an int with None raises
    # TypeError in Python 3, so None must mean "do not truncate".
    if max_len is not None and len(tokens) > max_len:
        stats["num_shortened"] += 1
        return tokens[:max_len]
    return tokens


def main(all_path, dirs, max_len, fields=('labels', 'articles', 'expected')):
    """BPE-encode the parallel TSV files of each challenge subdirectory.

    For every directory ``d`` in *dirs*, the files ``{all_path}/{d}/{field}.tsv``
    (one per name in *fields*) are read in lockstep: row ``i`` is the ``i``-th
    line of every file. Each line is encoded with ``encode_line`` and the
    space-separated token ids are written to
    ``{all_path}/{d}/{field}-bpe_roberta_gpt2-{max_len}.tsv``.

    Args:
        all_path: Path to the challenge directory.
        dirs: Names of the subdirectories of *all_path* to process.
        max_len: Maximum number of tokens kept per line (forwarded to
            ``encode_line``); also embedded in the output file names.
        fields: Base names of the parallel TSV files to encode.
    """
    bpe_encoder = RobertaModel.from_pretrained('roberta.base')

    for one_dir in dirs:
        dir_path = f'{all_path}/{one_dir}'
        stats = {
            "num_empty": 0,
            "num_processed": 0,
            "num_shortened": 0,
        }
        # Loop-invariant for the whole directory: bind encoder, limit and this
        # directory's counters once instead of once per row.
        encode = functools.partial(encode_line,
                                   encoder=bpe_encoder,
                                   max_len=max_len,
                                   stats=stats)

        with contextlib.ExitStack() as stack:
            inputs = [
                stack.enter_context(open(f'{dir_path}/{field}.tsv', "r", encoding="utf-8"))
                for field in fields
            ]
            outputs = [
                stack.enter_context(open(f'{dir_path}/{field}-bpe_roberta_gpt2-{max_len}.tsv',
                                         "w", encoding="utf-8"))
                for field in fields
            ]

            # zip() stops at the shortest input, so extra trailing lines in a
            # longer file are silently ignored.
            for i, lines in enumerate(zip(*inputs), start=1):
                enc_lines = [encode(line) for line in lines]
                # Drop the whole row if any field is empty, so that the output
                # files stay line-aligned with each other.
                if all(enc_line is not None for enc_line in enc_lines):
                    for enc_line, output_h in zip(enc_lines, outputs):
                        # .numpy() assumes encode() yields a torch tensor — confirm.
                        print(" ".join(str(x) for x in enc_line.numpy()), file=output_h)
                if i % 10000 == 0:
                    print("processed {} lines".format(i), file=sys.stderr)

        # Counts are per field-line (one row contributes len(fields) lines);
        # "empty lines" counts empty fields, not skipped rows.
        print("processed {} lines".format(stats["num_processed"]))
        print("skipped {} empty lines".format(stats["num_empty"]))
        print("shortened {} lines".format(stats["num_shortened"]))


if __name__ == '__main__':
    parser = configargparse.ArgumentParser(description=__doc__)
    # Without this path every file would resolve under a literal "None/" dir.
    parser.add_argument('--challenge-path',
                        type=str,
                        required=True,
                        help='path to the challenge directory')

    # nargs='+' makes the value a list; with a bare type=str a single
    # "--dirs train" would be a string that main() iterates char by char.
    parser.add_argument('--dirs',
                        nargs='+',
                        default=['dev-0', 'train', 'test-A', 'test-B'],
                        type=str,
                        help='names of the subdirectories in the challenge directory to be processed')

    parser.add_argument('--max-line-len',
                        default=None,
                        type=int,
                        help='number of bpe-encoded tokens to leave after encoding')

    args = parser.parse_args()
    main(args.challenge_path, args.dirs, args.max_line_len)
