import torch
import time
from data import *
from transformers import logging
from torch import nn, optim
# Lower the transformers library's log verbosity to WARNING (`logging` here is
# `transformers.logging`, imported above, not the stdlib module).
logging.set_verbosity_warning()



class Config:
    """Experiment settings, stored as class attributes.

    Every setting is a class attribute, so ``Config.x`` and ``Config().x``
    return the same value. ``device`` and ``num_classes`` are derived from
    other attributes once, when the class body executes; reassigning
    ``use_cuda``, ``discriminator_method`` or ``num_styles`` afterwards does
    NOT recompute them.
    """

    # Filesystem locations.
    data_path = '../data/yelp/'
    log_dir = 'runs/exp'
    save_path = '../save'
    bert_path = '../bert-base-uncased'
    pretrained_embed_path = '/'  # NOTE(review): '/' looks like a placeholder path — confirm

    # Set to False to force CPU even when CUDA is available.
    use_cuda = True
    device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

    discriminator_method = 'Multi'  # 'Multi' or 'Cond'
    load_pretrained_embed = False

    # Vocabulary and model sizes.
    min_freq = 3
    max_length = 16
    embed_size = 200
    d_model = 256
    hidden_size = 500
    h = 4  # presumably the number of attention heads — confirm
    num_styles = 2
    # 'Multi' uses num_styles + 1 classes; any other method uses 2.
    num_classes = num_styles + 1 if discriminator_method == 'Multi' else 2
    num_layers = 4

    # Training schedule.
    batch_size = 256
    lr_F = 0.0001
    lr_D = 0.0001
    L2 = 0
    iter_D = 10
    iter_F = 5
    F_pretrain_iter = 500
    log_steps = 5
    eval_steps = 25
    cls_epochs = 20
    learned_pos_embed = True
    dropout = 0
    # Schedules; the meaning of each tuple is defined by the consumer — confirm.
    drop_rate_config = [(1, 0)]
    temperature_config = [(1, 0)]

    # Loss weights (presumably self-reconstruction, cycle and adversarial
    # terms, judging by the names — confirm against the training loop).
    slf_factor = 0.25
    cyc_factor = 0.5
    adv_factor = 1

    # Input-noise settings; all currently 0.
    inp_shuffle_len = 0
    inp_unk_drop_fac = 0
    inp_rand_drop_fac = 0
    inp_drop_prob = 0


# Module-level setup: build the config, load the dataset splits and vocabulary,
# attach pretrained GloVe vectors, and build an embedding layer plus a tensor
# of position indices.
config = Config()
train_iters, dev_iters, test_iters, vocab = load_dataset(config)

# GloVe '6B' vectors whose dimension matches config.embed_size, cached under
# config.pretrained_embed_path.
# NOTE(review): `torchtext` is not imported in this file; it is presumably
# re-exported by `from data import *` — consider importing it explicitly.
vectors = torchtext.vocab.GloVe('6B', dim=config.embed_size, cache=config.pretrained_embed_path)
vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)
# from_pretrained freezes the embedding weights by default (freeze=True).
input_emb = nn.Embedding.from_pretrained(vocab.vectors)
# Position indices 0..max_length-1 repeated for every row of a batch, giving
# shape (batch_size, max_length). Derived from the config so it cannot drift
# from max_length / batch_size (they were previously hard-coded as 16 and 256).
pos_idx = torch.arange(config.max_length).unsqueeze(0).expand((config.batch_size, -1))

