import torch
from data import *
from models import StyleFlow, SentimentClassifier, MLPAttention
from train import *
from torch import nn


class Config:
    """Static experiment settings for StyleFlow training.

    Every option is a plain class attribute, so values can be read either
    from an instance (``Config().batch_size``) or from the class itself.
    """

    # --- filesystem locations ---
    data_path = './data/yelp/'
    log_dir = 'runs/exp'
    save_path = './save'
    T5_path = 'T5-base'  # swap this path to use a different base model
    pretrained_embed_path = './embedding/'
    load_pretrained_embed = True

    # --- hardware: use CUDA whenever torch can see a GPU, else CPU ---
    device = (
        torch.device('cuda') if torch.cuda.is_available()
        else torch.device('cpu')
    )

    # --- batching and layer sizes ---
    batch_size = 128
    attn_size = 300
    embed_size = 300
    d_model = 256
    hidden_size = 300

    # --- vocabulary / sequence / architecture ---
    min_freq = 3
    max_length = 16
    h = 4
    num_styles = 2
    num_classes = num_styles  # one class per style
    num_layers = 4

    # --- optimisation (presumably lr for model_F / model_D; L2 weight) ---
    lr_F = 1e-4
    lr_D = 1e-4
    L2 = 0

    # --- logging / evaluation cadence ---
    log_steps = 5
    eval_steps = 25
    cls_epochs = 20
    learned_pos_embed = False

    # --- dropout and schedule configs (consumed elsewhere) ---
    dropout = 0
    drop_rate_config = [(1, 0)]
    temperature_config = [(1, 0)]

    # --- loss weighting factors ---
    slf_factor = 1 / 6
    cyc_factor = 1 / 2
    content_factor = 1 / 6
    style_factor = 1 / 6

    # --- input noising knobs (all zero in this configuration) ---
    inp_shuffle_len = 0
    inp_unk_drop_fac = 0
    inp_rand_drop_fac = 0
    inp_drop_prob = 0


def main():
    """Train StyleFlow together with a pretrained sentiment classifier.

    Builds the dataset iterators and vocabulary via ``load_dataset``, creates
    ``StyleFlow`` (``model_F``) and ``SentimentClassifier`` (``model_D``) on
    ``config.device``, restores ``model_D``'s weights from
    ``config.senti_cls_path`` and then calls ``train``.

    Raises:
        AttributeError: if ``Config`` does not define ``senti_cls_path``
            (or sets it to ``None``). This is checked before any data is
            loaded or any model is built.
    """
    config = Config()

    # Validate the classifier checkpoint location up front. Otherwise the
    # script would load the dataset and build the flow model before failing
    # on a missing attribute.
    senti_cls_path = getattr(config, 'senti_cls_path', None)
    if senti_cls_path is None:
        raise AttributeError(
            "Config.senti_cls_path is not set; point it at the saved "
            "SentimentClassifier state_dict before running."
        )

    train_iters, dev_iters, test_iters, vocab = load_dataset(config)
    print('Vocab size:', len(vocab))

    model_F = StyleFlow(config, vocab).to(config.device)

    model_D = SentimentClassifier(config.hidden_size).to(config.device)
    # map_location remaps stored tensors onto config.device, so a checkpoint
    # saved on GPU still loads when the device fell back to CPU.
    state_dict = torch.load(senti_cls_path, map_location=config.device)
    model_D.load_state_dict(state_dict)

    # NOTE(review): dev_iters is unused and train() receives test_iters —
    # confirm that evaluating on the test split during training is intended.
    train(config, vocab, model_F, model_D, train_iters, test_iters)


if __name__ == '__main__':
    # Launch training only when this file is executed as a script,
    # never as a side effect of importing it.
    main()
