import torch
import wandb
import random

from .Trainer import Trainer, TrainerWordSimilarity, TrainerWordSimilarityCBOW
from pytorch_utils import TensorDataLoader, cuda_if_available

from ..models import LBLModel, BoxAffineTransform
from ..models import Word2Box, Word2Vec, Word2VecPooled, Word2BoxPooled, Word2BoxConjunction, Word2BoxConjunctionConditional, Word2BoxConjunctionBounded
from ..datasets.utils import load_lines, get_token_ids, get_iter, get_iter_on_device

# Module-level device settings. `global` is unnecessary here: at module
# scope these names are already globals.
use_cuda = torch.cuda.is_available()
# CUDA device index when a GPU is available, otherwise the string "cpu".
device = torch.cuda.current_device() if use_cuda else "cpu"


# Model types trained with TrainerWordSimilarity.
_WORD_SIMILARITY_MODEL_TYPES = ("Word2Box", "Word2Vec")
# Model types trained with TrainerWordSimilarityCBOW.
_CBOW_MODEL_TYPES = (
    "Word2BoxPooled",
    "Word2VecPooled",
    "Word2BoxConjunction",
    "Word2BoxConjunctionConditional",
    "Word2BoxConjunctionBounded",
)


def _model_spec(model_type):
    """Return ``(model_class, arg_map)`` for ``model_type``.

    ``arg_map`` maps each constructor keyword (besides ``TEXT``) to the
    config key its value is read from. The table is built on call so that
    the model classes are only looked up when a model is actually created.

    Raises:
        ValueError: if ``model_type`` is not a supported model type.
    """
    common = {
        "embedding_dim": "embedding_dim",
        "batch_size": "batch_size",
        "n_gram": "n_gram",
    }
    # Arguments shared by the box-embedding models.
    box = {
        **common,
        "intersection_temp": "int_temp",
        "volume_temp": "vol_temp",
        "box_type": "box_type",
    }
    registry = {
        "lbl": (
            LBLModel,
            {**common, "sep_output": "sep_output", "diag_context": "diag_context"},
        ),
        "box_affine": (BoxAffineTransform, {**box, "pooling": "pooling"}),
        "Word2Box": (Word2Box, {**box, "pooling": "pooling"}),
        "Word2Vec": (Word2Vec, common),
        "Word2VecPooled": (Word2VecPooled, {**common, "pooling": "pooling"}),
        "Word2BoxPooled": (
            Word2BoxPooled,
            {**box, "pooling": "pooling", "alpha_dim": "alpha_dim"},
        ),
        "Word2BoxConjunction": (Word2BoxConjunction, box),
        "Word2BoxConjunctionConditional": (Word2BoxConjunctionConditional, box),
        "Word2BoxConjunctionBounded": (Word2BoxConjunctionBounded, box),
    }
    if model_type not in registry:
        raise ValueError("Model type is not valid. Please enter a valid model type")
    return registry[model_type]


def _build_trainer(config, vocab, train_iter, val_iter):
    """Create the trainer that matches ``config["model_type"]``."""
    model_type = config["model_type"]
    if model_type in _WORD_SIMILARITY_MODEL_TYPES:
        trainer_cls = TrainerWordSimilarity
    elif model_type in _CBOW_MODEL_TYPES:
        trainer_cls = TrainerWordSimilarityCBOW
    else:
        # Remaining model types ("lbl", "box_affine") use the base Trainer.
        return Trainer(
            train_iter=train_iter,
            val_iter=val_iter,
            vocab=vocab,
            lr=config["lr"],
            n_gram=config["n_gram"],
            negative_samples=config["negative_samples"],
        )
    return trainer_cls(
        train_iter=train_iter,
        val_iter=val_iter,
        vocab=vocab,
        lr=config["lr"],
        n_gram=config["n_gram"],
        loss_fn=config["loss_fn"],
        negative_samples=config["negative_samples"],
        log_frequency=config["log_frequency"],
        margin=config["margin"],
        similarity_datasets_dir=config["eval_file"],
        subsampling_prob=None,  # pass: subsampling_prob, when you want to adjust neg_sampling distn
    )


def training(config):
    """Train the model described by ``config`` and log the run to wandb.

    Steps: seed ``torch`` and ``random``, start a wandb run, load the data
    with ``get_iter_on_device``, build the model selected by
    ``config["model_type"]``, move it to the GPU when CUDA is available,
    and run the matching trainer.

    Args:
        config: Mapping of hyper-parameters. Always read: ``model_type``,
            ``seed`` (``None`` or missing draws a random seed),
            ``batch_size``, ``dataset``, ``n_gram``, ``subsample_thresh``,
            ``data_device``, ``add_pad``, ``eos_mask``, ``embedding_dim``,
            ``lr``, ``negative_samples``, ``num_epochs`` and, optionally,
            ``save_model``. Model- and trainer-specific keys are read only
            for the model types that use them. When a seed is drawn it is
            written back to ``config["seed"]`` before the config is sent to
            wandb, so the seed actually used is recorded.

    Raises:
        ValueError: if ``config["model_type"]`` is not supported. This is
            checked before a wandb run is started or any data is loaded.
    """
    # Resolve the model type first so a bad config fails fast.
    model_cls, arg_map = _model_spec(config["model_type"])

    # Set the seed
    if config.get("seed") is None:
        # randint includes both endpoints; stay within [0, 2**32 - 1].
        config["seed"] = random.randint(0, 2**32 - 1)
    torch.manual_seed(config["seed"])
    random.seed(config["seed"])

    wandb.init(project="box-language-model", reinit=True)
    wandb.config.update(config)

    # The test iterator and subsampling probabilities are not used here.
    TEXT, train_iter, val_iter, _test_iter, _subsampling_prob = get_iter_on_device(
        config["batch_size"],
        config["dataset"],
        config["model_type"],
        config["n_gram"],
        config["subsample_thresh"],
        config["data_device"],
        config["add_pad"],
        config["eos_mask"],
    )

    model = model_cls(TEXT=TEXT, **{kw: config[key] for kw, key in arg_map.items()})
    if use_cuda:
        model.cuda()

    trainer = _build_trainer(config, TEXT, train_iter, val_iter)
    trainer.train_model(
        model=model,
        num_epochs=config["num_epochs"],
        path=wandb.run.dir,
        save_model=config.get("save_model", False),
    )
