Skip to content

Commit

Permalink
reset task and execution, delay add_store in storage
Browse files: browse the repository at this point in the history
  • Loading branch information
Shrinandan-N committed Jul 29, 2024
1 parent 3835992 commit 04d148f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 49 deletions.
53 changes: 14 additions & 39 deletions sky/data/storage.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -127,7 +127,8 @@ def from_cloud(cls, cloud: str) -> 'StoreType':
return StoreType.AZURE
elif cloud.lower() == str(clouds.Lambda()).lower():
with ux_utils.print_exception_no_traceback():
raise ValueError('Lambda Cloud does not provide cloud storage.')
raise ValueError(
'Lambda Cloud does not provide cloud storage.')
elif cloud.lower() == str(clouds.SCP()).lower():
with ux_utils.print_exception_no_traceback():
raise ValueError('SCP does not provide cloud storage.')
Expand Down Expand Up @@ -417,7 +418,6 @@ class StorageMetadata(object):
- (required) Source
- (optional) Storage mode.
- (optional) Set of stores managed by sky added to the Storage object
- (optional) Region of the bucket.
"""

def __init__(
Expand All @@ -427,14 +427,11 @@ def __init__(
source: Optional[SourceType],
mode: Optional[StorageMode] = None,
sky_stores: Optional[Dict[StoreType,
AbstractStore.StoreMetadata]] = None,
region: Optional[str] = None,
):
AbstractStore.StoreMetadata]] = None):
assert storage_name is not None or source is not None
self.storage_name = storage_name
self.source = source
self.mode = mode
self.region = region
# Only stores managed by sky are stored here in the
# global_user_state
self.sky_stores = {} if sky_stores is None else sky_stores
Expand All @@ -444,8 +441,7 @@ def __repr__(self):
f'\n\tstorage_name={self.storage_name},'
f'\n\tsource={self.source},'
f'\n\tmode={self.mode},'
f'\n\tstores={self.sky_stores})'
f'\n\tregion={self.region})')
f'\n\tstores={self.sky_stores})')

def add_store(self, store: AbstractStore) -> None:
storetype = StoreType.from_store(store)
Expand All @@ -462,7 +458,6 @@ def __init__(self,
stores: Optional[Dict[StoreType, AbstractStore]] = None,
persistent: Optional[bool] = True,
mode: StorageMode = StorageMode.MOUNT,
region: Optional[str] = None,
sync_on_reconstruction: bool = True) -> None:
"""Initializes a Storage object.
Expand Down Expand Up @@ -506,7 +501,6 @@ def __init__(self,
self.source = source
self.persistent = persistent
self.mode = mode
self.region = region
assert mode in StorageMode
self.sync_on_reconstruction = sync_on_reconstruction

Expand All @@ -526,20 +520,9 @@ def __init__(self,
handle = global_user_state.get_handle_from_storage_name(self.name)
if handle is not None:
self.handle = handle
self.handle.sky_stores = {
s_type: AbstractStore.StoreMetadata(
name=s_metadata.name,
source=s_metadata.source,
region=self.region,
is_sky_managed=s_metadata.is_sky_managed)
for s_type, s_metadata in self.handle.sky_stores.items()
}

# Reconstruct the Storage object from the global_user_state
logger.debug('Detected existing storage object, '
f'loading Storage: {self.name}')
self._add_store_from_metadata(self.handle.sky_stores)

# TODO(romilb): This logic should likely be in add_store to move
# syncing to file_mount stage.
if self.sync_on_reconstruction:
Expand All @@ -561,7 +544,6 @@ def __init__(self,
self.handle = self.StorageMetadata(storage_name=self.name,
source=self.source,
mode=self.mode,
region=self.region,
sky_stores=sky_managed_stores)

if self.source is not None:
Expand Down Expand Up @@ -842,7 +824,6 @@ def from_metadata(cls, metadata: StorageMetadata,
Used when reconstructing Storage object and AbstractStore objects from
global_user_state.
"""

# Name should not be specified if the source is a cloud store URL.
source = override_args.get('source', metadata.source)
name = override_args.get('name', metadata.storage_name)
Expand Down Expand Up @@ -982,7 +963,8 @@ def delete(self, store_type: Optional[StoreType] = None) -> None:
if delete:
global_user_state.remove_storage(self.name)
else:
global_user_state.set_storage_handle(self.name, self.handle)
global_user_state.set_storage_handle(
self.name, self.handle)
elif self.force_delete:
store.delete()
# Remove store from bookkeeping
Expand Down Expand Up @@ -1029,12 +1011,11 @@ def warn_for_git_dir(source: str):

# Upload succeeded - update state
if store.is_sky_managed:
global_user_state.set_storage_status(self.name, StorageStatus.READY)
global_user_state.set_storage_status(
self.name, StorageStatus.READY)

@classmethod
def from_yaml_config(cls,
config: Dict[str, Any],
region: Optional[str] = None) -> 'Storage':
def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage':
common_utils.validate_schema(config, schemas.get_storage_schema(),
'Invalid storage YAML: ')

Expand All @@ -1043,7 +1024,6 @@ def from_yaml_config(cls,
store = config.pop('store', None)
mode_str = config.pop('mode', None)
force_delete = config.pop('_force_delete', None)

if force_delete is None:
force_delete = False

Expand All @@ -1063,15 +1043,10 @@ def from_yaml_config(cls,
storage_obj = cls(name=name,
source=source,
persistent=persistent,
mode=mode,
region=region)

if store is not None:
storage_obj.add_store(StoreType(store.upper()))
mode=mode)

# Add force deletion flag
storage_obj.force_delete = force_delete

return storage_obj

def to_yaml_config(self) -> Dict[str, str]:
Expand Down Expand Up @@ -1280,7 +1255,7 @@ def delete(self) -> None:
msg_str = f'Deleted S3 bucket {self.name}.'
else:
msg_str = f'S3 bucket {self.name} may have been deleted ' \
f'externally. Removing from local state.'
f'externally. Removing from local state.'
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')

Expand Down Expand Up @@ -1696,7 +1671,7 @@ def delete(self) -> None:
msg_str = f'Deleted GCS bucket {self.name}.'
else:
msg_str = f'GCS bucket {self.name} may have been deleted ' \
f'externally. Removing from local state.'
f'externally. Removing from local state.'
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')

Expand Down Expand Up @@ -2561,7 +2536,7 @@ def _get_bucket(self) -> Tuple[str, bool]:
if 'Name or service not known' in error_message:
with ux_utils.print_exception_no_traceback():
raise exceptions.StorageBucketGetError(
'Attempted to fetch the container from non-existant '
'Attempted to fetch the container from non-existent '
'storage account '
f'name: {self.storage_account_name}. Please check '
'if the name is correct.')
Expand Down Expand Up @@ -2807,7 +2782,7 @@ def delete(self) -> None:
msg_str = f'Deleted R2 bucket {self.name}.'
else:
msg_str = f'R2 bucket {self.name} may have been deleted ' \
f'externally. Removing from local state.'
f'externally. Removing from local state.'
logger.info(f'{colorama.Fore.GREEN}{msg_str}'
f'{colorama.Style.RESET_ALL}')

Expand Down
3 changes: 1 addition & 2 deletions sky/execution.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -471,8 +471,7 @@ def launch(
no_setup=no_setup,
clone_disk_from=clone_disk_from,
_is_launched_by_jobs_controller=_is_launched_by_jobs_controller,
_is_launched_by_sky_serve_controller=
_is_launched_by_sky_serve_controller,
_is_launched_by_sky_serve_controller=_is_launched_by_sky_serve_controller,
)


Expand Down
16 changes: 8 additions & 8 deletions sky/task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -393,6 +393,11 @@ def from_yaml_config(
config['service'] = _fill_in_env_vars(config['service'],
config.get('envs', {}))

# Fill in any Task.envs into workdir
if config.get('workdir') is not None:
config['workdir'] = _fill_in_env_vars(config['workdir'],
config.get('envs', {}))

task = Task(
config.pop('name', None),
run=config.pop('run', None),
Expand All @@ -411,7 +416,6 @@ def from_yaml_config(
if file_mounts is not None:
copy_mounts = {}
for dst_path, src in file_mounts.items():

# Check if it is str path
if isinstance(src, str):
copy_mounts[dst_path] = src
Expand All @@ -431,10 +435,7 @@ def from_yaml_config(
mount_path = storage[0]
assert mount_path, 'Storage mount path cannot be empty.'
try:
resources = config.get('resources', None)
region = resources.get('region', None) if resources else None
storage_obj = storage_lib.Storage.from_yaml_config(
storage[1], region)
storage_obj = storage_lib.Storage.from_yaml_config(storage[1])
except exceptions.StorageSourceError as e:
# Patch the error message to include the mount path, if included
e.args = (e.args[0].replace('<destination_path>',
Expand Down Expand Up @@ -841,7 +842,6 @@ def set_storage_mounts(
Raises:
ValueError: if input paths are invalid.
"""

if storage_mounts is None:
self.storage_mounts = {}
# Clear the requires_fuse flag if no storage mounts are set.
Expand Down Expand Up @@ -958,7 +958,6 @@ def sync_storage_mounts(self) -> None:
for storage in self.storage_mounts.values():
if len(storage.stores) == 0:
store_type, store_region = self._get_preferred_store()

self.storage_plans[storage] = store_type
storage.add_store(store_type, store_region)
else:
Expand Down Expand Up @@ -1146,7 +1145,8 @@ def get_required_cloud_features(

# Multi-node
if self.num_nodes > 1:
required_features.add(clouds.CloudImplementationFeatures.MULTI_NODE)
required_features.add(
clouds.CloudImplementationFeatures.MULTI_NODE)

# Storage mounting
for _, storage_mount in self.storage_mounts.items():
Expand Down

0 comments on commit 04d148f

Please sign in to comment.