[Terraform] Add rate limiting #1084
@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, AsyncGenerator, Optional
from port_ocean.utils import http_async_client
import httpx
@@ -23,6 +25,12 @@ class CacheKeys(StrEnum):
PAGE_SIZE = 100


# https://developer.hashicorp.com/terraform/cloud-docs/api-docs#rate-limiting
RATE_LIMIT_PER_SECOND = 30
RATE_LIMIT_BUFFER = 5  # Buffer to avoid hitting the exact limit
MAX_CONCURRENT_REQUESTS = 10


class TerraformClient:
    def __init__(self, terraform_base_url: str, auth_token: str) -> None:
        self.terraform_base_url = terraform_base_url
@@ -32,7 +40,30 @@ def __init__(self, terraform_base_url: str, auth_token: str) -> None:
        }
        self.api_url = f"{self.terraform_base_url}/api/v2"
        self.client = http_async_client
        self.client.headers.update(self.base_headers)

        self.rate_limit = RATE_LIMIT_PER_SECOND
        self.rate_limit_remaining = RATE_LIMIT_PER_SECOND
        self.rate_limit_reset: float = 0.0
        self.last_request_time = time.time()
        self.request_times: list[float] = []
        self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
        self.rate_limit_lock = asyncio.Lock()

    async def wait_for_rate_limit(self) -> None:
        async with self.rate_limit_lock:
            current_time = time.time()
            self.request_times = [t for t in self.request_times if current_time - t < 1]

            if len(self.request_times) >= RATE_LIMIT_PER_SECOND:
                wait_time = 1 - (current_time - self.request_times[0])
                if wait_time > 0:
                    logger.info(
                        f"Rate limit reached, waiting for {wait_time:.2f} seconds"
                    )
                    await asyncio.sleep(wait_time)
                self.request_times = self.request_times[1:]

            self.request_times.append(current_time)
    async def send_api_request(
        self,

@@ -41,32 +72,64 @@ async def send_api_request(
        query_params: Optional[dict[str, Any]] = None,
        json_data: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        logger.info(f"Requesting Terraform Cloud data for endpoint: {endpoint}")
        try:
            url = f"{self.api_url}/{endpoint}"
            logger.info(
                f"URL: {url}, Method: {method}, Params: {query_params}, Body: {json_data}"
            )
            response = await self.client.request(
                method=method,
                url=url,
                params=query_params,
                json=json_data,
            )
            response.raise_for_status()
        async with self.semaphore:
Review comment: Is there a concurrency constraint on the Terraform Cloud API as well?
            await self.wait_for_rate_limit()

            logger.info(f"Successfully retrieved data for endpoint: {endpoint}")
            logger.info(f"Requesting Terraform Cloud data for endpoint: {endpoint}")
            try:
                url = f"{self.api_url}/{endpoint}"
                logger.info(
                    f"URL: {url}, Method: {method}, Params: {query_params}, Body: {json_data}"
                )
                response = await self.client.request(
                    method=method,
                    url=url,
                    params=query_params,
                    json=json_data,
                    headers=self.base_headers,
                )
                response.raise_for_status()

            return response.json()
                async with self.rate_limit_lock:
                    self.rate_limit = int(
                        response.headers.get("x-ratelimit-limit", RATE_LIMIT_PER_SECOND)
                    )
                    self.rate_limit_remaining = int(
                        response.headers.get(
                            "x-ratelimit-remaining", RATE_LIMIT_PER_SECOND
                        )
                    )
                    self.rate_limit_reset = float(
                        response.headers.get("x-ratelimit-reset", "0")
                    )

        except httpx.HTTPStatusError as e:
            logger.error(
                f"HTTP error on {endpoint}: {e.response.status_code} - {e.response.text}"
            )
            raise
        except httpx.HTTPError as e:
            logger.error(f"HTTP error on {endpoint}: {str(e)}")
            raise
                logger.debug(f"Successfully retrieved data for endpoint: {endpoint}")
                logger.debug(
                    f"Rate limit: {self.rate_limit_remaining}/{self.rate_limit}"
                )
                logger.debug(f"Rate limit reset: {self.rate_limit_reset}")

                return response.json()

            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429:
                    retry_after = float(
                        e.response.headers.get("x-ratelimit-reset", "1")
                    )
                    logger.warning(
                        f"Rate limit exceeded. Waiting for {retry_after} seconds before retrying."
                    )
                    await asyncio.sleep(retry_after)
                    return await self.send_api_request(
                        endpoint, method, query_params, json_data
                    )
                logger.error(
                    f"HTTP error on {endpoint}: {e.response.status_code} - {e.response.text}"
                )
                raise
            except httpx.HTTPError as e:
                logger.error(f"HTTP error on {endpoint}: {str(e)}")
                raise

    async def get_paginated_resources(
        self, endpoint: str, params: Optional[dict[str, Any]] = None
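To make the added limiter easier to follow outside the diff, here is a minimal standalone sketch of the same sliding-window idea: keep the timestamps of requests made in the last second and sleep until the oldest one falls out of the window. This is an illustration written for this review, not the PR's exact code; the class name and structure are invented.

```python
import asyncio
import time

RATE_LIMIT_PER_SECOND = 30  # documented Terraform Cloud limit


class SlidingWindowLimiter:
    """Illustrative sketch: allow at most `limit` acquisitions per one-second window."""

    def __init__(self, limit: int = RATE_LIMIT_PER_SECOND) -> None:
        self.limit = limit
        self.request_times: list[float] = []
        self.lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self.lock:
            now = time.time()
            # Keep only timestamps from the last second.
            self.request_times = [t for t in self.request_times if now - t < 1]
            if len(self.request_times) >= self.limit:
                # Sleep until the oldest timestamp leaves the window.
                wait = 1 - (now - self.request_times[0])
                if wait > 0:
                    await asyncio.sleep(wait)
                self.request_times = self.request_times[1:]
            self.request_times.append(now)
```

Each `acquire()` call would precede one HTTP request; combined with the `asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)` used above, this bounds both how many requests start per second and how many are in flight at once.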
@@ -1,4 +1,3 @@
import asyncio
from asyncio import gather
from enum import StrEnum
from typing import Any, List, Dict

@@ -19,7 +18,7 @@ class ObjectKind(StrEnum):


SKIP_WEBHOOK_CREATION = False
SEMAPHORE_LIMIT = 30
CHUNK_SIZE = 10


def init_terraform_client() -> TerraformClient:
@@ -40,51 +39,25 @@ def init_terraform_client() -> TerraformClient:
async def enrich_state_versions_with_output_data(
    http_client: TerraformClient, state_versions: List[dict[str, Any]]
) -> list[dict[str, Any]]:
    async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
        tasks = [
            http_client.get_state_version_output(state_version["id"])
            for state_version in state_versions
        ]

        output_batches = []
        for completed_task in asyncio.as_completed(tasks):
            output = await completed_task
            output_batches.append(output)
    async def get_output(state_version: dict[str, Any]) -> dict[str, Any]:
        output = await http_client.get_state_version_output(state_version["id"])
        return {**state_version, "__output": output}

        enriched_state_versions = [
            {**state_version, "__output": output}
            for state_version, output in zip(state_versions, output_batches)
        ]
    enriched_versions = []
    for chunk in [
        state_versions[i : i + CHUNK_SIZE]
        for i in range(0, len(state_versions), CHUNK_SIZE)
    ]:
        chunk_results = await gather(*[get_output(sv) for sv in chunk])
        enriched_versions.extend(chunk_results)

        return enriched_state_versions
    return enriched_versions

async def enrich_workspaces_with_tags(
    http_client: TerraformClient, workspaces: List[dict[str, Any]]
) -> list[dict[str, Any]]:
    async def get_tags_for_workspace(workspace: dict[str, Any]) -> dict[str, Any]:
        async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
            try:
                tags = []
                async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
                    tags.extend(tag_batch)
                return {**workspace, "__tags": tags}
            except Exception as e:
                logger.warning(
                    f"Failed to fetch tags for workspace {workspace['id']}: {e}"
                )
                return {**workspace, "__tags": []}

    tasks = [get_tags_for_workspace(workspace) for workspace in workspaces]
    enriched_workspaces = [await task for task in asyncio.as_completed(tasks)]

    return enriched_workspaces


async def enrich_workspace_with_tags(
    http_client: TerraformClient, workspace: dict[str, Any]
) -> dict[str, Any]:
    async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
        try:
            tags = []
            async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
@@ -94,6 +67,28 @@ async def enrich_workspace_with_tags(
        logger.warning(f"Failed to fetch tags for workspace {workspace['id']}: {e}")
        return {**workspace, "__tags": []}

    enriched_workspaces = []
    for chunk in [
        workspaces[i : i + CHUNK_SIZE] for i in range(0, len(workspaces), CHUNK_SIZE)
    ]:
        chunk_results = await gather(*[get_tags_for_workspace(w) for w in chunk])
        enriched_workspaces.extend(chunk_results)

    return enriched_workspaces


async def enrich_workspace_with_tags(
    http_client: TerraformClient, workspace: dict[str, Any]
) -> dict[str, Any]:
    try:
        tags = []
        async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
            tags.extend(tag_batch)
        return {**workspace, "__tags": tags}
    except Exception as e:
        logger.warning(f"Failed to fetch tags for workspace {workspace['id']}: {e}")
        return {**workspace, "__tags": []}

@ocean.on_resync(ObjectKind.ORGANIZATION)
async def resync_organizations(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE:

@@ -136,21 +131,20 @@ async def fetch_runs_for_workspace(
            )
        ]

    async def fetch_runs_for_all_workspaces() -> ASYNC_GENERATOR_RESYNC_TYPE:
        async for workspaces in terraform_client.get_paginated_workspaces():
            logger.info(
                f"Received {len(workspaces)} batch workspaces... fetching its associated {kind}"
            )
        async for workspaces in terraform_client.get_paginated_workspaces():
            logger.info(
                f"Received {len(workspaces)} batch workspaces... fetching its associated {kind}"
            )

            tasks = [fetch_runs_for_workspace(workspace) for workspace in workspaces]
            for completed_task in asyncio.as_completed(tasks):
                workspace_runs = await completed_task
            for chunk in [
                workspaces[i : i + CHUNK_SIZE]
                for i in range(0, len(workspaces), CHUNK_SIZE)
            ]:
                chunk_results = await gather(*[fetch_runs_for_workspace(w) for w in chunk])
Review comment: Is there a specific reason for replacing as_completed with gather? Waiting for all tasks within a chunk to complete before moving on to the next chunk appears to impede performance in this context. Please correct me.
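For reference, a rough sketch of the streaming alternative the reviewer is describing: a semaphore caps how many workspaces are fetched at once, while asyncio.as_completed yields each result as soon as it finishes instead of waiting for a whole chunk. This is illustrative only, not code from the PR; it assumes `import asyncio` (which this PR removes from the module) and would live inside the async generator above, where `workspaces`, `fetch_runs_for_workspace`, and `CHUNK_SIZE` are in scope.

```python
# Illustrative sketch, not part of this PR.
semaphore = asyncio.Semaphore(CHUNK_SIZE)

async def fetch_with_limit(workspace: dict[str, Any]) -> Any:
    async with semaphore:  # at most CHUNK_SIZE workspaces in flight
        return await fetch_runs_for_workspace(workspace)

tasks = [fetch_with_limit(workspace) for workspace in workspaces]
for completed in asyncio.as_completed(tasks):
    workspace_runs = await completed  # available as soon as each workspace finishes
    for runs in workspace_runs:
        yield runs
```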
                for workspace_runs in chunk_results:
                    for runs in workspace_runs:
                        yield runs

    async for run_batch in fetch_runs_for_all_workspaces():
        yield run_batch


@ocean.on_resync(ObjectKind.STATE_VERSION)
async def resync_state_versions(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE:
@@ -0,0 +1,107 @@
import asyncio
from typing import AsyncGenerator
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from client import TerraformClient, RATE_LIMIT_PER_SECOND
import time
import httpx


@pytest.fixture
def mock_http_client() -> AsyncMock:
    return AsyncMock(spec=httpx.AsyncClient)


@pytest.fixture
async def terraform_client(
    mock_http_client: AsyncMock,
) -> AsyncGenerator[TerraformClient, None]:
    with patch("client.http_async_client", mock_http_client):
        client = TerraformClient("https://app.terraform.io", "test_token")
        client.rate_limit_lock = asyncio.Lock()
        # Manually set the headers to avoid the coroutine warning
        client.client.headers = httpx.Headers(client.base_headers)
        yield client

@pytest.mark.asyncio
async def test_wait_for_rate_limit(terraform_client: TerraformClient) -> None:
    current_time = time.time()
    with patch("time.time", side_effect=[current_time, current_time + 0.1]):
        with patch.object(asyncio, "sleep", new_callable=AsyncMock) as mock_sleep:
            # Simulate rate limit not reached
            terraform_client.request_times = [current_time - 0.1] * (
                RATE_LIMIT_PER_SECOND - 1
            )
            await terraform_client.wait_for_rate_limit()
            mock_sleep.assert_not_called()

            # Simulate rate limit reached
            terraform_client.request_times = [
                current_time - 0.1
            ] * RATE_LIMIT_PER_SECOND
            await terraform_client.wait_for_rate_limit()
            mock_sleep.assert_called_once()
            assert mock_sleep.call_args[0][0] > 0  # Ensure sleep time is positive

    # Test when wait time is not needed
    current_time = time.time()
    with patch("time.time", return_value=current_time):
        with patch.object(asyncio, "sleep", new_callable=AsyncMock) as mock_sleep:
            terraform_client.request_times = [
                current_time - 1.1
            ] * RATE_LIMIT_PER_SECOND
            await terraform_client.wait_for_rate_limit()
            mock_sleep.assert_not_called()

@pytest.mark.asyncio
async def test_send_api_request(
    terraform_client: TerraformClient, mock_http_client: AsyncMock
) -> None:
    mock_response = MagicMock()
    mock_response.json.return_value = {"data": [{"id": "test"}]}
    mock_response.headers = {
        "x-ratelimit-limit": "30",
        "x-ratelimit-remaining": "29",
        "x-ratelimit-reset": "1.0",
    }
    mock_http_client.request.return_value = mock_response

    result = await terraform_client.send_api_request("test_endpoint")

    expected_headers = {
        "Authorization": "Bearer test_token",
        "Content-Type": "application/vnd.api+json",
    }

    mock_http_client.request.assert_called_once_with(
        method="GET",
        url="https://app.terraform.io/api/v2/test_endpoint",
        params=None,
        json=None,
        headers=expected_headers,
    )

    assert result == {"data": [{"id": "test"}]}

@pytest.mark.asyncio
async def test_get_paginated_resources(terraform_client: TerraformClient) -> None:
    mock_responses = [
        {"data": [{"id": "1"}, {"id": "2"}], "links": {"next": "page2"}},
        {"data": [{"id": "3"}, {"id": "4"}], "links": {"next": None}},
    ]

    with patch.object(
        terraform_client, "send_api_request", side_effect=mock_responses
    ) as mock_send:
        results = []
        async for resources in terraform_client.get_paginated_resources(
            "test_endpoint"
        ):
            results.extend(resources)

        assert len(results) == 4
        assert [r["id"] for r in results] == ["1", "2", "3", "4"]
        assert mock_send.call_count == 2
Review comment: Why 10?
Reply: 10 is how many workspaces we will process concurrently when fetching runs (or when enriching those workspaces).
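As a small illustration of that choice (a sketch using the PR's CHUNK_SIZE with made-up input, not code from the diff): the workspaces are split into chunks of 10, and each chunk is awaited with gather() before the next chunk starts, so at most 10 fetches run at once.

```python
CHUNK_SIZE = 10  # workspaces processed concurrently per batch

workspaces = [{"id": f"ws-{i}"} for i in range(25)]  # hypothetical input

chunks = [
    workspaces[i : i + CHUNK_SIZE] for i in range(0, len(workspaces), CHUNK_SIZE)
]
print([len(chunk) for chunk in chunks])  # [10, 10, 5]
# Each chunk would be awaited with `await gather(*[fetch(w) for w in chunk])`
# before the next chunk begins, so no more than CHUNK_SIZE requests are in flight.
```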