from solver import Solver, VariationalSolver
from data_loader import get_loader
from configs import get_config
from utils import Vocab, Tokenizer
import os
import pickle
from models import VariationalModels
import re


def load_pickle(path):
    """Read the file at ``path`` in binary mode and return the unpickled object.

    Only the first pickled object in the file is read. Unpickling can execute
    arbitrary code, so this should only be pointed at trusted files.
    """
    with open(path, mode='rb') as stream:
        payload = pickle.load(stream)
    return payload


if __name__ == '__main__':
    # Script entry point: build a solver in non-training mode from the
    # 'test' configuration and run its embedding metric.
    config = get_config(mode='test')

    # Word vocabulary, loaded from the two mapping files named in the config.
    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    # Type vocabulary, loaded the same way from its own mapping files.
    print('Loading Type Vocabulary...')
    vocab_t = Vocab()
    vocab_t.load(config.word2id_t_path, config.id2word_t_path)
    print(f'Type Vocabulary size: {vocab_t.vocab_size}')

    # Record the word-vocabulary size on the config before the solver is built.
    config.vocab_size = vocab.vocab_size

    # Unpickle the four inputs of the data loader.
    pickled_sentences = load_pickle(config.sentences_path)
    pickled_conversation_length = load_pickle(config.conversation_length_path)
    pickled_sentence_length = load_pickle(config.sentence_length_path)
    pickled_types = load_pickle(config.types_path)

    # shuffle=False: batches are requested in a fixed order.
    data_loader = get_loader(
        sentences=pickled_sentences,
        conversation_length=pickled_conversation_length,
        sentence_length=pickled_sentence_length,
        types=pickled_types,
        vocab=vocab,
        vocab_t=vocab_t,
        batch_size=config.batch_size,
        shuffle=False)

    # Models listed in VariationalModels get the variational solver; every
    # other model gets the plain one. Both take the same arguments.
    solver_cls = VariationalSolver if config.model in VariationalModels else Solver
    # The second positional argument is passed as None and the loader built
    # above goes in the third slot (presumably train / eval loaders -- confirm
    # against the Solver signature).
    solver = solver_cls(
        config,
        None,
        data_loader,
        vocab=vocab,
        vocab_t=vocab_t,
        is_train=False)

    solver.build()
    solver.embedding_metric()
