import functools
import threading
import time
from collections.abc import Mapping

from experiments.client_config import max_concurrency_dict, get_client_type

# Occupancy limit per client type: 1.5x its request-concurrency limit,
# truncated to an int.
max_occ_dict = {
    client_type: int(limit * 1.5)
    for client_type, limit in max_concurrency_dict.items()
}


class Resource(threading.Semaphore):
    """Counting semaphore whose counter is allowed to go negative.

    ``acquire`` succeeds as soon as the counter is positive and then subtracts
    ``n`` from it, so a single acquisition may overdraw the counter below
    zero. Subsequent acquirers block until enough ``release`` calls bring the
    counter back above zero.
    """

    def acquire(self, n=1, blocking=True, timeout=None):
        """Take ``n`` units once the counter is positive.

        Unlike ``threading.Semaphore.acquire``, success only requires the
        counter to be > 0 *before* the subtraction; it may be negative
        afterwards.

        Args:
            n: Units to subtract on success. Note that this is the first
                positional parameter, unlike in the base class.
            blocking: If False, return immediately instead of waiting.
            timeout: Maximum seconds to wait when blocking; None waits
                indefinitely.

        Returns:
            True if the units were taken; False if the call gave up
            (non-blocking with no capacity, or the timeout expired).

        Raises:
            ValueError: If ``timeout`` is combined with ``blocking=False``.
        """
        if not blocking and timeout is not None:
            raise ValueError("can't specify timeout for non-blocking acquire")
        rc = False
        endtime = None
        with self._cond:
            while self._value <= 0:
                if not blocking:
                    break
                if timeout is not None:
                    if endtime is None:
                        endtime = time.monotonic() + timeout
                    else:
                        # Recompute the remaining budget after each wakeup.
                        timeout = endtime - time.monotonic()
                        if timeout <= 0:
                            break
                self._cond.wait(timeout)
            else:
                # Reached only when the loop condition turned false (no break):
                # capacity was positive, so take n units even if that
                # drives the counter negative.
                self._value -= n
                rc = True
        return rc

    # ``with resource:`` acquires a single unit; the inherited ``__exit__``
    # calls ``release()``.
    __enter__ = acquire

    def release(self, n=1):
        """Return ``n`` units to the counter and wake all waiters.

        Raises:
            ValueError: If ``n`` is less than 1.
        """
        if n < 1:
            raise ValueError("n must be one or more")
        with self._cond:
            self._value += n
            # notify_all rather than notify(n): wait_till_available() waiters
            # share this condition and could otherwise swallow a wakeup meant
            # for an acquire() waiter, leaving it asleep while capacity exists.
            self._cond.notify_all()

    def wait_till_available(self, n=1, timeout=None):
        """Block until the counter is at least ``n`` without taking any units.

        Args:
            n: Counter value to wait for.
            timeout: Maximum seconds to wait; None waits indefinitely.

        Returns:
            True once the counter reaches ``n``; False if ``timeout``
            expired first.
        """
        with self._cond:
            return self._cond.wait_for(lambda: self._value >= n, timeout)


# One shared Resource per client type, sized by its request-concurrency limit.
req_resource_dict = {
    client_type: Resource(limit)
    for client_type, limit in max_concurrency_dict.items()
}

# One shared Resource per client type, sized by its occupancy limit.
occ_resource_dict = {
    client_type: Resource(limit)
    for client_type, limit in max_occ_dict.items()
}


def find_model_str_helper(a):
    """Extract the model identifier from ``a``.

    Args:
        a: Either the model string itself, or a dict whose
            ``"completion_kwargs"`` mapping contains a ``"model"`` entry.

    Returns:
        ``a`` unchanged if it is a ``str``; otherwise
        ``a["completion_kwargs"]["model"]``.

    Raises:
        ValueError: If no model can be found, including when
            ``"completion_kwargs"`` is present but is not a mapping.
    """
    # already a string
    if isinstance(a, str):
        return a
    if isinstance(a, dict):
        completion_kwargs = a.get("completion_kwargs")
        # Guard the membership test: `"model" in None` raises TypeError, and
        # `"model" in "<some str>"` is a substring check, not a key lookup.
        if (
            isinstance(completion_kwargs, Mapping)
            and "model" in completion_kwargs
        ):
            return completion_kwargs["model"]
    raise ValueError("Could not find model string")


def _client_type_for(maybe_model_str: object):
    """Resolve the client type for a model string or request dict.

    Accepts anything ``find_model_str_helper`` accepts and propagates its
    ``ValueError`` otherwise.
    """
    return get_client_type(find_model_str_helper(maybe_model_str))


def get_req_resource(maybe_model_str: object):
    """Return the request-concurrency ``Resource`` for this model's client type.

    Raises:
        ValueError: If no model string can be extracted.
        KeyError: If the client type has no entry in ``req_resource_dict``.
    """
    return req_resource_dict[_client_type_for(maybe_model_str)]


def get_occ_resource(maybe_model_str: object):
    """Return the occupancy ``Resource`` for this model's client type.

    Raises:
        ValueError: If no model string can be extracted.
        KeyError: If the client type has no entry in ``occ_resource_dict``.
    """
    return occ_resource_dict[_client_type_for(maybe_model_str)]


def get_max_req(maybe_model_str: object):
    """Return the configured request-concurrency limit for this model's client type.

    Raises:
        ValueError: If no model string can be extracted.
        KeyError: If the client type has no entry in ``max_concurrency_dict``.
    """
    return max_concurrency_dict[_client_type_for(maybe_model_str)]


def get_max_occ(maybe_model_str: object):
    """Return the occupancy limit for this model's client type.

    Raises:
        ValueError: If no model string can be extracted.
        KeyError: If the client type has no entry in ``max_occ_dict``.
    """
    return max_occ_dict[_client_type_for(maybe_model_str)]
