diff --git a/olmo/util.py b/olmo/util.py index 0e9c3a249..e29d98a43 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -503,11 +503,18 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1) -s3_client = boto3.client( - "s3", - config=Config(retries={"max_attempts": 10, "mode": "standard"}), - use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")), -) +_s3_client = None + + +def _get_s3_client(): + global _s3_client + if _s3_client is None: + _s3_client = boto3.client( + "s3", + config=Config(retries={"max_attempts": 10, "mode": "standard"}), + use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")), + ) + return _s3_client def _wait_before_retry(attempt: int): @@ -519,7 +526,7 @@ def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = if not save_overwrite: for attempt in range(1, max_attempts + 1): try: - s3_client.head_object(Bucket=bucket_name, Key=key) + _get_s3_client().head_object(Bucket=bucket_name, Key=key) raise FileExistsError( f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." ) @@ -537,7 +544,7 @@ def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = raise OlmoNetworkError("Failed to check object existence during s3 upload") from err try: - s3_client.upload_file(source, bucket_name, key) + _get_s3_client().upload_file(source, bucket_name, key) except boto_exceptions.ClientError as e: raise OlmoNetworkError("Failed to upload to s3") from e @@ -546,7 +553,7 @@ def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int: err: Optional[Exception] = None for attempt in range(1, max_attempts + 1): try: - return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"] + return _get_s3_client().head_object(Bucket=bucket_name, Key=key)["ContentLength"] except boto_exceptions.ClientError as e: if int(e.response["Error"]["Code"]) == 404: raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e @@ -565,9 +572,13 @@ def _s3_get_bytes_range( err: Optional[Exception] = None for attempt in range(1, max_attempts + 1): try: - return s3_client.get_object( - Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}" - )["Body"].read() + return ( + _get_s3_client() + .get_object( + Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}" + )["Body"] + .read() + ) except boto_exceptions.ClientError as e: if int(e.response["Error"]["Code"]) == 404: raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e