import os
import argparse
from tqdm import tqdm
import time
import random
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.min_norm_solvers import MinNormSolver, gradient_normalizers

from util import timing, get_model
from config import Config

# Command-line options; every flag is optional and falls back to its default.
parser = argparse.ArgumentParser()
parser.add_argument('-d', default='sentiment',
                    help="dataset name handed to Config; a loader is built for '20news' or 'sentiment'")
parser.add_argument('-m', default='textcnn',
                    help='model name handed to Config and embedded in the experiment id')
parser.add_argument('-t', default='tchebycheff',
                    help='training procedure: single, uniform, mgda, fudan, tchebycheff or tchebycheff_adv')
parser.add_argument('-g', default='0',
                    help='value exported as CUDA_VISIBLE_DEVICES')
parser.add_argument('-c', default='',
                    help='free-form suffix appended to the experiment id')
args = parser.parse_args()

# Restrict the visible GPUs before the device is selected.
os.environ['CUDA_VISIBLE_DEVICES'] = args.g
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shared, module-level experiment state used by every training routine below.
cfg = Config(data=args.d, model=args.m, task=args.t)
tasks = cfg.data['tasks']
criterion = torch.nn.CrossEntropyLoss()


if __name__ == '__main__':
    # Dataset classes are only imported when the file is run as a script.
    from data.dataset import NewsDataset, SentimentDataset
    if args.d == '20news':
        loader = NewsDataset(cfg)
    elif args.d == 'sentiment':
        loader = SentimentDataset(cfg)
    else:
        # Fail here instead of hitting a NameError on `loader` later on.
        parser.error(f"unsupported dataset {args.d!r}: expected '20news' or 'sentiment'")

    max_epoch = cfg.exp['epochs']

    # The experiment id concatenates the start time (HHMM), every key/value
    # pair of cfg.exp, the model name and the optional -c suffix.
    cfg_exp_str = '_'.join([k + str(cfg.exp[k]) for k in cfg.exp])
    exp_id = f'{time.strftime("%H%M", time.localtime())}_{cfg_exp_str}_{args.m}_{args.c}'
    print(f'# exp_id: {exp_id}')
    # TensorBoard logs are grouped by dataset, training procedure and date (MMDD).
    writer = SummaryWriter(log_dir=f'./runs/{args.d}/{args.t}/{time.strftime("%m%d")}/{exp_id}')


@timing
def init_model():
    """Return a new model built by ``get_model`` from the global config and the loader's word embeddings."""
    return get_model(cfg, loader.word_embeddings)


def scheduler_step(schedulers):
    """Call ``step()`` once on each scheduler, in the order given.

    :param schedulers: iterable of objects exposing a ``step()`` method;
        an empty iterable is a no-op.
    """
    for scheduler in schedulers:
        scheduler.step()


def get_opt(model, rep=False, dis=False):
    """Move every sub-module to ``device`` and build its optimizer(s).

    Parameters are collected into up to three groups, each with its own
    optimizer: ``'all'`` (always present), ``'rep'`` (only when ``rep`` is
    true and the model has a ``'rep'`` entry) and ``'dis'`` (likewise for
    ``dis``). Modules not split out end up in ``'all'``.

    Adam is used when ``cfg.exp['opt'] == 'Adam'`` (weight decay from
    ``cfg.exp['wd']``, 0 if absent). Otherwise SGD with momentum 0.9 is used
    and a ``StepLR`` scheduler (``step_size=cfg.exp['step']``, ``gamma=0.9``)
    is attached to each optimizer.

    :param model: mapping from module name to module.
    :param rep: give the ``'rep'`` module its own optimizer.
    :param dis: give the ``'dis'`` module its own optimizer.
    :return: ``(optimizers, schedulers)``; optimizers are ordered
        ``'all'``, ``'rep'``, ``'dis'`` (missing groups omitted) and
        ``schedulers`` is empty when Adam is selected.
    """
    base_lr = cfg.exp['lr']
    use_adam = 'opt' in cfg.exp and cfg.exp['opt'] == 'Adam'

    # Sort the modules' parameters into their optimizer groups.
    groups = {'all': []}
    for name in model:
        module = model[name]
        module.to(device)
        split_out = (name == 'rep' and rep) or (name == 'dis' and dis)
        if split_out:
            groups[name] = module.parameters()
        else:
            groups['all'] += module.parameters()

    optimizers, schedulers = [], []
    for name in ('all', 'rep', 'dis'):
        if name not in groups:
            continue
        if use_adam:
            decay = cfg.exp['wd'] if 'wd' in cfg.exp else 0
            optimizer = torch.optim.Adam(groups[name], lr=base_lr, weight_decay=decay)
        else:
            optimizer = torch.optim.SGD(groups[name], lr=base_lr, momentum=0.9)
            schedulers.append(torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.exp['step'], gamma=0.9))
        optimizers.append(optimizer)
    return optimizers, schedulers


def tchebycheff_adv_multi_task():
    """Tchebycheff-style multi-task training with an additional task discriminator.

    The encoder ``model['rep']`` is restored from ``cfg.checkpoint``. As
    returned by ``get_opt``, ``opt_c`` drives the modules other than the
    encoder and discriminator (the per-task heads), ``opt_g`` the encoder and
    ``opt_d`` the discriminator ``model['dis']``.

    Every epoch runs three stages:

    1. Over ``loader.all`` with the encoder output computed under
       ``torch.no_grad()``: each task head is updated on its own examples
       (loss scaled by ``cfg.w``) and, while ``flag_train_dis`` is set, the
       discriminator is updated to predict the task id.
    2. The objective to emphasise is selected: the discriminator when its
       mean loss is at most ``cfg.exp['alpha']`` (this also switches off the
       stage-1 discriminator updates for all later epochs), otherwise the
       task with the largest mean stage-1 loss.
    3. The encoder and the selected head are trained jointly on the matching
       loader; the discriminator loss is scaled by ``cfg.exp['beta']``.

    Validation runs every 10th epoch and the schedulers step once per epoch.
    """
    model = init_model()
    [opt_c, opt_g, opt_d], schedulers = get_opt(model, rep=True, dis=True)
    # map_location allows a checkpoint written on a GPU to be restored on `device`.
    state_dict = torch.load(cfg.checkpoint, map_location=device)
    model['rep'].load_state_dict(state_dict['model']['rep'])
    flag_train_dis = True
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # stage 1: heads (and possibly the discriminator) on a frozen encoder
        for m in model:
            model[m].train()
        train_losses = {t: [] for t in tasks + ['dis']}
        for batch in loader.all:
            x, y, t_label = batch.text.to(device), batch.label.to(device), batch.task.to(device)
            with torch.no_grad():
                rep = model['rep'](x)
            opt_c.zero_grad()
            for t_id, t in enumerate(tasks):
                is_t = t_label == t_id
                if not is_t.any():
                    # No example of this task in the batch: a mean-reduced
                    # cross entropy over zero rows is NaN and would corrupt
                    # the epoch average used to pick max_t.
                    continue
                yt_ = model[t](rep[is_t])
                loss = criterion(yt_, y[is_t]) * cfg.w[t_id]
                train_losses[t].append(loss.item())
                loss.backward()
            opt_c.step()
            # The discriminator loss is always recorded, but only optimised
            # while flag_train_dis is set.
            t_label_ = model['dis'](rep)
            dis_loss = criterion(t_label_, t_label)
            train_losses['dis'].append(dis_loss.item())
            if flag_train_dis:
                opt_d.zero_grad()
                dis_loss.backward()
                opt_d.step()

        # stage 2: compute weights (mean losses) and pick the objective
        # A task that never appeared this epoch gets weight 0.
        w = {t: sum(train_losses[t]) / max(len(train_losses[t]), 1) for t in tasks}
        dis_loss = sum(train_losses['dis']) / len(train_losses['dis'])
        max_t = 'dis' if dis_loss <= cfg.exp['alpha'] else max(w, key=w.get)
        if max_t == 'dis':
            flag_train_dis = False
        w['dis'] = dis_loss
        # The discriminator is logged with index len(tasks).
        writer.add_scalar('max_t', len(tasks) if max_t == 'dis' else tasks.index(max_t), epoch)
        for k in w:
            writer.add_scalar(f'w/{k}', w[k], epoch)

        # stage 3: train encoder + selected head
        train_dis = max_t == 'dis'
        max_loader = loader.all if train_dis else loader.train[max_t]
        for batch in max_loader:
            x = batch.text.to(device)
            # The discriminator predicts task ids (as in stage 1), so its
            # target is batch.task rather than the class label.
            target = (batch.task if train_dis else batch.label).to(device)
            y_ = model[max_t](model['rep'](x))
            loss = criterion(y_, target)
            opt_g.zero_grad()
            if train_dis:
                opt_d.zero_grad()
                loss = loss * cfg.exp['beta']
                loss.backward()
                opt_d.step()
            else:
                opt_c.zero_grad()
                loss.backward()
                opt_c.step()
            opt_g.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        scheduler_step(schedulers)


def tchebycheff_multi_task():
    """Tchebycheff-style multi-task training: each epoch emphasises the worst task.

    The encoder ``model['rep']`` is restored from ``cfg.checkpoint``. As
    returned by ``get_opt``, ``opt_c`` drives the modules other than the
    encoder (the per-task heads) and ``opt_g`` the encoder.

    Every epoch:

    1. Over ``loader.all`` with the encoder output computed under
       ``torch.no_grad()``, each task head is updated on its own examples
       (loss scaled by ``cfg.w``).
    2. The task with the largest mean stage-1 loss is selected and logged.
    3. The encoder and that task's head are trained jointly on
       ``loader.train[max_t]``.

    Validation runs every 10th epoch and the schedulers step once per epoch.
    """
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    # map_location allows a checkpoint written on a GPU to be restored on `device`.
    state_dict = torch.load(cfg.checkpoint, map_location=device)
    model['rep'].load_state_dict(state_dict['model']['rep'])
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # stage 1: train the heads on a frozen encoder
        for m in model:
            model[m].train()
        train_losses = {t: [] for t in tasks}
        for batch in loader.all:
            x, y, t_label = batch.text.to(device), batch.label.to(device), batch.task.to(device)
            with torch.no_grad():
                rep = model['rep'](x)
            opt_c.zero_grad()
            for t_id, t in enumerate(tasks):
                is_t = t_label == t_id
                if not is_t.any():
                    # No example of this task in the batch: a mean-reduced
                    # cross entropy over zero rows is NaN and would corrupt
                    # the epoch average used to pick max_t.
                    continue
                yt_ = model[t](rep[is_t])
                loss = criterion(yt_, y[is_t]) * cfg.w[t_id]
                train_losses[t].append(loss.item())
                loss.backward()
            opt_c.step()

        # stage 2: mean loss per task; a task never seen this epoch gets 0
        w = {t: sum(train_losses[t]) / max(len(train_losses[t]), 1) for t in tasks}
        for t in w:
            writer.add_scalar(f'w/{t}', w[t], epoch)
        max_t = max(w, key=w.get)
        writer.add_scalar('max_t', tasks.index(max_t), epoch)

        # stage 3: train encoder + head of the worst task
        for batch in loader.train[max_t]:
            x, y = batch.text.to(device), batch.label.to(device)
            y_ = model[max_t](model['rep'](x))
            loss = criterion(y_, y)
            opt_g.zero_grad()
            opt_c.zero_grad()
            loss.backward()
            opt_g.step()
            opt_c.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        scheduler_step(schedulers)


def fudan_multi_task():
    """Train a single monolithic model that computes its own losses.

    ``model(x, y, t_label)`` returns ``(c_losses, d_loss, r_losses)``; the
    optimised objective is ``sum(c_losses) + 0.05 * d_loss + sum(r_losses)``.
    (The names suggest classification / discriminator / regularisation terms;
    confirm against the model implementation.)

    Per-epoch means of the three terms are logged under ``train_loss/*``.
    Every 10th epoch each task's test split is evaluated through
    ``model.val(x, task_index)`` and loss/accuracy are logged. Learning-rate
    schedulers returned by ``get_opt`` are stepped once per epoch, as in the
    other training routines.
    """
    dis_weight = 0.05  # weight of d_loss in the total objective
    model = init_model()
    # get_opt expects a mapping of modules, so wrap the single model.
    [opt], schedulers = get_opt({'': model})
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        model.train()
        losses = {'c': [], 'd': [], 'r': []}
        for batch in loader.all:
            x, y, t_label = batch.text.to(device), batch.label.to(device), batch.task.to(device)
            c_losses, d_loss, r_losses = model(x, y, t_label)
            c_loss, r_loss = sum(c_losses), sum(r_losses)
            loss = c_loss + d_loss * dis_weight + r_loss
            losses['c'].append(c_loss.item())
            losses['d'].append(dis_weight * d_loss.item())
            losses['r'].append(r_loss.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
        for k in losses:
            writer.add_scalar(f'train_loss/{k}', sum(losses[k]) / len(losses[k]), epoch)

        # validation
        if epoch % 10 == 0:
            with torch.no_grad():
                model.eval()
                accs = []
                for i, t in enumerate(tasks):
                    test_loss = n_acc = n_all = 0
                    for batch in loader.test[t]:
                        x, y = batch.text.to(device), batch.label.to(device)
                        y_ = model.val(x, i)
                        test_loss += criterion(y_, y).item()
                        # .item() keeps the counters as Python ints, so the
                        # accuracy below is an ordinary float division.
                        n_acc += y_.argmax(1).eq(y).sum().item()
                        n_all += y.shape[0]
                    acc = n_acc / float(n_all)
                    accs.append(acc)
                    # test_loss is the sum over batches, not a mean.
                    writer.add_scalar(f'loss/test/{t}', test_loss, epoch)
                    writer.add_scalar(f'acc/test/{t}', acc, epoch)
                writer.add_scalar('avg_acc', sum(accs) / len(accs), epoch)

        # Without this the StepLR schedulers created for SGD never decay the
        # learning rate; with Adam the list is empty and this is a no-op.
        scheduler_step(schedulers)


def mgda_multi_task():
    """Multi-task training with per-task loss scales from a min-norm solver.

    For every step one batch per task is drawn. The gradient of each task
    loss with respect to the encoder output is collected, normalised with
    ``gradient_normalizers(..., 'loss+')`` and passed to
    ``MinNormSolver.find_min_norm_element`` to obtain one scale per task.
    Each task's scaled loss is then back-propagated through encoder and head,
    with an optimizer step after every task.

    The number of steps per epoch is the smallest task dataset size divided
    (floor) by ``cfg.data['batch_size']``. Validation runs every 10th epoch
    and the schedulers step once per epoch.
    """
    n_batches = min(len(loader.train[t].dataset.examples) for t in tasks) // cfg.data['batch_size']
    model = init_model()
    [opt_c, opt_g], schedulers = get_opt(model, rep=True)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for m in model:
            model[m].train()
        iters = {t: iter(loader.train[t]) for t in tasks}
        for _ in range(n_batches):
            grads, losses, xs, ys = {}, {}, {}, {}
            for t in tasks:
                batch = next(iters[t])
                # Skip a batch whose dim 1 differs from the configured batch
                # size (presumably the trailing, smaller batch; assumes
                # batch.text is laid out as (seq_len, batch) - TODO confirm).
                if batch.text.shape[1] != cfg.data['batch_size']:
                    batch = next(iters[t])
                x, y = batch.text.to(device), batch.label.to(device)
                xs[t], ys[t] = x, y
                # Detach the encoder output and make it a leaf so that
                # backward() yields d(loss)/d(rep) in rep.grad.
                with torch.no_grad():
                    rep = model['rep'](x)
                rep = rep.clone().requires_grad_()
                y_ = model[t](rep)
                loss = criterion(y_, y)
                losses[t] = loss.item()
                opt_c.zero_grad()
                loss.backward()
                grads[t] = [rep.grad.clone().requires_grad_(False)]
            gn = gradient_normalizers(grads, losses, 'loss+')
            # grads[t] is a list of tensors, so normalise element-wise;
            # dividing the list itself by a scalar is not a valid operation.
            # NOTE(review): assumes gn[t] is a scalar and the solver takes one
            # list of tensors per task - confirm against utils.min_norm_solvers.
            grads = {t: [g / gn[t] for g in grads[t]] for t in grads}
            sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
            scales = {t: float(sol[i]) for i, t in enumerate(tasks)}
            for t in tasks:
                y_ = model[t](model['rep'](xs[t]))
                loss = criterion(y_, ys[t]) * scales[t]
                opt_g.zero_grad()
                opt_c.zero_grad()
                loss.backward()
                opt_g.step()
                opt_c.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        scheduler_step(schedulers)


def uniform_multi_task():
    """Train all tasks with equal weight using one optimizer for every module.

    Each step visits the tasks in a freshly shuffled order, draws one batch
    per task and takes an optimizer step on that task's loss divided by the
    number of tasks. The number of steps per epoch is the smallest task
    dataset size divided (floor) by ``cfg.data['batch_size']``.

    Validation runs every 10th epoch; every 500th epoch (except epoch 0) the
    state dicts of all modules are saved under
    ``checkpoints/<dataset>/uniform/<epoch>``. Schedulers step once per epoch.
    """
    smallest = min(len(loader.train[name].dataset.examples) for name in tasks)
    steps_per_epoch = smallest // cfg.data['batch_size']
    model = init_model()
    optimizers, schedulers = get_opt(model)
    (optimizer,) = optimizers
    n_tasks = len(tasks)
    for epoch in tqdm(range(max_epoch), unit='epoch'):
        # train
        for key in model:
            model[key].train()
        batch_iters = {name: iter(loader.train[name]) for name in tasks}
        for _ in range(steps_per_epoch):
            # a new random task order for every step
            order = tasks.copy()
            random.shuffle(order)
            for name in order:
                batch = next(batch_iters[name])
                inputs = batch.text.to(device)
                targets = batch.label.to(device)
                logits = model[name](model['rep'](inputs))
                loss = criterion(logits, targets) / n_tasks
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # validation
        if epoch % 10 == 0:
            validate_all_tasks(model, epoch)

        # periodic checkpoint of every module
        if epoch > 0 and epoch % 500 == 0:
            state_dict = {'model': {key: model[key].state_dict() for key in model}}
            os.makedirs(f'checkpoints/{args.d}/uniform/', exist_ok=True)
            torch.save(state_dict, f'checkpoints/{args.d}/uniform/{epoch}')

        scheduler_step(schedulers)


def single_task():
    """Train every task in isolation, each with a freshly built model.

    For each task a new model and optimizer are created, then the encoder and
    that task's head are trained for ``max_epoch`` epochs on
    ``loader.train[t]``. The mean training loss is logged every epoch, the
    test split is evaluated every 10th epoch, and every 100th epoch (except
    epoch 0) the encoder and head state dicts are saved under
    ``checkpoints/<dataset>/single/<task>_<epoch>``.
    """
    for t in tasks:
        # Build model & optimizer at the start of each task so no extra,
        # unused model is constructed after the last one.
        model = init_model()
        [opt], schedulers = get_opt(model)
        for epoch in tqdm(range(max_epoch), unit='epoch', postfix=t):
            # train
            for m in model:
                model[m].train()
            train_loss = []
            for batch in loader.train[t]:
                x, y = batch.text.to(device), batch.label.to(device)
                y_ = model[t](model['rep'](x))
                loss = criterion(y_, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                train_loss.append(loss.item())
            scheduler_step(schedulers)
            writer.add_scalar(f'train_loss/{t}', sum(train_loss) / len(train_loss), epoch)

            # validation
            if epoch % 10 == 0:
                with torch.no_grad():
                    for m in model:
                        model[m].eval()
                    test_loss = n_acc = n_all = 0
                    for batch in loader.test[t]:
                        x, y = batch.text.to(device), batch.label.to(device)
                        y_ = model[t](model['rep'](x))
                        test_loss += criterion(y_, y).item()
                        # .item() keeps the counters as Python ints, so the
                        # accuracy below is an ordinary float division.
                        n_acc += y_.argmax(1).eq(y).sum().item()
                        n_all += y.shape[0]
                    acc = n_acc / float(n_all)
                    # test_loss is the sum over batches, not a mean.
                    writer.add_scalar(f'loss/test/{t}', test_loss, epoch)
                    writer.add_scalar(f'acc/test/{t}', acc, epoch)

            # periodic checkpoint of the encoder and this task's head
            if epoch % 100 == 0 and epoch > 0:
                state_dict = {'model': {key: model[key].state_dict() for key in ['rep', t]}}
                os.makedirs(f'checkpoints/{args.d}/single/', exist_ok=True)
                torch.save(state_dict, f'checkpoints/{args.d}/single/{t}_{epoch}')


def validate_all_tasks(model, epoch):
    """Evaluate every task on its test split and log the results.

    All modules are switched to eval mode and left that way; callers switch
    back to train mode themselves. For each task the summed (not averaged)
    batch loss and the accuracy are written to TensorBoard, followed by the
    unweighted mean accuracy over tasks as ``avg_acc``.

    :param model: mapping from module name to module; must contain ``'rep'``
        and one entry per task.
    :param epoch: global step used for the TensorBoard scalars.
    :raises ZeroDivisionError: if a task's test loader yields no examples.
    """
    with torch.no_grad():
        for m in model:
            model[m].eval()
        accs = []
        for t in tasks:
            test_loss = n_acc = n_all = 0
            for batch in loader.test[t]:
                x, y = batch.text.to(device), batch.label.to(device)
                y_ = model[t](model['rep'](x))
                test_loss += criterion(y_, y).item()
                # .item() keeps the counters as Python ints; accumulating the
                # integer tensor made the accuracy depend on tensor/float
                # division semantics and logged tensors instead of floats.
                n_acc += y_.argmax(1).eq(y).sum().item()
                n_all += y.shape[0]
            acc = n_acc / float(n_all)
            accs.append(acc)
            writer.add_scalar(f'loss/test/{t}', test_loss, epoch)
            writer.add_scalar(f'acc/test/{t}', acc, epoch)
        writer.add_scalar('avg_acc', sum(accs) / len(accs), epoch)


if __name__ == '__main__':
    # Map the -t option to the training routine implementing it.
    runners = {
        'single': single_task,
        'uniform': uniform_multi_task,
        'mgda': mgda_multi_task,
        'fudan': fudan_multi_task,
        'tchebycheff': tchebycheff_multi_task,
        'tchebycheff_adv': tchebycheff_adv_multi_task,
    }
    runner = runners.get(args.t)
    if runner is None:
        # Previously an unknown value silently trained nothing.
        parser.error(f"unknown training procedure {args.t!r}: expected one of {', '.join(runners)}")

    try:
        runner()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    print('exit')
