diff --git a/changes/.fix.md b/changes/3058.fix.md similarity index 100% rename from changes/.fix.md rename to changes/3058.fix.md diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index be186358b0..a817c35446 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -18,6 +18,7 @@ TYPE_CHECKING, Any, ClassVar, + Concatenate, Final, Generic, NamedTuple, @@ -745,19 +746,23 @@ def get_loader( @staticmethod def _get_func_key( - func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]], + func: Callable[Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT]], + **kwargs, ) -> int: - return hash(func) + func_and_kwargs = (func, *[(k, kwargs[k]) for k in sorted(kwargs.keys())]) + return hash(func_and_kwargs) def get_loader_by_func( self, context: ContextT, - batch_load_func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]], + batch_load_func: Callable[ + Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT] + ], # Using kwargs-only to prevent argument position confusion - # when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **extra_args)(keys)`. - **kwargs: Any, + # when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **kwargs)(keys)`. + **kwargs, ) -> DataLoader: - key = self._get_func_key(batch_load_func) + key = self._get_func_key(batch_load_func, **kwargs) loader = self.cache.get(key) if loader is None: loader = DataLoader(