import argparse
import os

from transformers.models.bert.tokenization_bert import BertTokenizer
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import numpy as np

from models.common.msu import MSU
from tasks.commons.train_common import eval_loss_manager
from utils.data_loader import ExecMode, ModelInput, prepare_dataloader
from logging import getLogger

# Module-level logger, named after this module.
logger = getLogger(__name__)


def exec_test_batch(model: MSU, dataloader: DataLoader, args: argparse.Namespace, device):
  """Run ``model`` over every batch of ``dataloader`` with gradients disabled.

  Args:
    model: model to evaluate. It is switched to eval mode and left there.
    dataloader: iterable of ``ModelInput`` batches.
    args: not used here; kept so the signature stays compatible with callers.
    device: device each batch is moved to before the forward pass.

  Returns:
    The tuple returned by the ``eval_loss_manager`` finalizer, extended with
    ``(predict_lists, truth_lists)``. These hold the per-example logits and
    labels as plain Python lists, in dataloader order.
  """
  model.eval()
  stepper, finalizer = eval_loss_manager()
  predict_lists = []
  truth_lists = []

  with torch.no_grad():
    batch: ModelInput
    for batch in tqdm(dataloader, desc="Prediction"):
      # Labels are taken from the batch before it is moved to ``device``.
      input_label_id = batch.labels
      logits, loss, w_topic_loss, u_topic_loss = model(batch.to(device))
      # NOTE(review): the losses are passed here as (loss, w_topic, u_topic),
      # while `test` unpacks the finalizer result as (loss, u_topic, w_topic).
      # Confirm the order in which eval_loss_manager returns them.
      stepper(loss, w_topic_loss, u_topic_loss)

      logits_det: np.ndarray = logits.detach().cpu().numpy()
      # Copy the labels to the host once per batch instead of once per row.
      # Each per-row .cpu() call is a separate device->host transfer.
      # Assumes batch.labels is a tensor (TODO confirm against ModelInput).
      labels_det: np.ndarray = input_label_id.detach().cpu().numpy()

      # Pair exactly one label row with each prediction row, as the
      # per-row indexing this replaces did.
      num_rows = len(logits_det)
      predict_lists.extend(logits_det.tolist())
      truth_lists.extend(labels_det[:num_rows].tolist())

  return finalizer() + (predict_lists, truth_lists)


def test(model: MSU, tokenizer: BertTokenizer, datasource, args: argparse.Namespace, device):
  """Evaluate ``model`` on the test split, then log and save the losses.

  Args:
    model: model to evaluate.
    tokenizer: tokenizer passed to ``prepare_dataloader``.
    datasource: data source passed to ``prepare_dataloader``.
    args: run configuration. It is forwarded to the dataloader and batch
      runner. ``args.dest`` and ``args.num_train_epochs`` build the path of
      the loss file ``<dest>/log/test_loss_<num_train_epochs>.log``.
    device: device the batches are moved to.

  Returns:
    dict with the keys ``test_loss``, ``u_topic_test_loss`` and
    ``w_topic_test_loss``, plus the per-example ``predict_lists`` and
    ``truth_lists``.
  """
  logger.info("=============== test starting ===============")
  dataloader = prepare_dataloader(tokenizer, datasource, args, ExecMode.TEST)

  test_loss, u_topic_test_loss, w_topic_test_loss, predict_lists, truth_lists = exec_test_batch(model, dataloader, args, device)

  # Scalar losses only. These are what get logged and written to disk. The
  # per-example lists are kept out because they can be very large.
  losses = {
    'test_loss': test_loss,
    'u_topic_test_loss': u_topic_test_loss,
    'w_topic_test_loss': w_topic_test_loss,
  }

  logger.info("***** test result *****")
  for key in sorted(losses):
    logger.info("  %s = %s", key, str(losses[key]))

  log_dir = os.path.join(args.dest, 'log')
  # Create the directory up front. Otherwise open() fails with
  # FileNotFoundError only after the whole test pass has already run.
  os.makedirs(log_dir, exist_ok=True)
  log_path = os.path.join(log_dir, "test_loss_{}.log".format(args.num_train_epochs))
  with open(log_path, "w", encoding="utf-8") as writer:
    # str() is used deliberately rather than format(), so that tensor
    # values are written exactly as str() renders them.
    writer.write(''.join(name + ': ' + str(value) + '\n' for name, value in losses.items()))

  result = dict(losses)
  result['predict_lists'] = predict_lists
  result['truth_lists'] = truth_lists
  return result