Skip to content

Commit

Permalink
fix: Delete vfolders in background when purge user
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jul 4, 2024
1 parent 6307df1 commit 36b2cf9
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 60 deletions.
73 changes: 73 additions & 0 deletions src/ai/backend/manager/models/keypair.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import base64
import logging
import secrets
import uuid
from datetime import datetime
Expand All @@ -15,10 +16,12 @@
from graphene.types.datetime import DateTime as GQLDateTime
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import false

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AccessKey, SecretKey

if TYPE_CHECKING:
Expand All @@ -43,6 +46,9 @@
from .user import ModifyUserInput, UserRole
from .utils import agg_to_array

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


__all__: Sequence[str] = (
"keypairs",
"KeyPairRow",
Expand Down Expand Up @@ -718,3 +724,70 @@ def verify_dotfile_name(dotfile: str) -> bool:
if dotfile in RESERVED_DOTFILES:
return False
return True


async def delete_kernels_by_access_key(
db_session: SASession,
access_key: AccessKey,
) -> int:
"""
Delete keypair's all kernels.
:param conn: DB connection
:param access_key: access key to delete kernels
:return: number of deleted rows
"""
from . import KernelRow

result = await db_session.execute(
sa.delete(KernelRow).where(KernelRow.access_key == access_key),
)
if result.rowcount > 0:
log.info("deleted {0} keypair's kernels ({1})", result.rowcount, access_key)
return result.rowcount


async def delete_sessions_by_access_key(
db_session: SASession,
access_key: AccessKey,
) -> int:
"""
Delete keypair's all sessions.
:param db_session: SQLAlchemy session
:param access_key: access key to delete sessions
:return: number of deleted rows
"""
from .session import SessionRow

result = await db_session.execute(
sa.delete(SessionRow).where(SessionRow.access_key == access_key)
)
if result.rowcount > 0:
log.info("deleted {0} user's sessions ({1})", result.rowcount, access_key)
return result.rowcount


async def access_key_has_active_sessions(
db_session: SASession,
access_key: AccessKey,
) -> bool:
"""
Check if the keypair does not have active sessions.
:param db_session: SQLAlchemy session
:param access_key: access key
:return: True if the access key has some active sessions.
"""
from . import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, KernelRow

active_kernel_count = await db_session.scalar(
sa.select(sa.func.count())
.select_from(KernelRow)
.where(
(KernelRow.access_key == access_key)
& (KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)),
),
)
return active_kernel_count > 0
195 changes: 135 additions & 60 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import enum
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Optional, Sequence, cast
from uuid import UUID, uuid4

import aiotools
import bcrypt
import graphene
import sqlalchemy as sa
Expand All @@ -23,10 +22,10 @@
from sqlalchemy.types import VARCHAR, TypeDecorator

from ai.backend.common import redis_helper
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import RedisConnectionInfo, VFolderID
from ai.backend.common.types import RedisConnectionInfo

from ..api.exceptions import VFolderOperationFailed
from ..defs import DEFAULT_KEYPAIR_RATE_LIMIT, DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME
from .base import (
Base,
Expand All @@ -48,11 +47,11 @@
from .gql_relay import AsyncNode, Connection, ConnectionResolverResult
from .minilang.ordering import OrderSpecItem, QueryOrderParser
from .minilang.queryfilter import FieldSpecItem, QueryFilterParser, enum_field_getter
from .storage import StorageSessionManager
from .utils import ExtendedAsyncSAEngine
from .utils import execute_with_retry

if TYPE_CHECKING:
from .gql import GraphQueryContext
from .vfolder import VFolderRow

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

Expand Down Expand Up @@ -958,6 +957,7 @@ class Arguments:

ok = graphene.Boolean()
msg = graphene.String()
bgtask_id = graphene.UUID(required=False)

@classmethod
async def mutate(
Expand All @@ -967,38 +967,109 @@ async def mutate(
email: str,
props: PurgeUserInput,
) -> PurgeUser:
from .keypair import (
KeyPairRow,
access_key_has_active_sessions,
delete_kernels_by_access_key,
delete_sessions_by_access_key,
)
from .vfolder import VFolderDeletionInfo, delete_vfolders

graph_ctx: GraphQueryContext = info.context
delete_query = sa.delete(users).where(users.c.email == email)

async def _pre_func(conn: SAConnection) -> None:
async with graph_ctx.db.begin_readonly() as conn:
user_uuid = await conn.scalar(
sa.select([users.c.uuid]).select_from(users).where(users.c.email == email),
sa.select(users.c.uuid).select_from(users).where(users.c.email == email),
)
log.info("Purging all records of the user {0}...", email)
# `keypairs.user_id` is email
# Check `src.ai.backend.manager.models.keypair.CreateKeyPair.prepare_new_keypair()`
keypair_rows = (
await SASession(conn).scalars(
sa.select(KeyPairRow).where(KeyPairRow.user_id == email)
)
).all()
keypair_rows = cast(list[KeyPairRow], keypair_rows)

if await cls.user_vfolder_mounted_to_active_kernels(conn, user_uuid):
raise RuntimeError(
"Some of user's virtual folders are mounted to active kernels. "
"Terminate those kernels first.",
async def _pre_purge() -> list[VFolderRow]:
async with graph_ctx.db.begin() as conn:
db_session = SASession(conn)
update_status_query = (
sa.update(users)
.values(status=UserStatus.DELETED, status_info="admin-requested")
.where(users.c.email == email)
)
if await cls.user_has_active_kernels(conn, user_uuid):
raise RuntimeError("User has some active kernels. Terminate them first.")

if not props.purge_shared_vfolders:
await cls.migrate_shared_vfolders(
conn,
deleted_user_uuid=user_uuid,
target_user_uuid=graph_ctx.user["uuid"],
target_user_email=graph_ctx.user["email"],
await conn.execute(update_status_query)

log.info("Purging all records of the user {0}...", email)

if await cls.user_vfolder_mounted_to_active_kernels(conn, user_uuid):
raise RuntimeError(
"Some of user's virtual folders are mounted to active kernels. "
"Terminate those kernels first.",
)
if await cls.user_has_active_kernels(conn, user_uuid):
raise RuntimeError("User has some active kernels. Terminate them first.")
for row in keypair_rows:
if await access_key_has_active_sessions(db_session, row.access_key):
raise RuntimeError(
f"One of keypairs the user owns has some active sessions. Terminate them first. (ak:{row.access_key})"
)

if not props.purge_shared_vfolders:
await cls.migrate_shared_vfolders(
conn,
deleted_user_uuid=user_uuid,
target_user_uuid=graph_ctx.user["uuid"],
target_user_email=graph_ctx.user["email"],
)
await cls.delete_error_logs(conn, user_uuid)
await cls.delete_endpoint(conn, user_uuid)
await cls.delete_kernels(conn, user_uuid)
for row in keypair_rows:
await delete_kernels_by_access_key(db_session, row.access_key)
await cls.delete_sessions(conn, user_uuid)
for row in keypair_rows:
await delete_sessions_by_access_key(db_session, row.access_key)
await cls.delete_keypairs(conn, graph_ctx.redis_stat, user_uuid)

vfolder_rows = await cls.alive_vfolders(conn, user_uuid)
if not vfolder_rows:
await cls.delete_vfolders(conn, user_uuid)
return vfolder_rows

alive_vfolders = await execute_with_retry(_pre_purge)

if alive_vfolders:
# Run bgtask
background_task_manager = graph_ctx.background_task_manager

async def _delete_vfolders_task(reporter: ProgressReporter) -> None:
await delete_vfolders(
[VFolderDeletionInfo(row.vfid, row.host) for row in alive_vfolders],
storage_manager=graph_ctx.storage_manager,
db=graph_ctx.db,
reporter=reporter,
)
await cls.delete_error_logs(conn, user_uuid)
await cls.delete_endpoint(conn, user_uuid)
await cls.delete_kernels(conn, user_uuid)
await cls.delete_sessions(conn, user_uuid)
await cls.delete_vfolders(graph_ctx.db, user_uuid, graph_ctx.storage_manager)
await cls.delete_keypairs(conn, graph_ctx.redis_stat, user_uuid)

delete_query = sa.delete(users).where(users.c.email == email)
return await simple_db_mutate(cls, graph_ctx, delete_query, pre_func=_pre_func)
async def _delete_records() -> None:
async with graph_ctx.db.begin_session() as db_session:
alive_vfolders = await cls.alive_vfolders(db_session.bind, user_uuid)
if not alive_vfolders:
await cls.delete_vfolders(db_session.bind, user_uuid)
await db_session.execute(delete_query)
else:
log.info(
"failed to delete some vfolders. delete them manually and try again."
)

await execute_with_retry(_delete_records)

task_id = await background_task_manager.start(_delete_vfolders_task)
return cls(True, "purge ongoing. finish after deleting all vfolders.", task_id)

else:
return await simple_db_mutate(cls, graph_ctx, delete_query)

@classmethod
async def migrate_shared_vfolders(
Expand Down Expand Up @@ -1089,46 +1160,29 @@ async def migrate_shared_vfolders(
@classmethod
async def delete_vfolders(
cls,
engine: ExtendedAsyncSAEngine,
conn: SAConnection,
user_uuid: UUID,
storage_manager: StorageSessionManager,
) -> int:
"""
Delete user's all virtual folders as well as their physical data.
Delete DB records of user's all virtual folders.
:param conn: DB connection
:param user_uuid: user's UUID to delete virtual folders
:return: number of deleted rows
"""
from . import VFolderDeletionInfo, initiate_vfolder_deletion, vfolder_permissions, vfolders
from . import (
VFolderRow,
vfolder_permissions,
)

async with engine.begin_session() as conn:
await conn.execute(
vfolder_permissions.delete().where(vfolder_permissions.c.user == user_uuid),
)
result = await conn.execute(
sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.quota_scope_id])
.select_from(vfolders)
.where(vfolders.c.user == user_uuid),
)
target_vfs = result.fetchall()

storage_ptask_group = aiotools.PersistentTaskGroup()
try:
await initiate_vfolder_deletion(
engine,
[VFolderDeletionInfo(VFolderID.from_row(vf), vf["host"]) for vf in target_vfs],
storage_manager,
storage_ptask_group,
)
except VFolderOperationFailed as e:
log.error("error on deleting vfolder filesystem directory: {0}", e.extra_msg)
raise
deleted_count = len(target_vfs)
if deleted_count > 0:
log.info("deleted {0} user's virtual folders ({1})", deleted_count, user_uuid)
return deleted_count
await conn.execute(
vfolder_permissions.delete().where(vfolder_permissions.c.user == user_uuid),
)
result = await conn.execute(sa.delete(VFolderRow).where(VFolderRow.user == user_uuid))
if result.rowcount > 0:
log.info("deleted {0} user's vfolders ({1})", result.rowcount, user_uuid)
return result.rowcount

@classmethod
async def user_vfolder_mounted_to_active_kernels(
Expand Down Expand Up @@ -1192,6 +1246,27 @@ async def user_has_active_kernels(
)
return active_kernel_count > 0

@classmethod
async def alive_vfolders(
cls,
conn: SAConnection,
user_uuid: UUID,
) -> list[VFolderRow]:
"""
:param conn: DB connection
:param user_uuid: user's UUID
:return: not-deleted vfolders the user owns.
"""
from .vfolder import VFolderOperationStatus, VFolderRow

vfolder_query = sa.select(VFolderRow).where(
(VFolderRow.user == user_uuid)
& (VFolderRow.status != VFolderOperationStatus.DELETE_COMPLETE)
)
vfolder_rows = (await SASession(conn).scalars(vfolder_query)).all()
return vfolder_rows

@classmethod
async def delete_endpoint(
cls,
Expand Down
Loading

0 comments on commit 36b2cf9

Please sign in to comment.