from argparse import Namespace
from statistics import mean
import torch

from utils.result_calculator import calc_result_multi, calc_result_classify
from tasks.commons.train_common import epoch_manager, eval_loss_manager, loss_manager, optimizer_manager, retrieve_optimizer
from tasks.test import exec_test_batch
from tqdm import trange, tqdm

from models.common.msu import MSU
from utils.data_loader import ExecMode, ModelInput, prepare_dataloader
from logging import getLogger

# Module-level logger used for the train-start / epoch-start progress messages below.
logger = getLogger(__name__)
# File names handed to epoch_manager() in train(); presumably the names under which
# the model config and weights are saved — confirm against epoch_manager.
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

def exec_train_batch(model: MSU, optim, dataloader, g_step, args: Namespace, device):
  """Train ``model`` for one full pass over ``dataloader``.

  Args:
    model: the network; its forward pass returns
      ``(_, loss, w_topic_loss, u_topic_loss)``.
    optim: optimizer handed to ``optimizer_manager``.
    dataloader: iterable of ``ModelInput`` batches; must support ``len()``.
    g_step: global step counter at the start of this pass.
    args: run configuration forwarded to the loss and optimizer managers.
    device: device each batch is moved to before the forward pass.

  Returns:
    ``optimizer finalizer result + loss finalizer result``; ``train`` unpacks
    it as ``(global_step, loss, w_topic_loss, u_topic_loss)``.
  """
  model.train()
  record_losses, summarize_losses = loss_manager(args)
  optimizer_step, summarize_optimizer = optimizer_manager(args, optim, g_step)

  batch: ModelInput
  for step_no, batch in enumerate(tqdm(dataloader, desc="Iteration"), start=1):
    # The forward pass is wrapped in a closure so the optimizer wrapper decides
    # when (and presumably how often) to evaluate it. It reads ``batch`` from
    # this iteration, so it must be consumed before the loop advances.
    def closure(update_step=True):
      # Debug aid: wrap the forward in torch.autograd.detect_anomaly() to trace NaNs.
      _, loss, w_topic_loss, u_topic_loss = model(batch.to(device))
      record_losses(loss, w_topic_loss, u_topic_loss, update_step)
      return {'loss': loss, 'w_topic_loss': w_topic_loss, 'u_topic_loss': u_topic_loss}

    is_last_batch = step_no == len(dataloader)
    optimizer_step(closure, is_last_batch)

  return summarize_optimizer() + summarize_losses()


def exec_eval_batch(model: MSU, dataloader, args: Namespace, device):
  """Evaluate ``model`` over ``dataloader`` with gradient tracking disabled.

  Args:
    model: network whose forward pass returns
      ``(_, loss, w_topic_loss, u_topic_loss)``.
    dataloader: iterable of ``ModelInput`` batches.
    args: unused here; kept for signature parity with ``exec_train_batch``.
    device: device each batch is moved to before the forward pass.

  Returns:
    Whatever the ``eval_loss_manager`` finalizer produces; ``train`` unpacks
    it as ``(eval_loss, eval_w_topic_loss, eval_u_topic_loss)``.
  """
  model.eval()
  record_losses, summarize_losses = eval_loss_manager()

  batch: ModelInput
  for batch in tqdm(dataloader, desc="Evaluation"):
    with torch.no_grad():
      _, loss, w_topic_loss, u_topic_loss = model(batch.to(device))
    record_losses(loss, w_topic_loss, u_topic_loss)

  return summarize_losses()


def _collect_online_test_metrics(model: MSU, test_dataloader, args: Namespace, device, result, eval_result):
  """Run the test split once and record its metrics into ``result`` and ``eval_result`` in place.

  For ``args.loss_type == 'classify'`` the accuracy / F1 / macro-F1 from
  ``calc_result_classify`` are stored. Otherwise the per-category MAE, acc7 and
  acc7-F1 from ``calc_result_multi`` are averaged across categories. The test
  loss is recorded in both cases.
  """
  test_loss, _, _, predict_lists, truth_lists = exec_test_batch(model, test_dataloader, args, device)
  if args.loss_type == 'classify':
    cls_result = calc_result_classify(predict_lists, truth_lists, do_log=False)
    eval_result['acc'] = result['test_acc'] = cls_result['acc']
    eval_result['F1'] = result['test_f1'] = cls_result['F1']
    eval_result['F1(Macro)'] = result['test_f1_macro'] = cls_result['F1(Macro)']
    # Recorded here too so both loss types report the test loss.
    eval_result['test_loss'] = result['test_loss'] = test_loss
  else:
    results = calc_result_multi(predict_lists, truth_lists, args.category, do_log=False)
    mae = mean([item['mae'] for item in results.values()])
    acc7 = mean([item['acc7'] for item in results.values()])
    acc7_f1 = mean([item['F1(acc7)'] for item in results.values()])
    eval_result['test_mae'] = result['test_mae'] = mae
    eval_result['test_loss'] = result['test_loss'] = test_loss
    eval_result['test_acc7'] = result['test_acc7'] = acc7
    eval_result['test_acc7_f1'] = result['test_acc7_f1'] = acc7_f1


def _selection_score(args: Namespace, result, eval_loss):
  """Return the per-epoch score handed to the epoch_manager stepper.

  All candidates are lower-is-better quantities; presumably the stepper keeps
  the best (lowest) epoch — confirm against epoch_manager.

  - ``args.online_test`` and classify: ``(1 - acc) + (1 - F1)`` on the test split.
  - ``args.online_test`` otherwise: mean test MAE.
  - no online test: the evaluation loss. Test metrics only exist in ``result``
    when the online test ran, so they must not be read otherwise.
  """
  if not args.online_test:
    return eval_loss
  if args.loss_type == 'classify':
    return 1 - result['test_acc'] + 1 - result['test_f1']
  return result['test_mae']


def train(model: MSU, tokenizer, train_datasrc, eval_datasrc, args: Namespace, device, test_datasrc=None):
  """Train ``model`` for ``int(args.num_train_epochs)`` epochs with per-epoch evaluation.

  Each epoch runs one training pass, one evaluation pass and, when
  ``args.online_test`` is set, one pass over the test split. The epoch
  summary is then passed to the ``epoch_manager`` stepper; a truthy return
  from the stepper stops training early.

  Args:
    model: the network to train.
    tokenizer: tokenizer used by ``prepare_dataloader``.
    train_datasrc: training data source.
    eval_datasrc: evaluation data source.
    args: run configuration (reads ``online_test``, ``loss_type``,
      ``category``, ``num_train_epochs`` here, and more in the managers).
    device: device batches are moved to.
    test_datasrc: test data source; only used when ``args.online_test``.

  Returns:
    The result of the ``epoch_manager`` finalizer.
  """
  logger.info("=============== train starting ===============")

  train_dataloader = prepare_dataloader(tokenizer, train_datasrc, args, ExecMode.TRAIN)
  eval_dataloader = prepare_dataloader(tokenizer, eval_datasrc, args, ExecMode.EVAL)
  test_dataloader = prepare_dataloader(tokenizer, test_datasrc, args, ExecMode.TEST) if args.online_test else None

  # prepare learning params and results
  stepper, finalizer = epoch_manager(model, CONFIG_NAME, WEIGHTS_NAME, args)
  global_step = 0
  optimizer = retrieve_optimizer(args, model=model)

  # exec train epochs
  for i in trange(int(args.num_train_epochs), desc="Epoch"):
    logger.info("--------------- starting epoch #{} ---------------".format(i+1))

    global_step, loss, w_topic_loss, u_topic_loss = exec_train_batch(model, optimizer, train_dataloader, global_step, args, device)
    eval_loss, eval_w_topic_loss, eval_u_topic_loss = exec_eval_batch(model, eval_dataloader, args, device)
    result = {'global_step': global_step, 'loss': loss, 'eval_loss': eval_loss}
    train_result = { 'loss': loss, 'w_topic_loss': w_topic_loss, 'u_topic_loss': u_topic_loss }
    eval_result = { 'eval_loss': eval_loss, 'eval_w_topic_loss': eval_w_topic_loss, 'eval_u_topic_loss': eval_u_topic_loss }

    if args.online_test:
      _collect_online_test_metrics(model, test_dataloader, args, device, result, eval_result)

    # Previously a mis-nested conditional read result['test_acc'] for
    # loss_type == 'classify' even without an online test (KeyError).
    score = _selection_score(args, result, eval_loss)
    if stepper(score, loss, result, train_result, eval_result):
      break

  return finalizer()