Skip to content

Commit

Permalink
Simplify model definition validation
Browse files Browse the repository at this point in the history
  • Loading branch information
thekevinscott committed Nov 21, 2023
1 parent 073a779 commit 68860d6
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 201 deletions.
52 changes: 26 additions & 26 deletions packages/shared/src/constants.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import * as tf from '@tensorflow/tfjs-node';
import { vi } from 'vitest';
import {
import {
ModelDefinition,
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from './types';
import {
import {
makeIsNDimensionalTensor,
isFourDimensionalTensor,
isThreeDimensionalTensor,
isFourDimensionalTensor,
isThreeDimensionalTensor,
isTensor,
isString,
isValidModelDefinition,
checkModelDefinition,
hasValidChannels,
isValidRange,
isNumber,
Expand Down Expand Up @@ -54,7 +54,7 @@ describe('isFourDimensionalTensor', () => {
expect(isFourDimensionalTensor(tf.tensor([[[1,],],]))).toEqual(false);
});

expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false);
expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false);
});

describe('isThreeDimensionalTensor', () => {
Expand Down Expand Up @@ -92,33 +92,33 @@ describe('isString', () => {

describe('isValidModelDefinition', () => {
it('throws error if given an undefined', () => {
expect(() => isValidModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
expect(() => checkModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
});

it('throws error if given no path', () => {
expect(() => isValidModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH);
expect(() => checkModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH);
});

it('throws error if given invalid model type', () => {
expect(() => isValidModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE);
expect(() => checkModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE);
});

it('returns true if given scale and path', () => {
expect(isValidModelDefinition({
path: 'foo',
expect(checkModelDefinition({
path: 'foo',
scale: 2,
modelType: 'layers',
})).toEqual(true);
})).toEqual(true);
});
});

describe('hasValidChannels', () => {
it('returns true if a tensor has valid channels', () => {
expect(hasValidChannels(tf.ones([4,4,3]))).toEqual(true);
expect(hasValidChannels(tf.ones([4, 4, 3]))).toEqual(true);
});

it('returns false if a tensor does not have valid channels', () => {
expect(hasValidChannels(tf.ones([4,4,4]))).toEqual(false);
expect(hasValidChannels(tf.ones([4, 4, 4]))).toEqual(false);
});
});

Expand Down Expand Up @@ -154,15 +154,15 @@ describe('isValidRange', () => {
});

it('returns false if it gets an array with three numbers', () => {
expect(isValidRange([1,2,3])).toEqual(false);
expect(isValidRange([1, 2, 3])).toEqual(false);
});

it('returns false if it gets an array with a number and a string', () => {
expect(isValidRange([1,'foo'])).toEqual(false);
expect(isValidRange([1, 'foo'])).toEqual(false);
});

it('returns true if it gets an array with two numbers', () => {
expect(isValidRange([1,2])).toEqual(true);
expect(isValidRange([1, 2])).toEqual(true);
});
});

Expand All @@ -176,19 +176,19 @@ describe('isShape4D', () => {
});

it('returns false if given an array of 3 numbers', () => {
expect(isShape4D([1,2,3])).toEqual(false);
expect(isShape4D([1, 2, 3])).toEqual(false);
});

it('returns false if given an array of 5 numbers', () => {
expect(isShape4D([1,2,3,4,5])).toEqual(false);
expect(isShape4D([1, 2, 3, 4, 5])).toEqual(false);
});

it('returns false if given an array of not all numbers', () => {
expect(isShape4D([1,null,3,'foo'])).toEqual(false);
expect(isShape4D([1, null, 3, 'foo'])).toEqual(false);
});

it('returns true if given an array of all numbers', () => {
expect(isShape4D([1,2,3,4])).toEqual(true);
expect(isShape4D([1, 2, 3, 4])).toEqual(true);
});

it('returns true if given an array containing nulls', () => {
Expand All @@ -201,9 +201,9 @@ describe('isFixedShape4D', () => {
[[null, null, null, 3], false],
[[null, -1, -1, 3], false],
[[null, 2, 2, 3], true],
])('%s | %s',(args, expectation) => {
expect(isFixedShape4D(args)).toEqual(expectation);
});
])('%s | %s', (args, expectation) => {
expect(isFixedShape4D(args)).toEqual(expectation);
});
});

describe('isDynamicShape', () => {
Expand All @@ -212,7 +212,7 @@ describe('isDynamicShape', () => {
[[null, -1, -1, 3], true],
[[null, 2, 2, 3], false],
])('%s | %s', (args, expectation) => {
expect(isDynamicShape4D(args)).toEqual(expectation);
});
expect(isDynamicShape4D(args)).toEqual(expectation);
});
});

13 changes: 0 additions & 13 deletions packages/shared/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@ export class ModelDefinitionValidationError extends Error {
}
}

export const isValidModelDefinition = (modelDefinition?: ModelDefinition): modelDefinition is ModelDefinition => {
if (modelDefinition === undefined) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
}
if (!isValidModelType(modelDefinition.modelType ?? 'layers')) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE);
}
if (!modelDefinition.path && !modelDefinition._internals?.path) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH);
}
return true;
};

export const hasValidChannels = (tensor: tf.Tensor): boolean => tensor.shape.slice(-1)[0] === 3;

export const isNumber = (el: unknown): el is number => typeof el === 'number';
Expand Down
6 changes: 0 additions & 6 deletions packages/shared/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,3 @@ export type ModelDefinitionFn = (tf: TF) => ModelDefinition;
export type ModelDefinitionObjectOrFn = ModelDefinitionFn | ModelDefinition;

export type IsTensor<T extends tf.Tensor> = (pixels: Tensor) => pixels is T;

export enum MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE {
UNDEFINED = 'undefined',
INVALID_MODEL_TYPE = 'invalidModelType',
MISSING_PATH = 'missingPath',
}
38 changes: 15 additions & 23 deletions packages/upscalerjs/src/browser/loadModel.browser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@ import {
import * as tf from '@tensorflow/tfjs-node';

import {
getModelDefinitionError,
ERROR_MODEL_DEFINITION_BUG,
} from '../shared/errors-and-warnings';
import {
ModelDefinition,
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from '../../../shared/src/types';

import {
ModelDefinitionValidationError,
isValidModelDefinition,
} from '../../../shared/src/constants';
checkModelDefinition,
} from '../shared/utils';

import type * as sharedConstants from '../../../shared/src/constants';
import type * as sharedUtils from '../shared/utils';
import type * as modelUtils from '../shared/model-utils';
import type * as errorsAndWarnings from '../shared/errors-and-warnings';
import type * as loadModelBrowser from './loadModel.browser';
Expand All @@ -47,18 +44,17 @@ vi.mock('../shared/model-utils', async () => {
});

vi.mock('../shared/errors-and-warnings', async () => {
const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings;
const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings;
return {
...rest,
getModelDefinitionError: vi.fn(getModelDefinitionError),
}
});

vi.mock('../../../shared/src/constants', async () => {
const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants;
vi.mock('../shared/utils', async () => {
const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils;
return {
...rest,
isValidModelDefinition: vi.fn(isValidModelDefinition),
checkModelDefinition: vi.fn(checkModelDefinition),
}
});

Expand Down Expand Up @@ -202,22 +198,18 @@ describe('loadModel browser tests', () => {
});

describe('loadModel', () => {
it('throws if not a valid model definition', async () => {
const e = new Error(ERROR_MODEL_DEFINITION_BUG);
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
it('throws if given a bad model definition', async () => {
vi.mocked(checkModelDefinition).mockImplementation(() => {
throw new Error();
});
vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e);

await expect(() => loadModel(tf, Promise.resolve({
path: 'foo',
scale: 2,
modelType: 'layers',
}))).rejects.toThrowError(e);
await expect(loadModel(tf, Promise.resolve({}) as Promise<ModelDefinition>))
.rejects
.toThrow();
});

it('loads a valid layers model successfully', async () => {
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true);
vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true);
const model = 'foo' as unknown as LayersModel;
vi.mocked(loadTfModel).mockImplementation(async () => model);
expect(loadTfModel).toHaveBeenCalledTimes(0);
Expand All @@ -240,7 +232,7 @@ describe('loadModel browser tests', () => {
});

it('loads a valid graph model successfully', async () => {
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true);
vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true);
const model = 'foo' as unknown as GraphModel;

const modelDefinition: ModelDefinition = {
Expand Down
20 changes: 2 additions & 18 deletions packages/upscalerjs/src/browser/loadModel.browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@ import {
} from '../shared/model-utils';
import {
ERROR_MODEL_DEFINITION_BUG,
getModelDefinitionError,
} from '../shared/errors-and-warnings';
import type {
TF,
} from '../../../shared/src/types';
import {
isValidModelDefinition,
} from '../../../shared/src/constants';
import {
errIsModelDefinitionValidationError,
} from '../shared/utils';
import { checkModelDefinition, } from '../shared/utils.js';

type CDN = 'jsdelivr' | 'unpkg';

Expand Down Expand Up @@ -69,17 +63,7 @@ export async function fetchModel<M extends ModelType, R = M extends 'graph' ? Gr
export const loadModel: LoadModel<TF> = async (tf, _modelDefinition) => {
const modelDefinition = await _modelDefinition;

try {
isValidModelDefinition(modelDefinition);
} catch (err: unknown) {
if (err instanceof Error) {
if (errIsModelDefinitionValidationError(err)) {
throw getModelDefinitionError(err, modelDefinition);
}
throw ERROR_MODEL_DEFINITION_BUG(err.message);
}
throw err;
}
checkModelDefinition(modelDefinition);

const parsedModelDefinition = parseModelDefinition(modelDefinition);

Expand Down
Loading

0 comments on commit 68860d6

Please sign in to comment.