diff --git a/.gitignore b/.gitignore index 7e07a5339..2cb00cc17 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # ----- Project ----- tmp +.idea # Created by https://www.toptal.com/developers/gitignore/api/python,node,visualstudiocode,jetbrains,macos,windows,linux # Edit at https://www.toptal.com/developers/gitignore?templates=python,node,visualstudiocode,jetbrains,macos,windows,linux diff --git a/githubkit/core.py b/githubkit/core.py index 2d9886d4c..2db19e4b9 100644 --- a/githubkit/core.py +++ b/githubkit/core.py @@ -1,5 +1,6 @@ from types import TracebackType from contextvars import ContextVar +from datetime import datetime, timezone, timedelta from contextlib import contextmanager, asynccontextmanager from typing import ( Any, @@ -22,7 +23,6 @@ from .response import Response from .utils import obj_to_jsonable from .config import Config, get_config -from .exception import RequestError, RequestFailed, RequestTimeout from .auth import BaseAuthStrategy, TokenAuthStrategy, UnauthAuthStrategy from .typing import ( URLTypes, @@ -32,6 +32,13 @@ RequestFiles, QueryParamTypes, ) +from .exception import ( + RequestError, + RequestFailed, + RequestTimeout, + PrimaryRateLimitExceeded, + SecondaryRateLimitExceeded, +) T = TypeVar("T") A = TypeVar("A", bound="BaseAuthStrategy") @@ -316,15 +323,73 @@ def _check( if response.is_error: error_models = error_models or {} status_code = str(response.status_code) + error_model = error_models.get( status_code, error_models.get( f"{status_code[:-2]}XX", error_models.get("default", Any) ), ) - rep = Response(response, error_model) - raise RequestFailed(rep) - return Response(response, response_model) + resp = Response(response, error_model) + else: + resp = Response(response, response_model) + + # only check rate limit when response is 403 or 429 + if response.status_code in (403, 429): + self._check_rate_limit(resp) + + if response.is_error: + raise RequestFailed(resp) + return resp + + # check rate limit + def _check_rate_limit(self, response: Response) -> None: + # check rate limit exceeded + # https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api#exceeding-the-rate-limit + # https://docs.github.com/en/graphql/overview/rate-limits-and-node-limits-for-the-graphql-api#exceeding-the-rate-limit + # https://github.com/octokit/plugin-throttling.js/blob/135a0f556752a6c4c0ed3b2798bb58e228cd179a/src/index.ts#L134-L179 + + # Secondary rate limits + # the `retry-after` response header is present + if "retry-after" in response.headers: + raise SecondaryRateLimitExceeded( + response, self._extract_retry_after(response) + ) + + if ( + "x-ratelimit-remaining" in response.headers + and response.headers["x-ratelimit-remaining"] == "0" + ): + retry_after = self._extract_retry_after(response) + + try: + error = response.json() + except Exception: + error = None + + # Secondary rate limits + # error message indicates that you exceeded a secondary rate limit + if ( + isinstance(error, dict) + and "message" in error + and "secondary rate" in error["message"] + ): + raise SecondaryRateLimitExceeded(response, retry_after) + + # Primary rate limits + raise PrimaryRateLimitExceeded(response, retry_after) + + def _extract_retry_after(self, response: Response) -> timedelta: + if "retry-after" in response.headers: + return timedelta(seconds=int(response.headers["retry-after"])) + elif "x-ratelimit-reset" in response.headers: + retry_after = datetime.fromtimestamp( + int(response.headers["x-ratelimit-reset"]), tz=timezone.utc + ) - datetime.now(tz=timezone.utc) + return max(retry_after, timedelta()) + else: + # wait for at least one minute before retrying + return timedelta(seconds=60) # sync request and check def request( diff --git a/githubkit/exception.py b/githubkit/exception.py index c1aefd7a5..d9f92c9fc 100644 --- a/githubkit/exception.py +++ b/githubkit/exception.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import TYPE_CHECKING import httpx @@ -50,6 +51,29 @@ def __repr__(self) -> str: ) +class RateLimitExceeded(RequestFailed): + """API request failed with rate limit exceeded""" + + def __init__(self, response: "Response", retry_after: timedelta): + super().__init__(response) + self.retry_after = retry_after + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(method={self.request.method}, " + f"url={self.request.url}, status_code={self.response.status_code}, " + f"retry_after={self.retry_after})" + ) + + +class PrimaryRateLimitExceeded(RateLimitExceeded): + """API request failed with primary rate limit exceeded""" + + +class SecondaryRateLimitExceeded(RateLimitExceeded): + """API request failed with secondary rate limit exceeded""" + + class GraphQLFailed(GitHubException): """GraphQL request with errors in response""" diff --git a/githubkit/github.py b/githubkit/github.py index 45e19cbb5..d0a891736 100644 --- a/githubkit/github.py +++ b/githubkit/github.py @@ -142,7 +142,8 @@ def graphql( json = build_graphql_request(query, variables) return parse_graphql_response( - self.request("POST", "/graphql", json=json, response_model=GraphQLResponse) + self, + self.request("POST", "/graphql", json=json, response_model=GraphQLResponse), ) async def async_graphql( @@ -151,9 +152,10 @@ async def async_graphql( json = build_graphql_request(query, variables) return parse_graphql_response( + self, await self.arequest( "POST", "/graphql", json=json, response_model=GraphQLResponse - ) + ), ) # rest pagination diff --git a/githubkit/graphql/__init__.py b/githubkit/graphql/__init__.py index 5ff0fbf02..ade1ed4b2 100644 --- a/githubkit/graphql/__init__.py +++ b/githubkit/graphql/__init__.py @@ -1,12 +1,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, cast -from githubkit.exception import GraphQLFailed +from githubkit.exception import GraphQLFailed, PrimaryRateLimitExceeded from .models import GraphQLError as GraphQLError from .models import SourceLocation as SourceLocation from .models import GraphQLResponse as GraphQLResponse if TYPE_CHECKING: + from githubkit.core import GitHubCore from githubkit.response import Response @@ -19,8 +20,18 @@ def build_graphql_request( return json -def parse_graphql_response(response: "Response[GraphQLResponse]") -> Dict[str, Any]: +def parse_graphql_response( + github: "GitHubCore", response: "Response[GraphQLResponse]" +) -> Dict[str, Any]: response_data = response.parsed_data if response_data.errors: + # check rate limit exceeded + # https://docs.github.com/en/graphql/overview/rate-limits-and-node-limits-for-the-graphql-api#exceeding-the-rate-limit + # x-ratelimit-remaining may not be 0, ignore it + # https://github.com/octokit/plugin-throttling.js/pull/636 + if any(error.type == "RATE_LIMITED" for error in response_data.errors): + raise PrimaryRateLimitExceeded( + response, github._extract_retry_after(response) + ) raise GraphQLFailed(response_data) return cast(Dict[str, Any], response_data.data) diff --git a/githubkit/graphql/models.py b/githubkit/graphql/models.py index ad98ce1f5..2c49ceda0 100644 --- a/githubkit/graphql/models.py +++ b/githubkit/graphql/models.py @@ -9,6 +9,7 @@ class SourceLocation(GitHubModel): class GraphQLError(GitHubModel): + type: str # https://github.com/octokit/graphql.js/pull/314 message: str locations: Optional[List[SourceLocation]] = None path: Optional[List[Union[int, str]]] = None