Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raw SQL query for fetching users #381

Merged
merged 9 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions apps/backend/src/app/api/v1/team-member-profiles/crud.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
include: fullInclude,
});

const lastActiveAtMillis = await getUsersLastActiveAtMillis(db.map(user => user.projectUserId), db.map(user => user.createdAt));
const lastActiveAtMillis = await getUsersLastActiveAtMillis(auth.project.id, db.map(user => user.projectUserId), db.map(user => user.createdAt));

return {
items: db.map((user, index) => prismaToCrud(user, lastActiveAtMillis[index])),
Expand Down Expand Up @@ -118,7 +118,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
throw new KnownErrors.TeamMembershipNotFound(params.team_id, params.user_id);
}

return prismaToCrud(db, await getUserLastActiveAtMillis(db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
return prismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
});
},
onUpdate: async ({ auth, data, params }) => {
Expand Down Expand Up @@ -151,7 +151,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
include: fullInclude,
});

return prismaToCrud(db, await getUserLastActiveAtMillis(db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
return prismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
});
},
}));
204 changes: 180 additions & 24 deletions apps/backend/src/app/api/v1/users/crud.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ensureTeamMembershipExists, ensureUserExists } from "@/lib/request-checks";
import { PrismaTransaction } from "@/lib/types";
import { sendTeamMembershipDeletedWebhook, sendUserCreatedWebhook, sendUserDeletedWebhook, sendUserUpdatedWebhook } from "@/lib/webhooks";
import { prismaClient, retryTransaction } from "@/prisma-client";
import { RawQuery, prismaClient, rawQuery, retryTransaction } from "@/prisma-client";
import { createCrudHandlers } from "@/route-handlers/crud-handler";
import { runAsynchronouslyAndWaitUntil } from "@/utils/vercel";
import { BooleanTrue, Prisma } from "@prisma/client";
Expand All @@ -11,8 +11,10 @@ import { UsersCrud, usersCrud } from "@stackframe/stack-shared/dist/interface/cr
import { userIdOrMeSchema, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { validateBase64Image } from "@stackframe/stack-shared/dist/utils/base64";
import { decodeBase64 } from "@stackframe/stack-shared/dist/utils/bytes";
import { getNodeEnvironment } from "@stackframe/stack-shared/dist/utils/env";
import { StackAssertionError, StatusError, throwErr } from "@stackframe/stack-shared/dist/utils/errors";
import { hashPassword, isPasswordHashValid } from "@stackframe/stack-shared/dist/utils/hashes";
import { deepPlainEquals } from "@stackframe/stack-shared/dist/utils/objects";
import { createLazyProxy } from "@stackframe/stack-shared/dist/utils/proxies";
import { typedToLowercase } from "@stackframe/stack-shared/dist/utils/strings";
import { teamPrismaToCrud, teamsCrudHandlers } from "../teams/crud";
Expand Down Expand Up @@ -203,45 +205,199 @@ async function getOtpConfig(tx: PrismaTransaction, projectConfigId: string) {
return otpConfig.length === 0 ? null : otpConfig[0];
}

export const getUserLastActiveAtMillis = async (userId: string): Promise<number | null> => {
const event = await prismaClient.event.findFirst({
where: {
data: {
path: ["$.userId"],
equals: userId,
},
},
orderBy: {
createdAt: 'desc',
},
});

return event?.createdAt.getTime() ?? null;
export const getUserLastActiveAtMillis = async (projectId: string, userId: string): Promise<number | null> => {
const res = (await getUsersLastActiveAtMillis(projectId, [userId], [0]))[0];
if (res === 0) {
return null;
}
return res;
};

// same as userIds.map(userId => getUserLastActiveAtMillis(userId, fallbackTo)), but uses a single query
export const getUsersLastActiveAtMillis = async (userIds: string[], fallbackTo: (number | Date)[]): Promise<number[]> => {
// same as userIds.map(userId => getUserLastActiveAtMillis(projectId, userId)), but uses a single query
export const getUsersLastActiveAtMillis = async (projectId: string, userIds: string[], userSignedUpAtMillis: (number | Date)[]): Promise<number[]> => {
if (userIds.length === 0) {
// Prisma.join throws an error if the array is empty, so we need to handle that case
return [];
}

const events = await prismaClient.$queryRaw<Array<{ userId: string, lastActiveAt: Date }>>`
SELECT data->>'userId' as "userId", MAX("createdAt") as "lastActiveAt"
SELECT data->>'userId' as "userId", MAX("eventStartedAt") as "lastActiveAt"
FROM "Event"
WHERE data->>'userId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(userIds)}]`})
WHERE data->>'userId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(userIds)}]`}) AND data->>'projectId' = ${projectId} AND "systemEventTypeIds" @> '{"$user-activity"}'
GROUP BY data->>'userId'
`;

return userIds.map((userId, index) => {
const event = events.find(e => e.userId === userId);
return event ? event.lastActiveAt.getTime() : (
typeof fallbackTo[index] === "number" ? (fallbackTo[index] as number) : (fallbackTo[index] as Date).getTime()
typeof userSignedUpAtMillis[index] === "number" ? (userSignedUpAtMillis[index] as number) : (userSignedUpAtMillis[index] as Date).getTime()
);
});
};

export function getUserQuery(projectId: string, userId: string): RawQuery<UsersCrud["Admin"]["Read"] | null> {
return {
sql: Prisma.sql`
SELECT to_json(
(
SELECT (
to_jsonb("ProjectUser".*) ||
jsonb_build_object(
'lastActiveAt', (
SELECT MAX("eventStartedAt") as "lastActiveAt"
FROM "Event"
WHERE data->>'projectId' = "ProjectUser"."projectId" AND ("data"->>'userId')::UUID = "ProjectUser"."projectUserId" AND "systemEventTypeIds" @> '{"$user-activity"}'
),
'UserActivityEvents', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("Event")
), '{}')
FROM "Event"
WHERE data->>'projectId' = "ProjectUser"."projectId" AND ("data"->>'userId')::UUID = "ProjectUser"."projectUserId" AND "systemEventTypeIds" @> '{"$user-activity"}'
),
'ContactChannels', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("ContactChannel") ||
jsonb_build_object()
), '{}')
FROM "ContactChannel"
WHERE "ContactChannel"."projectId" = "ProjectUser"."projectId" AND "ContactChannel"."projectUserId" = "ProjectUser"."projectUserId" AND "ContactChannel"."isPrimary" = 'TRUE'
),
'ProjectUserOAuthAccounts', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("ProjectUserOAuthAccount") ||
jsonb_build_object(
'ProviderConfig', (
SELECT to_jsonb("OAuthProviderConfig")
FROM "OAuthProviderConfig"
WHERE "OAuthProviderConfig"."id" = "ProjectUserOAuthAccount"."oauthProviderConfigId"
)
)
), '{}')
FROM "ProjectUserOAuthAccount"
WHERE "ProjectUserOAuthAccount"."projectId" = "ProjectUser"."projectId" AND "ProjectUserOAuthAccount"."projectUserId" = "ProjectUser"."projectUserId"
),
'AuthMethods', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("AuthMethod") ||
jsonb_build_object(
'PasswordAuthMethod', (
SELECT to_jsonb("PasswordAuthMethod")
FROM "PasswordAuthMethod"
WHERE "PasswordAuthMethod"."projectId" = "ProjectUser"."projectId" AND "PasswordAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasswordAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'OtpAuthMethod', (
SELECT to_jsonb("OtpAuthMethod")
FROM "OtpAuthMethod"
WHERE "OtpAuthMethod"."projectId" = "ProjectUser"."projectId" AND "OtpAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OtpAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'PasskeyAuthMethod', (
SELECT to_jsonb("PasskeyAuthMethod")
FROM "PasskeyAuthMethod"
WHERE "PasskeyAuthMethod"."projectId" = "ProjectUser"."projectId" AND "PasskeyAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasskeyAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'OAuthAuthMethod', (
SELECT to_jsonb("OAuthAuthMethod")
FROM "OAuthAuthMethod"
WHERE "OAuthAuthMethod"."projectId" = "ProjectUser"."projectId" AND "OAuthAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OAuthAuthMethod"."authMethodId" = "AuthMethod"."id"
)
)
), '{}')
FROM "AuthMethod"
WHERE "AuthMethod"."projectId" = "ProjectUser"."projectId" AND "AuthMethod"."projectUserId" = "ProjectUser"."projectUserId"
),
'SelectedTeamMember', (
SELECT
to_jsonb("TeamMember") ||
jsonb_build_object(
'Team', (
SELECT
to_jsonb("Team")
FROM "Team"
WHERE "Team"."projectId" = "ProjectUser"."projectId" AND "Team"."teamId" = "TeamMember"."teamId"
)
)
FROM "TeamMember"
WHERE "TeamMember"."projectId" = "ProjectUser"."projectId" AND "TeamMember"."projectUserId" = "ProjectUser"."projectUserId" AND "TeamMember"."isSelected" = 'TRUE'
)
)
)
FROM "ProjectUser"
WHERE "ProjectUser"."projectId" = ${projectId} AND "ProjectUser"."projectUserId" = ${userId}::UUID
)
) AS "row_data_json"
`,
postProcess: (queryResult) => {
if (queryResult.length === 0) {
return null;
}
if (queryResult.length !== 1) {
throw new StackAssertionError("Expected 1 result, got " + queryResult.length, queryResult);
}

const row = queryResult[0].row_data_json;console.log(row);
N2D4 marked this conversation as resolved.
Show resolved Hide resolved

const primaryEmailContactChannel = row.ContactChannels.find((c: any) => c.type === 'EMAIL' && c.isPrimary);
const passwordAuth = row.AuthMethods.find((m: any) => m.PasswordAuthMethod);
const otpAuth = row.AuthMethods.find((m: any) => m.OtpAuthMethod);
const passkeyAuth = row.AuthMethods.find((m: any) => m.PasskeyAuthMethod);

return {
id: row.projectUserId,
display_name: row.displayName,
primary_email: primaryEmailContactChannel?.value || null,
primary_email_verified: primaryEmailContactChannel?.isVerified || false,
primary_email_auth_enabled: primaryEmailContactChannel?.usedForAuth === 'TRUE' ? true : false,
profile_image_url: row.profileImageUrl,
signed_up_at_millis: new Date(row.createdAt + "Z").getTime(),
client_metadata: row.clientMetadata,
client_read_only_metadata: row.clientReadOnlyMetadata,
server_metadata: row.serverMetadata,
has_password: !!passwordAuth,
otp_auth_enabled: !!otpAuth,
auth_with_email: !!passwordAuth || !!otpAuth,
requires_totp_mfa: row.requiresTotpMfa,
passkey_auth_enabled: !!passkeyAuth,
oauth_providers: row.ProjectUserOAuthAccounts.map((a: any) => ({
id: a.oauthProviderConfigId,
account_id: a.providerAccountId,
email: a.email,
})),
selected_team_id: row.SelectedTeamMember?.teamId ?? null,
selected_team: row.SelectedTeamMember ? {
id: row.SelectedTeamMember.team.teamId,
display_name: row.SelectedTeamMember.team.displayName,
profile_image_url: row.SelectedTeamMember.team.profileImageUrl,
created_at_millis: row.SelectedTeamMember.team.createdAt.getTime(),
client_metadata: row.SelectedTeamMember.team.clientMetadata,
client_read_only_metadata: row.SelectedTeamMember.team.clientReadOnlyMetadata,
server_metadata: row.SelectedTeamMember.team.serverMetadata,
} : null,
last_active_at_millis: row.lastActiveAt ? new Date(row.lastActiveAt + "Z").getTime() : new Date(row.createdAt + "Z").getTime(),
};
},
};
}

export async function getUser(options: { projectId: string, userId: string }) {
const result = await rawQuery(getUserQuery(options.projectId, options.userId));

// In non-prod environments, let's also call the legacy function and ensure the result is the same
// TODO next-release: remove this
if (!getNodeEnvironment().includes("prod")) {
const legacyResult = await getUserLegacy(options);
if (!deepPlainEquals(result, legacyResult)) {
throw new StackAssertionError("User result mismatch", {
result,
legacyResult,
});
}
}

return result;
}

async function getUserLegacy(options: { projectId: string, userId: string }) {
const [db, lastActiveAtMillis] = await Promise.all([
prismaClient.projectUser.findUnique({
where: {
Expand All @@ -252,7 +408,7 @@ export async function getUser(options: { projectId: string, userId: string }) {
},
include: userFullInclude,
}),
getUserLastActiveAtMillis(options.userId),
getUserLastActiveAtMillis(options.projectId, options.userId),
]);

if (!db) {
Expand Down Expand Up @@ -333,7 +489,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
} : {},
});

const lastActiveAtMillis = await getUsersLastActiveAtMillis(db.map(user => user.projectUserId), db.map(user => user.createdAt));
const lastActiveAtMillis = await getUsersLastActiveAtMillis(auth.project.id, db.map(user => user.projectUserId), db.map(user => user.createdAt));
return {
// remove the last item because it's the next cursor
items: db.map((user, index) => userPrismaToCrud(user, lastActiveAtMillis[index])).slice(0, query.limit),
Expand Down Expand Up @@ -514,7 +670,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
throw new StackAssertionError("User was created but not found", newUser);
}

return userPrismaToCrud(user, await getUserLastActiveAtMillis(user.projectUserId) ?? user.createdAt.getTime());
return userPrismaToCrud(user, await getUserLastActiveAtMillis(auth.project.id, user.projectUserId) ?? user.createdAt.getTime());
});

if (auth.project.config.create_team_on_sign_up) {
Expand Down Expand Up @@ -826,7 +982,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
});
}

return userPrismaToCrud(db, await getUserLastActiveAtMillis(params.user_id) ?? db.createdAt.getTime());
return userPrismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, params.user_id) ?? db.createdAt.getTime());
});


Expand Down
48 changes: 48 additions & 0 deletions apps/backend/src/prisma-client.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Prisma, PrismaClient } from "@prisma/client";
import { withAccelerate } from "@prisma/extension-accelerate";
import { getEnvVariable, getNodeEnvironment } from '@stackframe/stack-shared/dist/utils/env';
import { isNotNull, typedFromEntries, typedKeys } from "@stackframe/stack-shared/dist/utils/objects";
import { Result } from "@stackframe/stack-shared/dist/utils/results";
import { traceSpan } from "./utils/telemetry";

Expand Down Expand Up @@ -39,3 +40,50 @@ export async function retryTransaction<T>(fn: (...args: Parameters<Parameters<ty
return Result.orThrow(res);
});
}

export type RawQuery<T> = {
sql: Prisma.Sql,
postProcess: (rows: any[]) => T, // Tip: If your postProcess is async, just set T = Promise<any> (compared to doing Promise.all in rawQuery, this ensures that there are no accidental timing attacks)
};

export async function rawQuery<Q extends RawQuery<any>>(query: Q): Promise<Awaited<ReturnType<Q["postProcess"]>>> {
const result = await rawQueryArray([query]);
return result[0];
}

export async function rawQueryAll<Q extends Record<string, undefined | RawQuery<any>>>(queries: Q): Promise<{ [K in keyof Q]: Awaited<ReturnType<NonNullable<Q[K]>["postProcess"]>> }> {
const keys = typedKeys(queries);
const result = await rawQueryArray(keys.map(key => queries[key]).filter(isNotNull));
return typedFromEntries(keys.map((key, index) => [key, result[index]]));
}

async function rawQueryArray<Q extends RawQuery<any>[]>(queries: Q): Promise<[] & { [K in keyof Q]: Awaited<ReturnType<Q[K]["postProcess"]>> }> {
if (queries.length === 0) return [] as any;

const query = Prisma.sql`
WITH ${Prisma.join(queries.map((q, index) => {
return Prisma.sql`${Prisma.raw("q" + index)} AS (
${q.sql}
)`;
}), ",\n")}

${Prisma.join(queries.map((q, index) => {
return Prisma.sql`
SELECT
${"q" + index} AS type,
row_to_json(c) AS json
FROM (SELECT * FROM ${Prisma.raw("q" + index)}) c
`;
}), "\nUNION ALL\n")}
`;
const rawResult = await prismaClient.$queryRaw(query) as { type: string, json: any }[];
const unprocessed = new Array(queries.length).fill(null).map(() => [] as any[]);
for (const row of rawResult) {
const type = row.type;
const index = +type.slice(1);
unprocessed[index].push(row.json);
}
const postProcessed = queries.map((q, index) => q.postProcess(unprocessed[index]));
return postProcessed as any;
}

2 changes: 1 addition & 1 deletion apps/backend/src/route-handlers/crud-handler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ export function createCrudHandlers<
},
});
} catch (error) {
if (allowedErrorTypes?.some((a) => error instanceof a)) {
if (allowedErrorTypes?.some((a) => error instanceof a) || error instanceof StackAssertionError) {
throw error;
}
throw new CrudHandlerInvocationError(error);
Expand Down
Loading
Loading