From e27d414163e8d93c994f6e2c6046828f461b22ee Mon Sep 17 00:00:00 2001 From: Wonsuk Choi Date: Tue, 14 Jan 2025 10:29:08 +0900 Subject: [PATCH 1/2] chore(package.json): use double quotes in scripts (#2932) * chore(package.json): use double quotes in scripts * chore(package.json): add curly braces in 'patch-d-ts' script --- package.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/package.json b/package.json index aaef2918cd..b700627d32 100644 --- a/package.json +++ b/package.json @@ -61,8 +61,8 @@ "sideEffects": false, "scripts": { "prebuild": "shx rm -rf dist", - "build": "pnpm run prebuild && pnpm run '/^build:.*/' && pnpm run postbuild", - "build-watch": "pnpm run '/^build:.*/' --watch", + "build": "pnpm run prebuild && pnpm run \"/^build:.*/\" && pnpm run postbuild", + "build-watch": "pnpm run \"/^build:.*/\" --watch", "build:base": "rollup -c", "build:utils": "rollup -c --config-utils", "build:babel:plugin-debug-label": "rollup -c --config-babel_plugin-debug-label", @@ -73,15 +73,15 @@ "build:react": "rollup -c --config-react --client-only", "build:react:utils": "rollup -c --config-react_utils --client-only", "postbuild": "pnpm patch-d-ts && pnpm copy && pnpm patch-ts3.8 && pnpm patch-old-ts && pnpm patch-esm-ts && pnpm patch-readme", - "fix:format": "prettier '*.{js,json,md}' '{src,tests,benchmarks,docs}/**/*.{ts,tsx,md,mdx}' --write", + "fix:format": "prettier \"*.{js,json,md}\" \"{src,tests,benchmarks,docs}/**/*.{ts,tsx,md,mdx}\" --write", "fix:lint": "eslint . --fix", - "test": "pnpm run '/^test:.*/'", - "test:format": "prettier '*.{js,json,md}' '{src,tests,benchmarks,docs}/**/*.{ts,tsx,md,mdx}' --list-different", + "test": "pnpm run \"/^test:.*/\"", + "test:format": "prettier \"*.{js,json,md}\" \"{src,tests,benchmarks,docs}/**/*.{ts,tsx,md,mdx}\" --list-different", "test:types": "tsc --noEmit", "test:lint": "eslint .", "test:spec": "vitest run", "test-build:spec": "vitest run", - "patch-d-ts": "node --input-type=module -e \"import { entries } from './rollup.config.mjs'; import shelljs from 'shelljs'; const { find, sed } = shelljs; find('dist/**/*.d.ts').forEach(f => { entries.forEach(({ find, replacement }) => sed('-i', new RegExp(' from \\'' + find.source.slice(0, -1) + '\\';$'), ' from \\'' + replacement + '\\';', f)); sed('-i', / from '(\\.[^']+)\\.ts';$/, ' from \\'\\$1\\';', f); });\"", + "patch-d-ts": "node --input-type=module -e \"import { entries } from './rollup.config.mjs'; import shelljs from 'shelljs'; const { find, sed } = shelljs; find('dist/**/*.d.ts').forEach(f => { entries.forEach(({ find, replacement }) => { sed('-i', new RegExp(' from \\'' + find.source.slice(0, -1) + '\\';$'), ' from \\'' + replacement + '\\';', f); }); sed('-i', / from '(\\.[^']+)\\.ts';$/, ' from \\'\\$1\\';', f); });\"", "copy": "shx cp -r dist/src/* dist/esm && shx cp -r dist/src/* dist && shx rm -rf dist/src && shx rm -rf dist/{src,tests} && downlevel-dts dist dist/ts3.8 --to=3.8 && shx cp package.json readme.md LICENSE dist && json -I -f dist/package.json -e \"this.private=false; this.devDependencies=undefined; this.optionalDependencies=undefined; this.scripts=undefined; this.prettier=undefined;\"", "patch-ts3.8": "node -e \"require('shelljs').find('dist/ts3.8/**/*.d.ts').forEach(f=>require('fs').appendFileSync(f,'declare type Awaited = T extends Promise ? V : T;'))\"", "patch-old-ts": "shx touch dist/ts_version_3.8_and_above_is_required.d.ts", From efb4573470762e42b5e47eaf2de57f28cc8c193c Mon Sep 17 00:00:00 2001 From: Daishi Kato Date: Tue, 14 Jan 2025 14:53:35 +0900 Subject: [PATCH 2/2] refactor: eliminate batch (#2925) * refactor: eliminate batch * secret internal symbol * Update src/vanilla/store.ts Co-authored-by: David Maskasky * add a failing test * fix it * add some comments * add failing test * naive workaround * refactor readAtomState * is this better? * update test * add a test * fix it * Update src/vanilla/store.ts Co-authored-by: David Maskasky * refactor * add another failing test * fix it * a dirty hack for syncEffect * add effect test * would it be acceptable? * possible fix for typing * enjoying the puzzle * fix the dirty hack ab61bbd6f9aae34c7c9a55805a1b42abc3282a5d --------- Co-authored-by: David Maskasky --- src/vanilla/store.ts | 385 +++++++++++++++-------------------- tests/vanilla/effect.test.ts | 179 ++++++++++------ tests/vanilla/store.test.tsx | 63 +++++- 3 files changed, 344 insertions(+), 283 deletions(-) diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index 6c20762127..53f9e40156 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -7,6 +7,7 @@ type AnyWritableAtom = WritableAtom type OnUnmount = () => void type Getter = Parameters[0] type Setter = Parameters[1] +type EpochNumber = number const isSelfAtom = (atom: AnyAtom, a: AnyAtom): boolean => atom.unstable_is ? atom.unstable_is(a) : a === atom @@ -77,45 +78,52 @@ type Mounted = { /** Set of mounted atoms that depends on the atom. */ readonly t: Set /** Function to run when the atom is unmounted. */ - u?: (batch: Batch) => void + u?: () => void } /** * Mutable atom state, * tracked for both mounted and unmounted atoms in a store. + * + * This should be garbage collectable. + * We can mutate it during atom read. (except for fields with TODO) */ type AtomState = { /** * Map of atoms that the atom depends on. * The map value is the epoch number of the dependency. */ - readonly d: Map + readonly d: Map /** * Set of atoms with pending promise that depend on the atom. * * This may cause memory leaks, but it's for the capability to continue promises + * TODO(daishi): revisit how to handle this */ readonly p: Set /** The epoch number of the atom. */ - n: number - /** Object to store mounted state of the atom. */ + n: EpochNumber + /** + * Object to store mounted state of the atom. + * TODO(daishi): move this out of AtomState + */ m?: Mounted // only available if the atom is mounted /** * Listener to notify when the atom value is updated. - * This is still an experimental API and subject to change without notice. + * This is an experimental API and will be changed in the next minor. + * TODO(daishi): move this store hooks */ - u?: (batch: Batch) => void + u?: () => void /** * Listener to notify when the atom is mounted or unmounted. - * This is still an experimental API and subject to change without notice. + * This is an experimental API and will be changed in the next minor. + * TODO(daishi): move this store hooks */ - h?: (batch: Batch) => void + h?: () => void /** Atom value */ v?: Value /** Atom error */ e?: AnyError - /** Indicates that the atom value has been changed */ - x?: true } const isAtomStateInitialized = (atomState: AtomState) => @@ -165,75 +173,6 @@ const addDependency = ( aState.m?.t.add(atom) } -// -// Batch -// - -type BatchPriority = 0 | 1 | 2 - -type Batch = [ - /** finish recompute */ - priority0: Set<() => void>, - /** atom listeners */ - priority1: Set<() => void>, - /** atom mount hooks */ - priority2: Set<() => void>, -] & { - /** changed Atoms */ - C: Set -} - -const createBatch = (): Batch => - Object.assign([new Set(), new Set(), new Set()], { C: new Set() }) as Batch - -const addBatchFunc = ( - batch: Batch, - priority: BatchPriority, - fn: () => void, -) => { - batch[priority].add(fn) -} - -const registerBatchAtom = ( - batch: Batch, - atom: AnyAtom, - atomState: AtomState, -) => { - if (!batch.C.has(atom)) { - batch.C.add(atom) - atomState.u?.(batch) - const scheduleListeners = () => { - atomState.m?.l.forEach((listener) => addBatchFunc(batch, 1, listener)) - } - addBatchFunc(batch, 1, scheduleListeners) - } -} - -const flushBatch = (batch: Batch) => { - let error: AnyError - let hasError = false - const call = (fn: () => void) => { - try { - fn() - } catch (e) { - if (!hasError) { - error = e - hasError = true - } - } - } - while (batch.C.size || batch.some((channel) => channel.size)) { - batch.C.clear() - for (const channel of batch) { - channel.forEach(call) - channel.clear() - } - } - if (hasError) { - throw error - } -} - // internal & unstable type type StoreArgs = readonly [ getAtomState: (atom: Atom) => AtomState | undefined, @@ -275,6 +214,11 @@ type Store = { export type INTERNAL_DevStoreRev4 = DevStoreRev4 export type INTERNAL_PrdStore = Store +/** + * This is an experimental API and will be changed in the next minor. + */ +const INTERNAL_flushStoreHook = Symbol.for('JOTAI.EXPERIMENTAL.FLUSHSTOREHOOK') + const buildStore = (...storeArgs: StoreArgs): Store => { const [ getAtomState, @@ -297,6 +241,55 @@ const buildStore = (...storeArgs: StoreArgs): Store => { return atomState } + // These are store state. + // As they are not garbage collectable, they shouldn't be mutated during atom read. + const invalidatedAtoms = new WeakMap() + const changedAtoms = new Map() + const unmountCallbacks = new Set<() => void>() + const mountCallbacks = new Set<() => void>() + + let inTransaction = 0 + const runWithTransaction = (fn: () => T): T => { + const errors: unknown[] = [] + const call = (fn: () => void) => { + try { + fn() + } catch (e) { + errors.push(e) + } + } + let result: T + ++inTransaction + try { + result = fn() + } finally { + if (inTransaction === 1) { + while ( + changedAtoms.size || + unmountCallbacks.size || + mountCallbacks.size + ) { + recomputeInvalidatedAtoms() + ;(store as any)[INTERNAL_flushStoreHook]?.() + const callbacks = new Set<() => void>() + const add = callbacks.add.bind(callbacks) + changedAtoms.forEach((atomState) => atomState.m?.l.forEach(add)) + changedAtoms.clear() + unmountCallbacks.forEach(add) + unmountCallbacks.clear() + mountCallbacks.forEach(add) + mountCallbacks.clear() + callbacks.forEach(call) + } + } + --inTransaction + } + if (errors.length) { + throw errors[0] + } + return result + } + const setAtomStateValueOrPromise = ( atom: AnyAtom, atomState: AtomState, @@ -315,7 +308,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => { atomState.v = valueOrPromise } delete atomState.e - delete atomState.x if (!hasPrevValue || !Object.is(prevValue, atomState.v)) { ++atomState.n if (pendingPromise) { @@ -324,17 +316,14 @@ const buildStore = (...storeArgs: StoreArgs): Store => { } } - const readAtomState = ( - batch: Batch | undefined, - atom: Atom, - ): AtomState => { + const readAtomState = (atom: Atom): AtomState => { const atomState = ensureAtomState(atom) // See if we can skip recomputing this atom. if (isAtomStateInitialized(atomState)) { // If the atom is mounted, we can use cached atom state. // because it should have been updated by dependencies. - // We can't use the cache if the atom is dirty. - if (atomState.m && !atomState.x) { + // We can't use the cache if the atom is invalidated. + if (atomState.m && invalidatedAtoms.get(atom) !== atomState.n) { return atomState } // Otherwise, check if the dependencies have changed. @@ -344,7 +333,7 @@ const buildStore = (...storeArgs: StoreArgs): Store => { ([a, n]) => // Recursively, read the atom state of the dependency, and // check if the atom epoch number is unchanged - readAtomState(batch, a).n === n, + readAtomState(a).n === n, ) ) { return atomState @@ -353,6 +342,11 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // Compute a new state for this atom. atomState.d.clear() let isSync = true + const mountDependenciesIfAsync = () => { + if (atomState.m) { + runWithTransaction(() => mountDependencies(atom, atomState)) + } + } const getter: Getter = (a: Atom) => { if (isSelfAtom(atom, a)) { const aState = ensureAtomState(a) @@ -367,17 +361,13 @@ const buildStore = (...storeArgs: StoreArgs): Store => { return returnAtomValue(aState) } // a !== atom - const aState = readAtomState(batch, a) + const aState = readAtomState(a) try { return returnAtomValue(aState) } finally { - if (isSync) { - addDependency(atom, atomState, a, aState) - } else { - const batch = createBatch() - addDependency(atom, atomState, a, aState) - mountDependencies(batch, atom, atomState) - flushBatch(batch) + addDependency(atom, atomState, a, aState) + if (!isSync) { + mountDependenciesIfAsync() } } } @@ -415,20 +405,12 @@ const buildStore = (...storeArgs: StoreArgs): Store => { setAtomStateValueOrPromise(atom, atomState, valueOrPromise) if (isPromiseLike(valueOrPromise)) { valueOrPromise.onCancel?.(() => controller?.abort()) - const complete = () => { - if (atomState.m) { - const batch = createBatch() - mountDependencies(batch, atom, atomState) - flushBatch(batch) - } - } - valueOrPromise.then(complete, complete) + valueOrPromise.then(mountDependenciesIfAsync, mountDependenciesIfAsync) } return atomState } catch (error) { delete atomState.v atomState.e = error - delete atomState.x ++atomState.n return atomState } finally { @@ -437,9 +419,9 @@ const buildStore = (...storeArgs: StoreArgs): Store => { } const readAtom = (atom: Atom): Value => - returnAtomValue(readAtomState(undefined, atom)) + returnAtomValue(readAtomState(atom)) - const getMountedOrBatchDependents = ( + const getMountedOrPendingDependents = ( atomState: AtomState, ): Map => { const dependents = new Map() @@ -458,11 +440,22 @@ const buildStore = (...storeArgs: StoreArgs): Store => { return dependents } - const recomputeDependents = ( - batch: Batch, - atom: Atom, - atomState: AtomState, - ) => { + const invalidateDependents = (atomState: AtomState) => { + const visited = new WeakSet() + const stack: AtomState[] = [atomState] + while (stack.length) { + const aState = stack.pop()! + if (!visited.has(aState)) { + visited.add(aState) + for (const [d, s] of getMountedOrPendingDependents(aState)) { + invalidatedAtoms.set(d, s.n) + stack.push(s) + } + } + } + } + + const recomputeInvalidatedAtoms = () => { // Step 1: traverse the dependency graph to build the topsorted atom list // We don't bother to check for cycles, which simplifies the algorithm. // This is a topological sort via depth-first search, slightly modified from @@ -471,14 +464,14 @@ const buildStore = (...storeArgs: StoreArgs): Store => { const topSortedReversed: [ atom: AnyAtom, atomState: AtomState, - epochNumber: number, + epochNumber: EpochNumber, ][] = [] - const visiting = new Set() - const visited = new Set() + const visiting = new WeakSet() + const visited = new WeakSet() // Visit the root atom. This is the only atom in the dependency graph // without incoming edges, which is one reason we can simplify the algorithm - const stack: [a: AnyAtom, aState: AtomState][] = [[atom, atomState]] - while (stack.length > 0) { + const stack: [a: AnyAtom, aState: AtomState][] = Array.from(changedAtoms) + while (stack.length) { const [a, aState] = stack[stack.length - 1]! if (visited.has(a)) { // All dependents have been processed, now process this atom @@ -489,18 +482,21 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // The algorithm calls for pushing onto the front of the list. For // performance, we will simply push onto the end, and then will iterate in // reverse order later. - topSortedReversed.push([a, aState, aState.n]) + if (invalidatedAtoms.get(a) === aState.n) { + topSortedReversed.push([a, aState, aState.n]) + } else { + invalidatedAtoms.delete(a) + changedAtoms.set(a, aState) + } // Atom has been visited but not yet processed visited.add(a) - // Mark atom dirty - aState.x = true stack.pop() continue } visiting.add(a) // Push unvisited dependents onto the stack - for (const [d, s] of getMountedOrBatchDependents(aState)) { - if (a !== d && !visiting.has(d)) { + for (const [d, s] of getMountedOrPendingDependents(aState)) { + if (!visiting.has(d)) { stack.push([d, s]) } } @@ -508,45 +504,38 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // Step 2: use the topSortedReversed atom list to recompute all affected atoms // Track what's changed, so that we can short circuit when possible - const finishRecompute = () => { - const changedAtoms = new Set([atom]) - for (let i = topSortedReversed.length - 1; i >= 0; --i) { - const [a, aState, prevEpochNumber] = topSortedReversed[i]! - let hasChangedDeps = false - for (const dep of aState.d.keys()) { - if (dep !== a && changedAtoms.has(dep)) { - hasChangedDeps = true - break - } + for (let i = topSortedReversed.length - 1; i >= 0; --i) { + const [a, aState, prevEpochNumber] = topSortedReversed[i]! + let hasChangedDeps = false + for (const dep of aState.d.keys()) { + if (dep !== a && changedAtoms.has(dep)) { + hasChangedDeps = true + break } - if (hasChangedDeps) { - readAtomState(batch, a) - mountDependencies(batch, a, aState) - if (prevEpochNumber !== aState.n) { - registerBatchAtom(batch, a, aState) - changedAtoms.add(a) - } + } + if (hasChangedDeps) { + readAtomState(a) + mountDependencies(a, aState) + if (prevEpochNumber !== aState.n) { + changedAtoms.set(a, aState) + aState.u?.() } - delete aState.x } + invalidatedAtoms.delete(a) } - addBatchFunc(batch, 0, finishRecompute) } const writeAtomState = ( - batch: Batch, atom: WritableAtom, ...args: Args ): Result => { - let isSync = true - const getter: Getter = (a: Atom) => - returnAtomValue(readAtomState(batch, a)) + const getter: Getter = (a: Atom) => returnAtomValue(readAtomState(a)) const setter: Setter = ( a: WritableAtom, ...args: As ) => { const aState = ensureAtomState(a) - try { + return runWithTransaction(() => { if (isSelfAtom(atom, a)) { if (!hasInitialValue(a)) { // NOTE technically possible but restricted as it may cause bugs @@ -555,49 +544,31 @@ const buildStore = (...storeArgs: StoreArgs): Store => { const prevEpochNumber = aState.n const v = args[0] as V setAtomStateValueOrPromise(a, aState, v) - mountDependencies(batch, a, aState) + mountDependencies(a, aState) if (prevEpochNumber !== aState.n) { - registerBatchAtom(batch, a, aState) - recomputeDependents(batch, a, aState) + changedAtoms.set(a, aState) + aState.u?.() + invalidateDependents(aState) } return undefined as R } else { - return writeAtomState(batch, a, ...args) + return writeAtomState(a, ...args) } - } finally { - if (!isSync) { - flushBatch(batch) - } - } - } - try { - return atomWrite(atom, getter, setter, ...args) - } finally { - isSync = false + }) } + return atomWrite(atom, getter, setter, ...args) } const writeAtom = ( atom: WritableAtom, ...args: Args - ): Result => { - const batch = createBatch() - try { - return writeAtomState(batch, atom, ...args) - } finally { - flushBatch(batch) - } - } + ): Result => runWithTransaction(() => writeAtomState(atom, ...args)) - const mountDependencies = ( - batch: Batch, - atom: AnyAtom, - atomState: AtomState, - ) => { + const mountDependencies = (atom: AnyAtom, atomState: AtomState) => { if (atomState.m && !isPendingPromise(atomState.v)) { for (const a of atomState.d.keys()) { if (!atomState.m.d.has(a)) { - const aMounted = mountAtom(batch, a, ensureAtomState(a)) + const aMounted = mountAtom(a, ensureAtomState(a)) aMounted.t.add(atom) atomState.m.d.add(a) } @@ -605,7 +576,7 @@ const buildStore = (...storeArgs: StoreArgs): Store => { for (const a of atomState.m.d || []) { if (!atomState.d.has(a)) { atomState.m.d.delete(a) - const aMounted = unmountAtom(batch, a, ensureAtomState(a)) + const aMounted = unmountAtom(a, ensureAtomState(a)) aMounted?.t.delete(atom) } } @@ -613,16 +584,15 @@ const buildStore = (...storeArgs: StoreArgs): Store => { } const mountAtom = ( - batch: Batch, atom: Atom, atomState: AtomState, ): Mounted => { if (!atomState.m) { // recompute atom state - readAtomState(batch, atom) + readAtomState(atom) // mount dependencies first for (const a of atomState.d.keys()) { - const aMounted = mountAtom(batch, a, ensureAtomState(a)) + const aMounted = mountAtom(a, ensureAtomState(a)) aMounted.t.add(atom) } // mount self @@ -631,43 +601,24 @@ const buildStore = (...storeArgs: StoreArgs): Store => { d: new Set(atomState.d.keys()), t: new Set(), } - atomState.h?.(batch) + atomState.h?.() if (isActuallyWritableAtom(atom)) { const mounted = atomState.m - let setAtom: (...args: unknown[]) => unknown - const createInvocationContext = (batch: Batch, fn: () => T) => { - let isSync = true - setAtom = (...args: unknown[]) => { - try { - return writeAtomState(batch, atom, ...args) - } finally { - if (!isSync) { - flushBatch(batch) - } - } - } - try { - return fn() - } finally { - isSync = false - } - } const processOnMount = () => { - const onUnmount = createInvocationContext(batch, () => - atomOnMount(atom, (...args) => setAtom(...args)), + const onUnmount = atomOnMount(atom, (...args) => + runWithTransaction(() => writeAtomState(atom, ...args)), ) if (onUnmount) { - mounted.u = (batch) => createInvocationContext(batch, onUnmount) + mounted.u = onUnmount } } - addBatchFunc(batch, 2, processOnMount) + mountCallbacks.add(processOnMount) } } return atomState.m } const unmountAtom = ( - batch: Batch, atom: Atom, atomState: AtomState, ): Mounted | undefined => { @@ -679,13 +630,13 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // unmount self const onUnmount = atomState.m.u if (onUnmount) { - addBatchFunc(batch, 2, () => onUnmount(batch)) + unmountCallbacks.add(onUnmount) } delete atomState.m - atomState.h?.(batch) + atomState.h?.() // unmount dependencies for (const a of atomState.d.keys()) { - const aMounted = unmountAtom(batch, a, ensureAtomState(a)) + const aMounted = unmountAtom(a, ensureAtomState(a)) aMounted?.t.delete(atom) } return undefined @@ -694,18 +645,18 @@ const buildStore = (...storeArgs: StoreArgs): Store => { } const subscribeAtom = (atom: AnyAtom, listener: () => void) => { - const batch = createBatch() const atomState = ensureAtomState(atom) - const mounted = mountAtom(batch, atom, atomState) - const listeners = mounted.l - listeners.add(listener) - flushBatch(batch) - return () => { - listeners.delete(listener) - const batch = createBatch() - unmountAtom(batch, atom, atomState) - flushBatch(batch) - } + return runWithTransaction(() => { + const mounted = mountAtom(atom, atomState) + const listeners = mounted.l + listeners.add(listener) + return () => { + runWithTransaction(() => { + listeners.delete(listener) + unmountAtom(atom, atomState) + }) + } + }) } const unstable_derive: Store['unstable_derive'] = (fn) => @@ -730,8 +681,8 @@ const deriveDevStoreRev4 = (store: Store): Store & DevStoreRev4 => { storeArgs[1] = function devSetAtomState(atom, atomState) { setAtomState(atom, atomState) const originalMounted = atomState.h - atomState.h = (batch) => { - originalMounted?.(batch) + atomState.h = () => { + originalMounted?.() if (atomState.m) { debugMountedAtoms.add(atom) } else { diff --git a/tests/vanilla/effect.test.ts b/tests/vanilla/effect.test.ts index 697d6a6733..7a1dfb9906 100644 --- a/tests/vanilla/effect.test.ts +++ b/tests/vanilla/effect.test.ts @@ -1,12 +1,11 @@ import { expect, it, vi } from 'vitest' -import type { Atom, Getter, Setter } from 'jotai/vanilla' +import type { Atom, Getter, Setter, WritableAtom } from 'jotai/vanilla' import { atom, createStore } from 'jotai/vanilla' type Store = ReturnType type GetAtomState = Parameters[0]>[0] type AtomState = NonNullable> type AnyAtom = Atom -type Batch = Parameters>[0] type Cleanup = () => void type Effect = (get: Getter, set: Setter) => Cleanup | void @@ -17,52 +16,59 @@ type Ref = { cleanup?: Cleanup | undefined } -const syncEffectChannelSymbol = Symbol() - function syncEffect(effect: Effect): Atom { const refAtom = atom(() => ({ inProgress: 0, epoch: 0 })) const refreshAtom = atom(0) - const internalAtom = atom((get) => { - get(refreshAtom) - const ref = get(refAtom) - if (ref.inProgress) { - return ref.epoch - } - ref.get = get - return ++ref.epoch - }) + const internalAtom = atom( + (get) => { + get(refreshAtom) + const ref = get(refAtom) + if (ref.inProgress) { + return ref.epoch + } + ref.get = get + return ++ref.epoch + }, + () => {}, + ) + internalAtom.onMount = () => { + return () => {} + } internalAtom.unstable_onInit = (store) => { const ref = store.get(refAtom) const runEffect = () => { const deps = new Set() - ref.cleanup?.() - ref.cleanup = - effect( - (a) => { - deps.add(a) - return ref.get!(a) - }, - (a, ...args) => { - try { - ++ref.inProgress - return store.set(a, ...args) - } finally { - deps.forEach(ref.get!) - --ref.inProgress - } - }, - ) || undefined + try { + ref.cleanup?.() + ref.cleanup = + effect( + (a) => { + deps.add(a) + return store.get(a) + }, + (a, ...args) => { + try { + ++ref.inProgress + return store.set(a, ...args) + } finally { + --ref.inProgress + } + }, + ) || undefined + } finally { + deps.forEach(ref.get!) + } } const internalAtomState = getAtomState(store, internalAtom) const originalMountHook = internalAtomState.h - internalAtomState.h = (batch) => { - originalMountHook?.(batch) + internalAtomState.h = () => { + originalMountHook?.() if (internalAtomState.m) { // mount store.set(refreshAtom, (v) => v + 1) } else { // unmount - const syncEffectChannel = ensureBatchChannel(batch) + const syncEffectChannel = ensureSyncEffectChannel(store) syncEffectChannel.add(() => { ref.cleanup?.() delete ref.cleanup @@ -70,10 +76,10 @@ function syncEffect(effect: Effect): Atom { } } const originalUpdateHook = internalAtomState.u - internalAtomState.u = (batch) => { - originalUpdateHook?.(batch) + internalAtomState.u = () => { + originalUpdateHook?.() // update - const syncEffectChannel = ensureBatchChannel(batch) + const syncEffectChannel = ensureSyncEffectChannel(store) syncEffectChannel.add(runEffect) } } @@ -82,37 +88,24 @@ function syncEffect(effect: Effect): Atom { }) } -type BatchWithSyncEffect = Batch & { - [syncEffectChannelSymbol]?: Set<() => void> -} -function ensureBatchChannel(batch: BatchWithSyncEffect) { - // ensure continuation of the flushBatch while loop - const originalQueue = batch[1] - if (!originalQueue) { - throw new Error('batch[1] must be present') - } - if (!batch[syncEffectChannelSymbol]) { - batch[syncEffectChannelSymbol] = new Set<() => void>() - batch[1] = { - ...originalQueue, - add(item) { - originalQueue.add(item) - return this - }, - clear() { - batch[syncEffectChannelSymbol]!.clear() - originalQueue.clear() - }, - forEach(callback) { - batch[syncEffectChannelSymbol]!.forEach(callback) - originalQueue.forEach(callback) - }, - get size() { - return batch[syncEffectChannelSymbol]!.size + originalQueue.size - }, +const INTERNAL_flushStoreHook = Symbol.for('JOTAI.EXPERIMENTAL.FLUSHSTOREHOOK') +const syncEffectChannelSymbol = Symbol() + +function ensureSyncEffectChannel(store: any) { + if (!store[syncEffectChannelSymbol]) { + store[syncEffectChannelSymbol] = new Set<() => void>() + const originalFlushHook = store[INTERNAL_flushStoreHook] + store[INTERNAL_flushStoreHook] = () => { + originalFlushHook?.() + const syncEffectChannel = store[syncEffectChannelSymbol] as Set< + () => void + > + const fns = Array.from(syncEffectChannel) + syncEffectChannel.clear() + fns.forEach((fn: () => void) => fn()) } } - return batch[syncEffectChannelSymbol]! + return store[syncEffectChannelSymbol] as Set<() => void> } const getAtomStateMap = new WeakMap() @@ -246,3 +239,59 @@ it('sets values to atoms without causing infinite loop', () => { unsub() expect(effect).toBeCalledTimes(2) }) + +// TODO: consider removing this after we provide a new syncEffect implementation +it('supports recursive setting synchronous in read', async () => { + const store = createStore() + const a = atom(0) + const refreshAtom = atom(0) + type Ref = { + isMounted?: true + recursing: number + set: Setter + } + const refAtom = atom( + () => ({ recursing: 0 }) as Ref, + (get, set) => { + const ref = get(refAtom) + ref.isMounted = true + ref.set = set + set(refreshAtom, (v) => v + 1) + }, + ) + refAtom.onMount = (mount) => mount() + const effectAtom = atom((get) => { + get(refreshAtom) + const ref = get(refAtom) + if (!ref.isMounted) { + return + } + const recurse = ( + a: WritableAtom, + ...args: Args + ): Result => { + ++ref.recursing + const value = ref.set(a, ...args) + return value as Result + } + function runEffect() { + const v = get(a) + if (v < 5) { + recurse(a, (v) => v + 1) + } + } + if (ref.recursing) { + let prevRecursing = ref.recursing + do { + prevRecursing = ref.recursing + runEffect() + } while (prevRecursing !== ref.recursing) + ref.recursing = 0 + return Promise.resolve() + } + return Promise.resolve().then(runEffect) + }) + store.sub(effectAtom, () => {}) + await Promise.resolve() + expect(store.get(a)).toBe(5) +}) diff --git a/tests/vanilla/store.test.tsx b/tests/vanilla/store.test.tsx index 425063617e..7202851f1d 100644 --- a/tests/vanilla/store.test.tsx +++ b/tests/vanilla/store.test.tsx @@ -984,6 +984,21 @@ it('mounted atom should be recomputed eagerly', () => { expect(result).toEqual(['bRead', 'aCallback', 'bCallback']) }) +it('should notify subscription even with reading atom in write', () => { + const a = atom(1) + const b = atom((get) => get(a) * 2) + const c = atom((get) => get(b) + 1) + const d = atom(null, (get, set) => { + set(a, 2) + get(b) + }) + const store = createStore() + const callback = vi.fn() + store.sub(c, callback) + store.set(d) + expect(callback).toHaveBeenCalledTimes(1) +}) + it('should process all atom listeners even if some of them throw errors', () => { const store = createStore() const a = atom(0) @@ -1091,7 +1106,6 @@ it('recomputes dependents of unmounted atoms', () => { const a = atom(0) a.debugLabel = 'a' const bRead = vi.fn((get: Getter) => { - console.log('bRead') return get(a) }) const b = atom(bRead) @@ -1108,3 +1122,50 @@ it('recomputes dependents of unmounted atoms', () => { store.set(w) expect(bRead).not.toHaveBeenCalled() }) + +it('should not inf on subscribe or unsubscribe', async () => { + const store = createStore() + const countAtom = atom(0) + const effectAtom = atom( + (get) => get(countAtom), + (_, set) => set, + ) + effectAtom.onMount = (setAtom) => { + const set = setAtom() + set(countAtom, 1) + return () => { + set(countAtom, 2) + } + } + const unsub = store.sub(effectAtom, () => {}) + expect(store.get(countAtom)).toBe(1) + unsub() + expect(store.get(countAtom)).toBe(2) +}) + +it('supports recursion in an atom subscriber', () => { + const a = atom(0) + const store = createStore() + store.sub(a, () => { + if (store.get(a) < 3) { + store.set(a, (v) => v + 1) + } + }) + store.set(a, 1) + expect(store.get(a)).toBe(3) +}) + +it('allows subcribing to atoms during mount', () => { + const store = createStore() + const a = atom(0) + a.onMount = () => { + store.sub(b, () => {}) + } + const b = atom(0) + let bMounted = false + b.onMount = () => { + bMounted = true + } + store.sub(a, () => {}) + expect(bMounted).toBe(true) +})