from solver import *
from data_loader import get_loader
from configs import get_config
from utils import Vocab
import os
import pickle
from models import VariationalModels
from latent_plot import latentPlot
def load_pickle(path):
    """Return the object deserialized from the pickle file at *path*.

    The file is opened in binary mode and closed before returning.

    NOTE(review): unpickling can execute arbitrary code; only call this on
    files produced by a trusted preprocessing step.
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj


if __name__ == '__main__':
    # Test-mode configuration; echoed so the run log records every setting.
    config = get_config(mode='test')
    print(config)

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')
    # The config carries the vocabulary size into the solver built below.
    config.vocab_size = vocab.vocab_size

    # Pickled corpus pieces are loaded in order: sentences, conversation
    # lengths, sentence lengths. Batches keep their stored order (no shuffle).
    data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size,
        shuffle=False,
    )

    # Choose the solver class for the configured model family. The second
    # argument (presumably the training loader -- TODO confirm) is None and
    # is_train=False, since this script only evaluates/plots.
    solver_cls = VariationalSolver if config.model in VariationalModels else Solver
    solver = solver_cls(config, None, data_loader, vocab=vocab, is_train=False)
    solver.build()

    print("Start ploting...")
    # Prior first, then posterior: collect the three latent-code sequences,
    # report their lengths, and plot the two sentence-level ones
    # (presumably question/answer latents, judging by the names).
    for label in ('prior', 'posterior'):
        collect = getattr(solver, f'solver_plot_{label}')
        z_conv, z_sent_q, z_sent_a = collect()
        for codes in (z_conv, z_sent_q, z_sent_a):
            print(len(codes))
        latentPlot(z_sent_q, z_sent_a, label)
