diff --git a/packages/fiber/src/core/hooks.tsx b/packages/fiber/src/core/hooks.tsx index 42cd9449dc..8317e09fb4 100644 --- a/packages/fiber/src/core/hooks.tsx +++ b/packages/fiber/src/core/hooks.tsx @@ -79,13 +79,20 @@ export function useGraph(object: THREE.Object3D) { return React.useMemo(() => buildGraph(object), [object]) } +const memoizedLoaders = new WeakMap, Loader>() + function loadingFn>( extensions?: Extensions, onProgress?: (event: ProgressEvent) => void, ) { return function (Proto: L, ...input: string[]) { // Construct new loader and run extensions - const loader = new Proto() + let loader = memoizedLoaders.get(Proto)! + if (!loader) { + loader = new Proto() + memoizedLoaders.set(Proto, loader) + } + if (extensions) extensions(loader) // Go through the urls and load them return Promise.all( @@ -103,7 +110,7 @@ function loadingFn>( ), ), ), - ) + ).finally(() => (loader as any).dispose?.()) } } diff --git a/packages/fiber/tests/core/hooks.test.tsx b/packages/fiber/tests/core/hooks.test.tsx index 0474bc9954..88e289c4d4 100644 --- a/packages/fiber/tests/core/hooks.test.tsx +++ b/packages/fiber/tests/core/hooks.test.tsx @@ -138,25 +138,21 @@ describe('hooks', () => { mesh2.name = 'Mesh 2' MockGroup.add(mesh1, mesh2) - jest.spyOn(Stdlib, 'GLTFLoader').mockImplementation( - () => - ({ - load: jest - .fn() - .mockImplementationOnce((_url, onLoad) => { - onLoad(MockMesh) - }) - .mockImplementationOnce((_url, onLoad) => { - onLoad({ scene: MockGroup }) - }), - setPath: () => {}, - } as unknown as Stdlib.GLTFLoader), - ) + class TestLoader extends THREE.Loader { + load = jest + .fn() + .mockImplementationOnce((_url, onLoad) => { + onLoad(MockMesh) + }) + .mockImplementationOnce((_url, onLoad) => { + onLoad(MockGroup) + }) + } + + const extensions = jest.fn() const Component = () => { - const [mockMesh, mockScene] = useLoader(Stdlib.GLTFLoader, ['/suzanne.glb', '/myModels.glb'], (loader) => { - loader.setPath('/public/models') - }) + const [mockMesh, mockScene] = useLoader(TestLoader, ['/suzanne.glb', '/myModels.glb'], extensions) return ( <> @@ -180,6 +176,8 @@ describe('hooks', () => { await waitFor(() => expect(scene.children[0]).toBeDefined()) expect(scene.children[0]).toBe(MockMesh) + expect(scene.children[1]).toBe(MockGroup) + expect(extensions).toBeCalledTimes(1) }) it('can handle useLoader with a loader extension', async () => {