import argparse
from logging import getLogger
from utils.data_loader import ModalityDatasource
from utils.task_common import init_logger, result_writeout, save_weight, setup_device, setup_directory
from transformers import AutoTokenizer, BertTokenizer

import torch
from models.bert.bert_classifier import BertClassifier
from tasks.test import test
from tasks.train import train
from utils.arg_parser import parse_arg
from utils.result_calculator import calc_result_multi

logger = getLogger(__name__)

def _freeze_parameters(module) -> None:
  """Turn off gradient updates for every parameter of ``module``."""
  for param in module.parameters():
    param.requires_grad = False


def preapare_model(config: argparse.Namespace) -> "tuple[BertClassifier, BertTokenizer]":
  """Build the classifier and its tokenizer, freezing parts of the encoder.

  Args:
    config: parsed options. Reads ``bert_model``, ``category``, ``loss_type``,
      ``no_bert_tuning``, ``less_bert_tuning``, ``bert_freeze_layers`` and
      ``pretrained_path``.

  Returns:
    ``(model, tokenizer)`` where ``model`` is the ``BertClassifier`` and
    ``tokenizer`` comes from ``AutoTokenizer.from_pretrained(config.bert_model)``.
  """
  # Read everything from ``config`` rather than the script-level ``args`` global,
  # which only exists when this file is executed as ``__main__``.
  model = BertClassifier(config.bert_model, class_size=len(config.category), loss_type=config.loss_type)

  # The encoder attribute name depends on the model family named in bert_model.
  is_xlnet = 'xlnet' in config.bert_model
  if 'berta' in config.bert_model:
    enc = model.classification.roberta
  elif is_xlnet:
    enc = model.classification.transformer
  else:
    enc = model.classification.bert

  # The embedding layer is always frozen.
  _freeze_parameters(enc.word_embedding if is_xlnet else enc.embeddings)

  # XLNet keeps its transformer blocks directly in ``enc.layer`` and has no
  # ``encoder`` attribute; the BERT/RoBERTa encoders nest them in ``enc.encoder.layer``.
  if config.no_bert_tuning:
    # Freeze the whole encoder stack.
    _freeze_parameters(enc.layer if is_xlnet else enc.encoder)
  else:
    if config.less_bert_tuning:
      # less_bert_tuning freezes only the first two encoder layers.
      frozen_layers = 2
    elif config.bert_freeze_layers > 0:
      frozen_layers = config.bert_freeze_layers
    else:
      frozen_layers = 0
    if frozen_layers:
      layers = enc.layer if is_xlnet else enc.encoder.layer
      for layer in layers[:frozen_layers]:
        _freeze_parameters(layer)

  tokenizer = AutoTokenizer.from_pretrained(config.bert_model)
  # Truthiness check also tolerates ``None`` for "no pretrained weights".
  if config.pretrained_path:
    # Load onto CPU first so a checkpoint written on a GPU can be restored on
    # any host; load_state_dict copies the values into the existing parameters.
    enc.load_state_dict(torch.load(config.pretrained_path, map_location='cpu'))

  return model, tokenizer


def load_pickles(config: argparse.Namespace):
  """Create the train, eval and test data sources.

  ``config`` is accepted but not read here; each source is a
  ``ModalityDatasource`` constructed without arguments.

  Returns:
    A 3-tuple ``(train, eval, test)`` of data sources.
  """
  train_ds, eval_ds, test_ds = (ModalityDatasource() for _ in range(3))
  return train_ds, eval_ds, test_ds

def get_model(model):
  """Return the encoder sub-module held by ``model.classification``.

  The encoder is looked up under ``bert``, then ``roberta``, then
  ``transformer`` (the attribute names used when the model is prepared).

  Raises:
    AttributeError: if ``model.classification`` exposes none of them.
  """
  classifier = model.classification
  # getattr with a default only absorbs a missing attribute; any other error
  # raised while accessing the encoder propagates instead of being hidden.
  for attr_name in ('bert', 'roberta', 'transformer'):
    encoder = getattr(classifier, attr_name, None)
    if encoder is not None:
      return encoder
  raise AttributeError(
    "model.classification has none of the encoder attributes: bert, roberta, transformer"
  )

def main(args: argparse.Namespace, device):
  """Run one train-then-test cycle and return the computed result.

  The model is trained, then evaluated on the test data. Unless
  ``args.online_test`` is set, both state dicts returned by ``train`` are
  tested and the one with the lower (or equal, favouring the first)
  ``test_loss`` is kept. The encoder weights are saved unless
  ``args.no_model_save`` is set.

  Returns:
    Whatever ``calc_result_multi`` produces from the chosen test result's
    ``predict_lists`` and ``truth_lists`` and ``args.category``.
  """
  model, tokenizer = preapare_model(args)
  model.to(device)
  train_data, eval_data, test_data = load_pickles(args)

  train_result, best_state_dict, best_train_state_dict = train(
    model, tokenizer, train_data, eval_data, args, device, test_datasrc=test_data
  )
  logger.info("*** train result ***")
  logger.info(train_result)

  if args.online_test:
    # Only the first state dict returned by train() is evaluated.
    model.load_state_dict(best_state_dict)
    result = test(model, tokenizer, test_data, args, device)
  else:
    # Evaluate both candidate weight sets on the test data.
    logger.info("picking best weight state from training phase...")
    model.load_state_dict(best_state_dict)
    eval_best_result = test(model, tokenizer, test_data, args, device)

    logger.info("picking train best weight state from training phase...")
    model.load_state_dict(best_train_state_dict)
    train_best_result = test(model, tokenizer, test_data, args, device)

    # Ties go to the first candidate.
    if eval_best_result['test_loss'] <= train_best_result['test_loss']:
      logger.info("a best score by eval is selected as a result weight")
      result, chosen_state = eval_best_result, best_state_dict
    else:
      logger.info("a best score by train is selected as a result weight")
      result, chosen_state = train_best_result, best_train_state_dict
    # Make sure the model carries the winning weights before saving.
    model.load_state_dict(chosen_state)

  if not args.no_model_save:
    save_weight(get_model(model).state_dict(), args)

  return calc_result_multi(result['predict_lists'], result['truth_lists'], args.category)


if __name__ == '__main__':
  # Script entry point: set up logging, options and device once, then repeat
  # the full train/test cycle ``args.attempt`` times.
  init_logger()
  args = parse_arg()
  device = setup_device(args)

  # Keep the destination as originally parsed; it is handed to
  # setup_directory on every attempt together with the attempt index.
  base_dest = args.dest
  for attempt_idx in range(args.attempt):
    attempt_no = attempt_idx + 1
    logger.info("start to exec attempt: {}".format(attempt_no))

    setup_directory(args, base_dest, attempt_idx)
    result_writeout(args, main(args, device))