import os
import sys
import torch
import argparse

from fairseq.models.bart import BARTModel

from utils.logConfig import Log

# Generation settings passed to ``bart.sample`` for every batch.
SAMPLE_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)


def iter_batches(lines, batch_size):
    """Yield consecutive lists of stripped lines, each at most ``batch_size`` long.

    Order is preserved and every input line (including blank ones) appears in
    exactly one batch, so output hypotheses stay line-aligned with the input.
    An empty input yields no batches.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    batch = []
    for line in lines:
        batch.append(line.strip())
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def main():
    """Summarize each line of the input file with BART; write one hypothesis per line."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', type=str, default="test.source", help='input document file')
    parser.add_argument('-o', type=str, default="test.hypo", help='output candidate file')
    parser.add_argument('-m', type=str, default="checkpoints/checkpoint_best.pt", help='checkpoint file name')
    parser.add_argument('-l', type=str, default="en", help='language')
    # NOTE(review): -d is parsed and logged but not otherwise used in this script.
    parser.add_argument('-d', type=str, default="<q>", help='delimiter')
    parser.add_argument('-b', type=int, default=32, help='batch size')
    args = parser.parse_args()

    logger = Log.getLogger(os.path.basename(sys.argv[0]), "%s.log" % args.l)
    logger.info(args)

    # NOTE(review): upstream fairseq's from_pretrained expects a model directory
    # plus checkpoint_file=...; passing a .pt path directly relies on the
    # installed fairseq version accepting it — confirm.
    bart = BARTModel.from_pretrained(args.m)
    bart.cuda()
    bart.eval()
    bart.half()

    with open(args.i, encoding="utf-8") as source, \
            open(args.o, 'w', encoding="utf-8") as fout:
        for batch in iter_batches(source, args.b):
            # Inference only: disable autograd for every batch, including the last.
            with torch.no_grad():
                hypotheses = bart.sample(batch, **SAMPLE_KWARGS)
            fout.writelines(hypothesis + '\n' for hypothesis in hypotheses)
            # Flush once per batch so partial results survive an interrupted run.
            fout.flush()


if __name__ == "__main__":
    main()