Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce getClaims method to verify asymmetric JWTs #1030

Merged
merged 18 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 134 additions & 17 deletions src/GoTrueClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
isAuthRetryableFetchError,
isAuthSessionMissingError,
isAuthImplicitGrantRedirectError,
AuthInvalidJwtError,
} from './lib/errors'
import {
Fetch,
Expand All @@ -30,7 +31,6 @@ import {
_ssoResponse,
} from './lib/fetch'
import {
decodeJWTPayload,
Deferred,
getItemAsync,
isBrowser,
Expand All @@ -43,6 +43,9 @@ import {
supportsLocalStorage,
parseParametersFromURL,
getCodeChallengeAndMethod,
getAlgorithm,
validateExp,
decodeJWT,
} from './lib/helpers'
import { localStorageAdapter, memoryLocalStorageAdapter } from './lib/local-storage'
import { polyfillGlobalThis } from './lib/polyfills'
Expand Down Expand Up @@ -86,7 +89,6 @@ import type {
MFAVerifyParams,
AuthMFAVerifyResponse,
AuthMFAListFactorsResponse,
AMREntry,
AuthMFAGetAuthenticatorAssuranceLevelResponse,
AuthenticatorAssuranceLevels,
Factor,
Expand All @@ -100,6 +102,7 @@ import type {
MFAEnrollPhoneParams,
AuthMFAEnrollTOTPResponse,
AuthMFAEnrollPhoneResponse,
JWK,
} from './lib/types'

polyfillGlobalThis() // Make "globalThis" available
Expand Down Expand Up @@ -140,7 +143,10 @@ export default class GoTrueClient {
protected storageKey: string

protected flowType: AuthFlowType

/**
* The JWKS used for verifying asymmetric JWTs
*/
protected jwks: { keys: JWK[] }
protected autoRefreshToken: boolean
protected persistSession: boolean
protected storage: SupportedStorage
Expand Down Expand Up @@ -220,7 +226,7 @@ export default class GoTrueClient {
} else {
this.lock = lockNoOp
}

this.jwks = { keys: [] }
this.mfa = {
verify: this._verify.bind(this),
enroll: this._enroll.bind(this),
Expand Down Expand Up @@ -1288,17 +1294,6 @@ export default class GoTrueClient {
}
}

/**
* Decodes a JWT (without performing any validation).
*/
private _decodeJWT(jwt: string): {
exp?: number
aal?: AuthenticatorAssuranceLevels | null
amr?: AMREntry[] | null
} {
return decodeJWTPayload(jwt)
}

/**
* Sets the session data from the current session. If the current session is expired, setSession will take care of refreshing it to obtain a new session.
* If the refresh token or access token in the current session is invalid, an error will be thrown.
Expand Down Expand Up @@ -1328,7 +1323,7 @@ export default class GoTrueClient {
let expiresAt = timeNow
let hasExpired = true
let session: Session | null = null
const payload = decodeJWTPayload(currentSession.access_token)
const { payload } = decodeJWT(currentSession.access_token)
if (payload.exp) {
expiresAt = payload.exp
hasExpired = expiresAt <= timeNow
Expand Down Expand Up @@ -2576,7 +2571,7 @@ export default class GoTrueClient {
}
}

const payload = this._decodeJWT(session.access_token)
const { payload } = decodeJWT(session.access_token)

let currentLevel: AuthenticatorAssuranceLevels | null = null

Expand All @@ -2599,4 +2594,126 @@ export default class GoTrueClient {
})
})
}

private async fetchJwk(kid: string, jwks: { keys: JWK[] } = { keys: [] }): Promise<JWK> {
// try fetching from the supplied jwks
let jwk = jwks.keys.find((key) => key.kid === kid)
if (jwk) {
return jwk
}

// try fetching from cache
jwk = this.jwks.keys.find((key) => key.kid === kid)
if (jwk) {
return jwk
}
hf marked this conversation as resolved.
Show resolved Hide resolved
// jwk isn't cached in memory so we need to fetch it from the well-known endpoint
const { data, error } = await _request(this.fetch, 'GET', `${this.url}/.well-known/jwks.json`, {
headers: this.headers,
})
if (error) {
throw error
}
if (!data.keys || data.keys.length === 0) {
throw new AuthInvalidJwtError('JWKS is empty')
}
this.jwks = data
// Find the signing key
jwk = data.keys.find((key: any) => key.kid === kid)
if (!jwk) {
throw new AuthInvalidJwtError('No matching signing key found in JWKS')
}
return jwk
}

/**
* @experimental This method may change in future versions.
* @description Gets the claims from a JWT. If the JWT is symmetric JWTs, it will call getUser() to verify against the server. If the JWT is asymmetric, it will be verified against the JWKS using the WebCrypto API.
*/
async getClaims(
jwt?: string,
jwks: { keys: JWK[] } = { keys: [] }
): Promise<
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
| {
data: { claims: { [key: string]: any } }
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
error: null
}
| { data: null; error: AuthError }
| { data: null; error: null }
> {
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
try {
let token = jwt
if (!token) {
const { data, error } = await this.getSession()
if (error || !data.session) {
return { data: null, error }
}
token = data.session.access_token
}

const {
header,
payload,
signature,
raw: { header: rawHeader, payload: rawPayload },
} = decodeJWT(token)

// Reject expired JWTs
validateExp(payload.exp)

// If symmetric algorithm, fallback to getUser()
if (header.alg === 'HS256') {
const { error } = await this.getUser(token)
if (error) {
throw error
}
// getUser succeeds so the claims in the JWT can be trusted
return {
data: {
claims: payload,
},
error: null,
}
}

if (!header.kid) {
throw new AuthInvalidJwtError('Missing kid claim')
}
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved

const algorithm = getAlgorithm(header.alg)
const signingKey = await this.fetchJwk(header.kid, jwks)
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved

// Convert JWK to CryptoKey
const publicKey = await crypto.subtle.importKey('jwk', signingKey, algorithm, true, [
'verify',
])
// Verify the signature
const isValid = await crypto.subtle.verify(
algorithm,
publicKey,
signature,
new TextEncoder().encode(`${rawHeader}.${rawPayload}`)
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
)

if (!isValid) {
throw new AuthInvalidJwtError('Invalid JWT signature')
}

// If verification succeeds, decode and return claims
return {
data: {
claims: payload,
},
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
error: null,
}
} catch (error) {
if (isAuthError(error)) {
return { data: null, error }
}
return {
data: null,
error: new AuthUnknownError('Unknown error occurred while getting claims', error),
}
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
6 changes: 6 additions & 0 deletions src/lib/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,9 @@ export class AuthWeakPasswordError extends CustomAuthError {
export function isAuthWeakPasswordError(error: unknown): error is AuthWeakPasswordError {
return isAuthError(error) && error.name === 'AuthWeakPasswordError'
}

export class AuthInvalidJwtError extends CustomAuthError {
constructor(message: string) {
super(message, 'AuthInvalidJwtError', 400, 'invalid_jwt')
}
}
98 changes: 58 additions & 40 deletions src/lib/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { API_VERSION_HEADER_NAME } from './constants'
import { SupportedStorage } from './types'
import { AuthInvalidJwtError } from './errors'
import { base64url } from './rfc4648'
import { JwtHeader, JwtPayload, SupportedStorage } from './types'

export function expiresAt(expiresIn: number) {
const timeNow = Math.round(Date.now() / 1000)
Expand Down Expand Up @@ -141,34 +143,6 @@ export const removeItemAsync = async (storage: SupportedStorage, key: string): P
await storage.removeItem(key)
}

export function decodeBase64URL(value: string): string {
const key = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/='
let base64 = ''
let chr1, chr2, chr3
let enc1, enc2, enc3, enc4
let i = 0
value = value.replace('-', '+').replace('_', '/')

while (i < value.length) {
enc1 = key.indexOf(value.charAt(i++))
enc2 = key.indexOf(value.charAt(i++))
enc3 = key.indexOf(value.charAt(i++))
enc4 = key.indexOf(value.charAt(i++))
chr1 = (enc1 << 2) | (enc2 >> 4)
chr2 = ((enc2 & 15) << 4) | (enc3 >> 2)
chr3 = ((enc3 & 3) << 6) | enc4
base64 = base64 + String.fromCharCode(chr1)

if (enc3 != 64 && chr2 != 0) {
base64 = base64 + String.fromCharCode(chr2)
}
if (enc4 != 64 && chr3 != 0) {
base64 = base64 + String.fromCharCode(chr3)
}
}
return base64
}

/**
* A deferred represents some asynchronous work that is not yet finished, which
* may or may not culminate in a value.
Expand All @@ -194,23 +168,39 @@ export class Deferred<T = any> {
}
}

// Taken from: https://stackoverflow.com/questions/38552003/how-to-decode-jwt-token-in-javascript-without-using-a-library
export function decodeJWTPayload(token: string) {
// Regex checks for base64url format
const base64UrlRegex = /^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}=?$|[a-z0-9_-]{2}(==)?$)$/i

export function decodeJWT(token: string): {
header: JwtHeader
payload: JwtPayload
signature: Uint8Array
raw: {
header: string
payload: string
}
} {
const parts = token.split('.')

if (parts.length !== 3) {
throw new Error('JWT is not valid: not a JWT structure')
throw new AuthInvalidJwtError('Invalid JWT structure')
}

if (!base64UrlRegex.test(parts[1])) {
throw new Error('JWT is not valid: payload is not in base64url format')
// Regex checks for base64url format
const base64UrlRegex = /^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}=?$|[a-z0-9_-]{2}(==)?$)$/i
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
for (let i = 0; i < parts.length; i++) {
if (!base64UrlRegex.test(parts[i])) {
throw new AuthInvalidJwtError('JWT not in base64url format')
}
}

const base64Url = parts[1]
return JSON.parse(decodeBase64URL(base64Url))
const decoder = new TextDecoder()
const data = {
header: JSON.parse(decoder.decode(base64url.parse(parts[0], { loose: true }))),
payload: JSON.parse(decoder.decode(base64url.parse(parts[1], { loose: true }))),
signature: base64url.parse(parts[2], { loose: true }),
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
raw: {
header: parts[0],
payload: parts[1],
},
}
return data
}

/**
Expand Down Expand Up @@ -344,3 +334,31 @@ export function parseResponseAPIVersion(response: Response) {
return null
}
}

export function validateExp(exp: number) {
if (!exp) {
throw new Error('Missing exp claim')
}
const timeNow = Date.now() / 1000
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
if (exp <= timeNow) {
throw new Error('JWT has expired')
}
}

export function getAlgorithm(alg: 'RS256' | 'ES256'): RsaHashedImportParams | EcKeyImportParams {
switch (alg) {
case 'RS256':
return {
name: 'RSASSA-PKCS1-v1_5',
hash: { name: 'SHA-256' },
}
case 'ES256':
return {
name: 'ECDSA',
namedCurve: 'P-256',
hash: { name: 'SHA-256' },
}
default:
throw new Error('Invalid alg claim')
}
}
Loading