Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] feat(cache): promise based caching with workers #3245

Open
wants to merge 7 commits into
base: v9
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jest.config.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'jsdom',
testEnvironment: '<rootDir>/packages/shared/ExtendedJSDOMEnvironment.ts',
testPathIgnorePatterns: ['/node_modules/'],
coveragePathIgnorePatterns: [
'<rootDir>/node_modules/',
Expand Down
87 changes: 87 additions & 0 deletions packages/fiber/src/core/cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
export const promiseCaches = new Set<PromiseCache>()

export class PromiseCache {
promises = new Map<string, Promise<any>>()
cachePromise: Promise<Cache>

constructor(cache: string | Cache | Promise<Cache>) {
this.cachePromise = Promise.resolve(cache).then((cache) => {
if (typeof cache === 'string') return caches.open(cache)
return cache
})

promiseCaches.add(this)
}

async run(url: string, handler: (url: string) => any) {
if (this.promises.has(url)) {
return this.promises.get(url)!
}

const promise = new Promise<any>(async (resolve, reject) => {
const blob = await this.fetch(url)
const blobUrl = URL.createObjectURL(blob)

try {
const result = await handler(blobUrl)
resolve(result)
} catch (error) {
reject(error)
} finally {
URL.revokeObjectURL(blobUrl)

// This hack is to simulate having processed the the promise with `React.use` already.
promise.then((result) => {
;(promise as any).status = 'fulfilled'
;(promise as any).value = result
})
}
})

this.promises.set(url, promise)

return promise
}

async fetch(url: string): Promise<Blob> {
const cache = await this.cachePromise

let response = await cache.match(url)

if (!response) {
const fetchResponse = await fetch(url)
if (fetchResponse.ok) {
await cache.put(url, fetchResponse.clone())
response = fetchResponse
}
}

return response!.blob()
}

add(url: string, promise: Promise<any>) {
this.promises.set(url, promise)
}

get(url: string) {
return this.promises.get(url)
}

has(url: string) {
return this.promises.has(url)
}

async delete(url: string): Promise<boolean> {
this.promises.delete(url)
return this.cachePromise.then((cache) => cache.delete(url))
}

async clear() {
this.promises.clear()
return this.cachePromise.then((cache) =>
cache.keys().then((keys) => Promise.all(keys.map((key) => cache.delete(key)))),
)
}
}

export const cacheName = 'assets'
95 changes: 52 additions & 43 deletions packages/fiber/src/core/hooks.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { context, RootState, RenderCallback, UpdateCallback, StageTypes, RootSto
import { buildGraph, ObjectMap, is, useMutableCallback, useIsomorphicLayoutEffect, isObject3D } from './utils'
import { Stages } from './stages'
import type { Instance } from './reconciler'
import { PromiseCache, cacheName } from './cache'

/**
* Exposes an object's {@link Instance}.
Expand Down Expand Up @@ -91,43 +92,36 @@ export type LoaderResult<T> = T extends { scene: THREE.Object3D } ? T & ObjectMa
export type Extensions<T> = (loader: Loader<T>) => void

const memoizedLoaders = new WeakMap<LoaderProto<any>, Loader<any>>()
const loaderCaches = new Map<Loader<any>, PromiseCache>()

const isConstructor = <T,>(value: unknown): value is LoaderProto<T> =>
typeof value === 'function' && value?.prototype?.constructor === value

function loadingFn<T>(extensions?: Extensions<T>, onProgress?: (event: ProgressEvent) => void) {
return async function (Proto: Loader<T> | LoaderProto<T>, ...input: string[]) {
let loader: Loader<any>

// Construct and cache loader if constructor was passed
if (isConstructor(Proto)) {
loader = memoizedLoaders.get(Proto)!
if (!loader) {
loader = new Proto()
memoizedLoaders.set(Proto, loader)
}
} else {
loader = Proto
function prepareLoaderInstance(loader: Loader<any> | LoaderProto<any>): Loader<any> {
let loaderInstance: Loader<any>

// Construct and cache loader if constructor was passed
if (isConstructor(loader)) {
loaderInstance = memoizedLoaders.get(loader)!
if (!loaderInstance) {
loaderInstance = new loader()
memoizedLoaders.set(loader, loaderInstance)
}
} else {
loaderInstance = loader as Loader<any>
}

// Apply loader extensions
if (extensions) extensions(loader)
Comment on lines -113 to -114
Copy link
Member

@CodyJasonBennett CodyJasonBennett May 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to keep this and deopt calls to useLoader with extensions to only cache for the caller component? Can pass React.useId to discern the component.


// Go through the urls and load them
return Promise.all(
input.map(
(input) =>
new Promise<LoaderResult<T>>((res, reject) =>
loader.load(
input,
(data) => res(isObject3D(data?.scene) ? Object.assign(data, buildGraph(data.scene)) : data),
onProgress,
(error) => reject(new Error(`Could not load ${input}: ${(error as ErrorEvent)?.message}`)),
),
),
),
)
if (!loaderCaches.has(loaderInstance)) {
loaderCaches.set(loaderInstance, new PromiseCache(cacheName))
}

return loaderInstance
}

async function loadAsset(url: string, loaderInstance: Loader<any>, onProgress?: (event: ProgressEvent) => void) {
const result = await loaderInstance.loadAsync(url, onProgress)
const graph = isObject3D(result?.scene) ? Object.assign(result, buildGraph(result.scene)) : result
return graph
}

/**
Expand All @@ -139,16 +133,22 @@ function loadingFn<T>(extensions?: Extensions<T>, onProgress?: (event: ProgressE
export function useLoader<T, U extends string | string[] | string[][]>(
loader: Loader<T> | LoaderProto<T>,
input: U,
extensions?: Extensions<T>,
onProgress?: (event: ProgressEvent) => void,
) {
// Use suspense to load async assets
const keys = (Array.isArray(input) ? input : [input]) as string[]
const results = suspend(loadingFn(extensions, onProgress), [loader, ...keys], { equal: is.equ })
): U extends any[] ? LoaderResult<T>[] : LoaderResult<T> {
const urls = (Array.isArray(input) ? input : [input]) as string[]
const loaderInstance = prepareLoaderInstance(loader)
const cache = loaderCaches.get(loaderInstance)!

let results: any[] = []

for (const url of urls) {
if (!cache.has(url)) cache.run(url, async (cacheUrl) => loadAsset(cacheUrl, loaderInstance, onProgress))
const result = React.use(cache.get(url)!)
results.push(result)
}

// Return the object(s)
return (Array.isArray(input) ? results : results[0]) as unknown as U extends any[]
? LoaderResult<T>[]
: LoaderResult<T>
return (Array.isArray(input) ? results : results[0]) as any
}

/**
Expand All @@ -157,10 +157,14 @@ export function useLoader<T, U extends string | string[] | string[][]>(
useLoader.preload = function <T, U extends string | string[] | string[][]>(
loader: Loader<T> | LoaderProto<T>,
input: U,
extensions?: Extensions<T>,
): void {
const keys = (Array.isArray(input) ? input : [input]) as string[]
return preload(loadingFn(extensions), [loader, ...keys])
const urls = (Array.isArray(input) ? input : [input]) as string[]
const loaderInstance = prepareLoaderInstance(loader)
const cache = loaderCaches.get(loaderInstance)!

for (const url of urls) {
if (!cache.has(url)) cache.run(url, async (cacheUrl) => loadAsset(cacheUrl, loaderInstance))
}
}

/**
Expand All @@ -170,6 +174,11 @@ useLoader.clear = function <T, U extends string | string[] | string[][]>(
loader: Loader<T> | LoaderProto<T>,
input: U,
): void {
const keys = (Array.isArray(input) ? input : [input]) as string[]
return clear([loader, ...keys])
const urls = (Array.isArray(input) ? input : [input]) as string[]
const loaderInstance = prepareLoaderInstance(loader)
const cache = loaderCaches.get(loaderInstance)!

for (const url of urls) {
cache.delete(url)
}
}
112 changes: 35 additions & 77 deletions packages/fiber/tests/hooks.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import {
Instance,
extend,
} from '../src'
import { promiseCaches } from '../src/core/cache'

interface GLTF {
scene: THREE.Object3D
}

extend(THREE as any)
const root = createRoot(document.createElement('canvas'))
Expand All @@ -24,6 +29,8 @@ describe('hooks', () => {

beforeEach(() => {
canvas = createCanvas()
// Clear all caches before each test
promiseCaches.forEach(async (cache) => await cache.clear())
})

it('can handle useThree hook', async () => {
Expand Down Expand Up @@ -87,21 +94,10 @@ describe('hooks', () => {
})

it('can handle useLoader hook', async () => {
const MockMesh = new THREE.Mesh()
MockMesh.name = 'Scene'

interface GLTF {
scene: THREE.Object3D
}
class GLTFLoader extends THREE.Loader {
load(url: string, onLoad: (gltf: GLTF) => void): void {
onLoad({ scene: MockMesh })
}
}
let gltf: Record<string, any> = {}

let gltf!: GLTF & ObjectMap
const Component = () => {
gltf = useLoader(GLTFLoader, '/suzanne.glb')
gltf = useLoader(MockLoader, gltfs.diamond)
return <primitive object={gltf.scene} />
}

Expand All @@ -111,87 +107,49 @@ describe('hooks', () => {
expect(scene.children[0]).toBe(MockMesh)
expect(gltf.scene).toBe(MockMesh)
expect(gltf.nodes.Scene).toBe(MockMesh)
expect(gltf.json.nodes[0].name).toEqual('Diamond')
})

it('can handle useLoader hook with an array of strings', async () => {
const MockMesh = new THREE.Mesh()

const MockGroup = new THREE.Group()
const mat1 = new THREE.MeshBasicMaterial()
mat1.name = 'Mat 1'
const mesh1 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat1)
mesh1.name = 'Mesh 1'
const mat2 = new THREE.MeshBasicMaterial()
mat2.name = 'Mat 2'
const mesh2 = new THREE.Mesh(new THREE.BoxGeometry(2, 2), mat2)
mesh2.name = 'Mesh 2'
MockGroup.add(mesh1, mesh2)

class TestLoader extends THREE.Loader {
load = jest
.fn()
.mockImplementationOnce((_url, onLoad) => {
onLoad(MockMesh)
})
.mockImplementationOnce((_url, onLoad) => {
onLoad(MockGroup)
})
}

const extensions = jest.fn()
it('can handle useLoader with an existing loader instance', async () => {
let gltf: Record<string, any> = {}
const loader = new MockLoader()

const Component = () => {
const [mockMesh, mockScene] = useLoader(TestLoader, ['/suzanne.glb', '/myModels.glb'], extensions)

return (
<>
<primitive object={mockMesh as THREE.Mesh} />
<primitive object={mockScene as THREE.Scene} />
</>
)
gltf = useLoader(loader, gltfs.diamond)
return <primitive object={gltf.scene} />
}

const store = await act(async () => root.render(<Component />))
const { scene } = store.getState()

expect(scene.children[0]).toBe(MockMesh)
expect(scene.children[1]).toBe(MockGroup)
expect(extensions).toBeCalledTimes(1)
expect(gltf.scene).toBe(MockMesh)
expect(gltf.nodes.Scene).toBe(MockMesh)
expect(gltf.json.nodes[0].name).toEqual('Diamond')
})

it('can handle useLoader with an existing loader instance', async () => {
class Loader extends THREE.Loader {
load(_url: string, onLoad: (result: null) => void): void {
onLoad(null)
}
}

const loader = new Loader()
let proto!: Loader

function Test(): null {
return useLoader(loader, '', (loader) => (proto = loader))
}
await act(async () => root.render(<Test />))
it('can handle useLoader hook with an array of strings', async () => {
let results: Record<string, any>[] = []

expect(proto).toBe(loader)
})
// The same MockMesh gets returned for each url, but the json is different.
// Because of this we clone for the test.
const Component = () => {
results = useLoader(MockLoader, [gltfs.diamond, gltfs.lightning])

it('can handle useLoader with a loader extension', async () => {
class Loader extends THREE.Loader {
load(_url: string, onLoad: (result: null) => void): void {
onLoad(null)
}
return (
<>
<primitive object={results[0].scene.clone()} />
<primitive object={results[1].scene.clone()} />
</>
)
}

let proto!: Loader

function Test(): null {
return useLoader(Loader, '', (loader) => (proto = loader))
}
await act(async () => root.render(<Test />))
const store = await act(async () => root.render(<Component />))
const { scene } = store.getState()

expect(proto).toBeInstanceOf(Loader)
expect(scene.children).toHaveLength(2)
expect(results[0].json.nodes[0].name).toEqual('Diamond')
expect(results[1].json.nodes[0].name).toEqual('lightning')
})

it('can handle useGraph hook', async () => {
Expand Down
Loading