import argparse
import json
import logging
import logging.config
import os
import random
import shutil
from logging import getLogger

import numpy as np
import torch
import yaml

# Module-level logger for this file. Its handlers and levels come from
# whatever logging configuration is active, e.g. the one init_logger() installs.
logger = getLogger(__name__)

def init_logger(config_path: str = 'config/logging/config.yaml'):
  """Configure the ``logging`` module from a YAML ``dictConfig`` file.

  Args:
    config_path: Path to a YAML file whose content is passed unchanged to
      ``logging.config.dictConfig``. It defaults to the previously hard-coded
      location, which is resolved relative to the current working directory.
  """
  # Pin the encoding so the config is decoded the same way on every platform,
  # instead of depending on the locale's default encoding.
  with open(config_path, 'rt', encoding='utf-8') as f:
    config = yaml.safe_load(f)
  logging.config.dictConfig(config)

def setup_device(config: argparse.Namespace):
  """Pick the torch device and seed the Python, NumPy and torch RNGs.

  Args:
    config: Parsed CLI namespace. Reads ``gpu_id`` (only when CUDA is used),
      ``no_cuda`` (only when CUDA is available) and ``ramdom_seed``. The
      attribute name is spelled exactly as callers define it. When
      ``ramdom_seed`` is None, a seed is drawn from ``[1, 10000)``.

  Returns:
    ``torch.device("cuda:<gpu_id>")`` if CUDA is available and ``no_cuda`` is
    false, otherwise ``torch.device("cpu")``.
  """
  use_cuda = torch.cuda.is_available() and not config.no_cuda
  device = torch.device(f"cuda:{config.gpu_id}" if use_cuda else "cpu")

  seed_num = config.ramdom_seed
  if seed_num is None:
    # Draw the fallback seed before np.random is re-seeded below.
    seed_num = np.random.randint(1, 10000)
  logger.info(f"current seed num is: {seed_num}")

  # NOTE: Python reads PYTHONHASHSEED at interpreter startup, so this only
  # affects processes launched afterwards, not hashing in this process.
  os.environ["PYTHONHASHSEED"] = str(seed_num)
  for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(seed_num)

  # Trade cuDNN autotuning speed for run-to-run reproducibility.
  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False
  return device

def setup_directory(args: argparse.Namespace, original_dest: str, attempt: int):
  """Create an empty ``attempt_NNN`` output directory and point ``args.dest`` at it.

  The directory is ``<original_dest>/attempt_<attempt + 1>``, with the number
  zero-padded to three digits. Anything already at that path is removed
  first. Then the directory and its ``log`` subdirectory are created.

  Args:
    args: Namespace that is mutated in place; ``args.dest`` is set to the new path.
    original_dest: Parent directory that holds every attempt.
    attempt: Zero-based attempt index (``0`` -> ``attempt_001``).
  """
  out_dir = os.path.join(original_dest, f'attempt_{attempt+1:03}')
  args.dest = out_dir
  # Delete stale output without a shell. A string-built `rm -rf` splits paths
  # containing spaces and allows shell injection. Symlinks (including dangling
  # ones) and plain files are unlinked rather than followed.
  if os.path.isdir(out_dir) and not os.path.islink(out_dir):
    shutil.rmtree(out_dir)
  elif os.path.lexists(out_dir):
    os.remove(out_dir)
  # A single call creates both out_dir and out_dir/log.
  os.makedirs(os.path.join(out_dir, 'log'))

class NumpyEncoder(json.JSONEncoder):
  """JSON encoder that turns NumPy scalars and arrays into native Python types.

  Use it as ``json.dumps(obj, cls=NumpyEncoder)``. Any other unsupported type
  is passed to ``json.JSONEncoder.default``, which raises ``TypeError``.
  """

  def default(self, obj):
    if isinstance(obj, np.integer):
      return int(obj)
    if isinstance(obj, np.floating):
      return float(obj)
    if isinstance(obj, np.bool_):
      # np.bool_ is not a subclass of bool, so the base encoder rejects it.
      return bool(obj)
    if isinstance(obj, np.ndarray):
      # tolist() recursively converts the elements to Python scalars.
      return obj.tolist()
    return super().default(obj)

def result_writeout(args: argparse.Namespace, results: dict, name_override=None):
  """Write ``results`` as indented, UTF-8 encoded JSON into ``args.dest``.

  Args:
    args: Namespace whose ``dest`` attribute is the output directory.
    results: JSON-serialisable mapping. NumPy scalars and arrays are converted
      by ``NumpyEncoder``.
    name_override: File name to use instead of ``"test_result.json"``. Any
      falsy value falls back to the default name.
  """
  out_path = os.path.join(args.dest, name_override or "test_result.json")
  # Serialize fully before writing, so an encoding error leaves no partial JSON.
  payload = json.dumps(results, ensure_ascii=False, indent=2, cls=NumpyEncoder)
  # ensure_ascii=False emits raw non-ASCII characters. Pin the file encoding
  # so the write cannot fail under a non-UTF-8 locale default such as cp1252.
  with open(out_path, "w", encoding="utf-8") as writer:
    writer.write(payload)

def save_weight(state_dict, args):
  """Save ``state_dict`` with ``torch.save`` to ``<args.dest>/best_weights.bin``.

  Args:
    state_dict: Object to serialize. It is presumably a model's
      ``state_dict()``; confirm against callers.
    args: Namespace whose ``dest`` attribute is the output directory.

  Returns:
    Path of the file that was written.
  """
  output_model_file = os.path.join(args.dest, "best_weights.bin")
  torch.save(state_dict, output_model_file)
  return output_model_file
