diff --git a/packages/upscalerjs/src/loadModel.browser.test.ts b/packages/upscalerjs/src/loadModel.browser.test.ts index 3bbd68e51..0af906c67 100644 --- a/packages/upscalerjs/src/loadModel.browser.test.ts +++ b/packages/upscalerjs/src/loadModel.browser.test.ts @@ -1,6 +1,5 @@ import type { GraphModel, io, LayersModel } from '@tensorflow/tfjs'; import { vi } from 'vitest'; -import { tf } from './dependencies.generated'; import { CDNS, CDN_PATH_DEFINITIONS, @@ -11,6 +10,7 @@ import { import { loadTfModel, } from './model-utils'; +import * as tf from '@tensorflow/tfjs-node'; import { getModelDefinitionError, @@ -24,7 +24,6 @@ import { MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '@upscalerjs/core'; -import type * as dependenciesGenerated from './dependencies.generated'; import type * as core from '@upscalerjs/core'; import type * as modelUtils from './model-utils'; import type * as errorsAndWarnings from './errors-and-warnings'; @@ -41,7 +40,7 @@ vi.mock('./model-utils', async () => { const { loadTfModel, ...rest } = await vi.importActual('./model-utils') as typeof modelUtils; return { ...rest, - loadTfModel: vi.fn(loadTfModel), + loadTfModel: vi.fn(), } }); @@ -61,18 +60,6 @@ vi.mock('@upscalerjs/core', async () => { } }); -vi.mock('./dependencies.generated', async () => { - const { tf, ...rest } = await vi.importActual('./dependencies.generated') as typeof dependenciesGenerated; - return { - ...rest, - tf: { - ...tf, - loadLayersModel: vi.fn(), - loadGraphModel: vi.fn(), - } - } -}); - describe('loadModel browser tests', () => { afterEach(() => { vi.clearAllMocks(); @@ -91,7 +78,7 @@ describe('loadModel browser tests', () => { version: 'version', }, }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(1); expect(loadTfModel).toBeCalledWith(tf, 'foo', 'layers'); }); @@ -107,7 +94,7 @@ describe('loadModel browser tests', () => { version: 'version', }, }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(1); expect(loadTfModel).toBeCalledWith(tf, 'foo', 'graph'); }); @@ -118,7 +105,7 @@ describe('loadModel browser tests', () => { path: 'foo', modelType: 'layers', }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(1); expect(loadTfModel).toBeCalledWith(tf, 'foo', 'layers'); }); @@ -138,7 +125,7 @@ describe('loadModel browser tests', () => { }, modelType: 'layers', }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(1); expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[0]](packageName, version, modelPath), 'layers'); }); @@ -156,7 +143,7 @@ describe('loadModel browser tests', () => { }, modelType: 'graph', }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(1); expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[0]](packageName, version, modelPath), 'graph'); }); @@ -180,7 +167,7 @@ describe('loadModel browser tests', () => { }, modelType: 'layers', }; - await fetchModel(modelDefinition); + await fetchModel(tf, modelDefinition); expect(loadTfModel).toBeCalledTimes(2); expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[1]](packageName, version, modelPath), 'layers'); }); @@ -201,7 +188,7 @@ describe('loadModel browser tests', () => { }, modelType: 'layers', }; - await expect(() => fetchModel(modelDefinition)) + await expect(() => fetchModel(tf, modelDefinition)) .rejects .toThrowError(getLoadModelErrorMessage(CDNS.map((cdn, i) => [cdn, new Error(`next: ${i}`)]), modelPath, { path: modelPath, @@ -220,7 +207,7 @@ describe('loadModel browser tests', () => { }); vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e); - await expect(() => loadModel(Promise.resolve({ + await expect(() => loadModel(tf, Promise.resolve({ path: 'foo', scale: 2, modelType: 'layers', @@ -239,7 +226,7 @@ describe('loadModel browser tests', () => { modelType: 'layers', }; - const result = await loadModel(Promise.resolve(modelDefinition)); + const result = await loadModel(tf, Promise.resolve(modelDefinition)); expect(loadTfModel).toHaveBeenCalledTimes(1); expect(loadTfModel).toHaveBeenCalledWith(tf, modelDefinition.path, 'layers'); @@ -253,9 +240,6 @@ describe('loadModel browser tests', () => { it('loads a valid graph model successfully', async () => { vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true); const model = 'foo' as unknown as GraphModel; - tf.loadLayersModel.mockImplementation(async () => 'layers model' as any); - tf.loadGraphModel.mockImplementation(async () => model); - expect(tf.loadLayersModel).toHaveBeenCalledTimes(0); const modelDefinition: ModelDefinition = { path: 'foo', @@ -263,7 +247,7 @@ describe('loadModel browser tests', () => { modelType: 'graph', }; - const result = await loadModel(Promise.resolve(modelDefinition)); + const result = await loadModel(tf, Promise.resolve(modelDefinition)); expect(loadTfModel).toHaveBeenCalledTimes(1); expect(loadTfModel).toHaveBeenCalledWith(tf, modelDefinition.path, 'graph'); diff --git a/packages/upscalerjs/src/loadModel.browser.ts b/packages/upscalerjs/src/loadModel.browser.ts index f8d42ad72..191002a25 100644 --- a/packages/upscalerjs/src/loadModel.browser.ts +++ b/packages/upscalerjs/src/loadModel.browser.ts @@ -1,5 +1,5 @@ -import { tf, } from './dependencies.generated'; -import { ModelDefinition, ModelType, ModelConfigurationInternals, } from '@upscalerjs/core'; +import type { LayersModel, } from '@tensorflow/tfjs-layers'; +import type { ModelDefinition, ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core'; import type { ParsedModelDefinition, ModelPackage, } from './types'; import { loadTfModel, @@ -10,6 +10,7 @@ import { getModelDefinitionError, } from './errors-and-warnings'; import { + TF, isValidModelDefinition, } from '@upscalerjs/core'; import { @@ -38,7 +39,7 @@ export const getLoadModelErrorMessage = (errs: Errors, modelPath: string, intern ...errs.map(([cdn, err, ]) => `- ${cdn}: ${err.message}`), ].join('\n')); -export async function fetchModel(modelConfiguration: { +export async function fetchModel(tf: TF, modelConfiguration: { modelType?: M; } & Omit): Promise { const { modelType, _internals, path: modelPath, } = modelConfiguration; @@ -64,9 +65,11 @@ export async function fetchModel, ): Promise => { const modelDefinition = await _modelDefinition; + try { isValidModelDefinition(modelDefinition); } catch (err: unknown) { @@ -78,7 +81,7 @@ export const loadModel = async ( const parsedModelDefinition = parseModelDefinition(modelDefinition); - const model = await fetchModel(parsedModelDefinition); + const model = await fetchModel(tf, parsedModelDefinition); return { model, diff --git a/packages/upscalerjs/src/loadModel.node.test.ts b/packages/upscalerjs/src/loadModel.node.test.ts index 3d6ac1a80..40e7120a6 100644 --- a/packages/upscalerjs/src/loadModel.node.test.ts +++ b/packages/upscalerjs/src/loadModel.node.test.ts @@ -9,6 +9,7 @@ import { vi } from 'vitest'; import path from 'path'; import { resolver, } from './resolver'; import type { ModelDefinition } from "@upscalerjs/core"; +import * as tf from '@tensorflow/tfjs-node'; import { ERROR_MODEL_DEFINITION_BUG, } from './errors-and-warnings'; @@ -21,7 +22,6 @@ import { MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '@upscalerjs/core'; -import type * as dependenciesGenerated from './dependencies.generated'; import type * as core from '@upscalerjs/core'; import type * as modelUtils from './model-utils'; import type * as errorsAndWarnings from './errors-and-warnings'; @@ -31,7 +31,7 @@ vi.mock('./model-utils', async () => { const { loadTfModel, ...rest } = await vi.importActual('./model-utils') as typeof modelUtils; return { ...rest, - loadTfModel: vi.fn(loadTfModel), + loadTfModel: vi.fn(), } }); @@ -57,17 +57,6 @@ vi.mock('./resolver', async () => { resolver: vi.fn(resolver), }; }); -vi.mock('./dependencies.generated', async () => { - const { tf, ...rest } = await vi.importActual('./dependencies.generated') as typeof dependenciesGenerated; - return { - ...rest, - tf: { - ...tf, - loadLayersModel: vi.fn(), - loadGraphModel: vi.fn(), - } - } -}); const getResolver = (fn: () => string) => (fn) as unknown as typeof require.resolve; @@ -135,7 +124,7 @@ describe('loadModel.node', () => { throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); }); - await expect(loadModel(Promise.resolve({}) as Promise)) + await expect(loadModel(tf, Promise.resolve({}) as Promise)) .rejects .toThrow(error); }); @@ -148,7 +137,7 @@ describe('loadModel.node', () => { const path = 'foo'; const modelDefinition: ModelDefinition = { path, scale: 2, modelType: 'layers' }; - const response = await loadModel(Promise.resolve(modelDefinition)); + const response = await loadModel(tf, Promise.resolve(modelDefinition)); expect(loadTfModel).toHaveBeenCalledWith(tf, path, 'layers'); expect(response).toEqual({ model: 'layers model', @@ -164,7 +153,7 @@ describe('loadModel.node', () => { const path = 'foo'; const modelDefinition: ModelDefinition = { path, scale: 2, modelType: 'graph' }; - const response = await loadModel(Promise.resolve(modelDefinition)); + const response = await loadModel(tf, Promise.resolve(modelDefinition)); expect(loadTfModel).toHaveBeenCalledWith(tf, path, 'graph'); expect(response).toEqual({ model: 'graph model', diff --git a/packages/upscalerjs/src/loadModel.node.ts b/packages/upscalerjs/src/loadModel.node.ts index f2131fb27..d96030a90 100644 --- a/packages/upscalerjs/src/loadModel.node.ts +++ b/packages/upscalerjs/src/loadModel.node.ts @@ -1,11 +1,11 @@ import path from 'path'; -import { tf, } from './dependencies.generated'; import type { ModelDefinition, } from "@upscalerjs/core"; import { loadTfModel, parseModelDefinition, } from './model-utils'; import { resolver, } from './resolver'; import { ParsedModelDefinition, ModelPackage, } from './types'; import { isValidModelDefinition, + TF, } from '@upscalerjs/core'; import { ERROR_MODEL_DEFINITION_BUG, @@ -43,6 +43,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string }; export const loadModel = async ( + tf: TF, _modelDefinition: Promise, ): Promise => { const modelDefinition = await _modelDefinition; diff --git a/packages/upscalerjs/src/upscale.test.ts b/packages/upscalerjs/src/upscale.test.ts index 807717005..19e33e616 100644 --- a/packages/upscalerjs/src/upscale.test.ts +++ b/packages/upscalerjs/src/upscale.test.ts @@ -127,6 +127,7 @@ describe('predict', () => { const spy = vi.spyOn(model, 'predict'); tensor = getTensor(2, 2); const result = await wrapGenerator(processPixels( + tf, tensor.expandDims(0), { output: 'base64', @@ -150,6 +151,7 @@ describe('predict', () => { it('should make a prediction with a patchSize', async () => { tensor = getTensor(2, 2); const result = await wrapGenerator(processPixels( + tf, tensor.expandDims(0), { output: 'base64', @@ -170,6 +172,7 @@ describe('predict', () => { it('should make a prediction with a patchSize and a tall image', async () => { const tensor = getTensor(4, 2); const result = await wrapGenerator(processPixels( + tf, tensor.expandDims(0), { output: 'base64', @@ -192,7 +195,7 @@ describe('predict', () => { const patchSize = 2; const progress = vi.fn(); await wrapGenerator( - processPixels(tensor, { + processPixels(tf, tensor, { progress, output: 'base64', progressOutput: 'base64', @@ -219,6 +222,7 @@ describe('predict', () => { const patchSize = 2; const progress = vi.fn((_1: any, _2: any) => { }); await wrapGenerator(processPixels( + tf, tensor, { progress, output: 'base64', @@ -246,6 +250,7 @@ describe('predict', () => { const patchSize = 2; const progress = vi.fn((_1: any, _2: any) => { }); await wrapGenerator(processPixels( + tf, tensor, { progress, @@ -278,6 +283,7 @@ describe('predict', () => { const patchSize = 2; const progress = vi.fn, Parameters>((_1: any, _2: any, _3: any) => { }); await wrapGenerator(processPixels( + tf, tensor, { progress, output: 'base64', @@ -314,6 +320,7 @@ describe('predict', () => { } }); await wrapGenerator(processPixels( + tf, tensor, { progress, @@ -364,6 +371,7 @@ describe('predict', () => { } }); await wrapGenerator(processPixels( + tf, tensor, { progress, @@ -414,6 +422,7 @@ describe('predict', () => { } }); await wrapGenerator(processPixels( + tf, tensor, { progress, @@ -453,6 +462,7 @@ describe('predict', () => { it('should warn if provided a progress callback without patchSize', async () => { tensor = getTensor(4, 4).expandDims(0) as tf.Tensor4D; await wrapGenerator(processPixels( + tf, tensor, { output: 'base64', @@ -474,7 +484,7 @@ describe('predict', () => { const IMG_SIZE = 2; tensor = getTensor(IMG_SIZE, IMG_SIZE).expandDims(0) as tf.Tensor4D; const startingTensors = tf.memory().numTensors; - const gen = processPixels(tensor, { + const gen = processPixels(tf, tensor, { output: 'base64', progressOutput: 'base64', }, modelPackage, { @@ -514,6 +524,7 @@ describe('predict', () => { const startingTensors = tf.memory().numTensors; const patchSize = 2; const gen = processPixels( + tf, tensor, { output: 'base64', @@ -618,7 +629,7 @@ describe('upscale', () => { }] } as unknown as tf.LayersModel; const tensorAsBase64 = vi.fn().mockImplementation(() => 'foobarbaz4'); - const result = await wrapGenerator(upscale(img, { + const result = await wrapGenerator(upscale(tf, img, { output: 'base64', progressOutput: 'base64', }, { @@ -649,7 +660,7 @@ describe('upscale', () => { shape: [null, null, null, 3], }] } as unknown as tf.LayersModel; - const result = await wrapGenerator(upscale(img, { + const result = await wrapGenerator(upscale(tf, img, { output: 'base64', progressOutput: 'base64', }, { @@ -684,7 +695,7 @@ describe('upscale', () => { }] } as unknown as tf.LayersModel; // (mockedTensorAsBase as any).default = async() => 'foobarbaz5'; - const result = await wrapGenerator(upscale(img, { output: 'tensor', progressOutput: 'tensor', }, { + const result = await wrapGenerator(upscale(tf, img, { output: 'tensor', progressOutput: 'tensor', }, { model, modelDefinition: { scale: 2, } as ModelDefinition, }, { @@ -722,7 +733,7 @@ describe('cancellableUpscale', () => { throw new Error(`Rate is too high: ${rate}`); } }); - await expect(() => cancellableUpscale(img, { + await expect(() => cancellableUpscale(tf, img, { output: 'base64', progressOutput: 'base64', patchSize, @@ -769,7 +780,7 @@ describe('cancellableUpscale', () => { throw new Error(`Rate is too high: ${rate}`); } }); - await expect(() => cancellableUpscale(img, { + await expect(() => cancellableUpscale(tf, img, { patchSize, padding: 0, progress, @@ -807,7 +818,7 @@ describe('cancellableUpscale', () => { shape: [null, null, null, 3], }] } as unknown as tf.LayersModel; - const result = await cancellableUpscale(img, { + const result = await cancellableUpscale(tf, img, { patchSize, padding: 0, output: 'base64', diff --git a/packages/upscalerjs/src/upscale.ts b/packages/upscalerjs/src/upscale.ts index 3ea2235dd..e63fb9f60 100644 --- a/packages/upscalerjs/src/upscale.ts +++ b/packages/upscalerjs/src/upscale.ts @@ -1,4 +1,4 @@ -import { tf, } from './dependencies.generated'; +import type { Tensor3D, Tensor4D, } from '@tensorflow/tfjs-core'; import type { PrivateUpscaleArgs, ModelPackage, @@ -10,9 +10,6 @@ import type { GetImageAsTensor, TensorAsBase64, } from './types'; -import { - Input, -} from './image.generated'; import { wrapGenerator, warn, @@ -36,6 +33,7 @@ import { isTensor, isFourDimensionalTensor, FixedShape4D, + TF, } from '@upscalerjs/core'; import { makeTick, } from './makeTick'; import { GraphModel, LayersModel, } from '@tensorflow/tfjs'; @@ -54,7 +52,7 @@ export const getPercentageComplete = (row: number, col: number, columns: number, return percent; }; -export const executeModel = (model: LayersModel | GraphModel, pixels: tf.Tensor4D): tf.Tensor4D => { +export const executeModel = (model: LayersModel | GraphModel, pixels: Tensor4D): Tensor4D => { const predictedPixels = model.predict(pixels); if (!isTensor(predictedPixels)) { throw new Error(ERROR_INVALID_MODEL_PREDICTION); @@ -67,8 +65,9 @@ export const executeModel = (model: LayersModel | GraphModel, pixels: tf.Tensor4 }; /* eslint-disable @typescript-eslint/require-await */ -export async function* processPixels( - pixels: tf.Tensor4D, +export async function* processPixels( + tf: T, + pixels: Tensor4D, { output, progress, progressOutput, }: Pick, modelPackage: ModelPackage, { @@ -80,8 +79,8 @@ export async function* processPixels( } & Pick, { tensorAsBase64, - }: Pick, 'tensorAsBase64'> -): AsyncGenerator { + }: Pick, 'tensorAsBase64'> +): AsyncGenerator { const { model, modelDefinition, } = modelPackage; const scale = modelDefinition.scale ?? 1; @@ -89,12 +88,12 @@ export async function* processPixels( const [height, width,] = pixels.shape.slice(1); const patches = getPatchesFromImage([width, height,], patchSize, padding); yield; - let upscaledTensor: undefined | tf.Tensor4D; + let upscaledTensor: undefined | Tensor4D; const total = patches.length * patches[0].length; for (let rowIdx = 0; rowIdx < patches.length; rowIdx++) { const row = patches[rowIdx]; const columns = row.length; - let colTensor: undefined | tf.Tensor4D; + let colTensor: undefined | Tensor4D; yield [colTensor, upscaledTensor,]; for (let colIdx = 0; colIdx < columns; colIdx++) { const { pre, post, } = row[colIdx]; @@ -124,7 +123,7 @@ export async function* processPixels( progress(percent); } else { /* eslint-disable @typescript-eslint/no-unnecessary-type-assertion */ - const squeezedTensor = processedPrediction.squeeze() as tf.Tensor3D; + const squeezedTensor = processedPrediction.squeeze() as Tensor3D; const sliceData: SliceData = { row: rowIdx, col: colIdx, @@ -146,12 +145,12 @@ export async function* processPixels( } yield [upscaledTensor, colTensor, processedPrediction,]; - colTensor = concatTensors(tf, [colTensor, processedPrediction,], 2); + colTensor = concatTensors(tf, [colTensor, processedPrediction,], 2); processedPrediction.dispose(); yield [upscaledTensor, colTensor,]; } - upscaledTensor = concatTensors(tf, [upscaledTensor, colTensor,], 1); + upscaledTensor = concatTensors(tf, [upscaledTensor, colTensor,], 1); /* eslint-disable @typescript-eslint/no-non-null-assertion */ colTensor!.dispose(); @@ -168,7 +167,7 @@ export async function* processPixels( upscaledTensor?.dispose(); yield [processedUpscaledTensor,]; - const squeezedTensor = processedUpscaledTensor!.squeeze() as tf.Tensor3D; + const squeezedTensor = processedUpscaledTensor!.squeeze() as Tensor3D; /* eslint-disable @typescript-eslint/no-non-null-assertion */ processedUpscaledTensor!.dispose(); return squeezedTensor; @@ -193,36 +192,40 @@ export async function* processPixels( // https://github.com/tensorflow/tfjs/issues/1125 /* eslint-disable @typescript-eslint/no-unnecessary-type-assertion */ - const squeezedTensor = postprocessedTensor.squeeze() as tf.Tensor3D; + const squeezedTensor = postprocessedTensor.squeeze() as Tensor3D; postprocessedTensor.dispose(); return squeezedTensor; } -export function upscale( +export function upscale( + tf: T, input: I, args: Omit & { output: BASE64; }, modelPackage: ModelPackage, - internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> + internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> ): AsyncGenerator; -export function upscale( +export function upscale( + tf: T, input: I, args: Omit & { output: TENSOR; }, modelPackage: ModelPackage, - internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> - ): AsyncGenerator; -export function upscale( + internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> +): AsyncGenerator; +export function upscale( + tf: T, input: I, args: Omit & { output: BASE64 | TENSOR; }, modelPackage: ModelPackage, - internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> -): AsyncGenerator; -export async function* upscale( + internalConfig: Pick, 'getImageAsTensor' | 'tensorAsBase64'> +): AsyncGenerator; +export async function* upscale( + tf: T, input: I, args: Omit & { output: BASE64 | TENSOR; @@ -231,8 +234,8 @@ export async function* upscale( { getImageAsTensor, tensorAsBase64, - }: Pick, 'getImageAsTensor' | 'tensorAsBase64'> -): AsyncGenerator { + }: Pick, 'getImageAsTensor' | 'tensorAsBase64'> +): AsyncGenerator { const parsedInput = getCopyOfInput(input); const startingPixels = await getImageAsTensor(tf, parsedInput); yield startingPixels; @@ -253,6 +256,7 @@ export async function* upscale( yield preprocessedPixels; const gen = processPixels( + tf, preprocessedPixels, { output: args.output, @@ -282,7 +286,7 @@ export async function* upscale( } } preprocessedPixels.dispose(); - const upscaledPixels: tf.Tensor3D = result.value; + const upscaledPixels: Tensor3D = result.value; if (args.output === 'tensor') { return upscaledPixels; @@ -293,38 +297,42 @@ export async function* upscale( return base64Src; }; -interface InternalConfig { +interface InternalConfig { checkValidEnvironment: CheckValidEnvironment; - getImageAsTensor: GetImageAsTensor, + getImageAsTensor: GetImageAsTensor, tensorAsBase64: TensorAsBase64, } -export function cancellableUpscale( - input: Input, +export function cancellableUpscale( + tf: T, + input: I, { signal, awaitNextFrame, ...args }: Omit & { output: TENSOR}, internalArgs: ModelPackage & { signal: AbortSignal; }, - internalConfig: InternalConfig, - ): Promise; -export function cancellableUpscale( - input: Input, + internalConfig: InternalConfig, +): Promise; +export function cancellableUpscale( + tf: T, + input: I, { signal, awaitNextFrame, ...args }: Omit & { output: BASE64}, internalArgs: ModelPackage & { signal: AbortSignal; }, - internalConfig: InternalConfig, + internalConfig: InternalConfig, ): Promise; -export function cancellableUpscale( - input: Input, +export function cancellableUpscale( + tf: T, + input: I, { signal, awaitNextFrame, ...args }: Omit & { output: BASE64 | TENSOR }, internalArgs: ModelPackage & { signal: AbortSignal; }, - internalConfig: InternalConfig, -): Promise; -export async function cancellableUpscale( - input: Input, + internalConfig: InternalConfig, +): Promise; +export async function cancellableUpscale( + tf: T, + input: I, { signal, awaitNextFrame, ...args }: Omit & { output: BASE64 | TENSOR}, internalArgs: ModelPackage & { signal: AbortSignal; @@ -332,7 +340,7 @@ export async function cancellableUpscale( { checkValidEnvironment, ...internalConfig - }: InternalConfig + }: InternalConfig ) { checkValidEnvironment(input, { output: args.output, @@ -341,6 +349,7 @@ export async function cancellableUpscale( const tick = makeTick(tf, signal || internalArgs.signal, awaitNextFrame); await tick(); const upscaledPixels = await wrapGenerator(upscale( + tf, input, args, internalArgs, diff --git a/packages/upscalerjs/src/upscaler.test.ts b/packages/upscalerjs/src/upscaler.test.ts index aecc380bc..08bc0c8bb 100644 --- a/packages/upscalerjs/src/upscaler.test.ts +++ b/packages/upscalerjs/src/upscaler.test.ts @@ -89,7 +89,7 @@ describe('Upscaler', () => { const tick = () => new Promise(resolve => setTimeout(resolve)); let count = 0; - vi.mocked(cancellableUpscale).mockImplementation(async function (_1, _2, { signal }: { + vi.mocked(cancellableUpscale).mockImplementation(async function (_0, _1, _2, { signal }: { signal: AbortSignal; }) { try { diff --git a/packages/upscalerjs/src/upscaler.ts b/packages/upscalerjs/src/upscaler.ts index 4d07daa8f..5a89dbb9a 100644 --- a/packages/upscalerjs/src/upscaler.ts +++ b/packages/upscalerjs/src/upscaler.ts @@ -91,7 +91,7 @@ export class Upscaler { this._opts = { ...opts, }; - this._model = loadModel(getModel(this.tf, this._opts.model || DEFAULT_MODEL)); + this._model = loadModel(tf, getModel(tf, this._opts.model || DEFAULT_MODEL)); this.ready = new Promise((resolve, reject) => { this._model.then(() => cancellableWarmup( tf, @@ -161,7 +161,7 @@ export class Upscaler { ) { await this.ready; const modelPackage = await this._model; - return cancellableUpscale(image, getUpscaleOptions(options), { + return cancellableUpscale(tf, image, getUpscaleOptions(options), { ...modelPackage, signal: this._abortController.signal, }, { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 5913f4b84..c95366c3c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -813,6 +813,12 @@ importers: packages/upscalerjs: dependencies: + '@tensorflow/tfjs-core': + specifier: ~4.8.0 + version: 4.8.0 + '@tensorflow/tfjs-layers': + specifier: ~4.8.0 + version: 4.8.0(@tensorflow/tfjs-core@4.8.0) '@upscalerjs/core': specifier: workspace:* version: link:../core diff --git a/test/lib/node/prepare.ts b/test/lib/node/prepare.ts index e79ec39b8..fe9f4f5c1 100644 --- a/test/lib/node/prepare.ts +++ b/test/lib/node/prepare.ts @@ -51,7 +51,7 @@ export const prepareScriptBundleForNodeCJS: Bundle = async ({ callback: async ({ moduleFolder }) => { const expectedFiles = extractAllFilesFromPackageJSON(moduleFolder).filter(file => file.includes('node')); await validateBuild(moduleFolder, expectedFiles, { includeFilesFromPackageJSON: false }); - console.log(`successfully built upscaler in ${moduleFolder}`); + console.log(`successfully built upscaler in ${moduleFolder} for node`); }, }, ...getAllAvailableModelPackages().map((packageName): DependencyDefinition => { @@ -66,7 +66,7 @@ export const prepareScriptBundleForNodeCJS: Bundle = async ({ callback: async ({ moduleFolder }) => { const expectedFiles = extractAllFilesFromPackageJSON(moduleFolder).filter(file => file.includes('node')); await validateBuild(moduleFolder, expectedFiles, { includeFilesFromPackageJSON: false }); - console.log(`successfully built ${packageName} in ${moduleFolder}`); + console.log(`successfully built ${packageName} in ${moduleFolder} for node`); }, }; }),