import torch
import numpy as np
import json, random
from functools import reduce
from pprint import pprint

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "ConfLoader",
    "overwrite_config",
    "pprint_config",
    "random_seeder",
]


class ConfLoader:
    """Load a JSON config file, allowing attribute access to dictionary items.

    The parsed document is stored in ``self.opt``. Every JSON object in it,
    at any nesting depth, becomes a ``DictWithAttributeAccess``, so nested
    values can be read either as ``loader.opt["model"]["lr"]`` or as
    ``loader.opt.model.lr``.
    """

    class DictWithAttributeAccess(dict):
        """``dict`` whose items can also be read, set and deleted as attributes.

        * Reading a missing key as an attribute returns ``None`` instead of
          raising, so optional entries can be probed with ``opt.some_key``.
        * Real ``dict`` attributes (``keys``, ``items``, ``copy`` ...) take
          precedence over keys of the same name when *reading*; use
          ``opt["items"]`` for such keys.
        * Missing dunder names (``__x__``) are never looked up in the
          dictionary; they raise ``AttributeError`` as on a plain object.
          Otherwise ``hasattr(opt, "__x__")`` would be ``True`` for every
          special name, which misleads code probing for optional protocols
          (``__setstate__``, ``__deepcopy__`` ...). Keys of that form remain
          reachable as ``opt["__x__"]``.
        """

        def __getattr__(self, item):
            # Python only calls __getattr__ after normal attribute lookup
            # has failed, so real dict methods/attributes are never shadowed.
            if item.startswith("__") and item.endswith("__"):
                raise AttributeError(item)
            return self.get(item)

        def __setattr__(self, key, value):
            # Store as an item (never in the instance __dict__) so attribute
            # writes are visible through the mapping interface.
            self[key] = value

        def __delattr__(self, key):
            # Mirror __setattr__: ``del opt.key`` removes the item.
            try:
                del self[key]
            except KeyError:
                raise AttributeError(key) from None

    def __init__(self, conf_name):
        """Parse the JSON file at ``conf_name`` into ``self.opt``.

        Args:
            conf_name: Path of the JSON config file.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON text is UTF-8 by specification (RFC 8259). Decode explicitly
        # instead of relying on the platform locale encoding (e.g. cp1252 on
        # Windows). "utf-8-sig" reads plain UTF-8 unchanged and also drops a
        # leading BOM, which json.load would otherwise reject.
        with open(conf_name, "r", encoding="utf-8-sig") as conf_file:
            self.opt = json.load(
                conf_file, object_hook=ConfLoader.DictWithAttributeAccess
            )


def overwrite_config(opt, args, update_paths):
    """Overwrite configuration entries with values parsed by argparse.

    For every ``arg -> path`` pair in ``update_paths``, if ``args`` has an
    attribute ``arg`` whose value is not ``None``, the config entry addressed
    by ``path`` is replaced with that value. ``opt`` is modified in place and
    also returned for convenience.

    Args:
        opt: Config root (e.g. ``ConfLoader(...).opt``). Dict nodes are
            traversed by key; any other node by attribute.
        args: Object holding the parsed options (e.g. ``argparse.Namespace``).
        update_paths: Mapping from an attribute name on ``args`` to the path
            of the config entry to overwrite. The path is either a sequence
            of keys (``("train", "lr")``) or a dotted string (``"train.lr"``).

    Returns:
        The same ``opt`` object, updated in place.

    Raises:
        ValueError: If a path is empty.
        AttributeError: If an intermediate node of a path is missing or None.
    """
    for arg, path in update_paths.items():
        value = getattr(args, arg, None)
        if value is None:
            # None means "not supplied": keep the value from the config file.
            continue

        if isinstance(path, str):
            keys = path.split(".") if path else []
        else:
            keys = list(path)
        if not keys:
            raise ValueError(f"empty config path for argument {arg!r}")
        *parents, leaf = keys

        target = opt
        for depth, key in enumerate(parents):
            # Look dict nodes up by key so names that shadow dict methods
            # ("items", "keys", "copy", ...) resolve to the config value, not
            # to the bound method.
            if isinstance(target, dict):
                target = target.get(key)
            else:
                target = getattr(target, key, None)
            if target is None:
                full = ".".join(map(str, keys))
                missing = ".".join(map(str, keys[: depth + 1]))
                raise AttributeError(
                    f"cannot overwrite {full!r} from argument {arg!r}: "
                    f"config node {missing!r} is missing or None"
                )

        if isinstance(target, dict):
            target[leaf] = value
        else:
            setattr(target, leaf, value)

    return opt


def pprint_config(opt):
    """Pretty-print ``opt`` framed by two ``=`` rules of equal width.

    The opening rule carries a " Configuration " label and is preceded by a
    blank line; the closing rule has the same width and is followed by one.
    """
    side = "=" * 50
    header = f"{side} Configuration {side}"
    print(f"\n{header}")
    pprint(opt, compact=True)
    # Closing rule is sized from the header so both lines stay aligned.
    print("=" * len(header) + "\n")


def random_seeder(seed):
    """Fix randomness across PyTorch, NumPy and Python's ``random``.

    Each library's global generator is seeded with ``seed`` (PyTorch first,
    then NumPy, then ``random``). cuDNN is then asked for deterministic
    kernels, and its benchmark autotuning, which may pick different
    algorithms from run to run, is disabled.
    """
    seeders = (torch.manual_seed, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
