Skip to content

Commit

Permalink
Replace len() Zero Checks with Pythonic Empty Sequence Checks (#4298)
Browse files Browse the repository at this point in the history
* style: mainly replace len() comparisons with 0/1 with pythonic empty sequence checks

* chore: more typings

* use `df.empty` for dataframe

* fix: more `df.empty`

* format

* revert partially

* style: add back comments

* style: format

* refactor: `dict[str, str]`

Co-authored-by: Tian Xia <[email protected]>

---------

Co-authored-by: Tian Xia <[email protected]>
  • Loading branch information
andylizf and cblmemo authored Dec 23, 2024
1 parent e596111 commit 2bd7c3e
Show file tree
Hide file tree
Showing 46 changed files with 143 additions and 141 deletions.
2 changes: 1 addition & 1 deletion examples/spot/lightning_cifar10/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def main():
)

model_ckpts = glob.glob(argv.root_dir + "/*.ckpt")
if argv.resume and len(model_ckpts) > 0:
if argv.resume and model_ckpts:
latest_ckpt = max(model_ckpts, key=os.path.getctime)
trainer.fit(model, cifar10_dm, ckpt_path=latest_ckpt)
else:
Expand Down
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2626,7 +2626,7 @@ def register_info(self, **kwargs) -> None:
self._optimize_target) or optimizer.OptimizeTarget.COST
self._requested_features = kwargs.pop('requested_features',
self._requested_features)
assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}'
assert not kwargs, f'Unexpected kwargs: {kwargs}'

def check_resources_fit_cluster(
self,
Expand Down
2 changes: 1 addition & 1 deletion sky/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_all_clouds():
'\nNote: The following clouds were disabled because they were not '
'included in allowed_clouds in ~/.sky/config.yaml: '
f'{", ".join([c for c in disallowed_cloud_names])}')
if len(all_enabled_clouds) == 0:
if not all_enabled_clouds:
echo(
click.style(
'No cloud is enabled. SkyPilot will not be able to run any '
Expand Down
50 changes: 24 additions & 26 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]:
glob_clusters = []
for cluster in clusters:
glob_cluster = global_user_state.get_glob_cluster_names(cluster)
if len(glob_cluster) == 0 and not silent:
if not glob_cluster and not silent:
click.echo(f'Cluster {cluster} not found.')
glob_clusters.extend(glob_cluster)
return list(set(glob_clusters))
Expand All @@ -125,7 +125,7 @@ def _get_glob_storages(storages: List[str]) -> List[str]:
glob_storages = []
for storage_object in storages:
glob_storage = global_user_state.get_glob_storage_name(storage_object)
if len(glob_storage) == 0:
if not glob_storage:
click.echo(f'Storage {storage_object} not found.')
glob_storages.extend(glob_storage)
return list(set(glob_storages))
Expand Down Expand Up @@ -1473,7 +1473,7 @@ def _get_services(service_names: Optional[List[str]],
if len(service_records) != 1:
plural = 's' if len(service_records) > 1 else ''
service_num = (str(len(service_records))
if len(service_records) > 0 else 'No')
if service_records else 'No')
raise click.UsageError(
f'{service_num} service{plural} found. Please specify '
'an existing service to show its endpoint. Usage: '
Expand Down Expand Up @@ -1696,8 +1696,7 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
if len(clusters) != 1:
with ux_utils.print_exception_no_traceback():
plural = 's' if len(clusters) > 1 else ''
cluster_num = (str(len(clusters))
if len(clusters) > 0 else 'No')
cluster_num = (str(len(clusters)) if clusters else 'No')
cause = 'a single' if len(clusters) > 1 else 'an existing'
raise ValueError(
_STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format(
Expand All @@ -1722,9 +1721,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
with ux_utils.print_exception_no_traceback():
plural = 's' if len(cluster_records) > 1 else ''
cluster_num = (str(len(cluster_records))
if len(cluster_records) > 0 else
f'{clusters[0]!r}')
verb = 'found' if len(cluster_records) > 0 else 'not found'
if cluster_records else f'{clusters[0]!r}')
verb = 'found' if cluster_records else 'not found'
cause = 'a single' if len(clusters) > 1 else 'an existing'
raise ValueError(
_STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format(
Expand Down Expand Up @@ -2470,7 +2468,7 @@ def start(
'(see `sky status`), or the -a/--all flag.')

if all:
if len(clusters) > 0:
if clusters:
click.echo('Both --all and cluster(s) specified for sky start. '
'Letting --all take effect.')

Expand Down Expand Up @@ -2800,7 +2798,7 @@ def _down_or_stop_clusters(
option_str = '{stop,down}'
operation = f'{verb} auto{option_str} on'

if len(names) > 0:
if names:
controllers = [
name for name in names
if controller_utils.Controllers.from_name(name) is not None
Expand All @@ -2814,7 +2812,7 @@ def _down_or_stop_clusters(
# Make sure the controllers are explicitly specified without other
# normal clusters.
if controllers:
if len(names) != 0:
if names:
names_str = ', '.join(map(repr, names))
raise click.UsageError(
f'{operation} controller(s) '
Expand Down Expand Up @@ -2867,7 +2865,7 @@ def _down_or_stop_clusters(

if apply_to_all:
all_clusters = global_user_state.get_clusters()
if len(names) > 0:
if names:
click.echo(
f'Both --all and cluster(s) specified for `sky {command}`. '
'Letting --all take effect.')
Expand All @@ -2894,7 +2892,7 @@ def _down_or_stop_clusters(
click.echo('Cluster(s) not found (tip: see `sky status`).')
return

if not no_confirm and len(clusters) > 0:
if not no_confirm and clusters:
cluster_str = 'clusters' if len(clusters) > 1 else 'cluster'
cluster_list = ', '.join(clusters)
click.confirm(
Expand Down Expand Up @@ -3003,7 +3001,7 @@ def check(clouds: Tuple[str], verbose: bool):
# Check only specific clouds - AWS and GCP.
sky check aws gcp
"""
clouds_arg = clouds if len(clouds) > 0 else None
clouds_arg = clouds if clouds else None
sky_check.check(verbose=verbose, clouds=clouds_arg)


Expand Down Expand Up @@ -3138,7 +3136,7 @@ def _get_kubernetes_realtime_gpu_table(
f'capacity ({list(capacity.keys())}), '
f'and available ({list(available.keys())}) '
'must be same.')
if len(counts) == 0:
if not counts:
err_msg = 'No GPUs found in Kubernetes cluster. '
debug_msg = 'To further debug, run: sky check '
if name_filter is not None:
Expand Down Expand Up @@ -3282,7 +3280,7 @@ def _output():
for tpu in service_catalog.get_tpus():
if tpu in result:
tpu_table.add_row([tpu, _list_to_str(result.pop(tpu))])
if len(tpu_table.get_string()) > 0:
if tpu_table.get_string():
yield '\n\n'
yield from tpu_table.get_string()

Expand Down Expand Up @@ -3393,7 +3391,7 @@ def _output():
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Cloud GPUs{colorama.Style.RESET_ALL}\n')

if len(result) == 0:
if not result:
quantity_str = (f' with requested quantity {quantity}'
if quantity else '')
cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.'
Expand Down Expand Up @@ -3522,7 +3520,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r
# Delete all storage objects.
sky storage delete -a
"""
if sum([len(names) > 0, all]) != 1:
if sum([bool(names), all]) != 1:
raise click.UsageError('Either --all or a name must be specified.')
if all:
storages = sky.storage_ls()
Expand Down Expand Up @@ -3881,8 +3879,8 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
exit_if_not_accessible=True)

job_id_str = ','.join(map(str, job_ids))
if sum([len(job_ids) > 0, name is not None, all]) != 1:
argument_str = f'--job-ids {job_id_str}' if len(job_ids) > 0 else ''
if sum([bool(job_ids), name is not None, all]) != 1:
argument_str = f'--job-ids {job_id_str}' if job_ids else ''
argument_str += f' --name {name}' if name is not None else ''
argument_str += ' --all' if all else ''
raise click.UsageError(
Expand Down Expand Up @@ -4523,9 +4521,9 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool,
# Forcefully tear down a specific replica, even in failed status.
sky serve down my-service --replica-id 1 --purge
"""
if sum([len(service_names) > 0, all]) != 1:
argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len(
service_names) > 0 else ''
if sum([bool(service_names), all]) != 1:
argument_str = (f'SERVICE_NAMES={",".join(service_names)}'
if service_names else '')
argument_str += ' --all' if all else ''
raise click.UsageError(
'Can only specify one of SERVICE_NAMES or --all. '
Expand Down Expand Up @@ -4898,7 +4896,7 @@ def benchmark_launch(
if idle_minutes_to_autostop is None:
idle_minutes_to_autostop = 5
commandline_args['idle-minutes-to-autostop'] = idle_minutes_to_autostop
if len(env) > 0:
if env:
commandline_args['env'] = [f'{k}={v}' for k, v in env]

# Launch the benchmarking clusters in detach mode in parallel.
Expand Down Expand Up @@ -5177,7 +5175,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool],
raise click.BadParameter(
'Either specify benchmarks or use --all to delete all benchmarks.')
to_delete = []
if len(benchmarks) > 0:
if benchmarks:
for benchmark in benchmarks:
record = benchmark_state.get_benchmark_from_name(benchmark)
if record is None:
Expand All @@ -5186,7 +5184,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool],
to_delete.append(record)
if all:
to_delete = benchmark_state.get_benchmarks()
if len(benchmarks) > 0:
if benchmarks:
print('Both --all and benchmark(s) specified '
'for sky bench delete. Letting --all take effect.')

Expand Down
2 changes: 1 addition & 1 deletion sky/cloud_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def is_directory(self, url: str) -> bool:
# If <url> is a bucket root, then we only need `gsutil` to succeed
# to make sure the bucket exists. It is already a directory.
_, key = data_utils.split_gcs_path(url)
if len(key) == 0:
if not key:
return True
# Otherwise, gsutil ls -d url will return:
# --> url.rstrip('/') if url is not a directory
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
ret_permissions = request.execute().get('permissions', [])

diffs = set(gcp_minimal_permissions).difference(set(ret_permissions))
if len(diffs) > 0:
if diffs:
identity_str = identity[0] if identity else None
return False, (
'The following permissions are not enabled for the current '
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _existing_allowed_contexts(cls) -> List[str]:
use the service account mounted in the pod.
"""
all_contexts = kubernetes_utils.get_all_kube_context_names()
if len(all_contexts) == 0:
if not all_contexts:
return []

all_contexts = set(all_contexts)
Expand Down
21 changes: 11 additions & 10 deletions sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,10 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9)
candidate_loc = sorted(candidate_loc)
candidate_strs = ''
if len(candidate_loc) > 0:
if candidate_loc:
candidate_strs = ', '.join(candidate_loc)
candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'

return candidate_strs

def _get_all_supported_regions_str() -> str:
Expand All @@ -286,7 +287,7 @@ def _get_all_supported_regions_str() -> str:
filter_df = df
if region is not None:
filter_df = _filter_region_zone(filter_df, region, zone=None)
if len(filter_df) == 0:
if filter_df.empty:
with ux_utils.print_exception_no_traceback():
error_msg = (f'Invalid region {region!r}')
candidate_strs = _get_candidate_str(
Expand All @@ -310,7 +311,7 @@ def _get_all_supported_regions_str() -> str:
if zone is not None:
maybe_region_df = filter_df
filter_df = filter_df[filter_df['AvailabilityZone'] == zone]
if len(filter_df) == 0:
if filter_df.empty:
region_str = f' for region {region!r}' if region else ''
df = maybe_region_df if region else df
with ux_utils.print_exception_no_traceback():
Expand Down Expand Up @@ -378,7 +379,7 @@ def get_vcpus_mem_from_instance_type_impl(
instance_type: str,
) -> Tuple[Optional[float], Optional[float]]:
df = _get_instance_type(df, instance_type, None)
if len(df) == 0:
if df.empty:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'No instance type {instance_type} found.')
assert len(set(df['vCPUs'])) == 1, ('Cannot determine the number of vCPUs '
Expand Down Expand Up @@ -484,7 +485,7 @@ def get_accelerators_from_instance_type_impl(
instance_type: str,
) -> Optional[Dict[str, Union[int, float]]]:
df = _get_instance_type(df, instance_type, None)
if len(df) == 0:
if df.empty:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'No instance type {instance_type} found.')
row = df.iloc[0]
Expand Down Expand Up @@ -518,7 +519,7 @@ def get_instance_type_for_accelerator_impl(
result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) &
(abs(df['AcceleratorCount'] - acc_count) <= 0.01)]
result = _filter_region_zone(result, region, zone)
if len(result) == 0:
if result.empty:
fuzzy_result = df[
(df['AcceleratorName'].str.contains(acc_name, case=False)) &
(df['AcceleratorCount'] >= acc_count)]
Expand All @@ -527,7 +528,7 @@ def get_instance_type_for_accelerator_impl(
fuzzy_result = fuzzy_result[['AcceleratorName',
'AcceleratorCount']].drop_duplicates()
fuzzy_candidate_list = []
if len(fuzzy_result) > 0:
if not fuzzy_result.empty:
for _, row in fuzzy_result.iterrows():
acc_cnt = float(row['AcceleratorCount'])
acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else
Expand All @@ -539,7 +540,7 @@ def get_instance_type_for_accelerator_impl(
result = _filter_with_cpus(result, cpus)
result = _filter_with_mem(result, memory)
result = _filter_region_zone(result, region, zone)
if len(result) == 0:
if result.empty:
return ([], [])

# Current strategy: choose the cheapest instance
Expand Down Expand Up @@ -680,7 +681,7 @@ def get_image_id_from_tag_impl(df: 'pd.DataFrame', tag: str,
df = _filter_region_zone(df, region, zone=None)
assert len(df) <= 1, ('Multiple images found for tag '
f'{tag} in region {region}')
if len(df) == 0:
if df.empty:
return None
image_id = df['ImageId'].iloc[0]
if pd.isna(image_id):
Expand All @@ -694,4 +695,4 @@ def is_image_tag_valid_impl(df: 'pd.DataFrame', tag: str,
df = df[df['Tag'] == tag]
df = _filter_region_zone(df, region, zone=None)
df = df.dropna(subset=['ImageId'])
return len(df) > 0
return not df.empty
2 changes: 1 addition & 1 deletion sky/clouds/service_catalog/data_fetchers/fetch_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_pricing_df(region: Optional[str] = None) -> 'pd.DataFrame':
content_str = r.content.decode('ascii')
content = json.loads(content_str)
items = content.get('Items', [])
if len(items) == 0:
if not items:
break
all_items += items
url = content.get('NextPageLink')
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def initialize_images_csv(csv_saving_path: str, vc_object,
gpu_name = tag_name.split('-')[1]
if gpu_name not in gpu_tags:
gpu_tags.append(gpu_name)
if len(gpu_tags) > 0:
if gpu_tags:
gpu_tags_str = str(gpu_tags).replace('\'', '\"')
f.write(f'{item.id},{vcenter_name},{item_cpu},{item_memory}'
f',,,\'{gpu_tags_str}\'\n')
Expand Down
6 changes: 3 additions & 3 deletions sky/clouds/utils/scp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __setitem__(self, instance_id: str, value: Optional[Dict[str,
if value is None:
if instance_id in metadata:
metadata.pop(instance_id) # del entry
if len(metadata) == 0:
if not metadata:
if os.path.exists(self.path):
os.remove(self.path)
return
Expand All @@ -84,7 +84,7 @@ def refresh(self, instance_ids: List[str]) -> None:
for instance_id in list(metadata.keys()):
if instance_id not in instance_ids:
del metadata[instance_id]
if len(metadata) == 0:
if not metadata:
os.remove(self.path)
return
with open(self.path, 'w', encoding='utf-8') as f:
Expand Down Expand Up @@ -410,7 +410,7 @@ def list_security_groups(self, vpc_id=None, sg_name=None):
parameter.append('vpcId=' + vpc_id)
if sg_name is not None:
parameter.append('securityGroupName=' + sg_name)
if len(parameter) > 0:
if parameter:
url = url + '?' + '&'.join(parameter)
return self._get(url)

Expand Down
Loading

0 comments on commit 2bd7c3e

Please sign in to comment.