Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage] Use task's region for initializing new stores #3319

Merged
merged 12 commits into from
Mar 25, 2024
7 changes: 5 additions & 2 deletions sky/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,15 @@ def split_cos_path(s3_path: str) -> Tuple[str, str, str]:
return bucket_name, data_path, region


def create_s3_client(region: Optional[str] = 'us-east-1') -> Client:
    """Helper method that connects to Boto3 client for S3 Bucket

    Args:
        region: str; Region name, e.g. us-west-1, us-east-2. If None, default
          region us-east-1 is used.

    Returns:
        A Boto3 S3 client bound to the resolved region.
    """
    if region is None:
        # Callers may pass region=None to mean "use the default"; resolve it
        # here so boto3 never receives region_name=None.
        region = 'us-east-1'
    return aws.client('s3', region_name=region)


Expand Down
4 changes: 3 additions & 1 deletion sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,9 @@ def _create_s3_bucket(self,
s3_client.create_bucket(Bucket=bucket_name)
else:
if region == 'us-east-1':
# If default us-east-1 region is used, the LocationConstraint must not be specified. https://stackoverflow.com/a/51912090
# If default us-east-1 region is used, the
# LocationConstraint must not be specified.
# https://stackoverflow.com/a/51912090
s3_client.create_bucket(Bucket=bucket_name)
else:
location = {'LocationConstraint': region}
Expand Down
7 changes: 4 additions & 3 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,8 @@ def update_storage_mounts(
task_storage_mounts.update(storage_mounts)
return self.set_storage_mounts(task_storage_mounts)

def get_preferred_store(self) -> Tuple[storage_lib.StoreType, Optional[str]]:
def _get_preferred_store(
self) -> Tuple[storage_lib.StoreType, Optional[str]]:
"""Returns the preferred store type and region for this task."""
# TODO(zhwu, romilb): The optimizer should look at the source and
# destination to figure out the right stores to use. For now, we
Expand Down Expand Up @@ -891,7 +892,7 @@ def get_preferred_store(self) -> Tuple[storage_lib.StoreType, Optional[str]]:
storage_cloud = clouds.CLOUD_REGISTRY.from_str(
enabled_storage_clouds[0])
assert storage_cloud is not None, enabled_storage_clouds[0]
storage_region = None # Use default region in the Store class
storage_region = None # Use default region in the Store class

store_type = storage_lib.get_storetype_from_cloud(storage_cloud)
return store_type, storage_region
Expand All @@ -905,7 +906,7 @@ def sync_storage_mounts(self) -> None:
"""
for storage in self.storage_mounts.values():
if len(storage.stores) == 0:
store_type, store_region = self.get_preferred_store()
store_type, store_region = self._get_preferred_store()
self.storage_plans[storage] = store_type
storage.add_store(store_type, store_region)
else:
Expand Down
104 changes: 103 additions & 1 deletion tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,26 @@ def test_spot_storage(generic_cloud: str):
yaml_str = pathlib.Path(
'examples/managed_spot_with_storage.yaml').read_text()
storage_name = f'sky-test-{int(time.time())}'

# Also perform region testing for bucket creation to validate if buckets are
# created in the correct region and correctly mounted in spot jobs.
# However, we inject this testing only for AWS and GCP since they are the
# supported object storage providers in SkyPilot.
region_flag = ''
region_validation_cmd = 'true'
if generic_cloud == 'aws':
region = 'eu-central-1'
region_flag = f' --region {region}'
region_cmd = TestStorageWithCredentials.cli_region_cmd(
storage_lib.StoreType.S3, storage_name)
region_validation_cmd = f'{region_cmd} | grep {region}'
elif generic_cloud == 'gcp':
region = 'us-west2'
region_flag = f' --region {region}'
region_cmd = TestStorageWithCredentials.cli_region_cmd(
storage_lib.StoreType.GCS, storage_name)
region_validation_cmd = f'{region_cmd} | grep {region}'

yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f:
f.write(yaml_str)
Expand All @@ -2604,7 +2624,8 @@ def test_spot_storage(generic_cloud: str):
'spot_storage',
[
*storage_setup_commands,
f'sky spot launch -n {name} --cloud {generic_cloud} {file_path} -y',
f'sky spot launch -n {name} --cloud {generic_cloud}{region_flag} {file_path} -y',
region_validation_cmd, # Check if the bucket is created in the correct region
'sleep 60', # Wait the spot queue to be updated
f'{_SPOT_QUEUE_WAIT}| grep {name} | grep SUCCEEDED',
f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]'
Expand Down Expand Up @@ -3757,6 +3778,20 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''):
bucket_name, Rclone.RcloneClouds.IBM)
return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'

@staticmethod
def cli_region_cmd(store_type, bucket_name):
    """Returns a shell command that prints the region of `bucket_name`.

    Args:
        store_type: storage_lib.StoreType; the object store backing the
          bucket (only S3 and GCS are supported).
        bucket_name: str; name of the bucket to query.

    Raises:
        NotImplementedError: for store types without region lookup support.
    """
    if store_type == storage_lib.StoreType.S3:
        # Prints the LocationConstraint; 'None' for us-east-1 buckets.
        return ('aws s3api get-bucket-location '
                f'--bucket {bucket_name} --output text')
    if store_type == storage_lib.StoreType.GCS:
        # gsutil reports e.g. "Location constraint: US-WEST2"; lowercase
        # the last field to match SkyPilot's region naming.
        return (f'gsutil ls -L -b gs://{bucket_name}/ | '
                'grep "Location constraint" | '
                'awk \'{print tolower($NF)}\'')
    if store_type == storage_lib.StoreType.R2:
        raise NotImplementedError
    if store_type == storage_lib.StoreType.IBM:
        raise NotImplementedError

@staticmethod
def cli_count_name_in_bucket(store_type, bucket_name, file_name, suffix=''):
if store_type == storage_lib.StoreType.S3:
Expand Down Expand Up @@ -4404,6 +4439,73 @@ def test_externally_created_bucket_mount_without_source(
if handle:
storage_obj.delete()

@pytest.mark.no_fluidstack
@pytest.mark.parametrize('region', [
    'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-south-1',
    'ap-southeast-1', 'ap-southeast-2', 'eu-central-1', 'eu-north-1',
    'eu-west-1', 'eu-west-2', 'eu-west-3', 'sa-east-1', 'us-east-1',
    'us-east-2', 'us-west-1', 'us-west-2'
])
def test_aws_regions(self, tmp_local_storage_obj, region):
    # This tests creation and upload to bucket in all AWS s3 regions
    # To test full functionality, use test_spot_storage above.
    store_type = storage_lib.StoreType.S3
    tmp_local_storage_obj.add_store(store_type, region=region)
    bucket_name = tmp_local_storage_obj.name

    # Confirm that the bucket was created in the correct region
    region_cmd = self.cli_region_cmd(store_type, bucket_name)
    out = subprocess.check_output(region_cmd, shell=True)
    output = out.decode('utf-8')
    expected_output_region = region
    if region == 'us-east-1':
        # get-bucket-location reports 'None' for the default region.
        expected_output_region = 'None'
    assert expected_output_region in output, (
        f'Bucket was not found in region {region} - '
        f'output of {region_cmd} was: {output}')

    # Check if tmp_source/tmp-file exists in the bucket using cli
    ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
    out = subprocess.check_output(ls_cmd, shell=True)
    output = out.decode('utf-8')
    assert 'tmp-file' in output, (
        f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')

@pytest.mark.no_fluidstack
@pytest.mark.parametrize('region', [
    'northamerica-northeast1', 'northamerica-northeast2', 'us-central1',
    'us-east1', 'us-east4', 'us-east5', 'us-south1', 'us-west1', 'us-west2',
    'us-west3', 'us-west4', 'southamerica-east1', 'southamerica-west1',
    'europe-central2', 'europe-north1', 'europe-southwest1', 'europe-west1',
    'europe-west2', 'europe-west3', 'europe-west4', 'europe-west6',
    'europe-west8', 'europe-west9', 'europe-west10', 'europe-west12',
    'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2',
    'asia-northeast3', 'asia-southeast1', 'asia-south1', 'asia-south2',
    'asia-southeast2', 'me-central1', 'me-central2', 'me-west1',
    'australia-southeast1', 'australia-southeast2', 'africa-south1'
])
def test_gcs_regions(self, tmp_local_storage_obj, region):
    # This tests creation and upload to bucket in all GCS regions
    # To test full functionality, use test_spot_storage above.
    store_type = storage_lib.StoreType.GCS
    tmp_local_storage_obj.add_store(store_type, region=region)
    bucket_name = tmp_local_storage_obj.name

    # Confirm that the bucket was created in the correct region
    region_cmd = self.cli_region_cmd(store_type, bucket_name)
    out = subprocess.check_output(region_cmd, shell=True)
    output = out.decode('utf-8')
    assert region in output, (
        f'Bucket was not found in region {region} - '
        f'output of {region_cmd} was: {output}')

    # Check if tmp_source/tmp-file exists in the bucket using cli
    ls_cmd = self.cli_ls_cmd(store_type, bucket_name)
    out = subprocess.check_output(ls_cmd, shell=True)
    output = out.decode('utf-8')
    assert 'tmp-file' in output, (
        f'tmp-file not found in bucket - output of {ls_cmd} was: {output}')


# ---------- Testing YAML Specs ----------
# Our sky storage requires credentials to check the bucket existance when
Expand Down
Loading