[Storage] Use task's region for initializing new stores (#3319)
* make stores region aware

* handle unspecified region for AWS

* Use us-east-1 default region

* lint

* Add region tests to TestStorageWithCredentials

* make _get_preferred_store private

* update smoke tests to include region testing for buckets

* reorder validation check

* lint

* lint

* comments
romilbhardwaj authored Mar 25, 2024
1 parent 90eeb00 commit eb442b0
Showing 4 changed files with 136 additions and 16 deletions.
7 changes: 5 additions & 2 deletions sky/data/data_utils.py
@@ -90,12 +90,15 @@ def split_cos_path(s3_path: str) -> Tuple[str, str, str]:
    return bucket_name, data_path, region


-def create_s3_client(region: str = 'us-east-2') -> Client:
+def create_s3_client(region: Optional[str] = None) -> Client:
    """Helper method that connects to Boto3 client for S3 Bucket
    Args:
-        region: str; Region name, e.g. us-west-1, us-east-2
+        region: str; Region name, e.g. us-west-1, us-east-2. If None, default
+            region us-east-1 is used.
    """
+    if region is None:
+        region = 'us-east-1'
    return aws.client('s3', region_name=region)
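A minimal usage sketch of the new default (assuming the function is called via sky.data.data_utils; illustrative only, not part of this diff):

    from sky.data import data_utils

    # No region given: the client now falls back to us-east-1 instead of
    # the previous hard-coded us-east-2 default.
    client = data_utils.create_s3_client()

    # An explicit region is forwarded to boto3 unchanged.
    eu_client = data_utils.create_s3_client(region='eu-central-1')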


23 changes: 17 additions & 6 deletions sky/data/storage.py

@@ -811,14 +811,18 @@ def from_metadata(cls, metadata: StorageMetadata,

        return storage_obj

-    def add_store(self, store_type: Union[str, StoreType]) -> AbstractStore:
+    def add_store(self,
+                  store_type: Union[str, StoreType],
+                  region: Optional[str] = None) -> AbstractStore:
        """Initializes and adds a new store to the storage.
        Invoked by the optimizer after it has selected a store to
        add it to Storage.
        Args:
            store_type: StoreType; Type of the storage [S3, GCS, AZURE, R2, IBM]
+            region: str; Region to place the bucket in. Caller must ensure that
+                the region is valid for the chosen store_type.
        """
        if isinstance(store_type, str):
            store_type = StoreType(store_type)
@@ -846,6 +850,7 @@ def add_store(self, store_type: Union[str, StoreType]) -> AbstractStore:
            store = store_cls(
                name=self.name,
                source=self.source,
+                region=region,
                sync_on_reconstruction=self.sync_on_reconstruction)
        except exceptions.StorageBucketCreateError:
            # Creation failed, so this must be sky managed store. Add failure
@@ -1330,7 +1335,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]:
            # Store object is being reconstructed for deletion or re-mount with
            # sky start, and error is raised instead.
            if self.sync_on_reconstruction:
-                bucket = self._create_s3_bucket(self.name)
+                bucket = self._create_s3_bucket(self.name, self.region)
                return bucket, True
            else:
                # Raised when Storage object is reconstructed for sky storage
@@ -1380,9 +1385,15 @@ def _create_s3_bucket(self,
            if region is None:
                s3_client.create_bucket(Bucket=bucket_name)
            else:
-                location = {'LocationConstraint': region}
-                s3_client.create_bucket(Bucket=bucket_name,
-                                        CreateBucketConfiguration=location)
+                if region == 'us-east-1':
+                    # If default us-east-1 region is used, the
+                    # LocationConstraint must not be specified.
+                    # https://stackoverflow.com/a/51912090
+                    s3_client.create_bucket(Bucket=bucket_name)
+                else:
+                    location = {'LocationConstraint': region}
+                    s3_client.create_bucket(Bucket=bucket_name,
+                                            CreateBucketConfiguration=location)
            logger.info(f'Created S3 bucket {bucket_name} in {region}')
        except aws.botocore_exceptions().ClientError as e:
            with ux_utils.print_exception_no_traceback():
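The special case exists because S3 rejects an explicit LocationConstraint of us-east-1; that region is the API's implicit default for CreateBucket. A standalone boto3 sketch of the failure mode the branch above avoids (bucket name is a placeholder):

    import boto3
    from botocore.exceptions import ClientError

    s3 = boto3.client('s3', region_name='us-east-1')
    try:
        # Passing us-east-1 explicitly is rejected by the S3 API.
        s3.create_bucket(
            Bucket='placeholder-bucket-name',
            CreateBucketConfiguration={'LocationConstraint': 'us-east-1'})
    except ClientError as e:
        # Typically surfaces as an InvalidLocationConstraint error.
        print(e.response['Error']['Code'])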
@@ -1756,7 +1767,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]:
            # is being reconstructed for deletion or re-mount with
            # sky start, and error is raised instead.
            if self.sync_on_reconstruction:
-                bucket = self._create_gcs_bucket(self.name)
+                bucket = self._create_gcs_bucket(self.name, self.region)
                return bucket, True
            else:
                # This is raised when Storage object is reconstructed for
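For reference, the GCS path presumably threads self.region through to the bucket-creation call in _create_gcs_bucket (not shown in this diff). With the google-cloud-storage client, a region-pinned bucket is created roughly like this (a hedged sketch; bucket name and region are placeholders):

    from google.cloud import storage

    client = storage.Client()
    # 'location' pins the bucket to a region, e.g. us-west2.
    bucket = client.create_bucket('placeholder-bucket-name',
                                  location='us-west2')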
19 changes: 12 additions & 7 deletions sky/task.py

@@ -855,15 +855,17 @@ def update_storage_mounts(
        task_storage_mounts.update(storage_mounts)
        return self.set_storage_mounts(task_storage_mounts)

-    def get_preferred_store_type(self) -> storage_lib.StoreType:
+    def _get_preferred_store(
+            self) -> Tuple[storage_lib.StoreType, Optional[str]]:
+        """Returns the preferred store type and region for this task."""
        # TODO(zhwu, romilb): The optimizer should look at the source and
        # destination to figure out the right stores to use. For now, we
        # use a heuristic solution to find the store type by the following
        # order:
-        # 1. cloud decided in best_resources.
-        # 2. cloud specified in the task resources.
+        # 1. cloud/region decided in best_resources.
+        # 2. cloud/region specified in the task resources.
        # 3. if not specified or the task's cloud does not support storage,
-        #    use the first enabled storage cloud.
+        #    use the first enabled storage cloud with default region.
        # This should be refactored and moved to the optimizer.

        # This check is not needed to support multiple accelerators;
@@ -877,9 +879,11 @@ def get_preferred_store_type(self) -> storage_lib.StoreType:

        if self.best_resources is not None:
            storage_cloud = self.best_resources.cloud
+            storage_region = self.best_resources.region
        else:
            resources = list(self.resources)[0]
            storage_cloud = resources.cloud
+            storage_region = resources.region
        if storage_cloud is not None:
            if str(storage_cloud) not in enabled_storage_clouds:
                storage_cloud = None
@@ -888,9 +892,10 @@ def get_preferred_store_type(self) -> storage_lib.StoreType:
            storage_cloud = clouds.CLOUD_REGISTRY.from_str(
                enabled_storage_clouds[0])
            assert storage_cloud is not None, enabled_storage_clouds[0]
+            storage_region = None  # Use default region in the Store class

        store_type = storage_lib.get_storetype_from_cloud(storage_cloud)
-        return store_type
+        return store_type, storage_region

    def sync_storage_mounts(self) -> None:
        """(INTERNAL) Eagerly syncs storage mounts to cloud storage.
@@ -901,9 +906,9 @@ def sync_storage_mounts(self) -> None:
"""
for storage in self.storage_mounts.values():
if len(storage.stores) == 0:
store_type = self.get_preferred_store_type()
store_type, store_region = self._get_preferred_store()
self.storage_plans[storage] = store_type
storage.add_store(store_type)
storage.add_store(store_type, store_region)
else:
# We will download the first store that is added to remote.
self.storage_plans[storage] = list(storage.stores.keys())[0]
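Taken together: the task now resolves a (store type, region) pair and passes both to Storage.add_store. A hypothetical end-to-end sketch of the same call path (bucket name, source path, and region are illustrative):

    from sky.data import storage as storage_lib

    # Pin a storage object's bucket to a region, as sync_storage_mounts
    # above now does automatically from the task's resources.
    storage_obj = storage_lib.Storage(name='my-sky-bucket', source='~/data')
    storage_obj.add_store(storage_lib.StoreType.S3, region='eu-central-1')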
103 changes: 102 additions & 1 deletion tests/test_smoke.py

@@ -2595,6 +2595,26 @@ def test_spot_storage(generic_cloud: str):
    yaml_str = pathlib.Path(
        'examples/managed_spot_with_storage.yaml').read_text()
    storage_name = f'sky-test-{int(time.time())}'

+    # Also perform region testing for bucket creation to validate if buckets are
+    # created in the correct region and correctly mounted in spot jobs.
+    # However, we inject this testing only for AWS and GCP since they are the
+    # supported object storage providers in SkyPilot.
+    region_flag = ''
+    region_validation_cmd = 'true'
+    if generic_cloud == 'aws':
+        region = 'eu-central-1'
+        region_flag = f' --region {region}'
+        region_cmd = TestStorageWithCredentials.cli_region_cmd(
+            storage_lib.StoreType.S3, storage_name)
+        region_validation_cmd = f'{region_cmd} | grep {region}'
+    elif generic_cloud == 'gcp':
+        region = 'us-west2'
+        region_flag = f' --region {region}'
+        region_cmd = TestStorageWithCredentials.cli_region_cmd(
+            storage_lib.StoreType.GCS, storage_name)
+        region_validation_cmd = f'{region_cmd} | grep {region}'
+
    yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
    with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
        f.write(yaml_str)
@@ -2604,7 +2624,8 @@
        'spot_storage',
        [
            *storage_setup_commands,
-            f'sky spot launch -n {name} --cloud {generic_cloud} {file_path} -y',
+            f'sky spot launch -n {name} --cloud {generic_cloud}{region_flag} {file_path} -y',
+            region_validation_cmd,  # Check if the bucket is created in the correct region
            'sleep 60',  # Wait the spot queue to be updated
            f'{_SPOT_QUEUE_WAIT}| grep {name} | grep SUCCEEDED',
            f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]'
@@ -3757,6 +3778,19 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''):
            bucket_name, Rclone.RcloneClouds.IBM)
        return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'

+    @staticmethod
+    def cli_region_cmd(store_type, bucket_name):
+        if store_type == storage_lib.StoreType.S3:
+            return ('aws s3api get-bucket-location '
+                    f'--bucket {bucket_name} --output text')
+        elif store_type == storage_lib.StoreType.GCS:
+            return (f'gsutil ls -L -b gs://{bucket_name}/ | '
+                    'grep "Location constraint" | '
+                    'awk \'{print tolower($NF)}\'')
+        else:
+            raise NotImplementedError(f'Region command not implemented for '
+                                      f'{store_type}')

    @staticmethod
    def cli_count_name_in_bucket(store_type, bucket_name, file_name, suffix=''):
        if store_type == storage_lib.StoreType.S3:
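A hedged usage sketch of the cli_region_cmd helper added above, mirroring how the smoke test consumes it (bucket name and region are placeholders):

    import subprocess

    cmd = TestStorageWithCredentials.cli_region_cmd(
        storage_lib.StoreType.S3, 'my-sky-bucket')
    out = subprocess.check_output(cmd, shell=True).decode('utf-8')
    # For a bucket created in eu-central-1. Note: the AWS CLI prints 'None'
    # for us-east-1, which test_aws_regions below accounts for.
    assert 'eu-central-1' in out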
@@ -4404,6 +4438,73 @@ def test_externally_created_bucket_mount_without_source(
        if handle:
            storage_obj.delete()

+    @pytest.mark.no_fluidstack
+    @pytest.mark.parametrize('region', [
+        'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-south-1',
+        'ap-southeast-1', 'ap-southeast-2', 'eu-central-1', 'eu-north-1',
+        'eu-west-1', 'eu-west-2', 'eu-west-3', 'sa-east-1', 'us-east-1',
+        'us-east-2', 'us-west-1', 'us-west-2'
+    ])
+    def test_aws_regions(self, tmp_local_storage_obj, region):
+        # This tests creation and upload to bucket in all AWS s3 regions
+        # To test full functionality, use test_spot_storage above.
+        store_type = storage_lib.StoreType.S3
+        tmp_local_storage_obj.add_store(store_type, region=region)
+        bucket_name = tmp_local_storage_obj.name
+
+        # Confirm that the bucket was created in the correct region
+        region_cmd = self.cli_region_cmd(store_type, bucket_name)
+        out = subprocess.check_output(region_cmd, shell=True)
+        output = out.decode('utf-8')
+        expected_output_region = region
+        if region == 'us-east-1':
+            expected_output_region = 'None'  # us-east-1 is the default region
+        assert expected_output_region in out.decode('utf-8'), (
+            f'Bucket was not found in region {region} - '
+            f'output of {region_cmd} was: {output}')
+
+        # Check if tmp_source/tmp-file exists in the bucket using cli
+        ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
+        out = subprocess.check_output(ls_cmd, shell=True)
+        output = out.decode('utf-8')
+        assert 'tmp-file' in output, (
+            f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')

+    @pytest.mark.no_fluidstack
+    @pytest.mark.parametrize('region', [
+        'northamerica-northeast1', 'northamerica-northeast2', 'us-central1',
+        'us-east1', 'us-east4', 'us-east5', 'us-south1', 'us-west1', 'us-west2',
+        'us-west3', 'us-west4', 'southamerica-east1', 'southamerica-west1',
+        'europe-central2', 'europe-north1', 'europe-southwest1', 'europe-west1',
+        'europe-west2', 'europe-west3', 'europe-west4', 'europe-west6',
+        'europe-west8', 'europe-west9', 'europe-west10', 'europe-west12',
+        'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2',
+        'asia-northeast3', 'asia-southeast1', 'asia-south1', 'asia-south2',
+        'asia-southeast2', 'me-central1', 'me-central2', 'me-west1',
+        'australia-southeast1', 'australia-southeast2', 'africa-south1'
+    ])
+    def test_gcs_regions(self, tmp_local_storage_obj, region):
+        # This tests creation and upload to bucket in all GCS regions
+        # To test full functionality, use test_spot_storage above.
+        store_type = storage_lib.StoreType.GCS
+        tmp_local_storage_obj.add_store(store_type, region=region)
+        bucket_name = tmp_local_storage_obj.name
+
+        # Confirm that the bucket was created in the correct region
+        region_cmd = self.cli_region_cmd(store_type, bucket_name)
+        out = subprocess.check_output(region_cmd, shell=True)
+        output = out.decode('utf-8')
+        assert region in out.decode('utf-8'), (
+            f'Bucket was not found in region {region} - '
+            f'output of {region_cmd} was: {output}')
+
+        # Check if tmp_source/tmp-file exists in the bucket using cli
+        ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
+        out = subprocess.check_output(ls_cmd, shell=True)
+        output = out.decode('utf-8')
+        assert 'tmp-file' in output, (
+            f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')


# ---------- Testing YAML Specs ----------
# Our sky storage requires credentials to check the bucket existance when
