
Commit

Merge branch 'main' into fine-grained-metrics
IanMagnusson authored Oct 16, 2023
2 parents 22068df + 4644ff5 commit 8a29d40
Showing 1 changed file with 22 additions and 11 deletions.
olmo/util.py: 33 changes (22 additions & 11 deletions)
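In short, the diff below replaces the module-level s3_client with a lazily constructed _get_s3_client() helper and routes every S3 call through it. What follows is a minimal, self-contained sketch of that pattern, not part of the commit; it assumes boto3/botocore are installed and that an AWS region and credentials are resolvable from the environment.

# Sketch only (not from the commit); assumes boto3/botocore are installed and
# that an AWS region and credentials can be resolved from the environment.
import os

import boto3
from botocore.config import Config

_s3_client = None


def _get_s3_client():
    # Build the boto3 S3 client on first use and cache it for later calls.
    global _s3_client
    if _s3_client is None:
        _s3_client = boto3.client(
            "s3",
            config=Config(retries={"max_attempts": 10, "mode": "standard"}),
            use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),  # OLMO_NO_SSL=1 disables SSL
        )
    return _s3_client


# Repeated calls return the same cached client object.
assert _get_s3_client() is _get_s3_client()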
@@ -503,11 +503,18 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
     return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
 
 
-s3_client = boto3.client(
-    "s3",
-    config=Config(retries={"max_attempts": 10, "mode": "standard"}),
-    use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
-)
+_s3_client = None
+
+
+def _get_s3_client():
+    global _s3_client
+    if _s3_client is None:
+        _s3_client = boto3.client(
+            "s3",
+            config=Config(retries={"max_attempts": 10, "mode": "standard"}),
+            use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
+        )
+    return _s3_client
 
 
 def _wait_before_retry(attempt: int):
@@ -519,7 +526,7 @@ def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
     if not save_overwrite:
         for attempt in range(1, max_attempts + 1):
             try:
-                s3_client.head_object(Bucket=bucket_name, Key=key)
+                _get_s3_client().head_object(Bucket=bucket_name, Key=key)
                 raise FileExistsError(
                     f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                 )
@@ -537,7 +544,7 @@ def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
             raise OlmoNetworkError("Failed to check object existence during s3 upload") from err
 
     try:
-        s3_client.upload_file(source, bucket_name, key)
+        _get_s3_client().upload_file(source, bucket_name, key)
     except boto_exceptions.ClientError as e:
         raise OlmoNetworkError("Failed to upload to s3") from e
 
@@ -546,7 +553,7 @@ def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int:
     err: Optional[Exception] = None
     for attempt in range(1, max_attempts + 1):
         try:
-            return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"]
+            return _get_s3_client().head_object(Bucket=bucket_name, Key=key)["ContentLength"]
         except boto_exceptions.ClientError as e:
             if int(e.response["Error"]["Code"]) == 404:
                 raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
@@ -565,9 +572,13 @@ def _s3_get_bytes_range(
     err: Optional[Exception] = None
     for attempt in range(1, max_attempts + 1):
         try:
-            return s3_client.get_object(
-                Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
-            )["Body"].read()
+            return (
+                _get_s3_client()
+                .get_object(
+                    Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
+                )["Body"]
+                .read()
+            )
         except boto_exceptions.ClientError as e:
             if int(e.response["Error"]["Code"]) == 404:
                 raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
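For reference, a hypothetical usage sketch of the call path after this change, using the same boto3 operations the helpers above wrap (head_object and a ranged get_object); the bucket and key names are invented for illustration.

# Hypothetical example (not from the commit); bucket and key are invented.
client = _get_s3_client()
size = client.head_object(Bucket="my-bucket", Key="olmo/checkpoint.pt")["ContentLength"]
first_kb = client.get_object(
    Bucket="my-bucket", Key="olmo/checkpoint.pt", Range="bytes=0-1023"
)["Body"].read()
print(size, len(first_kb))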
