diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 23af81d4e..c0b544e77 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -51,16 +51,16 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ + valid_schedulers: Dict[str, SchedulerFactory] = {} + if not skip_defaults: + for scheduler_name, path in DEFAULT_SCHEDULER_MODULES.items(): + valid_schedulers[scheduler_name] = _defer_load_scheduler(path) - default_schedulers: Dict[str, SchedulerFactory] = {} - for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): - default_schedulers[scheduler] = _defer_load_scheduler(path) + entry_point_schedulers = load_group(group, default=None, skip_defaults=True) + if entry_point_schedulers: + valid_schedulers.update(entry_point_schedulers) - return load_group( - group, - default=default_schedulers, - skip_defaults=skip_defaults, - ) + return valid_schedulers def get_default_scheduler_name() -> str: