from model.transformer import *
from util.batch_generator import *
from util.files import *
from util.trainer import EMNLPTrainer
import os
from util.args import EMNLPArgument
import apex
from pytorch_transformers import WarmupLinearSchedule
from util.sampling import *
import pandas as pd


def get_model(args):
    """Build the Transformer model, load trained weights, and prepare it for sampling.

    Args:
        args: parsed argument namespace. Reads the model hyper-parameters
            (``vocab_size``, ``batch_seqlen``, ``hidden_dim``, ``projection_dim``,
            ``n_heads``, ``head_dim``, ``n_layers``, ``cutoffs``, ``dropout_rate``,
            ``dropatt_rate``, ``padding_index``, ``relative_pos``,
            ``experimental_loss``), the target ``device`` and the checkpoint
            path ``saved_path``.

    Returns:
        The model moved to ``args.device`` with the checkpoint weights loaded,
        switched to evaluation mode so dropout is disabled during generation.
    """
    model = Transformer_Model(args.vocab_size, args.batch_seqlen, args.hidden_dim, args.projection_dim, args.n_heads,
                              args.head_dim, args.n_layers, args.cutoffs, args.dropout_rate, args.dropatt_rate,
                              args.padding_index, rel_att=args.relative_pos, experimental_loss=args.experimental_loss)
    # Random initialisation; the checkpoint loaded below overwrites the
    # parameters it contains.
    initializer = Initializer('normal', 0.02, 0.1)
    initializer.initialize(model)
    model = model.to(args.device)
    # map_location remaps tensors saved on another device (e.g. a GPU) onto the
    # device this run targets, so a GPU checkpoint can be loaded on CPU.
    state_dict = torch.load(args.saved_path, map_location=args.device)
    model.load_state_dict(state_dict)
    # Inference only: turn off dropout so samples are not perturbed by
    # training-time noise.
    model.eval()
    return model


def get_batchfier(args):
    """Create the iterator over the test set that feeds sampling.

    The ``'bugs'`` dataset goes through ``LyricsSampleBatchfier`` (given the test
    path wrapped in a list). Any other dataset is read with ``load_json`` and
    wrapped in ``SamplingIterator``. Both receive a batch size of
    ``args.batch_size * 16``, plus ``args.nprefix``, ``args.ngenerate`` and
    ``args.device``.
    """
    batch_size = args.batch_size * 16

    if args.dataset != 'bugs':
        return SamplingIterator(load_json(args.test_path), batch_size,
                                args.nprefix, args.ngenerate, device=args.device)

    # NOTE(review): the meaning of the positional 10000 is defined by
    # LyricsSampleBatchfier (not visible here); confirm before changing it.
    return LyricsSampleBatchfier([args.test_path], batch_size, 10000,
                                 args.nprefix, args.ngenerate, device=args.device)


def generate_sample(args, model, batchfier):
    """Sample a continuation for every test batch and collect results in a DataFrame.

    For each batch, ``inp[0]`` is split column-wise. The first ``args.nprefix``
    columns are passed to ``sample`` as the prompt, and the remaining columns are
    kept as the ground-truth continuation. Iteration stops at the first batch
    that has no columns beyond the prefix.

    Args:
        args: namespace providing ``nprefix``, ``ngenerate``, ``top_k``,
            ``temperature``, ``experimental_loss`` and ``sampling_mode``.
        model: the model handed through to ``sample``.
        batchfier: iterable of batches. ``inp[0]`` is sliced as a 2-D tensor
            (presumably (batch, seq_len) token ids; confirm against the batchfier).

    Returns:
        pandas.DataFrame with one row per sequence and columns ``prefix``,
        ``decoded_predict`` (sampled rows with their first ``nprefix`` entries
        removed) and ``decoded_true`` (ground-truth continuation).
    """
    import torch

    prefixes = []
    truths = []
    generated = []
    # Generation never back-propagates, so skip recording autograd state to
    # keep memory flat across the whole test set.
    with torch.no_grad():
        for inp in batchfier:
            tokens = inp[0]
            prefix = tokens[:, :args.nprefix]
            gt = tokens[:, args.nprefix:]
            if gt.size(-1) == 0:
                break
            res, _ = sample(model, args.ngenerate, prefix, args.top_k, args.temperature,
                            args.experimental_loss, args.sampling_mode)
            # Drop the first nprefix entries of each sampled row (presumably the
            # echoed prompt) so only newly generated tokens are kept.
            generated.extend(row[args.nprefix:] for row in res)
            truths.extend(gt.tolist())
            prefixes.extend(prefix.tolist())
    return pd.DataFrame({'prefix': prefixes, 'decoded_predict': generated, 'decoded_true': truths})


if __name__ == '__main__':
    # Script entry point: load the trained model, sample continuations for the
    # test set, and pickle the resulting DataFrame to args.sampled_savepath.
    args = EMNLPArgument(is_train=False)
    print(args.learning_rate, 'experimental : {} cutoffs : {}'.format(
        args.experimental_loss, len(args.cutoffs)))

    print(args.__dict__)
    print(args.sampled_savepath)
    model = get_model(args)
    test_batchfier = get_batchfier(args)
    df = generate_sample(args, model, test_batchfier)
    # dirname() is '' for a bare filename, and os.makedirs('') raises, so only
    # create a directory when the path has one. exist_ok replaces the racy
    # exists-then-create check.
    save_dir = os.path.dirname(args.sampled_savepath)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    df.to_pickle(args.sampled_savepath)
