import argparse
from logging import getLogger
import pickle
from models.wa_msf.model import WA_MSF
from utils.task_config import TaskConfig
from utils.data_loader import ModalityDatasource
from utils.task_common import init_logger, result_writeout, save_weight, setup_device, setup_directory
from transformers import AutoTokenizer, BertTokenizer

from tasks.test import test
from tasks.train import train
from utils.arg_parser import parse_arg
from utils.result_calculator import calc_result_multi, calc_result_classify
import torch
import gc

# Module-level logger named after this module; used by main() and the script entry point below.
logger = getLogger(__name__)

def preapare_model(cfg: TaskConfig) -> "tuple[torch.nn.Module, BertTokenizer]":
  """Build the WA_MSF model and its tokenizer, applying the BERT freezing policy.

  The BERT embedding parameters are always frozen. In addition, at most one of
  the following is applied, checked in this order:

  * ``cfg.no_bert_tuning``: every parameter of ``model.bert.encoder`` is frozen.
  * ``cfg.less_bert_tuning``: the first 2 encoder layers are frozen.
  * ``cfg.bert_freeze_layers > 0``: the first ``cfg.bert_freeze_layers``
    encoder layers are frozen.

  Args:
    cfg: task configuration. Reads ``no_bert_tuning``, ``less_bert_tuning``,
      ``bert_freeze_layers``, ``bert_model`` and ``pretrained_path``.

  Returns:
    ``(model, tokenizer)`` where the tokenizer comes from
    ``AutoTokenizer.from_pretrained(cfg.bert_model)``. When
    ``cfg.pretrained_path`` is non-empty, the state dict stored there has been
    loaded into ``model.bert``.
  """
  model = WA_MSF(cfg)
  bert = model.bert

  # Collect every sub-module whose parameters must not receive gradients.
  # The embedding layer is frozen unconditionally.
  frozen_modules = [bert.embeddings]
  if cfg.no_bert_tuning:
    # freeze the whole encoder
    frozen_modules.append(bert.encoder)
  elif cfg.less_bert_tuning:
    # freeze only the first 2 encoder layers
    frozen_modules.extend(bert.encoder.layer[:2])
  elif cfg.bert_freeze_layers > 0:
    # freeze a configurable number of leading encoder layers
    frozen_modules.extend(bert.encoder.layer[:cfg.bert_freeze_layers])

  for module in frozen_modules:
    for param in module.parameters():
      param.requires_grad = False

  tokenizer = AutoTokenizer.from_pretrained(cfg.bert_model)

  # Truthiness check (instead of len()) also treats a None path as "not given".
  if cfg.pretrained_path:
    # NOTE: torch.load unpickles the file, so the checkpoint must come from a
    # trusted source.
    # map_location="cpu" lets a checkpoint saved from CUDA tensors be read on a
    # CPU-only host; load_state_dict copies the values into the existing
    # parameters, so the storage device of the loaded tensors does not matter.
    state_dict = torch.load(cfg.pretrained_path, map_location="cpu")
    bert.load_state_dict(state_dict)

  return model, tokenizer


def load_pickles(config: argparse.Namespace):
  """Load the audio/video pickles into train/eval/test datasources.

  Each pickle file is expected to unpickle to a 3-element sequence, which is
  unpacked as ``(train, eval, test)`` and assigned to the ``audio_data`` /
  ``video_data`` attribute of the corresponding datasource.

  Args:
    config: parsed arguments. Reads ``no_audio``, ``no_video``, ``no_align``,
      and the paths ``src_audio`` / ``src_audio_raw`` / ``src_video`` /
      ``src_video_raw`` (the ``*_raw`` variant is used when ``no_align`` is set).

  Returns:
    ``(train_ds, eval_ds, test_ds)`` as ``ModalityDatasource`` instances. A
    modality disabled via ``no_audio`` / ``no_video`` is left untouched.

  Raises:
    AssertionError: if both ``no_audio`` and ``no_video`` are set.
  """
  train_ds = ModalityDatasource()
  eval_ds = ModalityDatasource()
  test_ds = ModalityDatasource()
  assert not (config.no_audio and config.no_video), \
    "no_audio and no_video must not both be set"

  # NOTE: pickle.load executes arbitrary code embedded in the file; only load
  # feature files from a trusted source.
  if not config.no_audio:
    audio_path = config.src_audio_raw if config.no_align else config.src_audio
    # `with` closes the handle even if unpickling or unpacking fails.
    with open(audio_path, mode='rb') as fp:
      train_ds.audio_data, eval_ds.audio_data, test_ds.audio_data = pickle.load(fp)

  if not config.no_video:
    video_path = config.src_video_raw if config.no_align else config.src_video
    with open(video_path, mode='rb') as fp:
      train_ds.video_data, eval_ds.video_data, test_ds.video_data = pickle.load(fp)

  return (train_ds, eval_ds, test_ds)

def main(args: argparse.Namespace, device):
  """Run one train-then-test attempt and return the computed metrics.

  Steps:
    1. Build the model and tokenizer. If ``args.param_confirm`` is set, print
       the total parameter count and terminate the process.
    2. Move the model to ``device``, load the datasources and call ``train``,
       which returns ``(train_result, best_state_dict, best_train_state_dict)``.
    3. Unless ``args.online_test`` is set, evaluate both returned weight sets
       on the test data and keep whichever has the lower ``'test_loss'``
       (``best_state_dict`` wins ties). With ``args.online_test`` only
       ``best_state_dict`` is evaluated.
    4. Save the selected weights unless ``args.no_model_save`` is set.

  Args:
    args: parsed arguments / task configuration.
    device: device the model and the restored weights are moved to.

  Returns:
    ``{'classification': ...}`` from ``calc_result_classify`` when
    ``args.loss_type == 'classify'``; otherwise the return value of
    ``calc_result_multi``.

  Raises:
    SystemExit: when ``args.param_confirm`` is set.
  """
  model, tokenizer = preapare_model(args)
  if args.param_confirm:
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f'total parameter: {pytorch_total_params}')
    # The builtin exit() is injected by the `site` module for interactive use
    # and may be missing (e.g. `python -S`); raising SystemExit is equivalent.
    raise SystemExit
  model.to(device)
  train_data, eval_data, test_data = load_pickles(args)

  train_result, best_state_dict, best_train_state_dict = train(
    model, tokenizer, train_data, eval_data, args, device, test_datasrc=test_data
  )
  logger.info("*** train result ***")
  logger.info(train_result)

  if not args.online_test:
    logger.info("picking best weight state from training phase...")
    model.load_state_dict({k: v.to(device) for k, v in best_state_dict.items()})
    test_result = test(model, tokenizer, test_data, args, device)

    logger.info("picking train best weight state from training phase...")
    model.load_state_dict({k: v.to(device) for k, v in best_train_state_dict.items()})
    test_result_train = test(model, tokenizer, test_data, args, device)

    # Lower test loss wins; on a tie the first candidate (best_state_dict) is kept.
    if test_result['test_loss'] <= test_result_train['test_loss']:
      logger.info("a best score by eval is selected as a result weight")
      result = test_result
      selected_state_dict = best_state_dict
    else:
      logger.info("a best score by train is selected as a result weight")
      result = test_result_train
      selected_state_dict = best_train_state_dict

    if not args.no_model_save:
      save_weight(selected_state_dict, args)

    # Drop every reference to the weight snapshots so gc.collect() can reclaim them.
    del selected_state_dict
    del best_state_dict
    del best_train_state_dict

  else:
    model.load_state_dict(best_state_dict)
    result = test(model, tokenizer, test_data, args, device)
    if not args.no_model_save:
      save_weight(best_state_dict, args)
    del best_state_dict

  gc.collect()

  predictions = result['predict_lists']
  truths = result['truth_lists']
  if args.loss_type == 'classify':
    return {'classification': calc_result_classify(predictions, truths)}
  return calc_result_multi(predictions, truths, args.category)


if __name__ == '__main__':
  # Script entry point: run `args.attempt` independent train/test attempts and
  # write out the result of each one.
  init_logger()
  args = parse_arg()
  device = setup_device(args)

  # Captured once before the loop and handed to setup_directory on every
  # attempt (presumably because setup_directory changes args.dest -- confirm).
  base_dest = args.dest
  for attempt_idx in range(args.attempt):
    torch.cuda.empty_cache()
    attempt_banner = "start to exec attempt: {}".format(attempt_idx + 1)
    logger.info(attempt_banner)

    setup_directory(args, base_dest, attempt_idx)
    attempt_results = main(args, device)
    result_writeout(args, attempt_results)