Skip to content

Commit

Permalink
[Serve] Use safe_cursor for serve state (#3299)
Browse files Browse the repository at this point in the history
Use safe_cursor for serve
  • Loading branch information
Michaelvll authored Mar 12, 2024
1 parent 1c32bbb commit c873c31
Showing 1 changed file with 117 additions and 107 deletions.
224 changes: 117 additions & 107 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ def add_service(name: str, controller_job_id: int, policy: str, version: int,
exists.
"""
try:
_DB.cursor.execute(
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str, current_version)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str, version))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str, current_version)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str, version))
except sqlite3.IntegrityError as e:
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
raise RuntimeError('Unexpected database error') from e
Expand All @@ -232,48 +232,49 @@ def add_service(name: str, controller_job_id: int, policy: str, version: int,

def remove_service(service_name: str) -> None:
"""Removes a service from the database."""
_DB.cursor.execute("""\
DELETE FROM services WHERE name=(?)""", (service_name,))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute("""\
DELETE FROM services WHERE name=(?)""", (service_name,))


def set_service_uptime(service_name: str, uptime: int) -> None:
"""Sets the uptime of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
uptime=(?) WHERE name=(?)""", (uptime, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
uptime=(?) WHERE name=(?)""", (uptime, service_name))


def set_service_status(service_name: str, status: ServiceStatus) -> None:
"""Sets the service status."""
_DB.cursor.execute(
"""\
UPDATE services SET
status=(?) WHERE name=(?)""", (status.value, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
status=(?) WHERE name=(?)""", (status.value, service_name))


def set_service_controller_port(service_name: str,
controller_port: int) -> None:
"""Sets the controller port of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
controller_port=(?) WHERE name=(?)""", (controller_port, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
controller_port=(?) WHERE name=(?)""",
(controller_port, service_name))


def set_service_load_balancer_port(service_name: str,
load_balancer_port: int) -> None:
"""Sets the load balancer port of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
load_balancer_port=(?) WHERE name=(?)""",
(load_balancer_port, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
load_balancer_port=(?) WHERE name=(?)""",
(load_balancer_port, service_name))


def _get_service_from_row(row) -> Dict[str, Any]:
Expand All @@ -299,7 +300,8 @@ def _get_service_from_row(row) -> Dict[str, Any]:

def get_services() -> List[Dict[str, Any]]:
"""Get all existing service records."""
rows = _DB.cursor.execute('SELECT * FROM services').fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT * FROM services').fetchall()
records = []
for row in rows:
records.append(_get_service_from_row(row))
Expand All @@ -308,7 +310,8 @@ def get_services() -> List[Dict[str, Any]]:

def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:
"""Get all existing service records."""
rows = _DB.cursor.execute('SELECT * FROM services WHERE name=(?)',
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT * FROM services WHERE name=(?)',
(service_name,)).fetchall()
for row in rows:
return _get_service_from_row(row)
Expand All @@ -317,10 +320,11 @@ def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:

def get_service_versions(service_name: str) -> List[int]:
"""Gets all versions of a service."""
rows = _DB.cursor.execute(
"""\
SELECT DISTINCT version FROM version_specs
WHERE service_name=(?)""", (service_name,)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT DISTINCT version FROM version_specs
WHERE service_name=(?)""", (service_name,)).fetchall()
return [row[0] for row in rows]


Expand All @@ -335,50 +339,52 @@ def get_glob_service_names(
Returns:
A list of non-duplicated service names.
"""
if service_names is None:
rows = _DB.cursor.execute('SELECT name FROM services').fetchall()
else:
rows = []
for service_name in service_names:
rows.extend(
_DB.cursor.execute(
'SELECT name FROM services WHERE name GLOB (?)',
(service_name,)).fetchall())
with db_utils.safe_cursor(_DB_PATH) as cursor:
if service_names is None:
rows = cursor.execute('SELECT name FROM services').fetchall()
else:
rows = []
for service_name in service_names:
rows.extend(
cursor.execute(
'SELECT name FROM services WHERE name GLOB (?)',
(service_name,)).fetchall())
return list({row[0] for row in rows})


# === Replica functions ===
def add_or_update_replica(service_name: str, replica_id: int,
replica_info: 'replica_managers.ReplicaInfo') -> None:
"""Adds a replica to the database."""
_DB.cursor.execute(
"""\
INSERT OR REPLACE INTO replicas
(service_name, replica_id, replica_info)
VALUES (?, ?, ?)""",
(service_name, replica_id, pickle.dumps(replica_info)))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT OR REPLACE INTO replicas
(service_name, replica_id, replica_info)
VALUES (?, ?, ?)""",
(service_name, replica_id, pickle.dumps(replica_info)))


def remove_replica(service_name: str, replica_id: int) -> None:
"""Removes a replica from the database."""
_DB.cursor.execute(
"""\
DELETE FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id))


def get_replica_info_from_id(
service_name: str,
replica_id: int) -> Optional['replica_managers.ReplicaInfo']:
"""Gets a replica info from the database."""
rows = _DB.cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id)).fetchall()
for row in rows:
return pickle.loads(row[0])
return None
Expand All @@ -387,16 +393,18 @@ def get_replica_info_from_id(
def get_replica_infos(
service_name: str) -> List['replica_managers.ReplicaInfo']:
"""Gets all replica infos of a service."""
rows = _DB.cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)""", (service_name,)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)""", (service_name,)).fetchall()
return [pickle.loads(row[0]) for row in rows]


def total_number_provisioning_replicas() -> int:
"""Returns the total number of provisioning replicas."""
rows = _DB.cursor.execute('SELECT replica_info FROM replicas').fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT replica_info FROM replicas').fetchall()
provisioning_count = 0
for row in rows:
replica_info: 'replica_managers.ReplicaInfo' = pickle.loads(row[0])
Expand All @@ -409,62 +417,64 @@ def total_number_provisioning_replicas() -> int:
def add_version(service_name: str) -> int:
"""Adds a version to the database."""

_DB.cursor.execute(
"""\
INSERT INTO version_specs
(version, service_name, spec)
VALUES (
(SELECT COALESCE(MAX(version), 0) + 1 FROM
version_specs WHERE service_name = ?), ?, ?)
RETURNING version""", (service_name, service_name, pickle.dumps(None)))
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT INTO version_specs
(version, service_name, spec)
VALUES (
(SELECT COALESCE(MAX(version), 0) + 1 FROM
version_specs WHERE service_name = ?), ?, ?)
RETURNING version""",
(service_name, service_name, pickle.dumps(None)))

inserted_version = _DB.cursor.fetchone()[0]
_DB.conn.commit()
inserted_version = cursor.fetchone()[0]

return inserted_version


def add_or_update_version(service_name: str, version: int,
spec: 'service_spec.SkyServiceSpec') -> None:
_DB.cursor.execute(
"""\
INSERT or REPLACE INTO version_specs
(service_name, version, spec)
VALUES (?, ?, ?)""", (service_name, version, pickle.dumps(spec)))
_DB.cursor.execute(
"""\
UPDATE services SET
current_version=(?) WHERE name=(?)""", (version, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT or REPLACE INTO version_specs
(service_name, version, spec)
VALUES (?, ?, ?)""", (service_name, version, pickle.dumps(spec)))
cursor.execute(
"""\
UPDATE services SET
current_version=(?) WHERE name=(?)""", (version, service_name))


def remove_service_versions(service_name: str) -> None:
"""Removes a replica from the database."""
_DB.cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))


def get_spec(service_name: str,
version: int) -> Optional['service_spec.SkyServiceSpec']:
"""Gets spec from the database."""
rows = _DB.cursor.execute(
"""\
SELECT spec FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT spec FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version)).fetchall()
for row in rows:
return pickle.loads(row[0])
return None


def delete_version(service_name: str, version: int) -> None:
"""Deletes a version from the database."""
_DB.cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version))

0 comments on commit c873c31

Please sign in to comment.