diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index d0d71af6b..8f0fdc394 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -58,8 +58,17 @@ def _get_bucket_resource(bucket_name: str) -> "Bucket": mode="standard", ) ) - s3 = boto3.resource('s3', config=config) - return s3.Bucket(bucket_name) + + R2_ACCOUNT_ID = os.environ.get("R2_ACCOUNT_ID", None) + if R2_ACCOUNT_ID: + s3 = boto3.resource('s3', + endpoint_url = f'https://{R2_ACCOUNT_ID}.r2.cloudflarestorage.com', + config=config + ) + return s3.Bucket(bucket_name) + else: + s3 = boto3.resource('s3', config=config) + return s3.Bucket(bucket_name) def get_s3_model_local_dir(model_id: str): @@ -222,4 +231,4 @@ def download_model_assets(self): return download_model_from_s3(self.bucket, self.model_id, self.extension) def get_local_path(self, model_id: str): - return get_s3_model_local_dir(model_id) \ No newline at end of file + return get_s3_model_local_dir(model_id)