Skip to content

Commit

Permalink
fix: Dataloader caches batch functions correctly (#3058)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Nov 8, 2024
1 parent 7b6072a commit 99ddaad
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
File renamed without changes.
17 changes: 11 additions & 6 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
Final,
Generic,
NamedTuple,
Expand Down Expand Up @@ -745,19 +746,23 @@ def get_loader(

@staticmethod
def _get_func_key(
func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]],
func: Callable[Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT]],
**kwargs,
) -> int:
return hash(func)
func_and_kwargs = (func, *[(k, kwargs[k]) for k in sorted(kwargs.keys())])
return hash(func_and_kwargs)

def get_loader_by_func(
self,
context: ContextT,
batch_load_func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]],
batch_load_func: Callable[
Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT]
],
# Using kwargs-only to prevent argument position confusion
# when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **extra_args)(keys)`.
**kwargs: Any,
# when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **kwargs)(keys)`.
**kwargs,
) -> DataLoader:
key = self._get_func_key(batch_load_func)
key = self._get_func_key(batch_load_func, **kwargs)
loader = self.cache.get(key)
if loader is None:
loader = DataLoader(
Expand Down

0 comments on commit 99ddaad

Please sign in to comment.