Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raw SQL query for fetching users #381

Merged
merged 9 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions apps/backend/src/app/api/v1/team-member-profiles/crud.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
include: fullInclude,
});

const lastActiveAtMillis = await getUsersLastActiveAtMillis(db.map(user => user.projectUserId), db.map(user => user.createdAt));
const lastActiveAtMillis = await getUsersLastActiveAtMillis(auth.project.id, db.map(user => user.projectUserId), db.map(user => user.createdAt));

return {
items: db.map((user, index) => prismaToCrud(user, lastActiveAtMillis[index])),
Expand Down Expand Up @@ -118,7 +118,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
throw new KnownErrors.TeamMembershipNotFound(params.team_id, params.user_id);
}

return prismaToCrud(db, await getUserLastActiveAtMillis(db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
return prismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
});
},
onUpdate: async ({ auth, data, params }) => {
Expand Down Expand Up @@ -151,7 +151,7 @@ export const teamMemberProfilesCrudHandlers = createLazyProxy(() => createCrudHa
include: fullInclude,
});

return prismaToCrud(db, await getUserLastActiveAtMillis(db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
return prismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, db.projectUser.projectUserId) ?? db.projectUser.createdAt.getTime());
});
},
}));
212 changes: 188 additions & 24 deletions apps/backend/src/app/api/v1/users/crud.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ensureTeamMembershipExists, ensureUserExists } from "@/lib/request-checks";
import { PrismaTransaction } from "@/lib/types";
import { sendTeamMembershipDeletedWebhook, sendUserCreatedWebhook, sendUserDeletedWebhook, sendUserUpdatedWebhook } from "@/lib/webhooks";
import { prismaClient, retryTransaction } from "@/prisma-client";
import { RawQuery, prismaClient, rawQuery, retryTransaction } from "@/prisma-client";
import { createCrudHandlers } from "@/route-handlers/crud-handler";
import { runAsynchronouslyAndWaitUntil } from "@/utils/vercel";
import { BooleanTrue, Prisma } from "@prisma/client";
Expand All @@ -11,8 +11,10 @@ import { UsersCrud, usersCrud } from "@stackframe/stack-shared/dist/interface/cr
import { userIdOrMeSchema, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { validateBase64Image } from "@stackframe/stack-shared/dist/utils/base64";
import { decodeBase64 } from "@stackframe/stack-shared/dist/utils/bytes";
import { getNodeEnvironment } from "@stackframe/stack-shared/dist/utils/env";
import { StackAssertionError, StatusError, throwErr } from "@stackframe/stack-shared/dist/utils/errors";
import { hashPassword, isPasswordHashValid } from "@stackframe/stack-shared/dist/utils/hashes";
import { deepPlainEquals } from "@stackframe/stack-shared/dist/utils/objects";
import { createLazyProxy } from "@stackframe/stack-shared/dist/utils/proxies";
import { typedToLowercase } from "@stackframe/stack-shared/dist/utils/strings";
import { teamPrismaToCrud, teamsCrudHandlers } from "../teams/crud";
Expand Down Expand Up @@ -203,45 +205,207 @@ async function getOtpConfig(tx: PrismaTransaction, projectConfigId: string) {
return otpConfig.length === 0 ? null : otpConfig[0];
}

export const getUserLastActiveAtMillis = async (userId: string): Promise<number | null> => {
const event = await prismaClient.event.findFirst({
where: {
data: {
path: ["$.userId"],
equals: userId,
},
},
orderBy: {
createdAt: 'desc',
},
});

return event?.createdAt.getTime() ?? null;
export const getUserLastActiveAtMillis = async (projectId: string, userId: string): Promise<number | null> => {
const res = (await getUsersLastActiveAtMillis(projectId, [userId], [0]))[0];
if (res === 0) {
return null;
}
return res;
};

// same as userIds.map(userId => getUserLastActiveAtMillis(userId, fallbackTo)), but uses a single query
export const getUsersLastActiveAtMillis = async (userIds: string[], fallbackTo: (number | Date)[]): Promise<number[]> => {
// same as userIds.map(userId => getUserLastActiveAtMillis(projectId, userId)), but uses a single query
export const getUsersLastActiveAtMillis = async (projectId: string, userIds: string[], userSignedUpAtMillis: (number | Date)[]): Promise<number[]> => {
if (userIds.length === 0) {
// Prisma.join throws an error if the array is empty, so we need to handle that case
return [];
}

const events = await prismaClient.$queryRaw<Array<{ userId: string, lastActiveAt: Date }>>`
SELECT data->>'userId' as "userId", MAX("createdAt") as "lastActiveAt"
SELECT data->>'userId' as "userId", MAX("eventStartedAt") as "lastActiveAt"
FROM "Event"
WHERE data->>'userId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(userIds)}]`})
WHERE data->>'userId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(userIds)}]`}) AND data->>'projectId' = ${projectId} AND "systemEventTypeIds" @> '{"$user-activity"}'
GROUP BY data->>'userId'
`;

return userIds.map((userId, index) => {
const event = events.find(e => e.userId === userId);
return event ? event.lastActiveAt.getTime() : (
typeof fallbackTo[index] === "number" ? (fallbackTo[index] as number) : (fallbackTo[index] as Date).getTime()
typeof userSignedUpAtMillis[index] === "number" ? (userSignedUpAtMillis[index] as number) : (userSignedUpAtMillis[index] as Date).getTime()
);
});
};

export function getUserQuery(projectId: string, userId: string): RawQuery<UsersCrud["Admin"]["Read"] | null> {
return {
sql: Prisma.sql`
SELECT to_json(
(
SELECT (
to_jsonb("ProjectUser".*) ||
jsonb_build_object(
'lastActiveAt', (
SELECT MAX("eventStartedAt") as "lastActiveAt"
FROM "Event"
WHERE data->>'projectId' = "ProjectUser"."projectId" AND ("data"->>'userId')::UUID = "ProjectUser"."projectUserId" AND "systemEventTypeIds" @> '{"$user-activity"}'
),
'ContactChannels', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("ContactChannel") ||
jsonb_build_object()
), '{}')
FROM "ContactChannel"
WHERE "ContactChannel"."projectId" = "ProjectUser"."projectId" AND "ContactChannel"."projectUserId" = "ProjectUser"."projectUserId" AND "ContactChannel"."isPrimary" = 'TRUE'
),
'ProjectUserOAuthAccounts', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("ProjectUserOAuthAccount") ||
jsonb_build_object(
'ProviderConfig', (
SELECT to_jsonb("OAuthProviderConfig")
FROM "OAuthProviderConfig"
WHERE "ProjectConfig"."id" = "OAuthProviderConfig"."projectConfigId" AND "OAuthProviderConfig"."id" = "ProjectUserOAuthAccount"."oauthProviderConfigId"
)
)
), '{}')
FROM "ProjectUserOAuthAccount"
WHERE "ProjectUserOAuthAccount"."projectId" = "ProjectUser"."projectId" AND "ProjectUserOAuthAccount"."projectUserId" = "ProjectUser"."projectUserId"
),
'AuthMethods', (
SELECT COALESCE(ARRAY_AGG(
to_jsonb("AuthMethod") ||
jsonb_build_object(
'PasswordAuthMethod', (
SELECT (
to_jsonb("PasswordAuthMethod") ||
jsonb_build_object()
)
FROM "PasswordAuthMethod"
WHERE "PasswordAuthMethod"."projectId" = "ProjectUser"."projectId" AND "PasswordAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasswordAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'OtpAuthMethod', (
SELECT (
to_jsonb("OtpAuthMethod") ||
jsonb_build_object()
)
FROM "OtpAuthMethod"
WHERE "OtpAuthMethod"."projectId" = "ProjectUser"."projectId" AND "OtpAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OtpAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'PasskeyAuthMethod', (
SELECT (
to_jsonb("PasskeyAuthMethod") ||
jsonb_build_object()
)
FROM "PasskeyAuthMethod"
WHERE "PasskeyAuthMethod"."projectId" = "ProjectUser"."projectId" AND "PasskeyAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasskeyAuthMethod"."authMethodId" = "AuthMethod"."id"
),
'OAuthAuthMethod', (
SELECT (
to_jsonb("OAuthAuthMethod") ||
jsonb_build_object()
)
FROM "OAuthAuthMethod"
WHERE "OAuthAuthMethod"."projectId" = "ProjectUser"."projectId" AND "OAuthAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OAuthAuthMethod"."authMethodId" = "AuthMethod"."id"
)
)
), '{}')
FROM "AuthMethod"
WHERE "AuthMethod"."projectId" = "ProjectUser"."projectId" AND "AuthMethod"."projectUserId" = "ProjectUser"."projectUserId"
),
'SelectedTeamMember', (
SELECT (
to_jsonb("TeamMember") ||
jsonb_build_object(
'Team', (
SELECT
to_jsonb("Team")
FROM "Team"
WHERE "Team"."projectId" = "ProjectUser"."projectId" AND "Team"."teamId" = "TeamMember"."teamId"
)
)
)
FROM "TeamMember"
WHERE "TeamMember"."projectId" = "ProjectUser"."projectId" AND "TeamMember"."projectUserId" = "ProjectUser"."projectUserId" AND "TeamMember"."isSelected" = 'TRUE'
)
)
)
FROM "ProjectUser"
LEFT JOIN "Project" ON "Project"."id" = "ProjectUser"."projectId"
LEFT JOIN "ProjectConfig" ON "ProjectConfig"."id" = "Project"."configId"
WHERE "ProjectUser"."projectId" = ${projectId} AND "ProjectUser"."projectUserId" = ${userId}::UUID
)
) AS "row_data_json"
`,
postProcess: (queryResult) => {
if (queryResult.length !== 1) {
throw new StackAssertionError("Expected 1 result, got " + queryResult.length, queryResult);
}

const row = queryResult[0].row_data_json;
if (!row) {
return null;
}

const primaryEmailContactChannel = row.ContactChannels.find((c: any) => c.type === 'EMAIL' && c.isPrimary);
const passwordAuth = row.AuthMethods.find((m: any) => m.PasswordAuthMethod);
const otpAuth = row.AuthMethods.find((m: any) => m.OtpAuthMethod);
const passkeyAuth = row.AuthMethods.find((m: any) => m.PasskeyAuthMethod);

return {
id: row.projectUserId,
display_name: row.displayName,
primary_email: primaryEmailContactChannel?.value || null,
primary_email_verified: primaryEmailContactChannel?.isVerified || false,
primary_email_auth_enabled: primaryEmailContactChannel?.usedForAuth === 'TRUE' ? true : false,
profile_image_url: row.profileImageUrl,
signed_up_at_millis: new Date(row.createdAt + "Z").getTime(),
client_metadata: row.clientMetadata,
client_read_only_metadata: row.clientReadOnlyMetadata,
server_metadata: row.serverMetadata,
has_password: !!passwordAuth,
otp_auth_enabled: !!otpAuth,
auth_with_email: !!passwordAuth || !!otpAuth,
requires_totp_mfa: row.requiresTotpMfa,
passkey_auth_enabled: !!passkeyAuth,
oauth_providers: row.ProjectUserOAuthAccounts.map((a: any) => ({
id: a.oauthProviderConfigId,
account_id: a.providerAccountId,
email: a.email,
})),
selected_team_id: row.SelectedTeamMember?.teamId ?? null,
selected_team: row.SelectedTeamMember ? {
id: row.SelectedTeamMember.team.teamId,
display_name: row.SelectedTeamMember.team.displayName,
profile_image_url: row.SelectedTeamMember.team.profileImageUrl,
created_at_millis: row.SelectedTeamMember.team.createdAt.getTime(),
client_metadata: row.SelectedTeamMember.team.clientMetadata,
client_read_only_metadata: row.SelectedTeamMember.team.clientReadOnlyMetadata,
server_metadata: row.SelectedTeamMember.team.serverMetadata,
} : null,
last_active_at_millis: row.lastActiveAt ? new Date(row.lastActiveAt + "Z").getTime() : new Date(row.createdAt + "Z").getTime(),
};
},
};
}

export async function getUser(options: { projectId: string, userId: string }) {
const result = await rawQuery(getUserQuery(options.projectId, options.userId));

// In non-prod environments, let's also call the legacy function and ensure the result is the same
// TODO next-release: remove this
if (!getNodeEnvironment().includes("prod")) {
const legacyResult = await getUserLegacy(options);
if (!deepPlainEquals(result, legacyResult)) {
throw new StackAssertionError("User result mismatch", {
result,
legacyResult,
});
}
}

return result;
}

async function getUserLegacy(options: { projectId: string, userId: string }) {
const [db, lastActiveAtMillis] = await Promise.all([
prismaClient.projectUser.findUnique({
where: {
Expand All @@ -252,7 +416,7 @@ export async function getUser(options: { projectId: string, userId: string }) {
},
include: userFullInclude,
}),
getUserLastActiveAtMillis(options.userId),
getUserLastActiveAtMillis(options.projectId, options.userId),
]);

if (!db) {
Expand Down Expand Up @@ -333,7 +497,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
} : {},
});

const lastActiveAtMillis = await getUsersLastActiveAtMillis(db.map(user => user.projectUserId), db.map(user => user.createdAt));
const lastActiveAtMillis = await getUsersLastActiveAtMillis(auth.project.id, db.map(user => user.projectUserId), db.map(user => user.createdAt));
return {
// remove the last item because it's the next cursor
items: db.map((user, index) => userPrismaToCrud(user, lastActiveAtMillis[index])).slice(0, query.limit),
Expand Down Expand Up @@ -514,7 +678,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
throw new StackAssertionError("User was created but not found", newUser);
}

return userPrismaToCrud(user, await getUserLastActiveAtMillis(user.projectUserId) ?? user.createdAt.getTime());
return userPrismaToCrud(user, await getUserLastActiveAtMillis(auth.project.id, user.projectUserId) ?? user.createdAt.getTime());
});

if (auth.project.config.create_team_on_sign_up) {
Expand Down Expand Up @@ -826,7 +990,7 @@ export const usersCrudHandlers = createLazyProxy(() => createCrudHandlers(usersC
});
}

return userPrismaToCrud(db, await getUserLastActiveAtMillis(params.user_id) ?? db.createdAt.getTime());
return userPrismaToCrud(db, await getUserLastActiveAtMillis(auth.project.id, params.user_id) ?? db.createdAt.getTime());
});


Expand Down
52 changes: 44 additions & 8 deletions apps/backend/src/lib/api-keys.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// TODO remove and replace with CRUD handler

import { prismaClient } from '@/prisma-client';
import { ApiKeySet } from '@prisma/client';
import { RawQuery, prismaClient, rawQuery } from '@/prisma-client';
import { ApiKeySet, Prisma } from '@prisma/client';
import { ApiKeysCrud } from '@stackframe/stack-shared/dist/interface/crud/api-keys';
import { yupString } from '@stackframe/stack-shared/dist/schema-fields';
import { typedIncludes } from '@stackframe/stack-shared/dist/utils/arrays';
Expand All @@ -13,10 +13,46 @@ export const publishableClientKeyHeaderSchema = yupString().matches(/^[a-zA-Z0-9
export const secretServerKeyHeaderSchema = publishableClientKeyHeaderSchema;
export const superSecretAdminKeyHeaderSchema = secretServerKeyHeaderSchema;

export async function checkApiKeySet(
...args: Parameters<typeof getApiKeySet>
): Promise<boolean> {
const set = await getApiKeySet(...args);
export function checkApiKeySetQuery(projectId: string, key: KeyType): RawQuery<boolean> {
key = validateKeyType(key);
const keyType = Object.keys(key)[0] as keyof KeyType;
const keyValue = key[keyType];

const whereClause = Prisma.sql`
${Prisma.raw(JSON.stringify(keyType))} = ${keyValue}
`;

return {
sql: Prisma.sql`
SELECT 't' AS "result"
FROM "ApiKeySet"
WHERE ${whereClause}
AND "projectId" = ${projectId}
AND "manuallyRevokedAt" IS NULL
AND "expiresAt" > ${new Date()}
`,
postProcess: (rows) => rows[0]?.result === "t",
};
}

export async function checkApiKeySet(projectId: string, key: KeyType): Promise<boolean> {
const result = await rawQuery(checkApiKeySetQuery(projectId, key));

// In non-prod environments, let's also call the legacy function and ensure the result is the same
// TODO next-release: remove this
const legacy = await checkApiKeySetLegacy(projectId, key);
if (legacy !== result) {
throw new StackAssertionError("checkApiKeySet result mismatch", {
result,
legacy,
});
}

return result;
}

async function checkApiKeySetLegacy(projectId: string, key: KeyType): Promise<boolean> {
const set = await getApiKeySet(projectId, key);
if (!set) return false;
if (set.manually_revoked_at_millis) return false;
if (set.expires_at_millis < Date.now()) return false;
Expand All @@ -29,7 +65,7 @@ type KeyType =
| { secretServerKey: string }
| { superSecretAdminKey: string };

function assertKeyType(obj: any): KeyType {
function validateKeyType(obj: any): KeyType {
if (typeof obj !== 'object' || obj === null) {
throw new StackAssertionError('Invalid key type', { obj });
}
Expand Down Expand Up @@ -64,7 +100,7 @@ export async function getApiKeySet(
}
}
: {
...assertKeyType(whereOrId),
...validateKeyType(whereOrId),
projectId,
};

Expand Down
Loading
Loading