diff --git a/packages/shared/src/constants.test.ts b/packages/shared/src/constants.test.ts index a176d9c7a..460e08877 100644 --- a/packages/shared/src/constants.test.ts +++ b/packages/shared/src/constants.test.ts @@ -1,16 +1,16 @@ import * as tf from '@tensorflow/tfjs-node'; import { vi } from 'vitest'; -import { +import { ModelDefinition, MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from './types'; -import { +import { makeIsNDimensionalTensor, - isFourDimensionalTensor, - isThreeDimensionalTensor, + isFourDimensionalTensor, + isThreeDimensionalTensor, isTensor, isString, - isValidModelDefinition, + checkModelDefinition, hasValidChannels, isValidRange, isNumber, @@ -54,7 +54,7 @@ describe('isFourDimensionalTensor', () => { expect(isFourDimensionalTensor(tf.tensor([[[1,],],]))).toEqual(false); }); - expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false); + expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false); }); describe('isThreeDimensionalTensor', () => { @@ -92,33 +92,33 @@ describe('isString', () => { describe('isValidModelDefinition', () => { it('throws error if given an undefined', () => { - expect(() => isValidModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); + expect(() => checkModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); }); it('throws error if given no path', () => { - expect(() => isValidModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH); + expect(() => checkModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH); }); it('throws error if given invalid model type', () => { - expect(() => isValidModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE); + expect(() => checkModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE); }); it('returns true if given scale and path', () => { - expect(isValidModelDefinition({ - path: 'foo', + expect(checkModelDefinition({ + path: 'foo', scale: 2, modelType: 'layers', - })).toEqual(true); + })).toEqual(true); }); }); describe('hasValidChannels', () => { it('returns true if a tensor has valid channels', () => { - expect(hasValidChannels(tf.ones([4,4,3]))).toEqual(true); + expect(hasValidChannels(tf.ones([4, 4, 3]))).toEqual(true); }); it('returns false if a tensor does not have valid channels', () => { - expect(hasValidChannels(tf.ones([4,4,4]))).toEqual(false); + expect(hasValidChannels(tf.ones([4, 4, 4]))).toEqual(false); }); }); @@ -154,15 +154,15 @@ describe('isValidRange', () => { }); it('returns false if it gets an array with three numbers', () => { - expect(isValidRange([1,2,3])).toEqual(false); + expect(isValidRange([1, 2, 3])).toEqual(false); }); it('returns false if it gets an array with a number and a string', () => { - expect(isValidRange([1,'foo'])).toEqual(false); + expect(isValidRange([1, 'foo'])).toEqual(false); }); it('returns true if it gets an array with two numbers', () => { - expect(isValidRange([1,2])).toEqual(true); + expect(isValidRange([1, 2])).toEqual(true); }); }); @@ -176,19 +176,19 @@ describe('isShape4D', () => { }); it('returns false if given an array of 3 numbers', () => { - expect(isShape4D([1,2,3])).toEqual(false); + expect(isShape4D([1, 2, 3])).toEqual(false); }); it('returns false if given an array of 5 numbers', () => { - expect(isShape4D([1,2,3,4,5])).toEqual(false); + expect(isShape4D([1, 2, 3, 4, 5])).toEqual(false); }); it('returns false if given an array of not all numbers', () => { - expect(isShape4D([1,null,3,'foo'])).toEqual(false); + expect(isShape4D([1, null, 3, 'foo'])).toEqual(false); }); it('returns true if given an array of all numbers', () => { - expect(isShape4D([1,2,3,4])).toEqual(true); + expect(isShape4D([1, 2, 3, 4])).toEqual(true); }); it('returns true if given an array containing nulls', () => { @@ -201,9 +201,9 @@ describe('isFixedShape4D', () => { [[null, null, null, 3], false], [[null, -1, -1, 3], false], [[null, 2, 2, 3], true], - ])('%s | %s',(args, expectation) => { - expect(isFixedShape4D(args)).toEqual(expectation); - }); + ])('%s | %s', (args, expectation) => { + expect(isFixedShape4D(args)).toEqual(expectation); + }); }); describe('isDynamicShape', () => { @@ -212,7 +212,7 @@ describe('isDynamicShape', () => { [[null, -1, -1, 3], true], [[null, 2, 2, 3], false], ])('%s | %s', (args, expectation) => { - expect(isDynamicShape4D(args)).toEqual(expectation); - }); + expect(isDynamicShape4D(args)).toEqual(expectation); + }); }); diff --git a/packages/shared/src/constants.ts b/packages/shared/src/constants.ts index c1e2f6ce4..31de9abab 100644 --- a/packages/shared/src/constants.ts +++ b/packages/shared/src/constants.ts @@ -38,19 +38,6 @@ export class ModelDefinitionValidationError extends Error { } } -export const isValidModelDefinition = (modelDefinition?: ModelDefinition): modelDefinition is ModelDefinition => { - if (modelDefinition === undefined) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); - } - if (!isValidModelType(modelDefinition.modelType ?? 'layers')) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE); - } - if (!modelDefinition.path && !modelDefinition._internals?.path) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH); - } - return true; -}; - export const hasValidChannels = (tensor: tf.Tensor): boolean => tensor.shape.slice(-1)[0] === 3; export const isNumber = (el: unknown): el is number => typeof el === 'number'; diff --git a/packages/shared/src/types.ts b/packages/shared/src/types.ts index d7d6b31e4..9db8fa645 100644 --- a/packages/shared/src/types.ts +++ b/packages/shared/src/types.ts @@ -101,9 +101,3 @@ export type ModelDefinitionFn = (tf: TF) => ModelDefinition; export type ModelDefinitionObjectOrFn = ModelDefinitionFn | ModelDefinition; export type IsTensor = (pixels: Tensor) => pixels is T; - -export enum MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE { - UNDEFINED = 'undefined', - INVALID_MODEL_TYPE = 'invalidModelType', - MISSING_PATH = 'missingPath', -} diff --git a/packages/upscalerjs/src/browser/loadModel.browser.test.ts b/packages/upscalerjs/src/browser/loadModel.browser.test.ts index 181d77b82..45ddaca1c 100644 --- a/packages/upscalerjs/src/browser/loadModel.browser.test.ts +++ b/packages/upscalerjs/src/browser/loadModel.browser.test.ts @@ -13,20 +13,17 @@ import { import * as tf from '@tensorflow/tfjs-node'; import { - getModelDefinitionError, ERROR_MODEL_DEFINITION_BUG, } from '../shared/errors-and-warnings'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '../../../shared/src/types'; import { - ModelDefinitionValidationError, - isValidModelDefinition, -} from '../../../shared/src/constants'; + checkModelDefinition, +} from '../shared/utils'; -import type * as sharedConstants from '../../../shared/src/constants'; +import type * as sharedUtils from '../shared/utils'; import type * as modelUtils from '../shared/model-utils'; import type * as errorsAndWarnings from '../shared/errors-and-warnings'; import type * as loadModelBrowser from './loadModel.browser'; @@ -47,18 +44,17 @@ vi.mock('../shared/model-utils', async () => { }); vi.mock('../shared/errors-and-warnings', async () => { - const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; + const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; return { ...rest, - getModelDefinitionError: vi.fn(getModelDefinitionError), } }); -vi.mock('../../../shared/src/constants', async () => { - const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants; +vi.mock('../shared/utils', async () => { + const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils; return { ...rest, - isValidModelDefinition: vi.fn(isValidModelDefinition), + checkModelDefinition: vi.fn(checkModelDefinition), } }); @@ -202,22 +198,18 @@ describe('loadModel browser tests', () => { }); describe('loadModel', () => { - it('throws if not a valid model definition', async () => { - const e = new Error(ERROR_MODEL_DEFINITION_BUG); - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); + it('throws if given a bad model definition', async () => { + vi.mocked(checkModelDefinition).mockImplementation(() => { + throw new Error(); }); - vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e); - await expect(() => loadModel(tf, Promise.resolve({ - path: 'foo', - scale: 2, - modelType: 'layers', - }))).rejects.toThrowError(e); + await expect(loadModel(tf, Promise.resolve({}) as Promise)) + .rejects + .toThrow(); }); it('loads a valid layers model successfully', async () => { - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true); const model = 'foo' as unknown as LayersModel; vi.mocked(loadTfModel).mockImplementation(async () => model); expect(loadTfModel).toHaveBeenCalledTimes(0); @@ -240,7 +232,7 @@ describe('loadModel browser tests', () => { }); it('loads a valid graph model successfully', async () => { - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true); const model = 'foo' as unknown as GraphModel; const modelDefinition: ModelDefinition = { diff --git a/packages/upscalerjs/src/browser/loadModel.browser.ts b/packages/upscalerjs/src/browser/loadModel.browser.ts index 5f052d2dd..efc0f096a 100644 --- a/packages/upscalerjs/src/browser/loadModel.browser.ts +++ b/packages/upscalerjs/src/browser/loadModel.browser.ts @@ -7,17 +7,11 @@ import { } from '../shared/model-utils'; import { ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, } from '../shared/errors-and-warnings'; import type { TF, } from '../../../shared/src/types'; -import { - isValidModelDefinition, -} from '../../../shared/src/constants'; -import { - errIsModelDefinitionValidationError, -} from '../shared/utils'; +import { checkModelDefinition, } from '../shared/utils.js'; type CDN = 'jsdelivr' | 'unpkg'; @@ -69,17 +63,7 @@ export async function fetchModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; - try { - isValidModelDefinition(modelDefinition); - } catch (err: unknown) { - if (err instanceof Error) { - if (errIsModelDefinitionValidationError(err)) { - throw getModelDefinitionError(err, modelDefinition); - } - throw ERROR_MODEL_DEFINITION_BUG(err.message); - } - throw err; - } + checkModelDefinition(modelDefinition); const parsedModelDefinition = parseModelDefinition(modelDefinition); diff --git a/packages/upscalerjs/src/node/loadModel.node.test.ts b/packages/upscalerjs/src/node/loadModel.node.test.ts index 1ff53b2f4..2d337c658 100644 --- a/packages/upscalerjs/src/node/loadModel.node.test.ts +++ b/packages/upscalerjs/src/node/loadModel.node.test.ts @@ -1,4 +1,4 @@ -import { +import { loadModel, getModelPath, getModuleFolder, @@ -9,7 +9,6 @@ import path from 'path'; import { resolver, } from './resolver'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from "../../../shared/src/types"; import * as tf from '@tensorflow/tfjs-node'; import { @@ -19,11 +18,10 @@ import { loadTfModel, } from '../shared/model-utils'; import { - isValidModelDefinition, - ModelDefinitionValidationError, -} from '../../../shared/src/constants'; + checkModelDefinition, +} from '../shared/utils'; -import type * as sharedConstants from '../../../shared/src/constants'; +import type * as sharedUtils from '../shared/utils'; import type * as modelUtils from '../shared/model-utils'; import type * as errorsAndWarnings from '../shared/errors-and-warnings'; import type * as resolverModule from './resolver'; @@ -37,18 +35,17 @@ vi.mock('../shared/model-utils', async () => { }); vi.mock('../shared/errors-and-warnings', async () => { - const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; + const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; return { ...rest, - getModelDefinitionError: vi.fn(getModelDefinitionError), } }); -vi.mock('../../../shared/src/constants', async () => { - const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants; +vi.mock('../shared/utils', async () => { + const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils; return { ...rest, - isValidModelDefinition: vi.fn(isValidModelDefinition), + checkModelDefinition: vi.fn(checkModelDefinition), } }); vi.mock('./resolver', async () => { @@ -91,8 +88,8 @@ describe('loadModel.node', () => { describe('getModelPath', () => { it('returns model path if provided a path', () => { vi.mocked(resolver).mockImplementation(getResolver(() => '')); - expect(getModelPath({ - path: 'foo', + expect(getModelPath({ + path: 'foo', _internals: { path: 'some-model', name: 'baz', @@ -100,7 +97,7 @@ describe('loadModel.node', () => { }, scale: 2, modelType: 'layers', - })).toEqual('foo'); + })).toEqual('foo'); }); it('returns model path if not provided a path', () => { @@ -118,21 +115,21 @@ describe('loadModel.node', () => { }); describe('loadModel', () => { - it('throws if given an undefined model definition', async () => { + it('throws if given a bad model definition', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); const error = ERROR_MODEL_DEFINITION_BUG; - vi.mocked(isValidModelDefinition).mockImplementation(() => { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); + vi.mocked(checkModelDefinition).mockImplementation(() => { + throw new Error(); }); await expect(loadModel(tf, Promise.resolve({}) as Promise)) .rejects - .toThrow(error); + .toThrow(); }); it('loads a valid layers model', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); - vi.mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(checkModelDefinition).mockImplementation(() => true); vi.mocked(loadTfModel).mockImplementation(async () => 'layers model' as any); const path = 'foo'; @@ -148,7 +145,7 @@ describe('loadModel.node', () => { it('loads a valid graph model', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); - vi.mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(checkModelDefinition).mockImplementation(() => true); vi.mocked(loadTfModel).mockImplementation(async () => 'graph model' as any); const path = 'foo'; diff --git a/packages/upscalerjs/src/node/loadModel.node.ts b/packages/upscalerjs/src/node/loadModel.node.ts index f383e9a97..9f6d9903f 100644 --- a/packages/upscalerjs/src/node/loadModel.node.ts +++ b/packages/upscalerjs/src/node/loadModel.node.ts @@ -2,19 +2,13 @@ import path from 'path'; import { loadTfModel, parseModelDefinition, } from '../shared/model-utils'; import { resolver, } from './resolver'; import { ParsedModelDefinition, LoadModel, } from '../shared/types'; -import { - isValidModelDefinition, -} from '../../../shared/src/constants'; import type { TF, } from '../../../shared/src/types'; import { ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, } from '../shared/errors-and-warnings'; -import { - errIsModelDefinitionValidationError, -} from '../shared/utils'; +import { checkModelDefinition, } from '../shared/utils.js'; export const getMissingMatchesError = (moduleEntryPoint: string): Error => new Error( `No matches could be found for module entry point ${moduleEntryPoint}` @@ -46,17 +40,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string export const loadModel: LoadModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; - try { - isValidModelDefinition(modelDefinition); - } catch (err: unknown) { - if (err instanceof Error) { - if (errIsModelDefinitionValidationError(err)) { - throw getModelDefinitionError(err, modelDefinition); - } - throw ERROR_MODEL_DEFINITION_BUG(err.message); - } - throw err; - } + checkModelDefinition(modelDefinition); const parsedModelDefinition = parseModelDefinition(modelDefinition); diff --git a/packages/upscalerjs/src/shared/errors-and-warnings.ts b/packages/upscalerjs/src/shared/errors-and-warnings.ts index 1a5f76ca6..9432895f6 100644 --- a/packages/upscalerjs/src/shared/errors-and-warnings.ts +++ b/packages/upscalerjs/src/shared/errors-and-warnings.ts @@ -1,6 +1,5 @@ import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from "../../../shared/src/types"; // import { ModelDefinitionValidationError, } from "../constants"; @@ -60,11 +59,12 @@ const ERROR_INVALID_MODEL_TYPE_URL = 'https://upscalerjs.com/documentation/troub const WARNING_INPUT_SIZE_AND_PATCH_SIZE_URL = 'https://upscalerjs.com/documentation/troubleshooting#input-size-and-patch-size'; const ERROR_WITH_MODEL_INPUT_SHAPE_URL = 'https://upscalerjs.com/documentation/troubleshooting#error-with-model-input-shape'; +export const ERROR_UNDEFINED_MODEL = new Error('An undefined model was provided to UpscalerJS'); export const ERROR_INVALID_MODEL_TYPE = (modelType: unknown) => ([ `You've provided an invalid model type: ${JSON.stringify(modelType)}. Accepted types are "layers" and "graph".`, `For more information, see ${ERROR_INVALID_MODEL_TYPE_URL}.`, ].join(' ')); -export const ERROR_MODEL_DEFINITION_BUG = (err: string) => new Error(`There is a bug with the upscaler code. Please report this. Error: ${err}`); +export const ERROR_MODEL_DEFINITION_BUG = (err?: string) => new Error(`There is a bug with the upscaler code. Please report this. ${err ? `Error: ${err}` : ''}`.trim()); export const WARNING_INPUT_SIZE_AND_PATCH_SIZE = [ 'You have provided a patchSize, but the model definition already includes an input size.', 'Your patchSize will be ignored.', @@ -113,23 +113,3 @@ export const GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS = (modelConfigur `For more information, see ${MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS_URL}.`, `The model configuration provided was: ${JSON.stringify(modelConfiguration)}`, ].join(' '); - -// TODO: Import this from ../constants -export class ModelDefinitionValidationError extends Error { - type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE; - - constructor(type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE) { - super(type); - this.type = type; - } -} -export function getModelDefinitionError({ type, message, }: ModelDefinitionValidationError, modelDefinition?: ModelDefinition): Error { - switch (type) { - case MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE: - return new Error(ERROR_INVALID_MODEL_TYPE(modelDefinition?.modelType)); - case MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH: - return new Error(GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDefinition)); - default: - return ERROR_MODEL_DEFINITION_BUG(message); - } -} diff --git a/packages/upscalerjs/src/shared/model-utils.test.ts b/packages/upscalerjs/src/shared/model-utils.test.ts index c94dfa6bb..4f5ad250d 100644 --- a/packages/upscalerjs/src/shared/model-utils.test.ts +++ b/packages/upscalerjs/src/shared/model-utils.test.ts @@ -1,5 +1,5 @@ import { vi } from 'vitest'; -import { +import { parseModelDefinition, getModel, loadTfModel, @@ -7,29 +7,27 @@ import { getModelInputShape, getPatchSizeAsMultiple, } from './model-utils'; -import type * as utils from './utils'; +import type * as utils from './utils'; import { warn, } from './utils'; import * as isLayersModel from './isLayersModel'; -import { - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, +import { ModelDefinition, ModelDefinitionFn, - } from '../../../shared/src/types'; +} from '../../../shared/src/types'; import type * as sharedConstants from '../../../shared/src/constants'; -import { +import { isShape4D, - } from '../../../shared/src/constants'; +} from '../../../shared/src/constants'; import { ModelPackage } from './types'; import { ERROR_INVALID_MODEL_TYPE, - ERROR_MODEL_DEFINITION_BUG, - ERROR_WITH_MODEL_INPUT_SHAPE, + ERROR_MODEL_DEFINITION_BUG, + ERROR_WITH_MODEL_INPUT_SHAPE, GET_INVALID_PATCH_SIZE, WARNING_INPUT_SIZE_AND_PATCH_SIZE, WARNING_UNDEFINED_PADDING, - getModelDefinitionError, MODEL_INPUT_SIZE_MUST_BE_SQUARE, GET_INVALID_PATCH_SIZE_AND_PADDING, GET_WARNING_PATCH_SIZE_INDIVISIBLE_BY_DIVISIBILITY_FACTOR, @@ -81,18 +79,6 @@ describe('model-utils', () => { vi.clearAllMocks(); }); - describe('getModelDefinitionError', () => { - it('returns an error if invalid model type is provided', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_INVALID_MODEL_TYPE('foo')); - }); - - it('returns a generic error otherwise', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_MODEL_DEFINITION_BUG); - }); - }) - describe('getModel', () => { describe('ModelDefinition', () => { it('returns model definition', async () => { diff --git a/packages/upscalerjs/src/shared/utils.test.ts b/packages/upscalerjs/src/shared/utils.test.ts index 67ac8d67f..bf2f9554c 100644 --- a/packages/upscalerjs/src/shared/utils.test.ts +++ b/packages/upscalerjs/src/shared/utils.test.ts @@ -1,22 +1,23 @@ import { Tensor3D } from '@tensorflow/tfjs-node'; import { vi } from 'vitest'; import * as tf from '@tensorflow/tfjs-node'; -import { +import { processAndDisposeOfTensor, - wrapGenerator, - isSingleArgProgress, - isMultiArgTensorProgress, - warn, + wrapGenerator, + isSingleArgProgress, + isMultiArgTensorProgress, + warn, isAborted, + checkModelDefinition, } from './utils'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '../../../shared/src/types'; import { ERROR_INVALID_MODEL_TYPE, - ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, + ERROR_MODEL_DEFINITION_BUG, + ERROR_UNDEFINED_MODEL, + GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS, } from './errors-and-warnings'; describe('isAborted', () => { @@ -109,7 +110,7 @@ describe('wrapGenerator', () => { return 'baz'; } - const callback = vi.fn(async () => {}); + const callback = vi.fn(async () => { }); await wrapGenerator(foo(), callback); expect(callback).toHaveBeenCalledTimes(2); expect(callback).toHaveBeenCalledWith('foo'); @@ -141,57 +142,72 @@ describe('wrapGenerator', () => { describe('isSingleArgProgress', () => { it('returns true for function', () => { - expect(isSingleArgProgress(() => {})).toEqual(true); + expect(isSingleArgProgress(() => { })).toEqual(true); }); it('returns true for a single arg function', () => { - expect(isSingleArgProgress((_1: any) => {})).toEqual(true); + expect(isSingleArgProgress((_1: any) => { })).toEqual(true); }); it('returns false for a double arg function', () => { - expect(isSingleArgProgress((_1: any, _2: any) => {})).toEqual(false); + expect(isSingleArgProgress((_1: any, _2: any) => { })).toEqual(false); }); }); describe('isMultiArgProgress', () => { it('returns false for a single arg function', () => { - expect(isMultiArgTensorProgress((_1: any) => {}, undefined, undefined)).toEqual(false); + expect(isMultiArgTensorProgress((_1: any) => { }, undefined, undefined)).toEqual(false); }); it('returns false for a zero arg function', () => { - expect(isMultiArgTensorProgress(() => {}, undefined, undefined, )).toEqual(false); + expect(isMultiArgTensorProgress(() => { }, undefined, undefined,)).toEqual(false); }); it('returns false for a multi arg tensor string function', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'base64', 'base64')).toEqual(false); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'base64', 'base64')).toEqual(false); }); it('returns false for a multi arg tensor string function with overloaded outputs', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', 'base64')).toEqual(false); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', 'base64')).toEqual(false); }); it('returns true for a multi arg tensor function', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', 'tensor')).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', 'tensor')).toEqual(true); }); it('returns true for a multi arg tensor function with conflicting outputs', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'base64', 'tensor')).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'base64', 'tensor')).toEqual(true); }); it('returns true for a multi arg tensor function with conflicting outputs with an undefined progressOutput', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', undefined)).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', undefined)).toEqual(true); }); }); -describe('getModelDefinitionError', () => { - it('returns an error if invalid model type is provided', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_INVALID_MODEL_TYPE('foo')); +describe('checkModelDefinition', () => { + it('throws if an undefined model is provided', () => { + expect(() => checkModelDefinition(undefined)).toThrowError(ERROR_UNDEFINED_MODEL); + }); + + it('throws if an invalid model is provided', () => { + const modelDef = { + modelType: 'foo', + } as unknown as ModelDefinition; + expect(() => checkModelDefinition(modelDef)).toThrowError(ERROR_INVALID_MODEL_TYPE(modelDef)); }); - it('returns a generic error otherwise', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_MODEL_DEFINITION_BUG); + it('throws if a model is missing a path and _internals', () => { + const modelDef = { + modelType: 'layers', + } as unknown as ModelDefinition; + expect(() => checkModelDefinition(modelDef)).toThrowError(GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDef)); + }); + + it('passes with a valid model', () => { + checkModelDefinition({ + modelType: 'layers', + path: '/foo', + }); }); }) @@ -202,7 +218,7 @@ describe('processAndDisposeOfTensor', () => { isDisposed = false; value?: number; - mockDispose: typeof vi.fn = vi.fn().mockImplementation(() => {}); + mockDispose: typeof vi.fn = vi.fn().mockImplementation(() => { }); constructor({ mockDispose, diff --git a/packages/upscalerjs/src/shared/utils.ts b/packages/upscalerjs/src/shared/utils.ts index 9d669958b..3f0bc6da2 100644 --- a/packages/upscalerjs/src/shared/utils.ts +++ b/packages/upscalerjs/src/shared/utils.ts @@ -1,12 +1,16 @@ import type { Tensor, } from '@tensorflow/tfjs-core'; import type { Progress, SingleArgProgress, ResultFormat, MultiArgTensorProgress, } from './types'; -import type { - ProcessFn, - TF, +import { + type ModelDefinition, + type ProcessFn, + type TF, } from '../../../shared/src/types'; import { - ModelDefinitionValidationError, -} from '../../../shared/src/constants'; + ERROR_INVALID_MODEL_TYPE, + ERROR_UNDEFINED_MODEL, + GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS, +} from './errors-and-warnings'; +import { isValidModelType, } from '../../../shared/src/constants'; export const warn = (msg: string | string[]): void => { console.warn(Array.isArray(msg) ? msg.join('\n') : msg);// skipcq: JS-0002 @@ -68,4 +72,14 @@ export function processAndDisposeOfTensor( return tensor; } -export const errIsModelDefinitionValidationError = (err: unknown): err is ModelDefinitionValidationError => err instanceof Error && 'type' in err; +export const checkModelDefinition = (modelDefinition?: ModelDefinition): void => { + if (modelDefinition === undefined) { + throw ERROR_UNDEFINED_MODEL; + } + if (!isValidModelType(modelDefinition.modelType ?? 'layers')) { + throw ERROR_INVALID_MODEL_TYPE(modelDefinition); + } + if (!modelDefinition.path && !modelDefinition._internals?.path) { + throw GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDefinition); + } +};