diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index b9cf03e545..c3f8f88033 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -70,7 +70,6 @@ URLColumn, batch_multiresult, batch_result, - mapper_registry, ) from .group import groups from .image import ImageNode, ImageRow @@ -373,110 +372,134 @@ async def handle_kernel_exception( raise -kernels = sa.Table( - "kernels", - mapper_registry.metadata, +class KernelRow(Base): + __tablename__ = "kernels" + # The Backend.AI-side UUID for each kernel # (mapped to a container in the docker backend and a pod in the k8s backend) - KernelIDColumn(), + id = KernelIDColumn() # session_id == id when the kernel is the main container in a multi-container session or a # single-container session. # Otherwise, it refers the kernel ID of the main container of the belonged multi-container session. - sa.Column( + session_id = sa.Column( "session_id", SessionIDColumnType, sa.ForeignKey("sessions.id"), unique=False, index=True, nullable=False, - ), - sa.Column("session_creation_id", sa.String(length=32), unique=False, index=False), - sa.Column("session_name", sa.String(length=64), unique=False, index=True), # previously sess_id - sa.Column( + ) + session_creation_id = sa.Column( + "session_creation_id", sa.String(length=32), unique=False, index=False + ) + session_name = sa.Column( + "session_name", sa.String(length=64), unique=False, index=True + ) # previously sess_id + session_type = sa.Column( "session_type", EnumType(SessionTypes), index=True, nullable=False, # previously sess_type default=SessionTypes.INTERACTIVE, server_default=SessionTypes.INTERACTIVE.name, - ), - sa.Column( + ) + cluster_mode = sa.Column( "cluster_mode", sa.String(length=16), nullable=False, default=ClusterMode.SINGLE_NODE, server_default=ClusterMode.SINGLE_NODE.name, - ), - sa.Column("cluster_size", sa.Integer, nullable=False, default=1), - sa.Column( + ) + cluster_size = sa.Column("cluster_size", sa.Integer, nullable=False, default=1) + cluster_role = sa.Column( "cluster_role", sa.String(length=16), nullable=False, default=DEFAULT_ROLE, index=True - ), - sa.Column("cluster_idx", sa.Integer, nullable=False, default=0), - sa.Column("local_rank", sa.Integer, nullable=False, default=0), - sa.Column("cluster_hostname", sa.String(length=64), nullable=False, default=default_hostname), + ) + cluster_idx = sa.Column("cluster_idx", sa.Integer, nullable=False, default=0) + local_rank = sa.Column("local_rank", sa.Integer, nullable=False, default=0) + cluster_hostname = sa.Column( + "cluster_hostname", sa.String(length=64), nullable=False, default=default_hostname + ) # Resource ownership - sa.Column("scaling_group", sa.ForeignKey("scaling_groups.name"), index=True, nullable=True), - sa.Column("agent", sa.String(length=64), sa.ForeignKey("agents.id"), nullable=True), - sa.Column("agent_addr", sa.String(length=128), nullable=True), - sa.Column("domain_name", sa.String(length=64), sa.ForeignKey("domains.name"), nullable=False), - sa.Column("group_id", GUID, sa.ForeignKey("groups.id"), nullable=False), - sa.Column("user_uuid", GUID, sa.ForeignKey("users.uuid"), nullable=False), - sa.Column("access_key", sa.String(length=20), sa.ForeignKey("keypairs.access_key")), + scaling_group = sa.Column( + "scaling_group", sa.ForeignKey("scaling_groups.name"), index=True, nullable=True + ) + agent = sa.Column("agent", sa.String(length=64), sa.ForeignKey("agents.id"), nullable=True) + agent_addr = sa.Column("agent_addr", sa.String(length=128), nullable=True) + domain_name = sa.Column( + "domain_name", sa.String(length=64), sa.ForeignKey("domains.name"), nullable=False + ) + group_id = sa.Column("group_id", GUID, sa.ForeignKey("groups.id"), nullable=False) + user_uuid = sa.Column("user_uuid", GUID, sa.ForeignKey("users.uuid"), nullable=False) + access_key = sa.Column("access_key", sa.String(length=20), sa.ForeignKey("keypairs.access_key")) # `image` is a string shaped "/:". it is identical to images.name column - sa.Column("image", sa.String(length=512)), - # ForeignKeyIDColumn("image_id", "images.id"), - sa.Column("architecture", sa.String(length=32), default="x86_64"), - sa.Column("registry", sa.String(length=512)), - sa.Column("tag", sa.String(length=64), nullable=True), + image = sa.Column("image", sa.String(length=512)) + # ForeignKeyIDColumn("image_id", "images.id") + architecture = sa.Column("architecture", sa.String(length=32), default="x86_64") + registry = sa.Column("registry", sa.String(length=512)) + tag = sa.Column("tag", sa.String(length=64), nullable=True) # Resource occupation - sa.Column("container_id", sa.String(length=64)), - sa.Column("occupied_slots", ResourceSlotColumn(), nullable=False), - sa.Column("requested_slots", ResourceSlotColumn(), nullable=False, default=ResourceSlot()), - sa.Column("occupied_shares", pgsql.JSONB(), nullable=False, default={}), # legacy - sa.Column("environ", sa.ARRAY(sa.String), nullable=True), - sa.Column("mounts", sa.ARRAY(sa.String), nullable=True), # list of list; legacy since 22.03 - sa.Column("mount_map", pgsql.JSONB(), nullable=True, default={}), # legacy since 22.03 - sa.Column("vfolder_mounts", StructuredJSONObjectListColumn(VFolderMount), nullable=True), - sa.Column("attached_devices", pgsql.JSONB(), nullable=True, default={}), - sa.Column("resource_opts", pgsql.JSONB(), nullable=True, default={}), - sa.Column("bootstrap_script", sa.String(length=16 * 1024), nullable=True), + container_id = sa.Column("container_id", sa.String(length=64)) + occupied_slots = sa.Column("occupied_slots", ResourceSlotColumn(), nullable=False) + requested_slots = sa.Column( + "requested_slots", ResourceSlotColumn(), nullable=False, default=ResourceSlot() + ) + occupied_shares = sa.Column( + "occupied_shares", pgsql.JSONB(), nullable=False, default={} + ) # legacy + environ = sa.Column("environ", sa.ARRAY(sa.String), nullable=True) + mounts = sa.Column( + "mounts", sa.ARRAY(sa.String), nullable=True + ) # list of list; legacy since 22.03 + mount_map = sa.Column( + "mount_map", pgsql.JSONB(), nullable=True, default={} + ) # legacy since 22.03 + vfolder_mounts = sa.Column( + "vfolder_mounts", StructuredJSONObjectListColumn(VFolderMount), nullable=True + ) + attached_devices = sa.Column("attached_devices", pgsql.JSONB(), nullable=True, default={}) + resource_opts = sa.Column("resource_opts", pgsql.JSONB(), nullable=True, default={}) + bootstrap_script = sa.Column("bootstrap_script", sa.String(length=16 * 1024), nullable=True) # Port mappings # If kernel_host is NULL, it is assumed to be same to the agent host or IP. - sa.Column("kernel_host", sa.String(length=128), nullable=True), - sa.Column("repl_in_port", sa.Integer(), nullable=False), - sa.Column("repl_out_port", sa.Integer(), nullable=False), - sa.Column("stdin_port", sa.Integer(), nullable=False), # legacy for stream_pty - sa.Column("stdout_port", sa.Integer(), nullable=False), # legacy for stream_pty - sa.Column("service_ports", pgsql.JSONB(), nullable=True), - sa.Column("preopen_ports", sa.ARRAY(sa.Integer), nullable=True), - sa.Column("use_host_network", sa.Boolean(), default=False, nullable=False), + kernel_host = sa.Column("kernel_host", sa.String(length=128), nullable=True) + repl_in_port = sa.Column("repl_in_port", sa.Integer(), nullable=False) + repl_out_port = sa.Column("repl_out_port", sa.Integer(), nullable=False) + stdin_port = sa.Column("stdin_port", sa.Integer(), nullable=False) # legacy for stream_pty + stdout_port = sa.Column("stdout_port", sa.Integer(), nullable=False) # legacy for stream_pty + service_ports = sa.Column("service_ports", pgsql.JSONB(), nullable=True) + preopen_ports = sa.Column("preopen_ports", sa.ARRAY(sa.Integer), nullable=True) + use_host_network = sa.Column("use_host_network", sa.Boolean(), default=False, nullable=False) # Lifecycle - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True), - sa.Column( + created_at = sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), index=True + ) + terminated_at = sa.Column( "terminated_at", sa.DateTime(timezone=True), nullable=True, default=sa.null(), index=True - ), - sa.Column("starts_at", sa.DateTime(timezone=True), nullable=True, default=sa.null()), - sa.Column( + ) + starts_at = sa.Column("starts_at", sa.DateTime(timezone=True), nullable=True, default=sa.null()) + status = sa.Column( "status", EnumType(KernelStatus), default=KernelStatus.PENDING, server_default=KernelStatus.PENDING.name, nullable=False, index=True, - ), - sa.Column( + ) + role = sa.Column( "role", EnumType(KernelRole), default=KernelRole.COMPUTE, server_default=KernelRole.COMPUTE.name, nullable=False, index=True, - ), - sa.Column("status_changed", sa.DateTime(timezone=True), nullable=True, index=True), - sa.Column("status_info", sa.Unicode(), nullable=True, default=sa.null()), + ) + status_changed = sa.Column( + "status_changed", sa.DateTime(timezone=True), nullable=True, index=True + ) + status_info = sa.Column("status_info", sa.Unicode(), nullable=True, default=sa.null()) # status_info contains a kebab-cased string that expresses a summary of the last status change. # Examples: "user-requested", "self-terminated", "predicate-checks-failed", "no-available-instances" - sa.Column("status_data", pgsql.JSONB(), nullable=True, default=sa.null()), + status_data = sa.Column("status_data", pgsql.JSONB(), nullable=True, default=sa.null()) # status_data contains a JSON object that contains detailed data for the last status change. # During scheduling (as PENDING + ("no-available-instances" | "predicate-checks-failed")): # { @@ -512,43 +535,43 @@ async def handle_kernel_exception( # // used to prevent duplication of SessionTerminatedEvent # } # } - sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null()), - sa.Column("callback_url", URLColumn, nullable=True, default=sa.null()), - sa.Column("startup_command", sa.Text, nullable=True), - sa.Column( + status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null()) + callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null()) + startup_command = sa.Column("startup_command", sa.Text, nullable=True) + result = sa.Column( "result", EnumType(SessionResult), default=SessionResult.UNDEFINED, server_default=SessionResult.UNDEFINED.name, nullable=False, index=True, - ), - sa.Column("internal_data", pgsql.JSONB(), nullable=True), - sa.Column("container_log", sa.LargeBinary(), nullable=True), + ) + internal_data = sa.Column("internal_data", pgsql.JSONB(), nullable=True) + container_log = sa.Column("container_log", sa.LargeBinary(), nullable=True) # Resource metrics measured upon termination - sa.Column("num_queries", sa.BigInteger(), default=0), - sa.Column("last_stat", pgsql.JSONB(), nullable=True, default=sa.null()), - sa.Index("ix_kernels_sess_id_role", "session_id", "cluster_role", unique=False), - sa.Index("ix_kernels_status_role", "status", "cluster_role"), - sa.Index( - "ix_kernels_updated_order", - sa.func.greatest("created_at", "terminated_at", "status_changed"), - unique=False, - ), - sa.Index( - "ix_kernels_unique_sess_token", - "access_key", - "session_name", - unique=True, - postgresql_where=sa.text( - "status NOT IN ('TERMINATED', 'CANCELLED') and cluster_role = 'main'" + num_queries = sa.Column("num_queries", sa.BigInteger(), default=0) + last_stat = sa.Column("last_stat", pgsql.JSONB(), nullable=True, default=sa.null()) + + __table_args__ = ( + # indexing + sa.Index("ix_kernels_sess_id_role", "session_id", "cluster_role", unique=False), + sa.Index("ix_kernels_status_role", "status", "cluster_role"), + sa.Index( + "ix_kernels_updated_order", + sa.func.greatest("created_at", "terminated_at", "status_changed"), + unique=False, ), - ), -) - + sa.Index( + "ix_kernels_unique_sess_token", + "access_key", + "session_name", + unique=True, + postgresql_where=sa.text( + "status NOT IN ('TERMINATED', 'CANCELLED') and cluster_role = 'main'" + ), + ), + ) -class KernelRow(Base): - __table__ = kernels session = relationship("SessionRow", back_populates="kernels") image_row = relationship( "ImageRow", @@ -715,6 +738,9 @@ async def _update() -> bool: return await execute_with_retry(_update) +# For compatibility +kernels = KernelRow.__table__ + DEFAULT_KERNEL_ORDERING = [ sa.desc( sa.func.greatest(