From 58499ad33612b83e72f64330538d6f6a01a1c731 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Sun, 9 Jul 2023 20:43:25 +0200 Subject: [PATCH] feat: refactor to `_useSession` semantics --- src/GoTrueClient.ts | 485 +++++++++++++++++++++++++------------------ src/lib/helpers.ts | 111 ++++++++++ test/helpers.test.ts | 47 +++++ 3 files changed, 438 insertions(+), 205 deletions(-) create mode 100644 test/helpers.test.ts diff --git a/src/GoTrueClient.ts b/src/GoTrueClient.ts index f498b3de0..1fc0351ab 100644 --- a/src/GoTrueClient.ts +++ b/src/GoTrueClient.ts @@ -29,6 +29,8 @@ import { generatePKCEVerifier, generatePKCEChallenge, supportsLocalStorage, + stackGuard, + isInStackGuard, } from './lib/helpers' import localStorageAdapter from './lib/local-storage' import { polyfillGlobalThis } from './lib/polyfills' @@ -279,8 +281,6 @@ export default class GoTrueClient { redirectType ) - await this._saveSession(session) - setTimeout(async () => { if (redirectType === 'recovery') { await this._notifyAllSubscribers('PASSWORD_RECOVERY', session) @@ -291,7 +291,6 @@ export default class GoTrueClient { return { error: null } } - // no login attempt via callback url try to recover session from storage await this._recoverAndRefresh() return { error: null } @@ -699,18 +698,20 @@ export default class GoTrueClient { */ async reauthenticate(): Promise { try { - const { - data: { session }, - error: sessionError, - } = await this.getSession() - if (sessionError) throw sessionError - if (!session) throw new AuthSessionMissingError() + return await this._useSession(async (result) => { + const { + data: { session }, + error: sessionError, + } = result + if (sessionError) throw sessionError + if (!session) throw new AuthSessionMissingError() - const { error } = await _request(this.fetch, 'GET', `${this.url}/reauthenticate`, { - headers: this.headers, - jwt: session.access_token, + const { error } = await _request(this.fetch, 'GET', `${this.url}/reauthenticate`, { + headers: this.headers, + jwt: session.access_token, + }) + return { data: { user: null, session: null }, error } }) - return { data: { user: null, session: null }, error } } catch (error) { if (isAuthError(error)) { return { data: { user: null, session: null }, error } @@ -768,7 +769,55 @@ export default class GoTrueClient { * Returns the session, refreshing it if necessary. * The session returned can be null if the session is not detected which can happen in the event a user is not signed-in or has logged out. */ - async getSession(): Promise< + async getSession() { + return this._useSession(async (result) => { + return result + }) + } + + /** + * Use instead of {@link #getSession} inside the library. It is + * semantically usually what you want, as getting a session involves some + * processing afterwards that requires only one client operating on the + * session at once across multiple tabs or processes. + */ + private async _useSession( + fn: ( + result: + | { + data: { + session: Session + } + error: null + } + | { + data: { + session: null + } + error: AuthError + } + | { + data: { + session: null + } + error: null + } + ) => Promise + ): Promise { + return await stackGuard('_useSession', async () => { + // the use of __loadSession here is the only correct use of the function! + const result = await this.__loadSession() + + return await fn(result) + }) + } + + /** + * NEVER USE DIRECTLY! + * + * Always use {@link #_useSession}. + */ + private async __loadSession(): Promise< | { data: { session: Session @@ -788,11 +837,15 @@ export default class GoTrueClient { error: null } > { + if (this.logDebugMessages && !isInStackGuard('_useSession')) { + throw new Error('Please use #_useSession()') + } + // make sure we've read the session from the url if there is one // save to just await, as long we make sure _initialize() never throws await this.initializePromise - this._debug('#getSession()', 'begin') + this._debug('#__loadSession()', 'begin') try { let currentSession: Session | null = null @@ -824,7 +877,7 @@ export default class GoTrueClient { : false this._debug( - '#getSession()', + '#__loadSession()', `session has${hasExpired ? '' : ' not'} expired`, 'expires_at', currentSession.expires_at @@ -841,7 +894,7 @@ export default class GoTrueClient { return { data: { session }, error: null } } finally { - this._debug('#getSession()', 'end') + this._debug('#__loadSession()', 'end') } } @@ -851,20 +904,22 @@ export default class GoTrueClient { */ async getUser(jwt?: string): Promise { try { - if (!jwt) { - const { data, error } = await this.getSession() - if (error) { - throw error - } + return await this._useSession(async (result) => { + if (!jwt) { + const { data, error } = result + if (error) { + throw error + } - // Default to Authorization header if there is no existing session - jwt = data.session?.access_token ?? undefined - } + // Default to Authorization header if there is no existing session + jwt = data.session?.access_token ?? undefined + } - return await _request(this.fetch, 'GET', `${this.url}/user`, { - headers: this.headers, - jwt: jwt, - xform: _userResponse, + return await _request(this.fetch, 'GET', `${this.url}/user`, { + headers: this.headers, + jwt: jwt, + xform: _userResponse, + }) }) } catch (error) { if (isAuthError(error)) { @@ -885,27 +940,29 @@ export default class GoTrueClient { } = {} ): Promise { try { - const { data: sessionData, error: sessionError } = await this.getSession() - if (sessionError) { - throw sessionError - } - if (!sessionData.session) { - throw new AuthSessionMissingError() - } - const session: Session = sessionData.session - const { data, error: userError } = await _request(this.fetch, 'PUT', `${this.url}/user`, { - headers: this.headers, - redirectTo: options?.emailRedirectTo, - body: attributes, - jwt: session.access_token, - xform: _userResponse, - }) - if (userError) throw userError - session.user = data.user as User - await this._saveSession(session) - await this._notifyAllSubscribers('USER_UPDATED', session) + return await this._useSession(async (result) => { + const { data: sessionData, error: sessionError } = result + if (sessionError) { + throw sessionError + } + if (!sessionData.session) { + throw new AuthSessionMissingError() + } + const session: Session = sessionData.session + const { data, error: userError } = await _request(this.fetch, 'PUT', `${this.url}/user`, { + headers: this.headers, + redirectTo: options?.emailRedirectTo, + body: attributes, + jwt: session.access_token, + xform: _userResponse, + }) + if (userError) throw userError + session.user = data.user as User + await this._saveSession(session) + await this._notifyAllSubscribers('USER_UPDATED', session) - return { data: { user: session.user }, error: null } + return { data: { user: session.user }, error: null } + }) } catch (error) { if (isAuthError(error)) { return { data: { user: null }, error } @@ -997,29 +1054,31 @@ export default class GoTrueClient { */ async refreshSession(currentSession?: { refresh_token: string }): Promise { try { - if (!currentSession) { - const { data, error } = await this.getSession() - if (error) { - throw error - } + return await this._useSession(async (result) => { + if (!currentSession) { + const { data, error } = result + if (error) { + throw error + } - currentSession = data.session ?? undefined - } + currentSession = data.session ?? undefined + } - if (!currentSession?.refresh_token) { - throw new AuthSessionMissingError() - } + if (!currentSession?.refresh_token) { + throw new AuthSessionMissingError() + } - const { session, error } = await this._callRefreshToken(currentSession.refresh_token) - if (error) { - return { data: { user: null, session: null }, error: error } - } + const { session, error } = await this._callRefreshToken(currentSession.refresh_token) + if (error) { + return { data: { user: null, session: null }, error: error } + } - if (!session) { - return { data: { user: null, session: null }, error: null } - } + if (!session) { + return { data: { user: null, session: null }, error: null } + } - return { data: { user: session.user, session }, error: null } + return { data: { user: session.user, session }, error: null } + }) } catch (error) { if (isAuthError(error)) { return { data: { user: null, session: null }, error } @@ -1142,27 +1201,29 @@ export default class GoTrueClient { * If using others scope, no `SIGNED_OUT` event is fired! */ async signOut({ scope }: SignOut = { scope: 'global' }): Promise<{ error: AuthError | null }> { - const { data, error: sessionError } = await this.getSession() - if (sessionError) { - return { error: sessionError } - } - const accessToken = data.session?.access_token - if (accessToken) { - const { error } = await this.admin.signOut(accessToken, scope) - if (error) { - // ignore 404s since user might not exist anymore - // ignore 401s since an invalid or expired JWT should sign out the current session - if (!(isAuthApiError(error) && (error.status === 404 || error.status === 401))) { - return { error } + return await this._useSession(async (result) => { + const { data, error: sessionError } = result + if (sessionError) { + return { error: sessionError } + } + const accessToken = data.session?.access_token + if (accessToken) { + const { error } = await this.admin.signOut(accessToken, scope) + if (error) { + // ignore 404s since user might not exist anymore + // ignore 401s since an invalid or expired JWT should sign out the current session + if (!(isAuthApiError(error) && (error.status === 404 || error.status === 401))) { + return { error } + } } } - } - if (scope !== 'others') { - await this._removeSession() - await removeItemAsync(this.storage, `${this.storageKey}-code-verifier`) - await this._notifyAllSubscribers('SIGNED_OUT', null) - } - return { error: null } + if (scope !== 'others') { + await this._removeSession() + await removeItemAsync(this.storage, `${this.storageKey}-code-verifier`) + await this._notifyAllSubscribers('SIGNED_OUT', null) + } + return { error: null } + }) } /** @@ -1195,20 +1256,22 @@ export default class GoTrueClient { } private async _emitInitialSession(id: string): Promise { - try { - const { - data: { session }, - error, - } = await this.getSession() - if (error) throw error + return await this._useSession(async (result) => { + try { + const { + data: { session }, + error, + } = result + if (error) throw error - await this.stateChangeEmitters.get(id)?.callback('INITIAL_SESSION', session) - this._debug('INITIAL_SESSION', 'callback id', id, 'session', session) - } catch (err) { - await this.stateChangeEmitters.get(id)?.callback('INITIAL_SESSION', null) - this._debug('INITIAL_SESSION', 'callback id', id, 'error', err) - console.error(err) - } + await this.stateChangeEmitters.get(id)?.callback('INITIAL_SESSION', session) + this._debug('INITIAL_SESSION', 'callback id', id, 'session', session) + } catch (err) { + await this.stateChangeEmitters.get(id)?.callback('INITIAL_SESSION', null) + this._debug('INITIAL_SESSION', 'callback id', id, 'error', err) + console.error(err) + } + }) } /** @@ -1634,28 +1697,30 @@ export default class GoTrueClient { const now = Date.now() try { - const { - data: { session }, - } = await this.getSession() - - if (!session || !session.refresh_token || !session.expires_at) { - this._debug('#_autoRefreshTokenTick()', 'no session') - return - } + return await this._useSession(async (result) => { + const { + data: { session }, + } = result + + if (!session || !session.refresh_token || !session.expires_at) { + this._debug('#_autoRefreshTokenTick()', 'no session') + return + } - // session will expire in this many ticks (or has already expired if <= 0) - const expiresInTicks = Math.floor( - (session.expires_at * 1000 - now) / AUTO_REFRESH_TICK_DURATION - ) + // session will expire in this many ticks (or has already expired if <= 0) + const expiresInTicks = Math.floor( + (session.expires_at * 1000 - now) / AUTO_REFRESH_TICK_DURATION + ) - this._debug( - '#_autoRefreshTokenTick()', - `access token expires in ${expiresInTicks} ticks, a tick lasts ${AUTO_REFRESH_TICK_DURATION}ms, refresh threshold is ${AUTO_REFRESH_TICK_THRESHOLD} ticks` - ) + this._debug( + '#_autoRefreshTokenTick()', + `access token expires in ${expiresInTicks} ticks, a tick lasts ${AUTO_REFRESH_TICK_DURATION}ms, refresh threshold is ${AUTO_REFRESH_TICK_THRESHOLD} ticks` + ) - if (expiresInTicks <= AUTO_REFRESH_TICK_THRESHOLD) { - await this._callRefreshToken(session.refresh_token) - } + if (expiresInTicks <= AUTO_REFRESH_TICK_THRESHOLD) { + await this._callRefreshToken(session.refresh_token) + } + }) } catch (e: any) { console.error('Auto refresh tick failed with error. This is likely a transient error.', e) } @@ -1777,14 +1842,16 @@ export default class GoTrueClient { private async _unenroll(params: MFAUnenrollParams): Promise { try { - const { data: sessionData, error: sessionError } = await this.getSession() - if (sessionError) { - return { data: null, error: sessionError } - } + return await this._useSession(async (result) => { + const { data: sessionData, error: sessionError } = result + if (sessionError) { + return { data: null, error: sessionError } + } - return await _request(this.fetch, 'DELETE', `${this.url}/factors/${params.factorId}`, { - headers: this.headers, - jwt: sessionData?.session?.access_token, + return await _request(this.fetch, 'DELETE', `${this.url}/factors/${params.factorId}`, { + headers: this.headers, + jwt: sessionData?.session?.access_token, + }) }) } catch (error) { if (isAuthError(error)) { @@ -1799,30 +1866,32 @@ export default class GoTrueClient { */ private async _enroll(params: MFAEnrollParams): Promise { try { - const { data: sessionData, error: sessionError } = await this.getSession() - if (sessionError) { - return { data: null, error: sessionError } - } + return await this._useSession(async (result) => { + const { data: sessionData, error: sessionError } = result + if (sessionError) { + return { data: null, error: sessionError } + } - const { data, error } = await _request(this.fetch, 'POST', `${this.url}/factors`, { - body: { - friendly_name: params.friendlyName, - factor_type: params.factorType, - issuer: params.issuer, - }, - headers: this.headers, - jwt: sessionData?.session?.access_token, - }) + const { data, error } = await _request(this.fetch, 'POST', `${this.url}/factors`, { + body: { + friendly_name: params.friendlyName, + factor_type: params.factorType, + issuer: params.issuer, + }, + headers: this.headers, + jwt: sessionData?.session?.access_token, + }) - if (error) { - return { data: null, error } - } + if (error) { + return { data: null, error } + } - if (data?.totp?.qr_code) { - data.totp.qr_code = `data:image/svg+xml;utf-8,${data.totp.qr_code}` - } + if (data?.totp?.qr_code) { + data.totp.qr_code = `data:image/svg+xml;utf-8,${data.totp.qr_code}` + } - return { data, error: null } + return { data, error: null } + }) } catch (error) { if (isAuthError(error)) { return { data: null, error } @@ -1836,32 +1905,34 @@ export default class GoTrueClient { */ private async _verify(params: MFAVerifyParams): Promise { try { - const { data: sessionData, error: sessionError } = await this.getSession() - if (sessionError) { - return { data: null, error: sessionError } - } + return await this._useSession(async (result) => { + const { data: sessionData, error: sessionError } = result + if (sessionError) { + return { data: null, error: sessionError } + } - const { data, error } = await _request( - this.fetch, - 'POST', - `${this.url}/factors/${params.factorId}/verify`, - { - body: { code: params.code, challenge_id: params.challengeId }, - headers: this.headers, - jwt: sessionData?.session?.access_token, + const { data, error } = await _request( + this.fetch, + 'POST', + `${this.url}/factors/${params.factorId}/verify`, + { + body: { code: params.code, challenge_id: params.challengeId }, + headers: this.headers, + jwt: sessionData?.session?.access_token, + } + ) + if (error) { + return { data: null, error } } - ) - if (error) { - return { data: null, error } - } - await this._saveSession({ - expires_at: Math.round(Date.now() / 1000) + data.expires_in, - ...data, - }) - await this._notifyAllSubscribers('MFA_CHALLENGE_VERIFIED', data) + await this._saveSession({ + expires_at: Math.round(Date.now() / 1000) + data.expires_in, + ...data, + }) + await this._notifyAllSubscribers('MFA_CHALLENGE_VERIFIED', data) - return { data, error } + return { data, error } + }) } catch (error) { if (isAuthError(error)) { return { data: null, error } @@ -1875,20 +1946,22 @@ export default class GoTrueClient { */ private async _challenge(params: MFAChallengeParams): Promise { try { - const { data: sessionData, error: sessionError } = await this.getSession() - if (sessionError) { - return { data: null, error: sessionError } - } - - return await _request( - this.fetch, - 'POST', - `${this.url}/factors/${params.factorId}/challenge`, - { - headers: this.headers, - jwt: sessionData?.session?.access_token, + return await this._useSession(async (result) => { + const { data: sessionData, error: sessionError } = result + if (sessionError) { + return { data: null, error: sessionError } } - ) + + return await _request( + this.fetch, + 'POST', + `${this.url}/factors/${params.factorId}/challenge`, + { + headers: this.headers, + jwt: sessionData?.session?.access_token, + } + ) + }) } catch (error) { if (isAuthError(error)) { return { data: null, error } @@ -1946,39 +2019,41 @@ export default class GoTrueClient { * {@see GoTrueMFAApi#getAuthenticatorAssuranceLevel} */ private async _getAuthenticatorAssuranceLevel(): Promise { - const { - data: { session }, - error: sessionError, - } = await this.getSession() - if (sessionError) { - return { data: null, error: sessionError } - } - if (!session) { - return { - data: { currentLevel: null, nextLevel: null, currentAuthenticationMethods: [] }, - error: null, + return await this._useSession(async (result) => { + const { + data: { session }, + error: sessionError, + } = result + if (sessionError) { + return { data: null, error: sessionError } + } + if (!session) { + return { + data: { currentLevel: null, nextLevel: null, currentAuthenticationMethods: [] }, + error: null, + } } - } - const payload = this._decodeJWT(session.access_token) + const payload = this._decodeJWT(session.access_token) - let currentLevel: AuthenticatorAssuranceLevels | null = null + let currentLevel: AuthenticatorAssuranceLevels | null = null - if (payload.aal) { - currentLevel = payload.aal - } + if (payload.aal) { + currentLevel = payload.aal + } - let nextLevel: AuthenticatorAssuranceLevels | null = currentLevel + let nextLevel: AuthenticatorAssuranceLevels | null = currentLevel - const verifiedFactors = - session.user.factors?.filter((factor: Factor) => factor.status === 'verified') ?? [] + const verifiedFactors = + session.user.factors?.filter((factor: Factor) => factor.status === 'verified') ?? [] - if (verifiedFactors.length > 0) { - nextLevel = 'aal2' - } + if (verifiedFactors.length > 0) { + nextLevel = 'aal2' + } - const currentAuthenticationMethods = payload.amr || [] + const currentAuthenticationMethods = payload.amr || [] - return { data: { currentLevel, nextLevel, currentAuthenticationMethods }, error: null } + return { data: { currentLevel, nextLevel, currentAuthenticationMethods }, error: null } + }) } } diff --git a/src/lib/helpers.ts b/src/lib/helpers.ts index e4e0e82c7..856afb07b 100644 --- a/src/lib/helpers.ts +++ b/src/lib/helpers.ts @@ -282,3 +282,114 @@ export async function generatePKCEChallenge(verifier: string) { const hashed = await sha256(verifier) return base64urlencode(hashed) } + +const STACK_GUARD_PREFIX = `__stack_guard__` +const STACK_GUARD_SUFFIX = `__` + +// Firefox and WebKit based browsers encode the stack entry differently, but +// they all include the function name. So instead of trying to parse the entry, +// we're only looking for the special string `__stack_guard__${guardName}__`. +// Guard names can only be letters with dashes or underscores. +// +// Example Firefox stack trace: +// ``` +// __stack_guard__EXAMPLE__@debugger eval code:1:55 +// @debugger eval code:1:3 +// ``` +// +// Example WebKit/Chrome stack trace: +// ``` +// Error +// at Object.__stack_guard__EXAMPLE__ (:1:55) +// at :1:13 +// ``` +// +const STACK_ENTRY_REGEX = /__stack_guard__([a-zA-Z0-9_-]+)__/ + +let STACK_GUARD_CHECKED = false +let STACK_GUARD_CHECK_FN: () => Promise // eslint-disable-line prefer-const + +/** + * Checks if the current caller of the function is in a {@link + * #stackGuard} of the provided `name`. Works by looking through + * the stack trace of an `Error` object for a special function + * name (generated by {@link #stackGuard}). + * + * @param name The name of the stack guard to check for. Must be `[a-zA-Z0-9_-]` only. + */ +export function isInStackGuard(name: string): boolean { + STACK_GUARD_CHECK_FN() + + let error: Error + + try { + throw new Error() + } catch (e: any) { + error = e + } + + const stack = error.stack?.split('\n') ?? [] + + for (let i = 0; i < stack.length; i += 1) { + const entry = stack[i] + const match = entry.match(STACK_ENTRY_REGEX) + + if (match && match[1] === name) { + return true + } + } + + return false +} + +/** + * Creates a minification resistant stack guard, i.e. if you + * call {@link #isInStackGuard} from within the `fn` parameter + * function, you will always get `true` otherwise it will be + * `false`. + * + * Works by dynamically defining a function name before calling + * into `fn`, which is then parsed from the stack trace on an + * `Error` object within {@link #isInStackGuard}. + * + * @param name The name of the stack guard. Must be `[a-zA-Z0-9_-]` only. + * @param fn The async/await function to be run within the stack guard. + */ +export async function stackGuard(name: string, fn: () => Promise): Promise { + await STACK_GUARD_CHECK_FN() + + const guardName = `${STACK_GUARD_PREFIX}${name}${STACK_GUARD_SUFFIX}` + + const guardFunc: { + [funcName: string]: () => Promise + } = { + // per ECMAScript rules, this defines a new function with the dynamic name + // contained in the `guardName` variable + // this function name shows up in stack traces and is resistant to mangling + // from minification processes as it is determined at runtime + [guardName]: async () => await fn(), + } + + return await guardFunc[guardName]() +} + +// In certain cases, if this file is transpiled using an ES2015 target, or is +// running in a JS engine that does not support async/await stack traces, this +// function will log a single warning message. +STACK_GUARD_CHECK_FN = async () => { + if (!STACK_GUARD_CHECKED) { + STACK_GUARD_CHECKED = true + + await stackGuard('ENV_CHECK', async () => { + const result = isInStackGuard('ENV_CHECK') + + if (!result) { + console.warn( + '@supabase/gotrue-js: Stack guards not supported in this environment. Generally not an issue but may point to a very conservative transpilation environment (use ES2017 or above) that implements async/await with generators, or this is a JavaScript engine that does not support async/await stack traces.' + ) + } + + return result + }) + } +} diff --git a/test/helpers.test.ts b/test/helpers.test.ts new file mode 100644 index 000000000..90729feef --- /dev/null +++ b/test/helpers.test.ts @@ -0,0 +1,47 @@ +import { stackGuard, isInStackGuard } from '../src/lib/helpers' + +describe('stackGuard and isInStackGuard', () => { + it('should detect that a nested function is in a stack guard', async () => { + let result: boolean | null = null + + const nested = async () => { + result = isInStackGuard('TEST') + } + + await stackGuard('TEST', async () => { + await nested() + }) + + expect(result).toBe(true) + }) + + it('should not detect that a nested function is in a stack guard', async () => { + let result: boolean | null = null + + const nested = async () => { + result = isInStackGuard('TEST') + } + + await stackGuard('DIFFERENT', async () => { + await nested() + }) + + expect(result).toBe(false) + }) + + it('should not detect that a function called outside a stack guard is in one', async () => { + let result: boolean | null = null + + const nested = async () => { + result = isInStackGuard('TEST') + } + + await stackGuard('TEST', async () => { + // not calling nested + }) + + await nested() + + expect(result).toBe(false) + }) +})