import os
import argparse
from tqdm import tqdm
import time
import random
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import torchtext
import torch.nn.functional as F
from utils.min_norm_solvers import MinNormSolver, gradient_normalizers
import numpy
from util import timing, get_model
from config import Config
import pandas as pd

# Command-line options. No `type=` is given, so every value arrives as a string
# and numeric options are converted where they are used.
parser = argparse.ArgumentParser()
parser.add_argument('-d', default='sentiment')  # dataset: sentiment/20news
parser.add_argument('-m', default='textcnn')    # feature extractor: textcnn/lstm
parser.add_argument('-t', default='meta')       # algorithm (selects the training routine at the bottom of the file)
parser.add_argument('-g', default='0')          # GPU number
parser.add_argument('-c', default='')           # remark (appended to the experiment id)
parser.add_argument('-alpha', default='0.1')    # read as float(args.alpha) in meta_multi_task (task-weight step size)
parser.add_argument('-split_scale', default='0.1')  # read as float in meta_multi_task and passed to split_dataset
args = parser.parse_args()

#torch.cuda.set_device(int(args.g))
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:'+str(args.g) if torch.cuda.is_available() else 'cpu')
cfg = Config(data=args.d, model=args.m, task=args.t)    # model config
tasks = cfg.data['tasks']
# Shared loss used by every training and evaluation routine in this file.
criterion = torch.nn.CrossEntropyLoss() 
# Per-epoch tables filled and written to CSV by validate_all_tasks:
# one column per task plus the average accuracy / the variance of the test losses.
print_result = pd.DataFrame(columns=tasks+['avg_acc'])
print_loss = pd.DataFrame(columns=tasks+['var'])

if __name__ == '__main__':
    # Script-only setup: the dataset loader, the experiment id and the
    # TensorBoard writer are created here and used as globals by the
    # training routines below.
    from data.dataset import NewsDataset, SentimentDataset
    if args.d == '20news':
        loader = NewsDataset(cfg)
    elif args.d == 'sentiment':
        loader = SentimentDataset(cfg)
    else:
        # Without this branch `loader` stays undefined and the failure only
        # shows up later as a NameError inside init_model().
        parser.error(f"unknown dataset {args.d!r} for -d; expected '20news' or 'sentiment'")

    max_epoch = cfg.exp['epochs']
    # Encode every experiment hyper-parameter in the run name so runs are
    # distinguishable in the log directory.
    cfg_exp_str = '_'.join([k + str(cfg.exp[k]) for k in cfg.exp])
    exp_id = f'{time.strftime("%H%M", time.localtime())}_{cfg_exp_str}_{args.m}_{args.c}'
    print(f'# exp_id: {exp_id}')
    save_path = f'./runs/{args.d}/{args.t}/{time.strftime("%m%d")}/{exp_id}'
    writer = SummaryWriter(log_dir=save_path)   #tensorboard


@timing
def init_model():
    """Build the model from the global config and the loader's word embeddings.

    Returns whatever `get_model` returns; the training routines in this file
    index the result by sub-module name (e.g. 'rep' or a task name).
    """
    return get_model(cfg, loader.word_embeddings)

def scheduler_step(schedulers):
    """Advance every learning-rate scheduler in `schedulers` by one step.

    Args:
        schedulers: iterable of objects exposing a `step()` method; may be
            empty, in which case nothing happens.
    """
    remaining = iter(schedulers)
    for scheduler in remaining:
        scheduler.step()

def get_opt(model, rep=False, dis=False):
    """Move every sub-module to `device` and create its optimizer(s).

    All sub-modules share one optimizer (group 'all'), except that the 'rep'
    and 'dis' sub-modules get an optimizer of their own when the matching
    flag is truthy.

    Args:
        model: mapping from sub-module name to module.
        rep: give the 'rep' sub-module a separate optimizer.
        dis: give the 'dis' sub-module a separate optimizer.

    Returns:
        A pair `(optimizers, schedulers)`. Optimizers are ordered
        'all', 'rep', 'dis', skipping groups that do not exist. A StepLR
        scheduler is created per optimizer only on the SGD path, so the
        scheduler list is empty when `cfg.exp['opt']` is 'Adam'.
    """
    learning_rate = cfg.exp['lr']
    grouped = {'all': []}
    for name in model:
        module = model[name]
        module.to(device)
        if name == 'rep':
            separate = rep
        elif name == 'dis':
            separate = dis
        else:
            separate = False
        if separate:
            grouped[name] = module.parameters()
        else:
            grouped['all'].extend(module.parameters())

    use_adam = 'opt' in cfg.exp and cfg.exp['opt'] == 'Adam'
    optimizers = []
    lr_schedulers = []
    for group in ('all', 'rep', 'dis'):
        if group not in grouped:
            continue
        if use_adam:
            decay = cfg.exp['wd'] if 'wd' in cfg.exp else 0
            optimizer = torch.optim.Adam(grouped[group], lr=learning_rate, weight_decay=decay)
        else:
            optimizer = torch.optim.SGD(grouped[group], lr=learning_rate, momentum=0.9)
            lr_schedulers.append(
                torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.exp['step'], gamma=0.9)
            )
        optimizers.append(optimizer)
    return optimizers, lr_schedulers


def meta_multi_task():
    """Train all tasks jointly while adapting one loss weight per task.

    Each epoch:
      1. (support pass) sums the loss over all support batches of every
         task, back-propagates it scaled by ``w_k[task] / len(tasks)`` and
         records the gradient of the shared 'rep' module;
      2. applies one optimizer step;
      3. (query pass) computes the mean query loss of every task and its
         gradient w.r.t. 'rep';
      4. updates ``w_k`` from the dot products between query and support
         gradients, scaled by the ``MinNormSolver`` solution and ``alpha``.

    Weights and losses are written to CSV files under ``save_path`` every
    epoch; validation runs every 10 epochs.
    """
    def split_dataset(scale):
        """Split every task's dataset into a support part and a query part.

        ``scale`` is the fraction handed to the second (query) part, since
        ``split_ratio=1-scale`` is passed to ``split``. Returns the two lists
        of datasets and a BucketIterator per task for each of them.
        """
        # NOTE(review): for 'sentiment' only the even entries of loader.data
        # are used — presumably the per-task training sets; confirm against
        # SentimentDataset.
        if args.d == 'sentiment':
            sets = [loader.data[2*i] for i in range(len(tasks))]
        elif args.d == '20news':
            sets = [loader.data[i] for i in range(len(tasks))]
        sets1 = []
        sets2 = []
        for task_set in sets:
            # random.seed() returns None, so random_state is always None here;
            # the visible effect of the call is re-seeding the global `random`
            # module from the clock before each split.
            s1, s2 = task_set.split(split_ratio=1-scale, random_state=random.seed(int(time.time()*256)))
            sets1.append(s1)
            sets2.append(s2)
        iter1 = [torchtext.data.BucketIterator(sets1[i], batch_size=cfg.data['batch_size'], sort_key=lambda x: len(x.text), repeat=False, shuffle=True) for i in range(len(tasks))]
        iter2 = [torchtext.data.BucketIterator(sets2[i], batch_size=cfg.exp['query_batch_size'], sort_key=lambda x: len(x.text), repeat=False, shuffle=True) for i in range(len(tasks))]
        return sets1, sets2, iter1, iter2

    def return_grad():
        """Return the current gradients of model['rep'] flattened into one 1-D tensor."""
        task_grad = torch.empty(0).to(device)
        for name, parms in model['rep'].named_parameters(): 
            if parms.requires_grad:
                # Copy the gradient so later backward passes do not alter the snapshot.
                task_grad = torch.cat((task_grad, parms.grad.clone().detach().view(1,-1)[0]), 0)
        return task_grad

    train_loss_pd = pd.DataFrame(columns=tasks)
    query_loss_pd = pd.DataFrame(columns=tasks)

    model = init_model()
    alpha = float(args.alpha)
    # get_opt orders its optimizers 'all', 'rep': opt_c covers everything
    # except 'rep', opt_g covers 'rep'.
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    support_set, query_set, support, query = split_dataset(float(args.split_scale))
    # One loss weight per task, all starting at 1.
    w_k = [1. for _ in range(len(tasks))]
    output_weight = pd.DataFrame(columns=tasks)

    for epoch in tqdm(range(max_epoch), unit='epoch'):
        for m in model:
            model[m].train()
        opt_g.zero_grad()
        opt_c.zero_grad()
        task_grad_list = []

        # Optionally re-draw the support/query split every `split_epoch` epochs.
        if cfg.exp['split_epoch'] != 'none':
            if epoch % int(cfg.exp['split_epoch']) == 0:
                support_set, query_set, support, query = split_dataset(float(args.split_scale))

        # train
        train_loss = []
        for t_iter, t in enumerate(tasks):
            n_batch = 0
            # The loss is summed over all support batches of the task before a
            # single backward pass.
            # NOTE(review): `loss` is not reset per task, so an empty support
            # iterator would reuse the previous task's value (and divide by
            # zero below).
            for i, batch in enumerate(support[t_iter]):
                n_batch += 1
                x, y = batch.text.to(device), batch.label.to(device)
                y_ = model[t](model['rep'](x))
                loss_t = criterion(y_, y)
                if i > 0:
                    loss += loss_t
                else:
                    loss = loss_t
            train_t_loss = loss.item() / n_batch
            train_loss.append(train_t_loss)
            writer.add_scalar(f'train_loss/{t}', train_t_loss, epoch)
            loss *= (w_k[t_iter] / len(tasks))
            loss.backward()
            # Gradients are not zeroed between tasks, so this snapshot is the
            # running sum over tasks 0..t_iter.
            task_grad_list.append(return_grad())

        # Turn the cumulative snapshots into per-task gradients (going backwards
        # so each subtraction still sees the untouched predecessor) ...
        for i in range(len(tasks)-1, 0, -1):
            task_grad_list[i] -= task_grad_list[i-1]
        # ... and undo the weight factor applied to the loss above.
        for i in range(len(task_grad_list)):
            task_grad_list[i] /= (w_k[i] / len(tasks))

        # One update with the accumulated, weighted gradients of all tasks.
        opt_g.step()
        opt_c.step()

        # query
        query_loss = []
        query_grad_list = []
        for t_iter, t in enumerate(tasks):
            loss = None
            n_batches = 0
            for i, batch in enumerate(query[t_iter]):
                n_batches += 1
                x, y = batch.text.to(device), batch.label.to(device)
                y_ = model[t](model['rep'](x))
                loss_t = criterion(y_, y)
                if i > 0:
                    loss += loss_t
                else:
                    loss = loss_t
            # Unlike the support pass, the query loss is averaged over batches
            # before the backward pass.
            loss /= n_batches
            query_loss.append(loss.item())
            writer.add_scalar(f'query_loss/{t}', loss.item(), epoch)
            # Gradients are cleared first, so the snapshot belongs to this task only.
            opt_g.zero_grad()
            opt_c.zero_grad()
            loss.backward()
            query_grad_list.append(return_grad())
        # w_list[t][i] = <query gradient of task t, support gradient of task i>.
        w_list = {}
        for t_iter, t in enumerate(tasks):
            w_list[t] = [torch.matmul(query_grad_list[t_iter], task_grad_list[i]) for i in range(len(tasks))]
        scale, min_norm = MinNormSolver.find_min_norm_element([w_list[t] for t in tasks])
        # Weight of task k moves by alpha * scale[k] * sum over query tasks of
        # their dot product with task k's support gradient.
        for t_iter, t in enumerate(tasks):
            w_k[t_iter] += (alpha * scale[t_iter] * sum([w_list[i][t_iter] for i in tasks]).item())  

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        # Persist the full history every epoch (the files are rewritten each time).
        output_weight.loc[epoch] = w_k
        train_loss_pd.loc[epoch] = train_loss
        query_loss_pd.loc[epoch] = query_loss
        output_weight.to_csv(os.path.join(save_path,"meta_scale.csv"))
        train_loss_pd.to_csv(os.path.join(save_path,"train_loss_"+exp_id+".csv"))
        query_loss_pd.to_csv(os.path.join(save_path,"query_loss_"+exp_id+".csv"))
        
        for t_iter, t in enumerate(tasks):
            writer.add_scalar(f'weights/{t}', w_k[t_iter], epoch)

        scheduler_step(schedulers) 


def bandit_multi_task():
    """Train all tasks jointly with adaptive task probabilities ``p_k``.

    Every step draws one batch per task, optimises the ``p_k``-weighted sum
    of the task losses (divided by the number of tasks) and then recomputes
    ``p_k``: ``q_k = p_k * exp(eta_p * loss_k)`` is raised to the power
    ``1 / (lam + 1)`` and renormalised, with ``lam`` returned by
    ``find_lambda``. ``p_k`` is written to a CSV file under ``save_path``
    after every epoch and validation runs every 10 epochs.

    ``rho`` and ``eta_p`` are read from ``cfg.exp``.
    """
    import math
    rho = cfg.exp['rho']
    eta_p = cfg.exp['eta_p']
    output_p = pd.DataFrame(columns=tasks)

    # function for computing lambda
    def compute_l(x,q,rho):
        """Evaluate the function whose root in ``x`` is searched by ``find_lambda``.

        With ``s_i = q_i ** (1 / (x + 1))`` this is
        ``log(len(tasks)) - rho - log(sum(s)) + sum(log(q_i) * s_i) / ((x + 1) * sum(s))``.
        """
        kk = 1/(x+1)
        q_kk = [math.pow(i,kk) for i in q]
        t1= sum(q_kk)
        t2 = sum([math.log(q[i])*q_kk[i] for i in range(len(q))])/(x+1)
        return math.log(len(tasks)) - rho - math.log(t1) + t2/t1

    # Algorithm 2 in paper
    def find_lambda(e,beta,upper,jump):
        """Search for ``x >= 0`` with ``compute_l(x, q_k, rho)`` close to 0.

        Uses the ``q_k`` of the enclosing scope. Returns 0 when the function
        is already non-positive at 0. Otherwise the bracket is widened in
        steps of ``beta`` (returning ``upper`` once it is exceeded, and
        stopping after more than ``jump`` widenings) and then bisected until
        the absolute value is at most ``e``; ``upper`` is returned when the
        bisection needs more than ``jump`` iterations.
        """
        if compute_l(0,q_k,rho) <= 0:
            return 0
        left = 0
        right = beta
        flag = 0
        while compute_l(right,q_k,rho) > 0:
            flag += 1
            left = right
            right = right + beta
            if right > upper:
                return upper
            if flag > jump:  
                break
        x = (left + right)/2
        ans = compute_l(x,q_k,rho)
        flag = 0
        while abs(ans) > e:
            flag += 1
            if ans > 0:
                left = x
            else:
                right = x
            x = (left + right)/2
            ans = compute_l(x,q_k,rho)
            if flag > jump:        # if lambda is too large, skip out the loop
                return upper
        return x

    n_batches = min([len(loader.train[t].dataset.examples) for t in tasks]) // cfg.data['batch_size']
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    # Start from the uniform distribution over tasks.
    p_k = [1/len(tasks) for i in tasks]
    # The parser defines no -rho / -eta options, so the file name is built
    # from the values actually used (previously `args.rho` / `args.eta`,
    # which raised AttributeError at the end of the first epoch).
    p_k_path = os.path.join(save_path, f'rho_{rho}_eta_{eta_p}_p_k.csv')
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for m in model:
            model[m].train()
        iters = {t: iter(loader.train[t]) for t in tasks}
        for n in range(n_batches): 
            losses = []
            loss = None
            for i,t in enumerate(tasks):
                batch = next(iters[t])
                # Skip a batch whose second dimension differs from the
                # configured batch size and take the following one instead.
                if batch.text.shape[1] != cfg.data['batch_size']:
                    batch = next(iters[t])
                x, y = batch.text.to(device), batch.label.to(device)
                y_ = model[t](model['rep'](x))
                loss_t = criterion(y_, y)
                losses.append(loss_t.item())
                if i > 0:
                    loss += loss_t * p_k[i]
                else:
                    loss = loss_t * p_k[i]
            loss = loss/len(tasks)
            opt_g.zero_grad()
            opt_c.zero_grad()
            loss.backward()
            opt_g.step()
            opt_c.step()
            
            # Re-estimate the task probabilities from this step's losses.
            q_k = [p_k[i] * math.exp(eta_p*losses[i]) for i in range(len(tasks))]
            lam = find_lambda(1e-15,10,2e5,1e5)
            q_lam = [math.pow(i,1/(lam+1)) for i in q_k]
            q_sum = sum(q_lam)
            p_k = [i/q_sum for i in q_lam]
        output_p.loc[epoch] = p_k
        output_p.to_csv(p_k_path)

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)
        scheduler_step(schedulers)   


def uncertain_multi_task():
    """Train with a learned per-task loss weight.

    Each task head returns ``(logits, logsigma)``; the task loss is
    ``(exp(-logsigma[0]) * CE + logsigma[0]) / 2`` and one optimizer step is
    taken per task batch. The weight ``exp(-logsigma[0]) / 2`` seen on the
    first step of every epoch is logged to ``uncertain_scale.csv`` under
    ``save_path``; validation runs every 10 epochs.
    """
    batch_size = cfg.data['batch_size']
    smallest_train_set = min([len(loader.train[t].dataset.examples) for t in tasks])
    n_batches = smallest_train_set // batch_size
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    output_weight = pd.DataFrame(columns=tasks)
    csv_path = os.path.join(save_path,"uncertain_scale.csv")

    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for name in model:
            model[name].train()
        batch_iters = {t: iter(loader.train[t]) for t in tasks}
        first_step_weights = []
        for step in range(n_batches):
            for t in tasks:
                batch = next(batch_iters[t])
                # A batch whose second dimension is not the configured batch
                # size is replaced by the next one.
                if batch.text.shape[1] != batch_size:
                    batch = next(batch_iters[t])
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                logits, logsigma = model[t](model['rep'](inputs))
                precision = torch.exp(-logsigma[0])
                if step == 0:
                    first_step_weights.append(float(precision/2))
                loss = (precision * criterion(logits, targets) + logsigma[0]) / 2
                opt_g.zero_grad()
                opt_c.zero_grad()
                loss.backward()
                opt_g.step()
                opt_c.step()
        output_weight.loc[epoch] = first_step_weights
        output_weight.to_csv(csv_path)

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)
        scheduler_step(schedulers)


def gn_multi_task():
    """Train all tasks with learned loss weights ``gn_scale`` (gradient-norm balancing).

    Per step, every task is first trained on one batch with its loss
    multiplied by the current (constant) weight. Then the gradient of each
    task loss w.r.t. the shared representation is measured, and the weights
    are updated so that the weighted gradient norms approach
    ``mean_norm * (relative loss ratio) ** gn_alpha``. After each update the
    weights are renormalised to sum to ``len(tasks)``. The weights are
    written to ``gn_scale.csv`` under ``save_path`` every epoch; validation
    runs every 10 epochs.
    """
    n_batches = min([len(loader.train[t].dataset.examples) for t in tasks]) // cfg.data['batch_size']
    model = init_model()
    # NOTE(review): `schedulers` is never stepped in this routine (only
    # gn_scheduler is), unlike the other training functions — confirm intended.
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    output_gn = pd.DataFrame(columns=tasks)

    # Running mean of each task's (weighted) training loss over epoch 0,
    # used as the reference for the loss ratios.
    first_train_loss = torch.zeros(len(tasks),dtype=torch.float32).to(device)
    # Most recent unweighted training loss of each task.
    gn_train_loss_arr = torch.zeros(len(tasks),dtype=torch.float32).to(device)
    gn_scale = torch.ones(len(tasks),dtype=torch.float32,requires_grad=True).to(device)
    gn_scale = Variable(gn_scale,requires_grad=True)
    opt_gn = torch.optim.SGD([gn_scale], lr=cfg.exp['gn_lr'])
    gn_alpha = cfg.exp['gn_alpha']
    dic_tasks = {}
    gn_scheduler = torch.optim.lr_scheduler.StepLR(opt_gn, step_size=cfg.exp['gn_step'], gamma=0.9)
    # Task name -> position in the weight/loss tensors.
    for i,j in enumerate(tasks):
        dic_tasks[j] = i

    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for m in model:
            model[m].train()
        iters = {t: iter(loader.train[t]) for t in tasks}
        for n in range(n_batches):
            grads = {}
            losses = {}
            xs = {}
            ys = {}
            for t in tasks:
                batch = next(iters[t])
                x, y = batch.text.to(device), batch.label.to(device)
                xs[t] = x
                ys[t] = y
                y_ = model[t](model['rep'](xs[t]))
                loss = criterion(y_, ys[t])
                gn_train_loss_arr[dic_tasks[t]] = loss.item() 
                # float(...) makes the weight a constant here: this backward
                # pass trains the model only, not gn_scale.
                loss = loss * float(gn_scale[dic_tasks[t]]) 
                opt_g.zero_grad()
                opt_c.zero_grad()
                loss.backward()
                opt_g.step()
                opt_c.step()
                if epoch == 0:
                    kk = dic_tasks[t]
                    first_train_loss[kk] = (first_train_loss[kk]*n+loss.item())/(n+1)    
            for t in tasks: 
                # Recompute the representation without a graph, then track
                # gradients w.r.t. the representation itself.
                with torch.no_grad():
                    rep = model['rep'](xs[t])
                rep = rep.clone().requires_grad_()
                y_ = model[t](rep)
                loss = criterion(y_, ys[t])
                losses[t] = loss.item()
                opt_c.zero_grad()
                loss.backward()
                grads[t] = rep.grad.clone().requires_grad_(False) # get gradient of the last layer

            # w_i
            gw_arr = [torch.sum(torch.norm(gn_scale[i]*grads[t],p=2, dim=1, keepdim=False)) for i,t in enumerate(tasks)]
            gw_arr = torch.stack(gw_arr)
            loss_ratio = [gn_train_loss_arr[i] / float(first_train_loss[i]) for i in range(len(tasks))]
            loss_ratio = torch.stack(loss_ratio)                # list -> tensor
            loss_ratio = loss_ratio / torch.mean(loss_ratio)    # loss_ratio is absolutely positive
            # constant term -> ideal gradient
            constant_term = (loss_ratio ** gn_alpha) * gw_arr.mean()
            constant_term = constant_term.detach()              # as a constant  
            # L_{grad}              
            grad_loss = F.l1_loss(gw_arr,constant_term,reduction="sum")
            # Clear the previous step's gradient first: backward() accumulates
            # into gn_scale.grad, and without this every update would use the
            # sum of all gradients since the start of training.
            opt_gn.zero_grad()
            grad_loss.backward()
            opt_gn.step()       # grad_norm optimizer, just for weights of loss
            # renormalization: keep the weights summing to len(tasks)
            gn_scale.data = gn_scale.data * len(tasks) / torch.sum(gn_scale.data)

        output_gn.loc[epoch] = gn_scale.cpu().detach().numpy().tolist()
        output_gn.to_csv(os.path.join(save_path,"gn_scale.csv"))

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)
        gn_scheduler.step()


def tchebycheff_adv_multi_task():
    """Min-max multi-task training with an additional task discriminator.

    Each epoch has two phases:
      1. the task heads (and, while enabled, the 'dis' discriminator) are
         trained on representations computed without gradients, i.e. the
         shared 'rep' module is not updated;
      2. the objective with the largest mean weighted loss — or 'dis' once
         its mean loss is at most ``cfg.exp['alpha']`` — is trained end to
         end, which also updates 'rep'.

    Validation runs every 10 epochs.
    """
    model = init_model()
    # get_opt orders its optimizers 'all', 'rep', 'dis'.
    [opt_c, opt_g, opt_d], schedulers = get_opt(model, rep=True, dis=True)
    # Once switched off below, the discriminator's phase-1 training is never re-enabled.
    flag_train_dis = True
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for m in model:
            model[m].train()
        train_losses = {t: [] for t in tasks+['dis']}
        for i, batch in enumerate(loader.all):
            x, y, t_label = batch.text.to(device), batch.label.to(device), batch.task.to(device)
            # The representation is computed without a graph, so phase 1
            # leaves model['rep'] untouched.
            with torch.no_grad():
                rep = model['rep'](x)
            opt_c.zero_grad()
            # NOTE: this inner `i` shadows the batch index of the outer loop,
            # which is not used afterwards.
            for i, t in enumerate(tasks):
                # Select the samples of the mixed batch that belong to task i.
                is_t = t_label == i
                xt = rep[is_t]
                yt = y[is_t]
                yt_ = model[t](xt)
                loss = criterion(yt_, yt) * cfg.w[i]
                train_losses[t].append(loss.item())
                loss.backward()
            opt_c.step()
            # train discriminator
            t_label_ = model['dis'](rep)
            dis_loss = criterion(t_label_, t_label)
            train_losses['dis'].append(dis_loss.item())
            if flag_train_dis:
                opt_d.zero_grad()
                dis_loss.backward()
                opt_d.step()
        
        # compute weight
        w = {t: sum(train_losses[t]) / len(train_losses[t]) for t in tasks}
        dis_loss = sum(train_losses['dis']) / len(train_losses['dis'])
        # 'dis' takes precedence over the worst task when its mean loss is
        # at most the threshold.
        max_t = 'dis' if dis_loss <= cfg.exp['alpha'] else max(w, key=w.get)
        if max_t == 'dis':
            flag_train_dis = False
        w['dis'] = dis_loss
        # 'dis' is logged with the index len(tasks), tasks with their own index.
        writer.add_scalar('max_t', len(tasks) if max_t == 'dis' else tasks.index(max_t), epoch)
        for k in w:
            writer.add_scalar(f'w/{k}', w[k], epoch)

        # train max
        max_loader = loader.all if max_t == 'dis' else loader.train[max_t]
        for i, batch in enumerate(max_loader):
            # NOTE(review): in the 'dis' case the target is batch.label,
            # whereas the discriminator above is trained against batch.task —
            # confirm this is intended.
            x, y = batch.text.to(device), batch.label.to(device)
            y_ = model[max_t](model['rep'](x))
            loss = criterion(y_, y)
            opt_g.zero_grad()
            if max_t == 'dis':
                opt_d.zero_grad()
                loss = loss * cfg.exp['beta']
                loss.backward()
                opt_d.step()
            else:
                opt_c.zero_grad()
                loss.backward()
                opt_c.step()
            opt_g.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)
        scheduler_step(schedulers)


def tchebycheff_multi_task():
    """Min-max multi-task training.

    Phase 1 of every epoch trains only the task heads on representations
    computed without gradients and collects the weighted loss of every task;
    phase 2 trains the task with the largest mean weighted loss end to end
    on its own training set. Validation runs every 10 epochs.
    """
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for name in model:
            model[name].train()

        # Phase 1: heads only, on the mixed-task loader.
        train_losses = {t: [] for t in tasks}
        for batch in loader.all:
            with torch.no_grad():
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                task_ids = batch.task.to(device)
                rep = model['rep'](inputs)
            opt_c.zero_grad()
            for task_index, t in enumerate(tasks):
                # Samples of the mixed batch that belong to this task.
                mask = task_ids == task_index
                task_logits = model[t](rep[mask])
                loss = criterion(task_logits, targets[mask]) * cfg.w[task_index]
                train_losses[t].append(loss.item())
                loss.backward()
            opt_c.step()

        # Mean weighted loss per task; the largest one selects the task to train.
        w = {}
        for t in tasks:
            w[t] = sum(train_losses[t]) / len(train_losses[t])
        for t in w:
            writer.add_scalar(f'w/{t}', w[t], epoch)
        max_t = max(w, key=w.get)
        writer.add_scalar('max_t', tasks.index(max_t), epoch)

        # Phase 2: end-to-end training of the selected task.
        for batch in loader.train[max_t]:
            inputs = batch.text.to(device)
            targets = batch.label.to(device)
            logits = model[max_t](model['rep'](inputs))
            loss = criterion(logits, targets)
            opt_g.zero_grad()
            opt_c.zero_grad()
            loss.backward()
            opt_g.step()
            opt_c.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)
        scheduler_step(schedulers)


def fudan_multi_task():
    """Train a single model whose forward pass returns its own loss terms.

    ``model(x, y, t_label)`` yields ``(c_losses, d_loss, r_losses)``; the
    optimised loss is ``sum(c_losses) + 0.05 * d_loss + sum(r_losses)``.
    The epoch means of the three parts are logged, and every 10 epochs each
    task is evaluated through ``model.val`` on its test loader.
    """
    model = init_model()
    # The whole model is wrapped in a one-entry mapping so get_opt can treat
    # it like the per-module models used elsewhere.
    [opt], schedulers = get_opt({'': model})
    # opt = model.to_and_get_optimizer(device)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        model.train()
        losses = {'c': [], 'd': [], 'r': []}
        for batch in loader.all:
            inputs = batch.text.to(device)
            targets = batch.label.to(device)
            task_ids = batch.task.to(device)
            c_losses, d_loss, r_losses = model(inputs, targets, task_ids)
            c_total = sum(c_losses)
            r_total = sum(r_losses)
            loss = c_total + d_loss * 0.05 + r_total
            losses['c'].append(c_total.item())
            losses['d'].append(0.05 * d_loss.item())
            losses['r'].append(r_total.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
        for part in losses:
            writer.add_scalar(f'train_loss/{part}', sum(losses[part])/len(losses[part]), epoch)

        # validation
        if epoch % 10 != 0:
            continue
        with torch.no_grad():
            model.eval()
            accs = []
            losses_var = []
            for task_index, t in enumerate(tasks):
                test_loss = 0
                n_acc = 0
                n_all = 0
                for batch in loader.test[t]:
                    inputs = batch.text.to(device)
                    targets = batch.label.to(device)
                    logits = model.val(inputs, task_index)
                    test_loss += criterion(logits, targets).item()
                    n_acc += logits.argmax(1).eq(targets).sum()
                    n_all += targets.shape[0]
                acc = n_acc / float(n_all)
                accs.append(acc)
                # The test loss is the sum over batches, not the mean.
                writer.add_scalar(f'loss/test/{t}', test_loss, epoch)
                writer.add_scalar(f'acc/test/{t}', acc, epoch)
                losses_var.append(test_loss)
            writer.add_scalar('avg_var', numpy.var(losses_var), epoch)
            writer.add_scalar('avg_acc', sum(accs)/len(accs), epoch)


def mgda_multi_task():
    """Train all tasks with per-step task scales from the min-norm solver.

    For every step the gradient of each task loss w.r.t. the shared
    representation is computed, normalised with ``gradient_normalizers``
    ('loss+'), and passed to ``MinNormSolver.find_min_norm_element``. Each
    task is then trained on its batch with its loss multiplied by the
    resulting scale. The scales of the first step of every epoch are written
    to ``mgda_scale.csv`` under ``save_path``; validation runs every
    10 epochs.
    """
    n_batches = min([len(loader.train[t].dataset.examples) for t in tasks]) // cfg.data['batch_size']
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    output_weight = pd.DataFrame(columns=tasks)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for m in model:
            model[m].train()
        iters = {t: iter(loader.train[t]) for t in tasks}
        output_w = []
        for n in range(n_batches): 
            grads = {}
            losses = {}
            xs = {}
            ys = {}
            for t in tasks:
                batch = next(iters[t])
                # Skip a batch whose second dimension differs from the
                # configured batch size and take the following one instead.
                if batch.text.shape[1] != cfg.data['batch_size']:
                    batch = next(iters[t])
                x, y = batch.text.to(device), batch.label.to(device)
                xs[t] = x
                ys[t] = y
                # Compute the representation without a graph, then track the
                # gradient of the task loss w.r.t. the representation itself.
                with torch.no_grad():
                    rep = model['rep'](x)
                rep = rep.clone().requires_grad_()
                y_ = model[t](rep)
                loss = criterion(y_, y)
                losses[t] = loss.item()
                opt_c.zero_grad()
                loss.backward()
                # One-element list: the solver receives a list of tensors per task.
                grads[t] = [rep.grad.clone().requires_grad_(False)]
            # Normalize all gradients, this is optional and not included in the paper.
            gn = gradient_normalizers(grads, losses, 'loss+')
            # grads[t] is a list, so the normaliser has to be applied to each
            # tensor inside it (dividing the list itself is not a valid
            # operation).
            grads = {t: [g / gn[t] for g in grads[t]] for t in grads}
            # Frank-Wolfe iteration to compute scales.
            sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
            scales = {t: sol[i] for i, t in enumerate(tasks)}
            for t in tasks:
                if n == 0:
                    output_w.append(scales[t])
                y_ = model[t](model['rep'](xs[t]))
                loss = criterion(y_, ys[t]) * scales[t]
                opt_g.zero_grad()
                opt_c.zero_grad()
                loss.backward()
                opt_g.step()
                opt_c.step()

        output_weight.loc[epoch] = output_w
        output_weight.to_csv(os.path.join(save_path,"mgda_scale.csv"))
        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        scheduler_step(schedulers)


def uniform_multi_task():
    """Train all tasks with equal weight using a single optimizer.

    Every step takes one batch per task and performs one update per task
    with the loss divided by the number of tasks. Validation runs every
    10 epochs and a checkpoint of all sub-modules is saved every 500 epochs
    (not at epoch 0).
    """
    smallest_train_set = min([len(loader.train[t].dataset.examples) for t in tasks])
    n_batches = smallest_train_set // cfg.data['batch_size']
    model = init_model()
    [opt], schedulers = get_opt(model)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for name in model:
            model[name].train()
        batch_iters = {t: iter(loader.train[t]) for t in tasks}
        for _ in range(n_batches):
            for t in tasks:
                batch = next(batch_iters[t])
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                logits = model[t](model['rep'](inputs))
                loss = criterion(logits, targets) / len(tasks)
                opt.zero_grad()
                loss.backward()
                opt.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        # checkpoint
        if epoch > 0 and epoch % 500 == 0:
            state_dict = {'model': {key: model[key].state_dict() for key in model}}
            os.makedirs(f'checkpoints/{args.d}/uniform/', exist_ok=True)
            torch.save(state_dict, f'checkpoints/{args.d}/uniform/{epoch}')

        scheduler_step(schedulers)


def single_task():
    """Train and evaluate every task on its own, with a fresh model per task.

    Each task is trained for ``max_epoch`` epochs and evaluated on its test
    loader every ``test_batch`` epochs; a checkpoint of 'rep' and the task
    head is saved every 100 epochs (not at epoch 0). After all tasks are
    done, the accuracy pooled over tasks and the variance of the test losses
    are logged for every evaluation point.
    """
    test_batch = 10
    model = init_model()
    [opt], schedulers = get_opt(model)
    # One inner list per task, one entry per evaluation point.
    losses_var, avg_acc, avg_num = [], [], []
    for t in tasks:
        task_losses, task_correct, task_total = [], [], []
        losses_var.append(task_losses)
        avg_acc.append(task_correct)
        avg_num.append(task_total)
        for epoch in tqdm(range(max_epoch), unit='epoch', postfix=t):
            # train
            for name in model:
                model[name].train()
            batch_losses = []
            for batch in loader.train[t]:
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                logits = model[t](model['rep'](inputs))
                loss = criterion(logits, targets)
                opt.zero_grad()
                loss.backward()
                opt.step()
                batch_losses.append(loss.item())
            scheduler_step(schedulers)
            writer.add_scalar(f'train_loss/{t}', sum(batch_losses)/len(batch_losses), epoch)

            # validation
            if epoch % test_batch == 0:
                with torch.no_grad():
                    for name in model:
                        model[name].eval()
                    test_loss = 0
                    n_acc = 0
                    n_all = 0
                    for batch in loader.test[t]:
                        inputs = batch.text.to(device)
                        targets = batch.label.to(device)
                        logits = model[t](model['rep'](inputs))
                        test_loss += criterion(logits, targets).item()
                        n_acc += logits.argmax(1).eq(targets).sum()
                        n_all += targets.shape[0]
                    acc = n_acc / float(n_all)
                    writer.add_scalar(f'loss/test/{t}', test_loss, epoch)
                    writer.add_scalar(f'acc/test/{t}', acc, epoch)
                    task_losses.append(test_loss)
                    task_correct.append(n_acc)
                    task_total.append(n_all)

            # checkpoint
            if epoch > 0 and epoch % 100 == 0:
                state_dict = {'model': {key: model[key].state_dict() for key in ['rep', t]}}
                os.makedirs(f'checkpoints/{args.d}/single/', exist_ok=True)
                torch.save(state_dict, f'checkpoints/{args.d}/single/{t}_{epoch}')
        # re-init model & optimizer
        model = init_model()
        [opt], schedulers = get_opt(model)

    # Aggregate over tasks for every evaluation point.
    task_indices = range(len(tasks))
    for point in range(len(avg_num[0])):
        tt_var = [losses_var[j][point] for j in task_indices]
        tt_acc = sum(avg_acc[j][point] for j in task_indices)
        tt_num = sum(avg_num[j][point] for j in task_indices)
        writer.add_scalar('avg_acc', tt_acc/float(tt_num), point*test_batch)
        writer.add_scalar('avg_var', numpy.var(tt_var), point*test_batch)


def validate_all_tasks(model, epoch):
    """Evaluate every task on its test loader and log the results.

    Logs the per-task mean test loss and accuracy, the average accuracy and
    the variance of the per-task losses to TensorBoard, stores them in the
    module-level ``print_result`` / ``print_loss`` tables (rewritten to CSV
    under ``save_path`` on every call) and prints the average accuracy.

    Args:
        model: mapping from sub-module name to module; every sub-module is
            switched to eval mode and left that way.
        epoch: index used as the logging step and as the table row label.
    """
    with torch.no_grad():
        for name in model:
            model[name].eval()
        # Heads trained with the 'uncertain' routine return (logits, logsigma).
        head_returns_pair = cfg.task == 'uncertain'
        accs = []
        losses_var = []
        for t in tasks:
            test_loss = 0
            n_acc = 0
            n_all = 0
            n_batch = 0
            for batch in loader.test[t]:
                n_batch += 1
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                output = model[t](model['rep'](inputs))
                if head_returns_pair:
                    logits, _ = output
                else:
                    logits = output
                test_loss += criterion(logits, targets).item()
                n_acc += logits.argmax(1).eq(targets).sum()
                n_all += targets.shape[0]
            acc = n_acc / float(n_all)
            accs.append(acc)
            test_loss /= n_batch
            losses_var.append(test_loss)
            writer.add_scalar(f'test_loss/{t}', test_loss, epoch)
            writer.add_scalar(f'acc/test/{t}', acc, epoch)

        avg_acc = sum(accs)/len(accs)
        avg_var = numpy.var(losses_var)
        writer.add_scalar('avg_acc', avg_acc, epoch)
        writer.add_scalar('avg_var', avg_var, epoch)

        # Table rows: the per-task values followed by the aggregate.
        print_result.loc[epoch] = [value.item() for value in accs + [avg_acc]]
        print_loss.loc[epoch] = losses_var + [avg_var]
        print_result.to_csv(os.path.join(save_path,"result_"+exp_id+".csv"))
        print_loss.to_csv(os.path.join(save_path,"test_loss_"+exp_id+".csv"))
        print(
            'Epoch: [%d | %d] | Natural Test Acc %.2f\n' % (
            epoch + 1,
            max_epoch,
            avg_acc,
            ))


if __name__ == '__main__':
    # Map every supported value of -t to its training routine; an unknown
    # value runs nothing and falls through to the final print.
    routines = {
        'single': single_task,
        'uniform': uniform_multi_task,
        'mgda': mgda_multi_task,
        'gn': gn_multi_task,
        'fudan': fudan_multi_task,
        'tchebycheff': tchebycheff_multi_task,
        'tchebycheff_adv': tchebycheff_adv_multi_task,
        'uncertain': uncertain_multi_task,
        'bandit': bandit_multi_task,
        'meta': meta_multi_task,
    }
    routine = routines.get(args.t)
    if routine is not None:
        routine()
    print('exit')
