Skip to content

Commit

Permalink
refactor(sdk): refactor paginator base types
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyyli-wandb committed Jan 11, 2025
1 parent 64f41bc commit 436be53
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 70 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ Section headings should be at level 3 (e.g. `### Added`).

## Unreleased

### Changed

- Paginated methods (and underlying paginators) that accept a `per_page` argument now only accept `int` values. Default `per_page` values are set directly in method signatures, and explicitly passing `None` is no longer supported.

### Fixed

- Fix `wandb.Settings` update regression in `wandb.integration.metaflow` (@kptkin in https://github.com/wandb/wandb/pull/9211)
120 changes: 82 additions & 38 deletions wandb/apis/paginator.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,103 @@
from typing import TYPE_CHECKING, Any, MutableMapping, Optional
from __future__ import annotations

from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Iterator,
Mapping,
Protocol,
Sized,
TypeVar,
overload,
)

if TYPE_CHECKING:
from wandb_gql import Client
from wandb_graphql.language.ast import Document

T = TypeVar("T")


# Structural type hint for the client instance
class _Client(Protocol):
    """Structural type for the client object that paginators send queries through.

    Being a `Protocol`, it is matched by shape: any object exposing a
    compatible `execute` method satisfies it, without inheriting from it.
    """

    # Paginator calls this as `execute(QUERY, variable_values=...)` and keeps
    # the returned mapping as its last response.
    def execute(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ...


class Paginator(Iterator[T]):
"""An iterator for paginated objects from GraphQL requests."""

class Paginator:
QUERY = None
QUERY: ClassVar[Document | None] = None

def __init__(
self,
client: "Client",
variables: MutableMapping[str, Any],
per_page: Optional[int] = None,
client: _Client,
variables: Mapping[str, Any],
per_page: int = 50, # We don't allow unbounded paging
):
self.client = client
self.variables = variables
# We don't allow unbounded paging
self.per_page = per_page
if self.per_page is None:
self.per_page = 50
self.objects = []
self.index = -1
self.last_response = None
self.client: _Client = client

def __iter__(self):
self.index = -1
return self
# shallow copy partly guards against mutating the original input
self.variables: dict[str, Any] = dict(variables)

def __len__(self):
if self.length is None:
self._load_page()
if self.length is None:
raise ValueError("Object doesn't provide length")
return self.length
self.per_page: int = per_page
self.objects: list[T] = []
self.index: int = -1
self.last_response: object | None = None

@property
def length(self):
raise NotImplementedError
def __iter__(self) -> Iterator[T]:
self.index = -1
return self

@property
def more(self):
@abstractmethod
def more(self) -> bool:
"""Whether there are more pages to be fetched."""
raise NotImplementedError

@property
def cursor(self):
@abstractmethod
def cursor(self) -> str | None:
"""The start cursor to use for the next fetched page."""
raise NotImplementedError

def convert_objects(self):
@abstractmethod
def convert_objects(self) -> list[T]:
"""Convert the last fetched response data into the iterated objects."""
raise NotImplementedError

def update_variables(self):
def update_variables(self) -> None:
"""Update the query variables for the next page fetch."""
self.variables.update({"perPage": self.per_page, "cursor": self.cursor})

def _load_page(self):
if not self.more:
return False
self.update_variables()
def _update_response(self) -> None:
"""Fetch and store the response data for the next page."""
self.last_response = self.client.execute(
self.QUERY, variable_values=self.variables
)

def _load_page(self) -> bool:
"""Fetch the next page, if any, returning True and storing the response if there was one."""
if not self.more:
return False
self.update_variables()
self._update_response()
self.objects.extend(self.convert_objects())
return True

@overload
def __getitem__(self, index: int) -> T: ...
@overload
def __getitem__(self, index: slice) -> list[T]: ...

def __getitem__(self, index: int | slice) -> T | list[T]:
    """Return the object(s) at `index`, fetching further pages as needed.

    Args:
        index: An integer position, or a slice, into the fetched objects.

    Returns:
        A single object for an integer index, or a list of objects for a
        slice.

    Raises:
        IndexError: If an integer index is still out of range once no
            more pages can be loaded.
    """
    stop = index.stop if isinstance(index, slice) else index
    if stop is None:
        # An open-ended slice (e.g. `[:]` or `[2:]`) has no upper bound to
        # compare against -- `None > int` would raise TypeError -- so the
        # only correct answer requires every remaining page.
        while self._load_page():
            pass
    else:
        # Keep loading until position `stop` is available or pages run out.
        # Negative positions never trigger a fetch here; they resolve
        # against whatever has been loaded so far.
        while stop > len(self.objects) - 1 and self._load_page():
            pass
    return self.objects[index]

def __next__(self):
def __next__(self) -> T:
self.index += 1
if len(self.objects) <= self.index:
if not self._load_page():
Expand All @@ -79,3 +107,19 @@ def __next__(self):
return self.objects[self.index]

next = __next__


class SizedPaginator(Paginator[T], Sized):
    """A `Paginator` that can also report how many objects exist in total.

    Subclasses provide `length`; `len()` on the paginator is answered from it.
    """

    @property
    @abstractmethod
    def length(self) -> int | None:
        """Total number of objects, or None when that count is unavailable."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the total count, attempting one page load if it is unknown.

        Raises:
            ValueError: If `length` is still None after the page load attempt.
        """
        if self.length is None:
            # Give the subclass one chance to learn the count from a response.
            self._load_page()
            if self.length is None:
                raise ValueError("Object doesn't provide length")
        return self.length
20 changes: 8 additions & 12 deletions wandb/apis/public/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,15 +746,14 @@ def _parse_artifact_path(self, path):
return parts

def projects(
self, entity: Optional[str] = None, per_page: Optional[int] = 200
self, entity: Optional[str] = None, per_page: int = 200
) -> "public.Projects":
"""Get projects for a given entity.
Args:
entity: (str) Name of the entity requested. If None, will fall back to the
default entity passed to `Api`. If no default entity, will raise a `ValueError`.
per_page: (int) Sets the page size for query pagination. None will use the default size.
Usually there is no reason to change this.
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
Returns:
A `Projects` object which is an iterable collection of `Project` objects.
Expand Down Expand Up @@ -797,7 +796,7 @@ def project(self, name: str, entity: Optional[str] = None) -> "public.Project":
return public.Project(self.client, entity, name, {})

def reports(
self, path: str = "", name: Optional[str] = None, per_page: Optional[int] = 50
self, path: str = "", name: Optional[str] = None, per_page: int = 50
) -> "public.Reports":
"""Get reports for a given project path.
Expand All @@ -806,8 +805,7 @@ def reports(
Args:
path: (str) path to project the report resides in, should be in the form: "entity/project"
name: (str, optional) optional name of the report requested.
per_page: (int) Sets the page size for query pagination. None will use the default size.
Usually there is no reason to change this.
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
Returns:
A `Reports` object which is an iterable collection of `BetaReport` objects.
Expand Down Expand Up @@ -1093,15 +1091,14 @@ def artifact_type(

@normalize_exceptions
def artifact_collections(
self, project_name: str, type_name: str, per_page: Optional[int] = 50
self, project_name: str, type_name: str, per_page: int = 50
) -> "public.ArtifactCollections":
"""Return a collection of matching artifact collections.
Args:
project_name: (str) The name of the project to filter on.
type_name: (str) The name of the artifact type to filter on.
per_page: (int, optional) Sets the page size for query pagination. None will use the default size.
Usually there is no reason to change this.
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
Returns:
An iterable `ArtifactCollections` object.
Expand Down Expand Up @@ -1160,16 +1157,15 @@ def artifacts(
self,
type_name: str,
name: str,
per_page: Optional[int] = 50,
per_page: int = 50,
tags: Optional[List[str]] = None,
) -> "public.Artifacts":
"""Return an `Artifacts` collection from the given parameters.
Args:
type_name: (str) The type of artifacts to fetch.
name: (str) An artifact collection name. May be prefixed with entity/project.
per_page: (int, optional) Sets the page size for query pagination. None will use the default size.
Usually there is no reason to change this.
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
tags: (list[str], optional) Only return artifacts with all of these tags.
Returns:
Expand Down
22 changes: 10 additions & 12 deletions wandb/apis/public/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import wandb
from wandb.apis import public
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import Paginator
from wandb.apis.paginator import Paginator, SizedPaginator
from wandb.errors.term import termlog
from wandb.sdk.lib import deprecate

Expand Down Expand Up @@ -62,7 +62,7 @@
}"""


class ArtifactTypes(Paginator):
class ArtifactTypes(Paginator["ArtifactType"]):
QUERY = gql(
"""
query ProjectArtifacts(
Expand All @@ -85,7 +85,7 @@ def __init__(
client: Client,
entity: str,
project: str,
per_page: Optional[int] = 50,
per_page: int = 50,
):
self.entity = entity
self.project = project
Expand All @@ -98,7 +98,7 @@ def __init__(
super().__init__(client, variable_values, per_page)

@property
def length(self):
def length(self) -> None:
# TODO
return None

Expand Down Expand Up @@ -207,14 +207,14 @@ def __repr__(self):
return f"<ArtifactType {self.type}>"


class ArtifactCollections(Paginator):
class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
def __init__(
self,
client: Client,
entity: str,
project: str,
type_name: str,
per_page: Optional[int] = 50,
per_page: int = 50,
):
self.entity = entity
self.project = project
Expand Down Expand Up @@ -742,7 +742,7 @@ def __repr__(self):
return f"<ArtifactCollection {self._name} ({self._type})>"


class Artifacts(Paginator):
class Artifacts(SizedPaginator["wandb.Artifact"]):
"""An iterable collection of artifact versions associated with a project and optional filter.
This is generally used indirectly via the `Api`.artifact_versions method.
Expand Down Expand Up @@ -858,10 +858,8 @@ def convert_objects(self):
]


class RunArtifacts(Paginator):
def __init__(
self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
):
class RunArtifacts(SizedPaginator["wandb.Artifact"]):
def __init__(self, client: Client, run: "Run", mode="logged", per_page: int = 50):
from wandb.sdk.artifacts.artifact import _gql_artifact_fragment

output_query = gql(
Expand Down Expand Up @@ -976,7 +974,7 @@ def convert_objects(self):
]


class ArtifactFiles(Paginator):
class ArtifactFiles(SizedPaginator["public.File"]):
QUERY = gql(
"""
query ArtifactFiles(
Expand Down
4 changes: 2 additions & 2 deletions wandb/apis/public/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from wandb import util
from wandb.apis.attrs import Attrs
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import Paginator
from wandb.apis.paginator import SizedPaginator
from wandb.apis.public import utils
from wandb.apis.public.api import Api
from wandb.apis.public.const import RETRY_TIMEDELTA
Expand Down Expand Up @@ -41,7 +41,7 @@
}"""


class Files(Paginator):
class Files(SizedPaginator["File"]):
"""An iterable collection of `File` objects."""

QUERY = gql(
Expand Down
5 changes: 3 additions & 2 deletions wandb/apis/public/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
}"""


class Projects(Paginator):
class Projects(Paginator["Project"]):
"""An iterable collection of `Project` objects."""

QUERY = gql(
Expand Down Expand Up @@ -49,7 +49,8 @@ def __init__(self, client, entity, per_page=50):
super().__init__(client, variables, per_page)

@property
def length(self):
def length(self) -> None:
# For backwards compatibility, even though this isn't a SizedPaginator
return None

@property
Expand Down
4 changes: 2 additions & 2 deletions wandb/apis/public/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import wandb
from wandb.apis import public
from wandb.apis.attrs import Attrs
from wandb.apis.paginator import Paginator
from wandb.apis.paginator import SizedPaginator
from wandb.sdk.lib import ipython


class Reports(Paginator):
class Reports(SizedPaginator["BetaReport"]):
"""Reports is an iterable collection of `BetaReport` objects."""

QUERY = gql(
Expand Down
4 changes: 2 additions & 2 deletions wandb/apis/public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from wandb.apis.attrs import Attrs
from wandb.apis.internal import Api as InternalApi
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import Paginator
from wandb.apis.paginator import SizedPaginator
from wandb.apis.public.const import RETRY_TIMEDELTA
from wandb.sdk.lib import ipython, json_util, runid
from wandb.sdk.lib.paths import LogicalPath
Expand Down Expand Up @@ -61,7 +61,7 @@
}"""


class Runs(Paginator):
class Runs(SizedPaginator["Run"]):
"""An iterable collection of runs associated with a project and optional filter.
This is generally used indirectly via the `Api`.runs method.
Expand Down

0 comments on commit 436be53

Please sign in to comment.