Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize base.pyx #493

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 46 additions & 44 deletions a_sync/a_sync/base.pyx
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# cython: boundscheck=False
import functools
import inspect
from contextlib import suppress
from functools import cached_property
from inspect import signature, _empty
from logging import DEBUG, getLogger
from libc.stdint cimport uintptr_t

from a_sync import exceptions
from a_sync._typing import *
from a_sync.a_sync._flags cimport validate_and_negate_if_necessary, validate_flag_value
from a_sync.a_sync.abstract import ASyncABC
from a_sync.a_sync.flags import VIABLE_FLAGS
from a_sync.exceptions import ASyncFlagException, FlagNotDefined, InvalidFlag, NoFlagsFound, TooManyFlags


logger = getLogger(__name__)

cdef object c_logger = logger
cdef object _logger_is_enabled_for = logger.isEnabledFor
cdef object _logger_debug = logger.debug
cdef object _logger_log = logger._log


class ASyncGenericBase(ASyncABC):
Expand Down Expand Up @@ -63,12 +64,12 @@ class ASyncGenericBase(ASyncABC):
def __a_sync_default_mode__(cls) -> bint: # type: ignore [override]
cdef object flag
cdef bint flag_value
if not c_logger.isEnabledFor(DEBUG):
if not _logger_is_enabled_for(DEBUG):
# we can optimize this if we dont need to log `flag` and the return value
try:
flag = _get_a_sync_flag_name_from_signature(cls, False)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
except NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)
return validate_and_negate_if_necessary(flag, flag_value)
Expand All @@ -79,12 +80,12 @@ class ASyncGenericBase(ASyncABC):
try:
flag = _get_a_sync_flag_name_from_signature(cls, True)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
except NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)

sync = validate_and_negate_if_necessary(flag, flag_value)
c_logger._log(
_logger_log(
DEBUG,
"`%s.%s` indicates default mode is %ssynchronous",
(cls, flag, "a" if sync is False else ""),
Expand All @@ -99,81 +100,82 @@ class ASyncGenericBase(ASyncABC):
)
ASyncABC.__init__(self)

@functools.cached_property
@cached_property
def __a_sync_flag_name__(self) -> str:
# TODO: cythonize this cache
cdef bint debug_logs
if debug_logs := c_logger.isEnabledFor(DEBUG):
c_logger._log(DEBUG, "checking a_sync flag for %s", (self, ))
if debug_logs := _logger_is_enabled_for(DEBUG):
_logger_log(DEBUG, "checking a_sync flag for %s", (self, ))
try:
flag = _get_a_sync_flag_name_from_signature(type(self), debug_logs)
except exceptions.ASyncFlagException:
except ASyncFlagException:
# We can't get the flag name from the __init__ signature,
# but maybe the implementation sets the flag somewhere else.
# Let's check the instance's atributes
if debug_logs:
c_logger._log(
_logger_log(
DEBUG,
"unable to find flag name using `%s.__init__` signature, checking for flag attributes defined on %s",
(self.__class__.__name__, self),
)
present_flags = [flag for flag in VIABLE_FLAGS if hasattr(self, flag)]
if not present_flags:
raise exceptions.NoFlagsFound(self) from None
raise NoFlagsFound(self) from None
if len(present_flags) > 1:
raise exceptions.TooManyFlags(self, present_flags) from None
raise TooManyFlags(self, present_flags) from None
flag = present_flags[0]
if not isinstance(flag, str):
raise exceptions.InvalidFlag(flag)
raise InvalidFlag(flag)
return flag

@functools.cached_property
@cached_property
def __a_sync_flag_value__(self) -> bint:
# TODO: cythonize this cache
"""If you wish to be able to hotswap default modes, just duplicate this def as a non-cached property."""
cdef str flag = self.__a_sync_flag_name__
flag_value = getattr(self, flag)
c_logger.debug("`%s.%s` is currently %s", self, flag, flag_value)
_logger_debug("`%s.%s` is currently %s", self, flag, flag_value)
return validate_flag_value(flag, flag_value)



cdef str _get_a_sync_flag_name_from_class_def(object cls):
c_logger.debug("Searching for flags defined on %s", cls)
_logger_debug("Searching for flags defined on %s", cls)
try:
return _parse_flag_name_from_list(cls, cls.__dict__) # type: ignore [arg-type]
# idk why __dict__ doesn't type check as a dict
except exceptions.NoFlagsFound:
except NoFlagsFound:
for base in cls.__bases__:
with suppress(exceptions.NoFlagsFound):
return _parse_flag_name_from_list(cls, base.__dict__) # type: ignore [arg-type]
# idk why __dict__ doesn't type check as a dict
raise exceptions.NoFlagsFound(cls, list(cls.__dict__.keys()))
try:
return _parse_flag_name_from_list(cls, base.__dict__) # type: ignore [arg-type] idk why __dict__ doesn't type check as a dict
except NoFlagsFound:
pass
raise NoFlagsFound(cls, list(cls.__dict__.keys()))


cdef bint _a_sync_flag_default_value_from_signature(object cls):
cdef object signature = _get_init_signature(cls)
if not c_logger.isEnabledFor(DEBUG):
if not _logger_is_enabled_for(DEBUG):
# we can optimize this much better
return signature.parameters[_get_a_sync_flag_name_from_signature(cls, False)].default

c_logger._log(
_logger_log(
DEBUG, "checking `__init__` signature for default %s a_sync flag value", (cls, )
)
cdef str flag = _get_a_sync_flag_name_from_signature(cls, True)
cdef object flag_value = signature.parameters[flag].default
if flag_value is inspect._empty: # type: ignore [attr-defined]
if flag_value is _empty: # type: ignore [attr-defined]
raise NotImplementedError(
"The implementation for 'cls' uses an arg to specify sync mode, instead of a kwarg. We are unable to proceed. I suppose we can extend the code to accept positional arg flags if necessary"
)
c_logger._log(DEBUG, "%s defines %s, default value %s", (cls, flag, flag_value))
_logger_log(DEBUG, "%s defines %s, default value %s", (cls, flag, flag_value))
return flag_value


cdef str _get_a_sync_flag_name_from_signature(object cls, bint debug_logs):
if cls.__name__ == "ASyncGenericBase":
if debug_logs:
c_logger._log(
_logger_log(
DEBUG, "There are no flags defined on the base class, this is expected. Skipping.", ()
)
return ""
Expand All @@ -183,24 +185,24 @@ cdef str _get_a_sync_flag_name_from_signature(object cls, bint debug_logs):
# we can also skip assigning params to a var
return _parse_flag_name_from_list(cls, _get_init_signature(cls).parameters)

c_logger._log(DEBUG, "Searching for flags defined on %s.__init__", (cls, ))
_logger_log(DEBUG, "Searching for flags defined on %s.__init__", (cls, ))
cdef object parameters = _get_init_signature(cls).parameters
c_logger._log(DEBUG, "parameters: %s", (parameters, ))
_logger_log(DEBUG, "parameters: %s", (parameters, ))
return _parse_flag_name_from_list(cls, parameters)


cdef str _parse_flag_name_from_list(object cls, object items):
cdef str flag
cdef list[str] present_flags = [flag for flag in VIABLE_FLAGS if flag in items]
if not present_flags:
c_logger.debug("There are no flags defined on %s", cls)
raise exceptions.NoFlagsFound(cls, items.keys())
_logger_debug("There are no flags defined on %s", cls)
raise NoFlagsFound(cls, items.keys())
if len(present_flags) > 1:
c_logger.debug("There are too many flags defined on %s", cls)
raise exceptions.TooManyFlags(cls, present_flags)
if c_logger.isEnabledFor(DEBUG):
_logger_debug("There are too many flags defined on %s", cls)
raise TooManyFlags(cls, present_flags)
if _logger_is_enabled_for(DEBUG):
flag = present_flags[0]
c_logger._log(DEBUG, "found flag %s", (flag, ))
_logger_log(DEBUG, "found flag %s", (flag, ))
return flag
return present_flags[0]

Expand All @@ -210,16 +212,16 @@ cdef inline bint _get_a_sync_flag_value_from_class_def(object cls, str flag):
for spec in [cls, *cls.__bases__]:
if flag in spec.__dict__:
return spec.__dict__[flag]
raise exceptions.FlagNotDefined(cls, flag)
raise FlagNotDefined(cls, flag)


cdef dict[uintptr_t, object] _init_signature_cache = {}


cdef _get_init_signature(object cls):
cdef uintptr_t cls_init_id = id(cls.__init__)
signature = _init_signature_cache.get(cls_init_id)
if signature is None:
signature = inspect.signature(cls.__init__)
_init_signature_cache[cls_init_id] = signature
return signature
init_sig = _init_signature_cache.get(cls_init_id)
if init_sig is None:
init_sig = signature(cls.__init__)
_init_signature_cache[cls_init_id] = init_sig
return init_sig
Loading