from io import TextIOWrapper
from typing import Callable, OrderedDict
import torch
import os
from copy import deepcopy

from torch import nn
from optimizer.radam import RAdam
from optimizer.sam import SAM
from contextlib import ExitStack
from logging import getLogger
from torch.optim.optimizer import Optimizer


# Module-level logger used by the optimizer and epoch managers in this module.
logger = getLogger(__name__)

def loss_manager(args):
  """Create `(step, finalize)` callbacks that backpropagate and track training losses.

  Args:
    args: Object providing `gradient_accumulation_steps`; every recorded loss
      is divided by it before being added to the running totals.

  Returns:
    A `(step, finalize)` tuple of closures sharing the same running totals.
  """
  ttl_loss, ttl_w_topic_loss, ttl_u_topic_loss = 0, 0, 0
  steps = 0

  def step(loss, w_topic_loss, u_topic_loss, update_step=True):
    """Backpropagate the given losses and optionally record them.

    `loss` is required; `w_topic_loss` and `u_topic_loss` may be None and are
    then skipped. Returns the running averages `(loss, w_topic_loss,
    u_topic_loss)` when `update_step` is true, otherwise None.
    """
    nonlocal ttl_loss, ttl_w_topic_loss, ttl_u_topic_loss, steps

    pending = [loss] + [term for term in (w_topic_loss, u_topic_loss) if term is not None]
    last = len(pending) - 1
    for index, term in enumerate(pending):
      # Keep the graph only while another backward pass still has to run.
      # The final pass releases it, whichever optional losses were supplied,
      # instead of leaving it alive until the loss tensors are dropped.
      term.backward(retain_graph=index < last)

    if update_step:
      ttl_loss += (loss / args.gradient_accumulation_steps).item()
      if w_topic_loss is not None:
        ttl_w_topic_loss += (w_topic_loss / args.gradient_accumulation_steps).item()
      if u_topic_loss is not None:
        ttl_u_topic_loss += (u_topic_loss / args.gradient_accumulation_steps).item()
      steps += 1

      return ttl_loss / steps, ttl_w_topic_loss / steps, ttl_u_topic_loss / steps

  def finalize():
    """Return the averaged `(loss, w_topic_loss, u_topic_loss)` over recorded steps.

    Raises:
      ZeroDivisionError: if no step was recorded with `update_step=True`.
    """
    loss = ttl_loss / steps
    w_topic_loss = ttl_w_topic_loss / steps
    u_topic_loss = ttl_u_topic_loss / steps

    return loss, w_topic_loss, u_topic_loss

  return step, finalize

def eval_loss_manager():
  """Create `(step, finalize)` callbacks that accumulate evaluation losses.

  `step` adds the `.item()` value of each provided loss to a running sum
  (the two topic losses are optional and skipped when None); `finalize`
  returns the three sums divided by the number of `step` calls.
  """
  # Slots: 0 = loss, 1 = word-topic loss, 2 = utterance-topic loss.
  sums = [0, 0, 0]
  count = 0

  def step(loss, w_topic_loss, u_topic_loss):
    nonlocal count

    sums[0] += loss.item()
    for slot, optional_loss in ((1, w_topic_loss), (2, u_topic_loss)):
      if optional_loss is not None:
        sums[slot] += optional_loss.item()
    count += 1

  def finalize():
    # Raises ZeroDivisionError when `step` was never called.
    return tuple(total / count for total in sums)

  return step, finalize

def optimizer_manager(args, optimizer: Optimizer, initial_global_step: int = 0):
  """Create `(step, finalize)` callbacks that apply gradient accumulation.

  Args:
    args: Object providing `gradient_accumulation_steps`.
    optimizer: Optimizer whose `step(closure)` and `zero_grad()` are called on
      every update.
    initial_global_step: Starting value of the update counter.

  Returns:
    A `(step, finalize)` tuple. `finalize()` returns the one-element tuple
    `(global_step,)`.
  """
  steps = 0
  global_step = initial_global_step

  def step(closure: "Callable[[bool], dict[str, float]]" = None, force=False):
    """Count one call; update parameters on every N-th call or when `force` is set.

    On calls that do not update, the closure (if any) is still invoked so its
    side effects happen.
    """
    nonlocal steps, global_step
    steps += 1
    if steps % args.gradient_accumulation_steps != 0 and not force:
      if closure is not None:
        closure()
      return

    values = optimizer.step(closure)
    optimizer.zero_grad()
    # The parameters were updated, so count the update even when the optimizer
    # returns nothing to log (e.g. when no closure was given).
    global_step += 1
    if values is None:
      return
    print("") # just an empty line
    logger.info("paramter is updated!")
    for k, v in values.items():
      if v is None: continue
      # Accept both tensor-like values exposing `.item()` and plain numbers.
      scalar = v.item() if hasattr(v, "item") else v
      logger.info(f"step {steps} {k} is {scalar}")

  def finalize():
    return (global_step, )

  return step, finalize


def retrieve_optimizer(args, model=None, params=None):
  """Build an RAdam optimizer with `lr=args.learning_rate`, eps 1e-6 and no weight decay.

  The optimized parameters come from `model.parameters()` when a model is
  given; otherwise `params` is passed through as-is.
  """
  if model is not None:
    target_params = model.parameters()
  else:
    target_params = params
  return RAdam(target_params, lr=args.learning_rate, eps=1e-6, weight_decay=0)


def epoch_manager(model: nn.Module, config_name, weights_name, args):
  """Create `(step_result, finalize)` callbacks that track per-epoch results.

  The callbacks remember the best evaluation result, a snapshot of the model
  weights at the lowest evaluation loss and at the lowest training loss, and
  how many consecutive epochs passed without a training-loss improvement.

  Args:
    model: Model whose `state_dict()` is snapshotted. If it exposes a
      `.module` attribute, that inner module is snapshotted instead.
    config_name: Not used by this function; kept for interface compatibility.
    weights_name: Not used by this function; kept for interface compatibility.
    args: Object providing `dest` (output directory that `step_result` writes
      into without creating it), `max_stalled_attempt` and `num_train_epochs`.

  Returns:
    A `(step_result, finalize)` tuple of closures.
  """
  min_loss = float('inf')
  min_train_loss = float('inf')
  stalled_sequences = 0
  train_loss_results = []
  eval_loss_results = []
  best_state_dict = None        # CPU copy of the weights at the best eval loss
  best_train_state_dict = None  # deep copy of the weights at the best train loss
  best_result = None
  step = 0

  def current_module():
    # Unwrap containers that hold the real model in `.module`.
    return model.module if hasattr(model, 'module') else model

  def step_result(eval_loss, train_loss, result, train_result, eval_result) -> bool:
    """Record one epoch and return True when training should stop early.

    Writes `latest_results.txt` on every call and `eval_results.txt` whenever
    `eval_loss` improves. Early stopping is driven by the training loss: the
    stall counter resets on a new minimum and increments otherwise.
    """
    nonlocal min_loss, stalled_sequences, best_state_dict, best_result, step, min_train_loss, best_train_state_dict
    train_loss_results.append(train_result)
    eval_loss_results.append(eval_result)
    step += 1

    logger.info("***** Eval results *****")
    for key in result.keys():
      logger.info("  %s = %s", key, str(result[key]))

    if train_loss < min_train_loss:
      logger.info("Better train loss has been retrieved!")
      min_train_loss = train_loss
      best_train_state_dict = deepcopy(current_module().state_dict())
      stalled_sequences = 0
    else:
      stalled_sequences += 1
      logger.info(f"Epoch {step} training seems to be stalled. current sequences: {stalled_sequences} and limit is {args.max_stalled_attempt}")

    output_latest_file = os.path.join(args.dest, "latest_results.txt")
    with open(output_latest_file, "w") as writer:
      for key in result.keys():
        writer.write("%s = %s\n" % (key, str(result[key])))
      writer.write("============ eval results ============\n")
      for key in eval_result.keys():
        writer.write("%s = %s\n" % (key, str(eval_result[key])))

    if eval_loss < min_loss:
      logger.info("Better loss has been retrieved!")
      min_loss = eval_loss
      # `copy=True` forces a real copy even when the tensor already lives on
      # the CPU; a plain `.cpu()` would return a tensor sharing storage with
      # the live parameter, so the snapshot would change with later training.
      best_state_dict = {
        k: v.detach().to('cpu', copy=True)
        for k, v in current_module().state_dict().items()
      }
      best_result = deepcopy(result)
      output_eval_file = os.path.join(args.dest, "eval_results.txt")
      with open(output_eval_file, "w") as writer:
        for key in result.keys():
          writer.write("%s = %s\n" % (key, str(result[key])))

    should_stop = stalled_sequences >= args.max_stalled_attempt
    if should_stop:
      logger.info("Sequential stalled count exceeded. Making earlly stop for this training task!")

    return should_stop

  def log_out_path(kind):
    return os.path.join(args.dest, 'log', "{}_{}.log".format(kind, args.num_train_epochs))

  def write_result(stack: ExitStack, results: "list[dict[str, float]]", prefix: str):
    """Write one log file per metric key, one line per recorded epoch.

    Files are named `<prefix>_<key>_<num_train_epochs>.log` so that metrics,
    and train versus eval histories, never overwrite one another.
    """
    if not results:
      return  # nothing was recorded; avoid indexing an empty history
    os.makedirs(os.path.join(args.dest, 'log'), exist_ok=True)
    # The key set of the first epoch decides which files are produced.
    writers = {
      key: stack.enter_context(open(log_out_path("{}_{}".format(prefix, key)), "w"))
      for key in results[0]
    }
    for result in results:
      for key, writer in writers.items():
        writer.write(str(result[key]) + '\n')

  def finalize():
    """Dump the loss histories and return `(best_result, best_state_dict, best_train_state_dict)`."""
    with ExitStack() as stack:
      write_result(stack, train_loss_results, "train")
      write_result(stack, eval_loss_results, "eval")

    return best_result, best_state_dict, best_train_state_dict

  return step_result, finalize