Skip to content

Commit

Permalink
Use an enum for load balancer types
Browse files Browse the repository at this point in the history
  • Loading branch information
ucbstudent committed Dec 10, 2024
1 parent e1037ab commit 74bcb7d
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
12 changes: 9 additions & 3 deletions sky/serve/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Constants used for SkyServe."""

from enum import Enum

CONTROLLER_TEMPLATE = 'sky-serve-controller.yaml.j2'

SKYSERVE_METADATA_DIR = '~/.sky/serve'
Expand Down Expand Up @@ -105,6 +107,10 @@
ENVOY_THREADS = '1'
ENVOY_VERSION = '1.32.0'

LB_TYPE_PYTHON = 'python'
LB_TYPE_ENVOY = 'envoy'
LB_TYPES = [LB_TYPE_PYTHON, LB_TYPE_ENVOY]

class LbType(Enum):
PYTHON = 'python'
ENVOY = 'envoy'


ALL_LB_TYPES = [t.value for t in LbType]
2 changes: 1 addition & 1 deletion sky/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def up(
# Check that the Envoy load balancer isn't being used on an unsupported
# cloud.
lb_type = task_config.get('service', {}).get('load_balancer_type', None)
if lb_type == serve_constants.LB_TYPE_ENVOY:
if lb_type == serve_constants.LbType.ENVOY.value:
for resource in controller_resources:
if resource.cloud is None:
continue
Expand Down
4 changes: 2 additions & 2 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,15 @@ def run_load_balancer(service_name: str,
None.
"""

if load_balancer_type == constants.LB_TYPE_PYTHON \
if load_balancer_type == constants.LbType.PYTHON.value \
or load_balancer_type is None:
plb = PythonLoadBalancer(
service_name=service_name,
controller_url=controller_addr,
load_balancer_port=load_balancer_port,
load_balancing_policy_name=load_balancing_policy_name)
plb.run()
elif load_balancer_type == constants.LB_TYPE_ENVOY:
elif load_balancer_type == constants.LbType.ENVOY.value:
elb = EnvoyLoadBalancer(service_name=service_name,
controller_url=controller_addr,
load_balancer_port=load_balancer_port)
Expand Down
4 changes: 2 additions & 2 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def __init__(
f'Available policies: {list(serve.LB_POLICIES.keys())}')

if (load_balancer_type is not None and
load_balancer_type not in constants.LB_TYPES):
load_balancer_type not in constants.ALL_LB_TYPES):
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Unknown load balancer type: {load_balancer_type}. '
f'Available load balancers: {constants.LB_TYPES}')
f'Available load balancers: {constants.ALL_LB_TYPES}')
self._readiness_path: str = readiness_path
self._initial_delay_seconds: int = initial_delay_seconds
self._readiness_timeout_seconds: int = readiness_timeout_seconds
Expand Down
2 changes: 1 addition & 1 deletion sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def get_service_schema():
},
'load_balancer_type': {
'type': 'string',
'case_insensitive_enum': serve_constants.LB_TYPES
'case_insensitive_enum': serve_constants.ALL_LB_TYPES,
},
}
}
Expand Down

0 comments on commit 74bcb7d

Please sign in to comment.