Skip to content

Commit

Permalink
refactor naming of db_connection, /load endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dzakharchuk committed Jan 16, 2024
1 parent ec3efe6 commit 807b59c
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 181 deletions.
68 changes: 34 additions & 34 deletions featureflags/graph/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
from featureflags.utils import select_scalar


async def gen_id(local_id: LocalId, *, db_connection: SAConnection) -> UUID:
async def gen_id(local_id: LocalId, *, conn: SAConnection) -> UUID:
assert local_id.scope and local_id.value, local_id

id_ = await select_scalar(
db_connection,
conn,
(
insert(LocalIdMap.__table__)
.values(
Expand All @@ -54,7 +54,7 @@ async def gen_id(local_id: LocalId, *, db_connection: SAConnection) -> UUID:
)
if id_ is None:
id_ = await select_scalar(
db_connection,
conn,
(
select([LocalIdMap.id]).where(
and_(
Expand All @@ -67,12 +67,12 @@ async def gen_id(local_id: LocalId, *, db_connection: SAConnection) -> UUID:
return id_


async def get_auth_user(username: str, *, db_connection: SAConnection) -> UUID:
async def get_auth_user(username: str, *, conn: SAConnection) -> UUID:
user_id_select = select([AuthUser.id]).where(AuthUser.username == username)
user_id = await select_scalar(db_connection, user_id_select)
user_id = await select_scalar(conn, user_id_select)
if user_id is None:
user_id = await select_scalar(
db_connection,
conn,
(
insert(AuthUser.__table__)
.values(
Expand All @@ -86,7 +86,7 @@ async def get_auth_user(username: str, *, db_connection: SAConnection) -> UUID:
),
)
if user_id is None:
user_id = await select_scalar(db_connection, user_id_select)
user_id = await select_scalar(conn, user_id_select)
assert user_id is not None
return user_id

Expand All @@ -96,19 +96,19 @@ async def sign_in(
username: str,
password: str,
*,
db_connection: SAConnection,
conn: SAConnection,
session: UserSession,
ldap: BaseLDAP,
) -> bool:
assert username and password, "Username and password are required"
if not await ldap.check_credentials(username, password):
return False

user_id = await get_auth_user(username, db_connection=db_connection)
user_id = await get_auth_user(username, conn=conn)

now = datetime.utcnow()
expiration_time = now + AUTH_SESSION_TTL
await db_connection.execute(
await conn.execute(
insert(AuthSession.__table__)
.values(
{
Expand All @@ -133,10 +133,10 @@ async def sign_in(

@track
async def sign_out(
*, db_connection: SAConnection, session: UserSession
*, conn: SAConnection, session: UserSession
) -> None:
if session.ident:
await db_connection.execute(
await conn.execute(
AuthSession.__table__.delete().where(
AuthSession.session == session.ident
)
Expand All @@ -149,14 +149,14 @@ async def sign_out(
async def enable_flag(
flag_id: str,
*,
db_connection: SAConnection,
conn: SAConnection,
dirty: DirtyProjects,
changes: Changes,
) -> None:
assert flag_id, "Flag id is required"

flag_uuid = UUID(hex=flag_id)
await db_connection.execute(
await conn.execute(
Flag.__table__.update()
.where(Flag.id == flag_uuid)
.values({Flag.enabled: True})
Expand All @@ -170,14 +170,14 @@ async def enable_flag(
async def disable_flag(
flag_id: str,
*,
db_connection: SAConnection,
conn: SAConnection,
dirty: DirtyProjects,
changes: Changes,
) -> None:
assert flag_id, "Flag id is required"

flag_uuid = UUID(hex=flag_id)
await db_connection.execute(
await conn.execute(
Flag.__table__.update()
.where(Flag.id == flag_uuid)
.values({Flag.enabled: False})
Expand All @@ -191,19 +191,19 @@ async def disable_flag(
async def reset_flag(
flag_id: str,
*,
db_connection: SAConnection,
conn: SAConnection,
dirty: DirtyProjects,
changes: Changes,
) -> None:
assert flag_id, "Flag id is required"

flag_uuid = UUID(hex=flag_id)
await db_connection.execute(
await conn.execute(
Flag.__table__.update()
.where(Flag.id == flag_uuid)
.values({Flag.enabled: None})
)
await db_connection.execute(
await conn.execute(
Condition.__table__.delete().where(Condition.flag == flag_uuid)
)
dirty.by_flag.add(flag_uuid)
Expand All @@ -213,15 +213,15 @@ async def reset_flag(
@auth_required
@track
async def delete_flag(
flag_id: str, *, db_connection: SAConnection, changes: Changes
flag_id: str, *, conn: SAConnection, changes: Changes
) -> None:
assert flag_id, "Flag id is required"

flag_uuid = UUID(hex=flag_id)
await db_connection.execute(
await conn.execute(
Condition.__table__.delete().where(Condition.flag == flag_uuid)
)
await db_connection.execute(
await conn.execute(
Flag.__table__.delete().where(Flag.id == flag_uuid)
)

Expand All @@ -231,17 +231,17 @@ async def delete_flag(
@auth_required
@track
async def add_check(
op: AddCheckOp, *, db_connection: SAConnection, dirty: DirtyProjects
op: AddCheckOp, *, conn: SAConnection, dirty: DirtyProjects
) -> dict[LocalId, UUID]:
id_ = await gen_id(op.local_id, db_connection=db_connection)
id_ = await gen_id(op.local_id, conn=conn)
variable_id = UUID(hex=op.variable)
values = {
Check.id: id_,
Check.variable: variable_id,
Check.operator: Operator(op.operator),
}
values.update(Check.value_from_op(op)) # type: ignore
await db_connection.execute(
await conn.execute(
insert(Check.__table__).values(values).on_conflict_do_nothing()
)
dirty.by_variable.add(variable_id)
Expand All @@ -253,20 +253,20 @@ async def add_check(
async def add_condition(
op: AddConditionOp,
*,
db_connection: SAConnection,
conn: SAConnection,
ids: dict,
dirty: DirtyProjects,
changes: Changes,
) -> dict:
id_ = await gen_id(op.local_id, db_connection=db_connection)
id_ = await gen_id(op.local_id, conn=conn)

flag_id = UUID(hex=op.flag_id)
checks = [
ids[check.local_id] if check.local_id else UUID(hex=check.id)
for check in op.checks
]

await db_connection.execute(
await conn.execute(
insert(Condition.__table__)
.values(
{
Expand All @@ -287,14 +287,14 @@ async def add_condition(
async def disable_condition(
condition_id: str,
*,
db_connection: SAConnection,
conn: SAConnection,
dirty: DirtyProjects,
changes: Changes,
) -> None:
assert condition_id, "Condition id is required"

flag_id = await select_scalar(
db_connection,
conn,
(
Condition.__table__.delete()
.where(Condition.id == UUID(hex=condition_id))
Expand All @@ -307,7 +307,7 @@ async def disable_condition(


async def postprocess(
*, db_connection: SAConnection, dirty: DirtyProjects
*, conn: SAConnection, dirty: DirtyProjects
) -> None:
selections = []
for flag_id in dirty.by_flag:
Expand All @@ -317,22 +317,22 @@ async def postprocess(
select([Variable.project]).where(Variable.id == variable_id)
)
if selections:
await db_connection.execute(
await conn.execute(
update(Project.__table__)
.where(or_(*[Project.id.in_(sel) for sel in selections]))
.values({Project.version: Project.version + 1})
)


async def update_changelog(
*, session: UserSession, db_connection: SAConnection, changes: Changes
*, session: UserSession, conn: SAConnection, changes: Changes
) -> None:
actions = changes.get_actions()
if actions:
assert session.user is not None
for flag, flag_actions in actions:
assert flag_actions, repr(flag_actions)
await db_connection.execute(
await conn.execute(
insert(Changelog.__table__).values(
{
Changelog.timestamp: datetime.utcnow(),
Expand Down
26 changes: 13 additions & 13 deletions featureflags/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ async def sing_in(ctx: dict, options: dict) -> AuthResult:
success = await actions.sign_in(
username,
password,
db_connection=conn,
conn=conn,
session=ctx[GraphContext.USER_SESSION],
ldap=ctx[GraphContext.LDAP_SERVICE],
)
Expand All @@ -523,7 +523,7 @@ async def sing_out(ctx: dict) -> AuthResult:
return AuthResult(None)
async with ctx[GraphContext.DB_ENGINE].acquire() as conn:
await actions.sign_out(
db_connection=conn,
conn=conn,
session=ctx[GraphContext.USER_SESSION],
)
return AuthResult(None)
Expand All @@ -545,29 +545,29 @@ async def save_flag(ctx: dict, options: dict) -> SaveFlagResult:
case Operation.ENABLE_FLAG:
await actions.enable_flag(
operation_payload["flag_id"],
db_connection=conn,
conn=conn,
dirty=ctx[GraphContext.DIRTY_PROJECTS],
changes=ctx[GraphContext.CHANGES],
)
case Operation.DISABLE_FLAG:
await actions.disable_flag(
operation_payload["flag_id"],
db_connection=conn,
conn=conn,
dirty=ctx[GraphContext.DIRTY_PROJECTS],
changes=ctx[GraphContext.CHANGES],
)
case Operation.ADD_CHECK:
new_ids = await actions.add_check(
AddCheckOp(operation_payload),
db_connection=conn,
conn=conn,
dirty=ctx[GraphContext.DIRTY_PROJECTS],
)
if new_ids is not None:
ctx[GraphContext.CHECK_IDS].update(new_ids)
case Operation.ADD_CONDITION:
new_ids = await actions.add_condition(
AddConditionOp(operation_payload),
db_connection=conn,
conn=conn,
ids=ctx[GraphContext.CHECK_IDS],
dirty=ctx[GraphContext.DIRTY_PROJECTS],
changes=ctx[GraphContext.CHANGES],
Expand All @@ -577,19 +577,19 @@ async def save_flag(ctx: dict, options: dict) -> SaveFlagResult:
case Operation.DISABLE_CONDITION:
await actions.disable_condition(
operation_payload["condition_id"],
db_connection=conn,
conn=conn,
dirty=ctx[GraphContext.DIRTY_PROJECTS],
changes=ctx[GraphContext.CHANGES],
)
case _:
raise ValueError(f"Unknown operation: {operation_type}")

await actions.postprocess(
db_connection=conn, dirty=ctx[GraphContext.DIRTY_PROJECTS]
conn=conn, dirty=ctx[GraphContext.DIRTY_PROJECTS]
)
await actions.update_changelog(
session=ctx[GraphContext.USER_SESSION],
db_connection=conn,
conn=conn,
changes=ctx[GraphContext.CHANGES],
)

Expand All @@ -601,16 +601,16 @@ async def reset_flag(ctx: dict, options: dict) -> ResetFlagResult:
async with ctx[GraphContext.DB_ENGINE].acquire() as conn:
await actions.reset_flag(
options["id"],
db_connection=conn,
conn=conn,
dirty=ctx[GraphContext.DIRTY_PROJECTS],
changes=ctx[GraphContext.CHANGES],
)
await actions.postprocess(
db_connection=conn, dirty=ctx[GraphContext.DIRTY_PROJECTS]
conn=conn, dirty=ctx[GraphContext.DIRTY_PROJECTS]
)
await actions.update_changelog(
session=ctx[GraphContext.USER_SESSION],
db_connection=conn,
conn=conn,
changes=ctx[GraphContext.CHANGES],
)

Expand All @@ -622,7 +622,7 @@ async def delete_flag(ctx: dict, options: dict) -> DeleteFlagResult:
async with ctx[GraphContext.DB_ENGINE].acquire() as conn:
await actions.delete_flag(
options["id"],
db_connection=conn,
conn=conn,
changes=ctx[GraphContext.CHANGES],
)

Expand Down
4 changes: 2 additions & 2 deletions featureflags/http/api/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
router = APIRouter(prefix="/flags", tags=["flags"])


@router.post("/preload")
@router.post("/load")
@inject
async def preload_flags(
async def load_flags(
request: PreloadFlagsRequest,
flags_repo: FlagsRepository = Depends(Provide[Container.flags_repo]),
) -> PreloadFlagsResponse:
Expand Down
Loading

0 comments on commit 807b59c

Please sign in to comment.