Skip to content

Commit

Permalink
chore: Migrate imperative kernels table to declarative KernelRow
Browse files Browse the repository at this point in the history
…ORM class (#2309)

We are migrating from [imperative mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#imperative-mapping) to [declarative mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping).
Add the `KernelRow` declarative ORM class and keep the `kernels` table object for compatibility.

**Checklist:** (if applicable)

- [x] Milestone metadata specifying the target backport version
  • Loading branch information
fregataa committed Jun 20, 2024
1 parent 6e6a420 commit 6c2ec78
Showing 1 changed file with 115 additions and 89 deletions.
204 changes: 115 additions & 89 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
URLColumn,
batch_multiresult,
batch_result,
mapper_registry,
)
from .group import groups
from .image import ImageNode, ImageRow
Expand Down Expand Up @@ -373,110 +372,134 @@ async def handle_kernel_exception(
raise


kernels = sa.Table(
"kernels",
mapper_registry.metadata,
class KernelRow(Base):
__tablename__ = "kernels"

# The Backend.AI-side UUID for each kernel
# (mapped to a container in the docker backend and a pod in the k8s backend)
KernelIDColumn(),
id = KernelIDColumn()
# session_id == id when the kernel is the main container in a multi-container session or a
# single-container session.
# Otherwise, it refers to the kernel ID of the main container of the multi-container session it belongs to.
sa.Column(
session_id = sa.Column(
"session_id",
SessionIDColumnType,
sa.ForeignKey("sessions.id"),
unique=False,
index=True,
nullable=False,
),
sa.Column("session_creation_id", sa.String(length=32), unique=False, index=False),
sa.Column("session_name", sa.String(length=64), unique=False, index=True), # previously sess_id
sa.Column(
)
session_creation_id = sa.Column(
"session_creation_id", sa.String(length=32), unique=False, index=False
)
session_name = sa.Column(
"session_name", sa.String(length=64), unique=False, index=True
) # previously sess_id
session_type = sa.Column(
"session_type",
EnumType(SessionTypes),
index=True,
nullable=False, # previously sess_type
default=SessionTypes.INTERACTIVE,
server_default=SessionTypes.INTERACTIVE.name,
),
sa.Column(
)
cluster_mode = sa.Column(
"cluster_mode",
sa.String(length=16),
nullable=False,
default=ClusterMode.SINGLE_NODE,
server_default=ClusterMode.SINGLE_NODE.name,
),
sa.Column("cluster_size", sa.Integer, nullable=False, default=1),
sa.Column(
)
cluster_size = sa.Column("cluster_size", sa.Integer, nullable=False, default=1)
cluster_role = sa.Column(
"cluster_role", sa.String(length=16), nullable=False, default=DEFAULT_ROLE, index=True
),
sa.Column("cluster_idx", sa.Integer, nullable=False, default=0),
sa.Column("local_rank", sa.Integer, nullable=False, default=0),
sa.Column("cluster_hostname", sa.String(length=64), nullable=False, default=default_hostname),
)
cluster_idx = sa.Column("cluster_idx", sa.Integer, nullable=False, default=0)
local_rank = sa.Column("local_rank", sa.Integer, nullable=False, default=0)
cluster_hostname = sa.Column(
"cluster_hostname", sa.String(length=64), nullable=False, default=default_hostname
)
# Resource ownership
sa.Column("scaling_group", sa.ForeignKey("scaling_groups.name"), index=True, nullable=True),
sa.Column("agent", sa.String(length=64), sa.ForeignKey("agents.id"), nullable=True),
sa.Column("agent_addr", sa.String(length=128), nullable=True),
sa.Column("domain_name", sa.String(length=64), sa.ForeignKey("domains.name"), nullable=False),
sa.Column("group_id", GUID, sa.ForeignKey("groups.id"), nullable=False),
sa.Column("user_uuid", GUID, sa.ForeignKey("users.uuid"), nullable=False),
sa.Column("access_key", sa.String(length=20), sa.ForeignKey("keypairs.access_key")),
scaling_group = sa.Column(
"scaling_group", sa.ForeignKey("scaling_groups.name"), index=True, nullable=True
)
agent = sa.Column("agent", sa.String(length=64), sa.ForeignKey("agents.id"), nullable=True)
agent_addr = sa.Column("agent_addr", sa.String(length=128), nullable=True)
domain_name = sa.Column(
"domain_name", sa.String(length=64), sa.ForeignKey("domains.name"), nullable=False
)
group_id = sa.Column("group_id", GUID, sa.ForeignKey("groups.id"), nullable=False)
user_uuid = sa.Column("user_uuid", GUID, sa.ForeignKey("users.uuid"), nullable=False)
access_key = sa.Column("access_key", sa.String(length=20), sa.ForeignKey("keypairs.access_key"))
# `image` is a string shaped "<REGISTRY>/<IMAGE>:<TAG>". It is identical to the images.name column.
sa.Column("image", sa.String(length=512)),
# ForeignKeyIDColumn("image_id", "images.id"),
sa.Column("architecture", sa.String(length=32), default="x86_64"),
sa.Column("registry", sa.String(length=512)),
sa.Column("tag", sa.String(length=64), nullable=True),
image = sa.Column("image", sa.String(length=512))
# ForeignKeyIDColumn("image_id", "images.id")
architecture = sa.Column("architecture", sa.String(length=32), default="x86_64")
registry = sa.Column("registry", sa.String(length=512))
tag = sa.Column("tag", sa.String(length=64), nullable=True)
# Resource occupation
sa.Column("container_id", sa.String(length=64)),
sa.Column("occupied_slots", ResourceSlotColumn(), nullable=False),
sa.Column("requested_slots", ResourceSlotColumn(), nullable=False, default=ResourceSlot()),
sa.Column("occupied_shares", pgsql.JSONB(), nullable=False, default={}), # legacy
sa.Column("environ", sa.ARRAY(sa.String), nullable=True),
sa.Column("mounts", sa.ARRAY(sa.String), nullable=True), # list of list; legacy since 22.03
sa.Column("mount_map", pgsql.JSONB(), nullable=True, default={}), # legacy since 22.03
sa.Column("vfolder_mounts", StructuredJSONObjectListColumn(VFolderMount), nullable=True),
sa.Column("attached_devices", pgsql.JSONB(), nullable=True, default={}),
sa.Column("resource_opts", pgsql.JSONB(), nullable=True, default={}),
sa.Column("bootstrap_script", sa.String(length=16 * 1024), nullable=True),
container_id = sa.Column("container_id", sa.String(length=64))
occupied_slots = sa.Column("occupied_slots", ResourceSlotColumn(), nullable=False)
requested_slots = sa.Column(
"requested_slots", ResourceSlotColumn(), nullable=False, default=ResourceSlot()
)
occupied_shares = sa.Column(
"occupied_shares", pgsql.JSONB(), nullable=False, default={}
) # legacy
environ = sa.Column("environ", sa.ARRAY(sa.String), nullable=True)
mounts = sa.Column(
"mounts", sa.ARRAY(sa.String), nullable=True
) # list of list; legacy since 22.03
mount_map = sa.Column(
"mount_map", pgsql.JSONB(), nullable=True, default={}
) # legacy since 22.03
vfolder_mounts = sa.Column(
"vfolder_mounts", StructuredJSONObjectListColumn(VFolderMount), nullable=True
)
attached_devices = sa.Column("attached_devices", pgsql.JSONB(), nullable=True, default={})
resource_opts = sa.Column("resource_opts", pgsql.JSONB(), nullable=True, default={})
bootstrap_script = sa.Column("bootstrap_script", sa.String(length=16 * 1024), nullable=True)
# Port mappings
# If kernel_host is NULL, it is assumed to be the same as the agent host or IP.
sa.Column("kernel_host", sa.String(length=128), nullable=True),
sa.Column("repl_in_port", sa.Integer(), nullable=False),
sa.Column("repl_out_port", sa.Integer(), nullable=False),
sa.Column("stdin_port", sa.Integer(), nullable=False), # legacy for stream_pty
sa.Column("stdout_port", sa.Integer(), nullable=False), # legacy for stream_pty
sa.Column("service_ports", pgsql.JSONB(), nullable=True),
sa.Column("preopen_ports", sa.ARRAY(sa.Integer), nullable=True),
sa.Column("use_host_network", sa.Boolean(), default=False, nullable=False),
kernel_host = sa.Column("kernel_host", sa.String(length=128), nullable=True)
repl_in_port = sa.Column("repl_in_port", sa.Integer(), nullable=False)
repl_out_port = sa.Column("repl_out_port", sa.Integer(), nullable=False)
stdin_port = sa.Column("stdin_port", sa.Integer(), nullable=False) # legacy for stream_pty
stdout_port = sa.Column("stdout_port", sa.Integer(), nullable=False) # legacy for stream_pty
service_ports = sa.Column("service_ports", pgsql.JSONB(), nullable=True)
preopen_ports = sa.Column("preopen_ports", sa.ARRAY(sa.Integer), nullable=True)
use_host_network = sa.Column("use_host_network", sa.Boolean(), default=False, nullable=False)
# Lifecycle
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True),
sa.Column(
created_at = sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True
)
terminated_at = sa.Column(
"terminated_at", sa.DateTime(timezone=True), nullable=True, default=sa.null(), index=True
),
sa.Column("starts_at", sa.DateTime(timezone=True), nullable=True, default=sa.null()),
sa.Column(
)
starts_at = sa.Column("starts_at", sa.DateTime(timezone=True), nullable=True, default=sa.null())
status = sa.Column(
"status",
EnumType(KernelStatus),
default=KernelStatus.PENDING,
server_default=KernelStatus.PENDING.name,
nullable=False,
index=True,
),
sa.Column(
)
role = sa.Column(
"role",
EnumType(KernelRole),
default=KernelRole.COMPUTE,
server_default=KernelRole.COMPUTE.name,
nullable=False,
index=True,
),
sa.Column("status_changed", sa.DateTime(timezone=True), nullable=True, index=True),
sa.Column("status_info", sa.Unicode(), nullable=True, default=sa.null()),
)
status_changed = sa.Column(
"status_changed", sa.DateTime(timezone=True), nullable=True, index=True
)
status_info = sa.Column("status_info", sa.Unicode(), nullable=True, default=sa.null())
# status_info contains a kebab-cased string that expresses a summary of the last status change.
# Examples: "user-requested", "self-terminated", "predicate-checks-failed", "no-available-instances"
sa.Column("status_data", pgsql.JSONB(), nullable=True, default=sa.null()),
status_data = sa.Column("status_data", pgsql.JSONB(), nullable=True, default=sa.null())
# status_data contains a JSON object that contains detailed data for the last status change.
# During scheduling (as PENDING + ("no-available-instances" | "predicate-checks-failed")):
# {
Expand Down Expand Up @@ -512,43 +535,43 @@ async def handle_kernel_exception(
# // used to prevent duplication of SessionTerminatedEvent
# }
# }
sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null()),
sa.Column("callback_url", URLColumn, nullable=True, default=sa.null()),
sa.Column("startup_command", sa.Text, nullable=True),
sa.Column(
status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null())
callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null())
startup_command = sa.Column("startup_command", sa.Text, nullable=True)
result = sa.Column(
"result",
EnumType(SessionResult),
default=SessionResult.UNDEFINED,
server_default=SessionResult.UNDEFINED.name,
nullable=False,
index=True,
),
sa.Column("internal_data", pgsql.JSONB(), nullable=True),
sa.Column("container_log", sa.LargeBinary(), nullable=True),
)
internal_data = sa.Column("internal_data", pgsql.JSONB(), nullable=True)
container_log = sa.Column("container_log", sa.LargeBinary(), nullable=True)
# Resource metrics measured upon termination
sa.Column("num_queries", sa.BigInteger(), default=0),
sa.Column("last_stat", pgsql.JSONB(), nullable=True, default=sa.null()),
sa.Index("ix_kernels_sess_id_role", "session_id", "cluster_role", unique=False),
sa.Index("ix_kernels_status_role", "status", "cluster_role"),
sa.Index(
"ix_kernels_updated_order",
sa.func.greatest("created_at", "terminated_at", "status_changed"),
unique=False,
),
sa.Index(
"ix_kernels_unique_sess_token",
"access_key",
"session_name",
unique=True,
postgresql_where=sa.text(
"status NOT IN ('TERMINATED', 'CANCELLED') and cluster_role = 'main'"
num_queries = sa.Column("num_queries", sa.BigInteger(), default=0)
last_stat = sa.Column("last_stat", pgsql.JSONB(), nullable=True, default=sa.null())

__table_args__ = (
# indexing
sa.Index("ix_kernels_sess_id_role", "session_id", "cluster_role", unique=False),
sa.Index("ix_kernels_status_role", "status", "cluster_role"),
sa.Index(
"ix_kernels_updated_order",
sa.func.greatest("created_at", "terminated_at", "status_changed"),
unique=False,
),
),
)

sa.Index(
"ix_kernels_unique_sess_token",
"access_key",
"session_name",
unique=True,
postgresql_where=sa.text(
"status NOT IN ('TERMINATED', 'CANCELLED') and cluster_role = 'main'"
),
),
)

class KernelRow(Base):
__table__ = kernels
session = relationship("SessionRow", back_populates="kernels")
image_row = relationship(
"ImageRow",
Expand Down Expand Up @@ -715,6 +738,9 @@ async def _update() -> bool:
return await execute_with_retry(_update)


# For compatibility: keep the `kernels` table object available to code that still references it.
kernels = KernelRow.__table__

DEFAULT_KERNEL_ORDERING = [
sa.desc(
sa.func.greatest(
Expand Down

0 comments on commit 6c2ec78

Please sign in to comment.