diff --git a/modules/auth_oauth2/config.ts b/modules/auth_oauth2/config.ts new file mode 100644 index 00000000..76e61b58 --- /dev/null +++ b/modules/auth_oauth2/config.ts @@ -0,0 +1,11 @@ +export interface Config { + providers: Record; +} + +export interface ProviderEndpoints { + authorization: string; + token: string; + userinfo: string; + scopes: string; + userinfoKey: string; +} diff --git a/modules/auth_oauth2/db/migrations/20240508161825_/migration.sql b/modules/auth_oauth2/db/migrations/20240508161825_/migration.sql new file mode 100644 index 00000000..5986a899 --- /dev/null +++ b/modules/auth_oauth2/db/migrations/20240508161825_/migration.sql @@ -0,0 +1,45 @@ +-- CreateTable +CREATE TABLE "OAuthUsers" ( + "userId" UUID NOT NULL, + "provider" TEXT NOT NULL, + "sub" TEXT NOT NULL, + "createdAt" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "OAuthUsers_pkey" PRIMARY KEY ("provider","userId") +); + +-- CreateTable +CREATE TABLE "OAuthLoginAttempt" ( + "id" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "state" TEXT NOT NULL, + "codeVerifier" TEXT NOT NULL, + "targetUrl" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "completedAt" TIMESTAMP(3), + "invalidatedAt" TIMESTAMP(3), + + CONSTRAINT "OAuthLoginAttempt_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "OAuthCreds" ( + "id" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "accessToken" TEXT NOT NULL, + "refreshToken" TEXT NOT NULL, + "expiresAt" TIMESTAMP(3) NOT NULL, + "userToken" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "loginAttemptId" TEXT NOT NULL, + + CONSTRAINT "OAuthCreds_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "OAuthCreds_loginAttemptId_key" ON "OAuthCreds"("loginAttemptId"); + +-- AddForeignKey +ALTER TABLE "OAuthCreds" ADD CONSTRAINT "OAuthCreds_loginAttemptId_fkey" FOREIGN KEY ("loginAttemptId") REFERENCES "OAuthLoginAttempt"("id") ON DELETE RESTRICT ON UPDATE CASCADE; diff --git a/modules/auth_oauth2/db/migrations/migration_lock.toml b/modules/auth_oauth2/db/migrations/migration_lock.toml new file mode 100644 index 00000000..fbffa92c --- /dev/null +++ b/modules/auth_oauth2/db/migrations/migration_lock.toml @@ -0,0 +1,3 @@ +# Please do not edit this file manually +# It should be added in your version-control system (i.e. Git) +provider = "postgresql" \ No newline at end of file diff --git a/modules/auth_oauth2/db/schema.prisma b/modules/auth_oauth2/db/schema.prisma new file mode 100644 index 00000000..5f40c38f --- /dev/null +++ b/modules/auth_oauth2/db/schema.prisma @@ -0,0 +1,48 @@ +// Do not modify this `datasource` block +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model OAuthUsers { + userId String @db.Uuid + + provider String + sub String + createdAt DateTime @default(now()) @db.Timestamp + + @@id([provider, userId]) +} + +model OAuthLoginAttempt { + id String @id @default(uuid()) + + provider String + state String + codeVerifier String + targetUrl String + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + completedAt DateTime? + invalidatedAt DateTime? + + creds OAuthCreds? +} + +model OAuthCreds { + id String @id @default(uuid()) + + provider String + accessToken String + refreshToken String + expiresAt DateTime + userToken String + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + loginAttemptId String @unique + loginAttempt OAuthLoginAttempt @relation(fields: [loginAttemptId], references: [id]) +} + diff --git a/modules/auth_oauth2/module.yaml b/modules/auth_oauth2/module.yaml new file mode 100644 index 00000000..294f0d9f --- /dev/null +++ b/modules/auth_oauth2/module.yaml @@ -0,0 +1,33 @@ +name: OAuth Authentication +description: Authenticate users with OAuth 2.0. +icon: key +tags: + - core + - auth + - user +authors: + - rivet-gg + - Skyler Calaman +status: stable +dependencies: + users: {} + tokens: {} + rate_limit: {} +scripts: + login_link: + name: Login Link + description: Generate a login link for accessing OpenGB. + public: true + api: + methods: [GET] + data: [query] + login_callback: + name: OAuth Redirect Callback + description: Verify a user's OAuth login and create a session. + public: true + api: + methods: [GET] + data: [query] +errors: + invalid_config: + name: Invalid OAuth Provider Configuration diff --git a/modules/auth_oauth2/scripts/login_callback.ts b/modules/auth_oauth2/scripts/login_callback.ts new file mode 100644 index 00000000..64930765 --- /dev/null +++ b/modules/auth_oauth2/scripts/login_callback.ts @@ -0,0 +1,131 @@ +import { ScriptContext, RuntimeError, Empty } from "../_gen/scripts/login_callback.ts"; +import { getHttpPath, getCodeVerifierFromCookie, getStateFromCookie, getLoginIdFromCookie } from "../utils/trace.ts"; +import { getFullConfig } from "../utils/env.ts"; +import { getClient } from "../utils/client.ts"; +import { getUserUniqueIdentifier } from "../utils/client.ts"; + +export type Request = Record; +export type Response = Empty; + +export async function run( + ctx: ScriptContext, + req: Request, +): Promise { + // Max 2 login attempts per IP per minute + ctx.modules.rateLimit.throttlePublic({ requests: 5, period: 60 }); + + // Ensure that the provider configurations are valid + const config = await getFullConfig(ctx.userConfig); + if (!config) throw new RuntimeError("invalid_config", { statusCode: 500 }); + + const loginId = getLoginIdFromCookie(ctx); + const codeVerifier = getCodeVerifierFromCookie(ctx); + const state = getStateFromCookie(ctx); + + if (!loginId || !codeVerifier || !state) throw new RuntimeError("missing_login_data", { statusCode: 400 }); + + + // Get the login attempt stored in the database + const loginAttempt = await ctx.db.oAuthLoginAttempt.findUnique({ + where: { id: loginId, completedAt: null, invalidatedAt: null }, + }); + + if (!loginAttempt) throw new RuntimeError("login_not_found", { statusCode: 400 }); + if (loginAttempt.state !== state) throw new RuntimeError("invalid_state", { statusCode: 400 }); + if (loginAttempt.codeVerifier !== codeVerifier) throw new RuntimeError("invalid_code_verifier", { statusCode: 400 }); + + // Get the provider config + const provider = config.providers[loginAttempt.provider]; + if (!provider) throw new RuntimeError("invalid_provider", { statusCode: 400 }); + + // Get the oauth client + const client = getClient(config, provider.name); + if (!client.config.redirectUri) throw new RuntimeError("invalid_config", { statusCode: 500 }); + + + // Get the URI that this request was made to + const uri = new URL(client.config.redirectUri); + + const path = getHttpPath(ctx); + if (!path) throw new RuntimeError("invalid_request", { statusCode: 400 }); + uri.pathname = path; + + for (const key in req) { + const value = req[key]; + if (typeof value === "string") { + uri.searchParams.set(key, value); + } else { + uri.searchParams.set(key, JSON.stringify(value)); + } + } + const uriStr = uri.toString(); + + // Get the user's tokens and sub + let tokens: Awaited>; + let sub: string; + try { + tokens = await client.code.getToken(uriStr, { state, codeVerifier }); + sub = await getUserUniqueIdentifier(tokens.accessToken, provider); + } catch (e) { + console.error(e); + throw new RuntimeError("invalid_oauth_response", { statusCode: 502 }); + } + + const expiresIn = tokens.expiresIn ?? 3600; + const expiry = new Date(Date.now() + expiresIn); + + // Ensure the user is registered with this sub/provider combo + const user = await ctx.db.oAuthUsers.findFirst({ + where: { + sub, + provider: loginAttempt.provider, + }, + }); + + let userId: string; + if (user) { + userId = user.userId; + } else { + const { user: newUser } = await ctx.modules.users.createUser({ username: sub }); + await ctx.db.oAuthUsers.create({ + data: { + sub, + provider: loginAttempt.provider, + userId: newUser.id, + }, + }); + + userId = newUser.id; + } + + // Generate a token which the user can use to authenticate with this module + const { token } = await ctx.modules.users.createUserToken({ userId }); + + // Record the credentials + await ctx.db.oAuthCreds.create({ + data: { + loginAttemptId: loginAttempt.id, + provider: provider.name, + accessToken: tokens.accessToken, + refreshToken: tokens.refreshToken ?? "", + userToken: token.token, + expiresAt: expiry, + }, + }); + + + // Redirect the user to the target URL + ctx.statusCode = 303; + ctx.headers.set("Location", loginAttempt.targetUrl); + + // Set token cookie and expire state cookie + const cookieAttribs = `Path=/; Max-Age=${expiresIn}; SameSite=Lax; Expires=${expiry.toUTCString()}`; + ctx.headers.append("Set-Cookie", `token=${token.token}; ${cookieAttribs}`); + + const expireAttribs = `Path=/; Max-Age=0; SameSite=Lax; Expires=${new Date(0).toUTCString()}`; + ctx.headers.append("Set-Cookie", `login_id=EXPIRED; ${expireAttribs}`); + ctx.headers.append("Set-Cookie", `code_verifier=EXPIRED; ${expireAttribs}`); + ctx.headers.append("Set-Cookie", `state=EXPIRED; ${expireAttribs}`); + + return {}; +} diff --git a/modules/auth_oauth2/scripts/login_link.ts b/modules/auth_oauth2/scripts/login_link.ts new file mode 100644 index 00000000..37a87658 --- /dev/null +++ b/modules/auth_oauth2/scripts/login_link.ts @@ -0,0 +1,46 @@ +import { ScriptContext, RuntimeError, Empty } from "../_gen/scripts/login_link.ts"; +import { getFullConfig } from "../utils/env.ts"; +import { getClient } from "../utils/client.ts"; +import { generateStateStr } from "../utils/state.ts"; + +export interface Request { + provider: string; + targetUrl: string; +} + +export type Response = Empty; + +export async function run( + ctx: ScriptContext, + req: Request, +): Promise { + // Max 5 login attempts per IP per minute + ctx.modules.rateLimit.throttlePublic({ requests: 5, period: 60 }); + + // Ensure that the provider configurations are valid + const providers = await getFullConfig(ctx.userConfig); + if (!providers) throw new RuntimeError("invalid_config", { statusCode: 500 }); + + const client = getClient(providers, req.provider); + const state = generateStateStr(); + + const { uri, codeVerifier } = await client.code.getAuthorizationUri({ state }); + + const { id: loginId } = await ctx.db.oAuthLoginAttempt.create({ + data: { + provider: req.provider, + targetUrl: req.targetUrl, + state, + codeVerifier, + }, + }); + + ctx.statusCode = 303; + ctx.headers.set("Location", uri.toString()); + + ctx.headers.set("Cache-Control", "no-store"); + ctx.headers.append("Set-Cookie", `login_id=${encodeURIComponent(loginId)}; SameSite=Lax; Path=/; Max-Age=300`); + ctx.headers.append("Set-Cookie", `code_verifier=${encodeURIComponent(codeVerifier)}; SameSite=Lax; Path=/; Max-Age=300`); + ctx.headers.append("Set-Cookie", `state=${encodeURIComponent(state)}; SameSite=Lax; Path=/; Max-Age=300`); + return {}; +} diff --git a/modules/auth_oauth2/utils/client.ts b/modules/auth_oauth2/utils/client.ts new file mode 100644 index 00000000..cf696db4 --- /dev/null +++ b/modules/auth_oauth2/utils/client.ts @@ -0,0 +1,52 @@ +import { OAuth2Client } from "https://deno.land/x/oauth2_client@v1.0.2/mod.ts"; +import { FullConfig, ProviderConfig } from "./env.ts"; +import { RuntimeError } from "../_gen/mod.ts"; + +export function getClient(cfg: FullConfig, provider: string) { + const providerCfg = cfg.providers[provider]; + if (!providerCfg) throw new RuntimeError("invalid_provider", { statusCode: 400 }); + + return new OAuth2Client({ + clientId: providerCfg.clientId, + clientSecret: providerCfg.clientSecret, + authorizationEndpointUri: providerCfg.endpoints.authorization, + tokenUri: providerCfg.endpoints.token, + // TODO: Use a real redirect URI + redirectUri: "http://localhost:8080/modules/auth_oauth2/scripts/login_callback/call", + defaults: { + scope: providerCfg.endpoints.scopes, + }, + }); +} + +export async function getUserUniqueIdentifier(accessToken: string, provider: ProviderConfig): Promise { + const res = await fetch(provider.endpoints.userinfo, { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (!res.ok) throw new RuntimeError("bad_oauth_response", { statusCode: 502 }); + + let json: unknown; + try { + json = await res.json(); + } catch { + throw new RuntimeError("bad_oauth_response", { statusCode: 502 }); + } + + if (typeof json !== "object" || json === null) { + throw new RuntimeError("bad_oauth_response", { statusCode: 502 }); + } + + const jsonObj = json as Record; + const uniqueIdent = jsonObj[provider.endpoints.userinfoKey]; + + if (typeof uniqueIdent !== "string" && typeof uniqueIdent !== "number") { + console.warn("Invalid userinfo response", jsonObj); + throw new RuntimeError("bad_oauth_response", { statusCode: 502 }); + } + if (!uniqueIdent) throw new RuntimeError("bad_oauth_response", { statusCode: 502 }); + + return uniqueIdent.toString(); +} diff --git a/modules/auth_oauth2/utils/env.ts b/modules/auth_oauth2/utils/env.ts new file mode 100644 index 00000000..a5490143 --- /dev/null +++ b/modules/auth_oauth2/utils/env.ts @@ -0,0 +1,55 @@ +import { Config, ProviderEndpoints } from "../config.ts"; +import { getFromOidcWellKnown } from "./wellknown.ts"; + +export interface FullConfig { + providers: Record; + oauthSecret: string; +} + +export interface ProviderConfig { + name: string; + clientId: string; + clientSecret: string; + endpoints: ProviderEndpoints; +} + +export async function getProvidersEnvConfig(providerCfg: Config["providers"]): Promise { + const baseProviders = Object.entries(providerCfg).map(([name, config]) => ({ name, config })); + + const providers: ProviderConfig[] = []; + for (const { name, config } of baseProviders) { + const clientIdEnv = `${name.toUpperCase()}_OAUTH_CLIENT_ID`; + const clientSecretEnv = `${name.toUpperCase()}_OAUTH_CLIENT_SECRET`; + + const clientId = Deno.env.get(clientIdEnv); + const clientSecret = Deno.env.get(clientSecretEnv); + if (!clientId || !clientSecret) return null; + + let resolvedConfig: ProviderEndpoints; + if (typeof config === "string") { + resolvedConfig = await getFromOidcWellKnown(config); + } else { + resolvedConfig = config; + } + + providers.push({ name, clientId, clientSecret, endpoints: resolvedConfig }); + } + + return providers; +} + +export function getOauthSecret(): string | null { + return Deno.env.get("OAUTH_SECRET") ?? null; +} + +export async function getFullConfig(cfg: Config): Promise { + const providerArr = await getProvidersEnvConfig(cfg.providers); + if (!providerArr) return null; + + const providers = Object.fromEntries(providerArr.map(p => [p.name, p])); + + const oauthSecret = getOauthSecret(); + if (!oauthSecret) return null; + + return { providers, oauthSecret }; +} diff --git a/modules/auth_oauth2/utils/state.ts b/modules/auth_oauth2/utils/state.ts new file mode 100644 index 00000000..40a0f198 --- /dev/null +++ b/modules/auth_oauth2/utils/state.ts @@ -0,0 +1,46 @@ +import base64 from "https://deno.land/x/b64@1.1.28/src/base64.js"; + +const STATE_BYTES = 16; + + +type InputData = ArrayBufferLike | Uint8Array | string; + +/** + * Generates a new random `STATE_BYTES`-byte state buffer. + * + * @returns A new random state buffer + */ +export function generateState(): ArrayBufferLike { + return crypto.getRandomValues(new Uint8Array(STATE_BYTES)); +} + +/** + * Generates a new random string with `STATE_BYTES` bytes of entropy. + */ +export function generateStateStr(): string { + return base64.fromArrayBuffer(generateState()); +} + +/** + * Compares two buffers for equality in a way that is resistant to timing + * attacks. + * + * @param a The first buffer + * @param b The second buffer + * @returns Whether the two buffers are equal + */ +export function compareConstantTime(a: InputData, b: InputData): boolean { + const bufLikeA = typeof a === "string" ? new TextEncoder().encode(a) : a; + const bufLikeB = typeof b === "string" ? new TextEncoder().encode(b) : b; + + if (bufLikeA.byteLength !== bufLikeB.byteLength) return false; + + const bufA = new Uint8Array(bufLikeA); + const bufB = new Uint8Array(bufLikeB); + + let result = 0; + for (let i = 0; i < bufLikeA.byteLength; i++) { + result |= bufA[i] ^ bufB[i]; + } + return result === 0; +} diff --git a/modules/auth_oauth2/utils/trace.ts b/modules/auth_oauth2/utils/trace.ts new file mode 100644 index 00000000..38e8ae39 --- /dev/null +++ b/modules/auth_oauth2/utils/trace.ts @@ -0,0 +1,51 @@ +import { ModuleContext } from "../_gen/mod.ts"; + +export function getHttpPath(ctx: T): string | undefined { + for (const entry of ctx.trace.entries) { + if ("httpRequest" in entry.type) { + return entry.type.httpRequest.path; + } + } + return undefined; +} + +export function getCookieString(ctx: T): string | undefined { + for (const entry of ctx.trace.entries) { + if ("httpRequest" in entry.type) { + return entry.type.httpRequest.headers["cookie"]; + } + } + return undefined; +} + +export function getCookieObject(ctx: T): Record | null { + const cookieString = getCookieString(ctx); + if (!cookieString) return null; + + const pairs = cookieString + .split(";") + .map(pair => pair.trim()) + .map(pair => pair.split("=")) + .map(([key, value]) => [decodeURIComponent(key), decodeURIComponent(value)]); + + return Object.fromEntries(pairs); +} + + +export function getLoginIdFromCookie(ctx: T): string | null { + const cookies = getCookieObject(ctx); + if (!cookies) return null; + return cookies["login_id"] || null; +} + +export function getCodeVerifierFromCookie(ctx: T): string | null { + const cookies = getCookieObject(ctx); + if (!cookies) return null; + return cookies["code_verifier"] || null; +} + +export function getStateFromCookie(ctx: T): string | null { + const cookies = getCookieObject(ctx); + if (!cookies) return null; + return cookies["state"] || null; +} diff --git a/modules/auth_oauth2/utils/wellknown.ts b/modules/auth_oauth2/utils/wellknown.ts new file mode 100644 index 00000000..3abab568 --- /dev/null +++ b/modules/auth_oauth2/utils/wellknown.ts @@ -0,0 +1,40 @@ +import { RuntimeError } from "../_gen/mod.ts"; +import { ProviderEndpoints } from "../config.ts"; + +/** + * Get the OIDC well-known config object from the given URL. + * + * @param wellKnownUrl The URL of the OIDC well-known config + * @returns The OIDC well-known config object + */ +export async function getFromOidcWellKnown(wellKnownUrl: string): Promise { + const res = await fetch(wellKnownUrl).catch(() => { throw new RuntimeError("invalid_config") }); + if (!res.ok) throw new RuntimeError("invalid_config"); + + const json: unknown = await res.json().catch(() => { throw new RuntimeError("invalid_config") }); + if (typeof json !== "object" || json === null) throw new RuntimeError("invalid_config"); + + const jsonObj = json as Record; + + const { + authorization_endpoint, + token_endpoint, + userinfo_endpoint, + scopes_supported, + } = jsonObj; + + if (typeof authorization_endpoint !== "string") throw new RuntimeError("invalid_config"); + if (typeof token_endpoint !== "string") throw new RuntimeError("invalid_config"); + if (typeof userinfo_endpoint !== "string") throw new RuntimeError("invalid_config"); + if (!Array.isArray(scopes_supported)) throw new RuntimeError("invalid_config"); + if (scopes_supported.some(scope => typeof scope !== "string")) throw new RuntimeError("invalid_config"); + + + return { + authorization: authorization_endpoint, + token: token_endpoint, + userinfo: userinfo_endpoint, + scopes: scopes_supported.join(" "), + userinfoKey: "sub", + }; +} diff --git a/tests/basic/backend.yaml b/tests/basic/backend.yaml index 19142bec..07ce1a46 100644 --- a/tests/basic/backend.yaml +++ b/tests/basic/backend.yaml @@ -25,3 +25,15 @@ modules: test: {} # sendGrid: # apiKeyVariable: SENDGRID_API_KEY + auth_oauth2: + registry: local + config: + providers: [] + # google: https://accounts.google.com/.well-known/openid-configuration + # microsoft: https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration + # github: + # authorization: https://github.com/login/oauth/authorize + # token: https://github.com/login/oauth/access_token + # userinfo: https://api.github.com/user + # scope: [read:user] + # userinfoKey: id