from model.transformer import *
from util.batch_generator import *
from util.files import *
from util.trainer import Evaluater
import os
from util.args import EMNLPArgument
import apex
from util.sampling import *


def get_model(args):
    """Build the Transformer model described by ``args`` and load saved weights.

    Args:
        args: parsed experiment arguments (an ``EMNLPArgument`` in this script).
            Reads the model hyper-parameters, ``device`` and ``saved_path``.

    Returns:
        The ``Transformer_Model`` moved to ``args.device`` with the state dict
        stored at ``args.saved_path`` loaded into it.
    """
    model = Transformer_Model(args.vocab_size, args.batch_seqlen, args.hidden_dim, args.projection_dim,
                              args.n_heads, args.head_dim, args.n_layers, args.cutoffs, args.dropout_rate,
                              args.dropatt_rate, args.padding_index, rel_att=args.relative_pos,
                              experimental_loss=args.experimental_loss)
    # Same initialization as the original flow. The strict (default)
    # load_state_dict below then replaces every entry of the model's state dict.
    initializer = Initializer('normal', 0.02, 0.1)
    initializer.initialize(model)
    model = model.to(args.device)
    # map_location remaps the checkpoint's tensors onto args.device, so a model
    # saved on a GPU can still be loaded on a CPU-only (or different-GPU) host.
    state_dict = torch.load(args.saved_path, map_location=args.device)
    model.load_state_dict(state_dict)
    return model


def get_batchfier(args):
    """Create the evaluation batch iterator over ``args.test_path``.

    The ``'bugs'`` dataset is served by ``Lyrics_Batchfier``. Every other
    dataset is read with ``load_json`` and wrapped in a ``BpttIterator``
    placed on ``args.device``.
    """
    if args.dataset != 'bugs':
        test_data = load_json(args.test_path)
        return BpttIterator(test_data, args.batch_size, args.batch_seqlen, device=args.device)

    # NOTE(review): epoch_shuffle=True on the test iterator may make evaluation
    # order non-deterministic -- confirm this is intended.
    return Lyrics_Batchfier([args.test_path], args.batch_size,
                            seq_len=args.batch_seqlen,
                            padding_index=args.padding_index,
                            epoch_shuffle=True)


if __name__ == '__main__':
    # Parse evaluation-time arguments and echo the configuration in use.
    args = EMNLPArgument(is_train=False)
    learning_rate = args.learning_rate
    header = 'experimental : {} cutoffs : {}'.format(args.experimental_loss, len(args.cutoffs))
    print(learning_rate, header)

    print(args.__dict__)
    print(args.sampled_savepath)

    # Restore the trained model, build the test-set iterator, and evaluate.
    model = get_model(args)
    test_batchfier = get_batchfier(args)
    evaluater = Evaluater(model, test_batchfier, args.padding_index, args.experimental_loss)
    evaluater.eval()