# coding=latin-1
import util201217 as ut
import step210125 as st
import models210212 as mod
import argparse, re, sys, os
import numpy as np
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, precision_recall_fscore_support, log_loss, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from collections import defaultdict, Counter
import pandas as pd
import random
import torch
import torch.nn as nn
from torch import optim
# import warnings filter
from warnings import simplefilter
# ignore all future warnings
# simplefilter(action='ignore', category=(FutureWarning, UserWarning))
###################################################################################################
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``-save False`` used to turn
    saving ON. This converter parses the text instead.

    Args:
        value: raw string from the command line (a ``bool`` is passed through).

    Returns:
        ``True`` or ``False``. The empty string maps to ``False``, as it did
        under ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


parser = argparse.ArgumentParser()
# torch settings
parser.add_argument("-seed",   type=int, default=1234)
parser.add_argument("-device", type=str, default='cuda:0')
parser.add_argument("-dtype",  type=int, default=32, choices=[32, 64])
# inputs
parser.add_argument("-dir_data", type=str, default='preproc/210217165643/')

parser.add_argument("-bert_lookup_kind",   type=str, default='mat', choices=['cls', 'meanmat', 'mat'])
parser.add_argument("-bertcontexts", type=str, nargs='+', default=['single', 'paired_cont1'], help="['single', 'paired_cont1', 'paired_cont2']")
parser.add_argument("-contexts", type=int, nargs='+', default=[0, 1])
# preproc
# # model settings
parser.add_argument("-patience",    type=int,   default = 3)
parser.add_argument("-min_gain",    type=int,   default = 12)
# Booleans go through _str2bool so that "-save False" really means False.
parser.add_argument("-save",        type=_str2bool, default = False)
parser.add_argument("-nr_exps",     type=int,   default = 1)
parser.add_argument("-splits",      type=int,   default = 10, help='almeno 3 o d un errore, credo dovuto all\'output dello stratified')
parser.add_argument("-big_batsize",     type=int,   default = 64)
parser.add_argument("-small_batsize",   type=int,   default = 64)
parser.add_argument("-high_learate",    type=float, default = 0.002)
parser.add_argument("-low_learate",     type=float, default = 0.001)
parser.add_argument("-droprob",     type=float, default = 0.2)
parser.add_argument("-trainable",   type=_str2bool, default = False)
# # attention
parser.add_argument("-att_heads",      type=int, default = 1)
parser.add_argument("-att_layers",     type=int, default = 1)
parser.add_argument("-txt_fc_outsize", type=int, default = 10)
parser.add_argument("-doc_fc_outsize", type=int, default = 10)
# fc settings
parser.add_argument("-fc_layers", type=int, default=1)
# conv settings
parser.add_argument("-conv_channels",     type=int, nargs='+', default=[32, 64], help="nr of channels conv by conv")
parser.add_argument("-conv_filter_sizes", type=int, nargs='+', default=[2, 3],   help="sizes of filters: window, in each conv")
parser.add_argument("-conv_stridesizes",  type=int, nargs='+', default=[1, 1],   help="conv stride size, conv by conv")
parser.add_argument("-pool_filtersizes",  type=int, nargs='+', default=[2, 2],   help="pool filter size, conv by conv. in order to have a vector as output, the last value will be substituted with the column size of the last conv, so that the last column size will be 1, then squeezed")
parser.add_argument("-pool_stridesizes",  type=int, nargs='+', default=[1, 1],   help="pool stride size, conv by conv")
# bootstrap
parser.add_argument("-n_short_loops", type=int,   default=100)
parser.add_argument("-n_loops",       type=int,   default=100)
parser.add_argument("-perc_sample",   type=float, default=.3)

args = parser.parse_args()

# Run tag embedded in the log name: summarises the main hyper-parameters.
run_tag = (
    f"exp{args.nr_exps}_splits{args.splits}_bat{args.big_batsize}_{args.bert_lookup_kind}"
    f"_lr{str(args.high_learate)[2:]}_drop{str(args.droprob)[2:]}"
    f"_pat{args.patience}gain{args.min_gain}_broadto0_contexts"
)

# From here on everything written to stdout/stderr goes to the project logger object.
log = ut.log(__file__, run_tag)
sys.stdout = log
sys.stderr = log

# Copy this script and the three project modules it imports into the log's path.
snapshot_cmd = f"cp {__file__} {ut.__file__} {st.__file__} {mod.__file__} {log.pathtime}"
os.system(snapshot_cmd)

startime = ut.start()
# print_args returns the column width used to align the "key....value" lines below.
align_size = ut.print_args(args)
print(f"{'dirout':.<{align_size}} {log.pathtime}")
###################################################################################################
# Seed the Python, NumPy and PyTorch generators with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
torch.backends.cudnn.deterministic = True

# Floating-point precision is selectable from the command line.
if args.dtype == 64:
    dtype_float = torch.float64
else:
    dtype_float = torch.float32
# Integer tensors are always 64-bit (original author's note: anything else is rejected).
dtype_int = torch.int64

# Use the requested device only when CUDA is actually present; otherwise fall back to CPU.
cuda_available = torch.cuda.is_available()
device = torch.device(args.device if cuda_available else "cpu")
# device = "cpu"
if cuda_available:
    print(f"{'GPU in use':.<{align_size}} {device}\n{'#'*80}")
else:
    print(f"No GPU available, using the CPU.\n{'#'*80}")
###################################################################################################
# Maps an integer class id (0..6) to its emotion name.
int2str = dict(enumerate(('anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise')))
# Loss criteria, each moved to the selected device.
# Original author's note on `ce`: it does not accept float targets.
bce, ce, mse = (
    loss_cls().to(device=device)
    for loss_cls in (nn.BCELoss, nn.CrossEntropyLoss, nn.MSELoss)
)

# Side-input feature arrays: one .npy file per feature (gender, income, iri) and
# per split (trn/dev/tst), read from args.dir_data. Shapes are printed after each group.
x_gen_trn, x_gen_dev, x_gen_tst = [
    np.load(f"{args.dir_data}x_gender_{_split}.npy") for _split in ("trn", "dev", "tst")
]
for _label, _array in zip(("x_gen_trn", "x_gen_dev", "x_gen_tst"), (x_gen_trn, x_gen_dev, x_gen_tst)):
    print(f"{_label:<50}{_array.shape}")

x_inc_trn, x_inc_dev, x_inc_tst = [
    np.load(f"{args.dir_data}x_income_{_split}.npy") for _split in ("trn", "dev", "tst")
]
for _label, _array in zip(("x_inc_trn", "x_inc_dev", "x_inc_tst"), (x_inc_trn, x_inc_dev, x_inc_tst)):
    print(f"{_label:<50}{_array.shape}")

x_iri_trn, x_iri_dev, x_iri_tst = [
    np.load(f"{args.dir_data}x_iri_{_split}.npy") for _split in ("trn", "dev", "tst")
]
for _label, _array in zip(("x_iri_trn", "x_iri_dev", "x_iri_tst"), (x_iri_trn, x_iri_dev, x_iri_tst)):
    print(f"{_label:<50}{_array.shape}")

# Target arrays read from args.dir_data. `emo` has all three splits; `emp` and `dis`
# are loaded for trn/dev only (no test file is read for them here).
emo_trn, emo_dev, emo_tst = [
    np.load(f"{args.dir_data}y_emo_{_split}.npy") for _split in ("trn", "dev", "tst")
]
emp_trn, emp_dev = [
    np.load(f"{args.dir_data}y_emp_{_split}.npy") for _split in ("trn", "dev")
]
dis_trn, dis_dev = [
    np.load(f"{args.dir_data}y_dis_{_split}.npy") for _split in ("trn", "dev")
]

# Report the shape and the first five values of every target array.
for _label, _array in (
    ("emo_trn", emo_trn),
    ("emo_dev", emo_dev),
    ("emo_tst", emo_tst),
    ("emp_trn", emp_trn),
    ("emp_dev", emp_dev),
    ("dis_trn", dis_trn),
    ("dis_dev", dis_dev),
):
    print(f"{_label:<50}{_array.shape}, {_array[:5]}")

# Pre-computed BERT lookup tensors, one file per split; the file name depends on
# the chosen -bert_lookup_kind ('cls', 'meanmat' or 'mat').
file_lookup_trn = f"bert_enlarge_single_emp_trn_lookup_{args.bert_lookup_kind}.pt"
file_lookup_dev = f"bert_enlarge_single_emp_dev_lookup_{args.bert_lookup_kind}.pt"
file_lookup_tst = f"bert_enlarge_single_emp_tst_lookup_{args.bert_lookup_kind}.pt"
# map_location=device: `device` falls back to CPU when CUDA is unavailable, but a
# tensor saved from a GPU is otherwise deserialized onto its original CUDA device
# and torch.load fails on a CPU-only host. Remapping at load time makes the
# fallback actually work; the trailing .to() is then a no-op kept for safety.
lookup_trn = torch.load(f"{args.dir_data}{file_lookup_trn}", map_location=device).to(device=device)
lookup_dev = torch.load(f"{args.dir_data}{file_lookup_dev}", map_location=device).to(device=device)
lookup_tst = torch.load(f"{args.dir_data}{file_lookup_tst}", map_location=device).to(device=device)
print(f"{'lookup_trn':<50}{lookup_trn.shape}")
print(f"{'lookup_dev':<50}{lookup_dev.shape}")
print(f"{'lookup_tst':<50}{lookup_tst.shape}")

# Output directory for this run, created inside the log's path
# (os.mkdir raises if it already exists).
dirout = "{}mtl_bert{}/".format(log.pathtime, args.bert_lookup_kind)
os.mkdir(dirout)
separator = '#' * 80
print(f"{separator}\n{dirout}")

# Project helpers: `proc` runs the experiments, `boot` collects their results
# for the bootstrap comparison.
proc = st.Processing(dirout, device)
boot = st.Bootstrap(dirout)
if args.bert_lookup_kind != 'mat':
    # Vector lookups ('cls' / 'meanmat'): the embedding size is read from axis 1
    # of the lookup tensor. For every experiment round six models are trained in
    # a fixed order; the first one is the control and the other five are fed to
    # the bootstrap as treatments compared against it.
    #
    # The six runs used to be six copy-pasted stanzas. They are now driven by
    # the table below; call order, keyword sets and input lists are unchanged.

    # Optional side inputs, in the order in which the multi-input models receive them.
    side_trn = [x_gen_trn, x_inc_trn, x_iri_trn]
    side_dev = [x_gen_dev, x_inc_dev, x_iri_dev]
    side_tst = [x_gen_tst, x_inc_tst, x_iri_tst]

    # (label, model class, number of side inputs, train with the two extra targets?)
    # The fifth label used to be "mi0_mtl1" although Mtl2VecTransFc has two extra
    # heads (y1, y2), exactly like the "mi3_mtl2" run; it is now "mi0_mtl2".
    vec_experiments = [
        ("mi0_mtl0", mod.StlVecTransFc,     0, False),
        ("mi1_mtl0", mod.Mi1StlVecTransFc,  1, False),
        ("mi2_mtl0", mod.Mi2StlVecTransFc,  2, False),
        ("mi3_mtl0", mod.Mi3StlVecTransFc,  3, False),
        ("mi0_mtl2", mod.Mtl2VecTransFc,    0, True),
        ("mi3_mtl2", mod.Mi3Mtl2VecTransFc, 3, True),
    ]

    for exp in range(1, args.nr_exps + 1):
        for position, (info, model_cls, n_side, multitask) in enumerate(vec_experiments):
            direxp = f"exp{exp}_{info}"

            # Constructor keywords common to all six models.
            model_kwargs = dict(emb_size   = lookup_trn.shape[1],
                                y0_size    = len(set(emo_trn)),
                                att_heads  = args.att_heads,
                                att_layers = args.att_layers,
                                droprob    = args.droprob,
                                device     = device)
            if n_side:
                # Multi-input models: one size per side input, plus per-branch fc depths.
                for idx in range(n_side):
                    model_kwargs[f"x{idx}_size"] = side_trn[idx].shape[1]
                model_kwargs["emb_fc_layers"] = args.fc_layers
                model_kwargs["out_fc_layers"] = args.fc_layers
                # NOTE(review): the original calls never pass x2_fc_layers to the
                # three-input models; that is preserved here. Presumably the model
                # supplies a default -- confirm in the models module.
                for idx in range(min(n_side, 2)):
                    model_kwargs[f"x{idx}_fc_layers"] = args.fc_layers
            else:
                model_kwargs["fc_layers"] = args.fc_layers
            if multitask:
                model_kwargs["y1_size"] = 1
                model_kwargs["y2_size"] = 1

            # Build the model right before training it, as the original did, so
            # the order in which random numbers are consumed stays the same.
            model = model_cls(**model_kwargs)
            optimizer = optim.Adam(model.parameters(), lr=args.high_learate)

            n_targets = 3 if multitask else 1
            # NOTE(review): emo_tst is passed for every test target, as in the
            # original calls; no emp/dis test arrays are loaded in this file.
            # Confirm that proc.exp does not score the extra targets on test.
            dirres, preds, targs, _, m_epochs = proc.exp(
                model           = model,
                optimizer       = optimizer,
                lossfuncs       = [ce, mse, mse][:n_targets],
                x_inputs        = [lookup_trn] + side_trn[:n_side],
                x_inputs_dev    = [lookup_dev] + side_dev[:n_side],
                x_inputs_tst    = [lookup_tst] + side_tst[:n_side],
                x_dtypes        = [dtype_float] * (n_side + 1),
                y_inputs        = [emo_trn, emp_trn, dis_trn][:n_targets],
                y_inputs_dev    = [emo_dev, emp_dev, dis_dev][:n_targets],
                y_inputs_tst    = [emo_tst] * n_targets,
                y_dtypes        = [dtype_int, dtype_float, dtype_float][:n_targets],
                batsize         = args.big_batsize,
                patience        = args.patience,
                min_gain        = args.min_gain,
                n_splits        = args.splits, # cross-validation only: number of folds
                save            = args.save,
                additional_tsts = (),
                str_info        = direxp)

            if position == 0:
                # First run of the round: registered as the control condition.
                control = f"{model.__class__.__name__}"
                boot.feed(control=control, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)
            else:
                treatment = f"{model.__class__.__name__}_{info}"
                boot.feed(control=control, treatment=treatment, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)

    boot.run(args.n_loops, args.perc_sample)

else:

    for exp in range(1, args.nr_exps + 1):
        # Control run: single-input, single-target model on the 'mat' lookup
        # (embedding size is read from axis 2 of the lookup tensor).
        info = "mi0_mtl0"
        direxp = f"exp{exp}_{info}"
        model = mod.StlLookupTransFc(
            emb_size=lookup_trn.shape[2],
            y0_size=len(set(emo_trn)),
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)
        run_kwargs = dict(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce],
            x_inputs=[lookup_trn],
            x_inputs_dev=[lookup_dev],
            x_inputs_tst=[lookup_tst],
            x_dtypes=[dtype_float],
            y_inputs=[emo_trn],
            y_inputs_dev=[emo_dev],
            y_inputs_tst=[emo_tst],
            y_dtypes=[dtype_int],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )
        dirres, preds, targs, _, m_epochs = proc.exp(**run_kwargs)
        # The control is identified by the bare class name (no info suffix).
        control = model.__class__.__name__
        boot.feed(control=control, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)
        # Write the control's predictions as emotion names (rewritten on every round).
        ut.list2file([int2str[label_id] for label_id in preds], f"{dirout}/predictions_EMO.tsv")
    
        # Treatment: lookup plus one side input (x_gen), single emotion target.
        # The label used to be "mtl0", which broke the "mi<k>_mtl<j>" pattern of
        # every other run (the vector-lookup counterpart of this model is
        # labelled "mi1_mtl0"); it names the output directory and the treatment.
        info = "mi1_mtl0"
        direxp = f"exp{exp}_{info}"
        model = mod.Mi1StlLookupTransFc(
            emb_size=lookup_trn.shape[2],
            x0_size=x_gen_trn.shape[1],
            y0_size=len(set(emo_trn)),
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            emb_fc_layers=args.fc_layers,
            x0_fc_layers=args.fc_layers,
            out_fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)
        dirres, preds, targs, _, m_epochs = proc.exp(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce],
            x_inputs=[lookup_trn, x_gen_trn],
            x_inputs_dev=[lookup_dev, x_gen_dev],
            x_inputs_tst=[lookup_tst, x_gen_tst],
            x_dtypes=[dtype_float, dtype_float],
            y_inputs=[emo_trn],
            y_inputs_dev=[emo_dev],
            y_inputs_tst=[emo_tst],
            y_dtypes=[dtype_int],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )
        # Treatments are identified by class name plus the info label.
        treatment = f"{model.__class__.__name__}_{info}"
        boot.feed(control=control, treatment=treatment, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)

        # Treatment: lookup plus two side inputs (x_gen, x_inc), single emotion target.
        info = "mi2_mtl0"
        direxp = f"exp{exp}_{info}"
        model = mod.Mi2StlLookupTransFc(
            emb_size=lookup_trn.shape[2],
            x0_size=x_gen_trn.shape[1],
            x1_size=x_inc_trn.shape[1],
            y0_size=len(set(emo_trn)),
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            emb_fc_layers=args.fc_layers,
            x0_fc_layers=args.fc_layers,
            x1_fc_layers=args.fc_layers,
            out_fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)
        run_kwargs = dict(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce],
            x_inputs=[lookup_trn, x_gen_trn, x_inc_trn],
            x_inputs_dev=[lookup_dev, x_gen_dev, x_inc_dev],
            x_inputs_tst=[lookup_tst, x_gen_tst, x_inc_tst],
            x_dtypes=[dtype_float, dtype_float, dtype_float],
            y_inputs=[emo_trn],
            y_inputs_dev=[emo_dev],
            y_inputs_tst=[emo_tst],
            y_dtypes=[dtype_int],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )
        dirres, preds, targs, _, m_epochs = proc.exp(**run_kwargs)
        # Treatments are identified by class name plus the info label.
        treatment = "_".join((model.__class__.__name__, info))
        boot.feed(control=control, treatment=treatment, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)

        # Treatment: lookup plus three side inputs (x_gen, x_inc, x_iri), single emotion target.
        info = "mi3_mtl0"
        direxp = f"exp{exp}_{info}"
        # NOTE(review): no x2_fc_layers keyword is passed although x2_size is;
        # presumably the model supplies a default -- confirm in the models module.
        model = mod.Mi3StlLookupTransFc(
            emb_size=lookup_trn.shape[2],
            x0_size=x_gen_trn.shape[1],
            x1_size=x_inc_trn.shape[1],
            x2_size=x_iri_trn.shape[1],
            y0_size=len(set(emo_trn)),
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            emb_fc_layers=args.fc_layers,
            x0_fc_layers=args.fc_layers,
            x1_fc_layers=args.fc_layers,
            out_fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)
        run_kwargs = dict(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce],
            x_inputs=[lookup_trn, x_gen_trn, x_inc_trn, x_iri_trn],
            x_inputs_dev=[lookup_dev, x_gen_dev, x_inc_dev, x_iri_dev],
            x_inputs_tst=[lookup_tst, x_gen_tst, x_inc_tst, x_iri_tst],
            x_dtypes=[dtype_float, dtype_float, dtype_float, dtype_float],
            y_inputs=[emo_trn],
            y_inputs_dev=[emo_dev],
            y_inputs_tst=[emo_tst],
            y_dtypes=[dtype_int],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )
        dirres, preds, targs, _, m_epochs = proc.exp(**run_kwargs)
        # Treatments are identified by class name plus the info label.
        treatment = "_".join((model.__class__.__name__, info))
        boot.feed(control=control, treatment=treatment, fold=dirres, preds=preds, targs=targs, epochs=m_epochs)

        # ---- Experiment "mi0_mtl2" ----------------------------------------------------------
        # Input: the lookup tensor only; targets: emo plus emp and dis, each with its own loss.
        # NOTE(review): "mi0" / "mtl2" presumably stand for no extra inputs and two auxiliary
        # tasks — confirm against the model definitions.
        info = "mi0_mtl2"
        direxp = f"exp{exp}_{info}"

        model = mod.Mtl2LookupTransFc(
            emb_size=lookup_trn.shape[2],
            y0_size=len(set(emo_trn)),  # number of distinct labels in the training targets
            y1_size=1,
            y2_size=1,
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)

        # The y_* lists and lossfuncs are positional: (emo, ce), (emp, mse), (dis, mse).
        dirres, preds, targs, _, m_epochs = proc.exp(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce, mse, mse],
            x_inputs=[lookup_trn],
            x_inputs_dev=[lookup_dev],
            x_inputs_tst=[lookup_tst],
            x_dtypes=[dtype_float],
            y_inputs=[emo_trn, emp_trn, dis_trn],
            y_inputs_dev=[emo_dev, emp_dev, dis_dev],
            # NOTE(review): emo_tst is passed for all three targets, unlike trn/dev which use
            # emp_* and dis_*. Confirm whether this is a deliberate placeholder (no emp/dis
            # test labels) or a copy-paste slip.
            y_inputs_tst=[emo_tst, emo_tst, emo_tst],
            y_dtypes=[dtype_int, dtype_float, dtype_float],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )

        # Register the run under "<model class>_<experiment tag>".
        treatment = f"{type(model).__name__}_{info}"
        boot.feed(
            control=control,
            treatment=treatment,
            fold=dirres,
            preds=preds,
            targs=targs,
            epochs=m_epochs,
        )
    
        # ---- Experiment "mi3_mtl2" ----------------------------------------------------------
        # Inputs: the lookup tensor plus the gen / inc / iri feature matrices;
        # targets: emo plus emp and dis, each with its own loss.
        # NOTE(review): "mi3" / "mtl2" presumably stand for three extra inputs and two auxiliary
        # tasks — confirm against the model definitions.
        info = "mi3_mtl2"
        direxp = f"exp{exp}_{info}"

        # Every size argument except the two fixed 1s is read from the training arrays.
        model = mod.Mi3Mtl2LookupTransFc(
            emb_size=lookup_trn.shape[2],
            x0_size=x_gen_trn.shape[1],
            x1_size=x_inc_trn.shape[1],
            x2_size=x_iri_trn.shape[1],
            y0_size=len(set(emo_trn)),  # number of distinct labels in the training targets
            y1_size=1,
            y2_size=1,
            att_heads=args.att_heads,
            att_layers=args.att_layers,
            # args.fc_layers is reused for all four *_fc_layers arguments.
            # NOTE(review): no x2_fc_layers is passed — confirm the model does not expect one.
            emb_fc_layers=args.fc_layers,
            x0_fc_layers=args.fc_layers,
            x1_fc_layers=args.fc_layers,
            out_fc_layers=args.fc_layers,
            droprob=args.droprob,
            device=device,
        )
        optimizer = optim.Adam(model.parameters(), lr=args.high_learate)

        # The x_* / y_* lists are positional; losses pair up as (emo, ce), (emp, mse), (dis, mse).
        dirres, preds, targs, _, m_epochs = proc.exp(
            model=model,
            optimizer=optimizer,
            lossfuncs=[ce, mse, mse],
            x_inputs=[lookup_trn, x_gen_trn, x_inc_trn, x_iri_trn],
            x_inputs_dev=[lookup_dev, x_gen_dev, x_inc_dev, x_iri_dev],
            x_inputs_tst=[lookup_tst, x_gen_tst, x_inc_tst, x_iri_tst],
            x_dtypes=[dtype_float] * 4,  # one float dtype per x input
            y_inputs=[emo_trn, emp_trn, dis_trn],
            y_inputs_dev=[emo_dev, emp_dev, dis_dev],
            # NOTE(review): emo_tst is passed for all three targets, unlike trn/dev which use
            # emp_* and dis_*. Confirm whether this is a deliberate placeholder (no emp/dis
            # test labels) or a copy-paste slip.
            y_inputs_tst=[emo_tst, emo_tst, emo_tst],
            y_dtypes=[dtype_int, dtype_float, dtype_float],
            batsize=args.big_batsize,
            patience=args.patience,
            min_gain=args.min_gain,
            n_splits=args.splits,  # cross-validation only: number of folds
            save=args.save,
            additional_tsts=(),
            str_info=direxp,
        )

        # Register the run under "<model class>_<experiment tag>".
        treatment = f"{type(model).__name__}_{info}"
        boot.feed(
            control=control,
            treatment=treatment,
            fold=dirres,
            preds=preds,
            targs=targs,
            epochs=m_epochs,
        )

    # Dedented out of the experiment block above, so this executes only after that block
    # (and every boot.feed call in it) has finished.
    # NOTE(review): presumably runs a bootstrap comparison of the fed treatments against the
    # control, with args.n_loops iterations and args.perc_sample as the sample fraction —
    # confirm against the definition of `boot`.
    boot.run(args.n_loops, args.perc_sample)

# NOTE(review): presumably reports the elapsed time since `startime` — confirm in ut.end.
ut.end(startime)
