Skip to content

Commit

Permalink
Refactor error handling to surface the error (#1270)
Browse files Browse the repository at this point in the history
* Refactor error handling to surface the error
  • Loading branch information
thekevinscott authored Nov 21, 2023
1 parent 5b6279b commit 516a2c7
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 222 deletions.
63 changes: 18 additions & 45 deletions packages/shared/src/constants.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import * as tf from '@tensorflow/tfjs-node';
import { vi } from 'vitest';
import {
ModelDefinition,
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from './types';
import {
import {
makeIsNDimensionalTensor,
isFourDimensionalTensor,
isThreeDimensionalTensor,
isFourDimensionalTensor,
isThreeDimensionalTensor,
isTensor,
isString,
isValidModelDefinition,
hasValidChannels,
isValidRange,
isNumber,
Expand Down Expand Up @@ -54,7 +49,7 @@ describe('isFourDimensionalTensor', () => {
expect(isFourDimensionalTensor(tf.tensor([[[1,],],]))).toEqual(false);
});

expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false);
expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false);
});

describe('isThreeDimensionalTensor', () => {
Expand Down Expand Up @@ -90,35 +85,13 @@ describe('isString', () => {
});
});

describe('isValidModelDefinition', () => {
it('throws error if given an undefined', () => {
expect(() => isValidModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
});

it('throws error if given no path', () => {
expect(() => isValidModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH);
});

it('throws error if given invalid model type', () => {
expect(() => isValidModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE);
});

it('returns true if given scale and path', () => {
expect(isValidModelDefinition({
path: 'foo',
scale: 2,
modelType: 'layers',
})).toEqual(true);
});
});

describe('hasValidChannels', () => {
it('returns true if a tensor has valid channels', () => {
expect(hasValidChannels(tf.ones([4,4,3]))).toEqual(true);
expect(hasValidChannels(tf.ones([4, 4, 3]))).toEqual(true);
});

it('returns false if a tensor does not have valid channels', () => {
expect(hasValidChannels(tf.ones([4,4,4]))).toEqual(false);
expect(hasValidChannels(tf.ones([4, 4, 4]))).toEqual(false);
});
});

Expand Down Expand Up @@ -154,15 +127,15 @@ describe('isValidRange', () => {
});

it('returns false if it gets an array with three numbers', () => {
expect(isValidRange([1,2,3])).toEqual(false);
expect(isValidRange([1, 2, 3])).toEqual(false);
});

it('returns false if it gets an array with a number and a string', () => {
expect(isValidRange([1,'foo'])).toEqual(false);
expect(isValidRange([1, 'foo'])).toEqual(false);
});

it('returns true if it gets an array with two numbers', () => {
expect(isValidRange([1,2])).toEqual(true);
expect(isValidRange([1, 2])).toEqual(true);
});
});

Expand All @@ -176,19 +149,19 @@ describe('isShape4D', () => {
});

it('returns false if given an array of 3 numbers', () => {
expect(isShape4D([1,2,3])).toEqual(false);
expect(isShape4D([1, 2, 3])).toEqual(false);
});

it('returns false if given an array of 5 numbers', () => {
expect(isShape4D([1,2,3,4,5])).toEqual(false);
expect(isShape4D([1, 2, 3, 4, 5])).toEqual(false);
});

it('returns false if given an array of not all numbers', () => {
expect(isShape4D([1,null,3,'foo'])).toEqual(false);
expect(isShape4D([1, null, 3, 'foo'])).toEqual(false);
});

it('returns true if given an array of all numbers', () => {
expect(isShape4D([1,2,3,4])).toEqual(true);
expect(isShape4D([1, 2, 3, 4])).toEqual(true);
});

it('returns true if given an array containing nulls', () => {
Expand All @@ -201,9 +174,9 @@ describe('isFixedShape4D', () => {
[[null, null, null, 3], false],
[[null, -1, -1, 3], false],
[[null, 2, 2, 3], true],
])('%s | %s',(args, expectation) => {
expect(isFixedShape4D(args)).toEqual(expectation);
});
])('%s | %s', (args, expectation) => {
expect(isFixedShape4D(args)).toEqual(expectation);
});
});

describe('isDynamicShape', () => {
Expand All @@ -212,7 +185,7 @@ describe('isDynamicShape', () => {
[[null, -1, -1, 3], true],
[[null, 2, 2, 3], false],
])('%s | %s', (args, expectation) => {
expect(isDynamicShape4D(args)).toEqual(expectation);
});
expect(isDynamicShape4D(args)).toEqual(expectation);
});
});

23 changes: 1 addition & 22 deletions packages/shared/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as tf from '@tensorflow/tfjs-core';
import { Tensor, Tensor3D, Tensor4D, } from '@tensorflow/tfjs-core';
import { DynamicShape4D, FixedShape4D, IsTensor, MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, ModelDefinition, ModelType, Shape4D } from './types';
import { DynamicShape4D, FixedShape4D, IsTensor, ModelType, Shape4D } from './types';

export const isShape4D = (shape?: unknown): shape is Shape4D => {
if (!Boolean(shape) || !Array.isArray(shape) || shape.length !== 4) {
Expand Down Expand Up @@ -29,27 +29,6 @@ export const isTensor = (input: unknown): input is tf.Tensor => input instanceof
export const isString = (el: unknown): el is string => typeof el === 'string';

export const isValidModelType = (modelType: unknown): modelType is ModelType => typeof modelType === 'string' && ['layers', 'graph',].includes(modelType);
export class ModelDefinitionValidationError extends Error {
type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE;

constructor(type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE) {
super(type);
this.type = type;
}
}

export const isValidModelDefinition = (modelDefinition?: ModelDefinition): modelDefinition is ModelDefinition => {
if (modelDefinition === undefined) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
}
if (!isValidModelType(modelDefinition.modelType ?? 'layers')) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE);
}
if (!modelDefinition.path && !modelDefinition._internals?.path) {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH);
}
return true;
};

export const hasValidChannels = (tensor: tf.Tensor): boolean => tensor.shape.slice(-1)[0] === 3;

Expand Down
6 changes: 0 additions & 6 deletions packages/shared/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,3 @@ export type ModelDefinitionFn = (tf: TF) => ModelDefinition;
export type ModelDefinitionObjectOrFn = ModelDefinitionFn | ModelDefinition;

export type IsTensor<T extends tf.Tensor> = (pixels: Tensor) => pixels is T;

export enum MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE {
UNDEFINED = 'undefined',
INVALID_MODEL_TYPE = 'invalidModelType',
MISSING_PATH = 'missingPath',
}
38 changes: 15 additions & 23 deletions packages/upscalerjs/src/browser/loadModel.browser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@ import {
import * as tf from '@tensorflow/tfjs-node';

import {
getModelDefinitionError,
ERROR_MODEL_DEFINITION_BUG,
} from '../shared/errors-and-warnings';
import {
ModelDefinition,
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from '../../../shared/src/types';

import {
ModelDefinitionValidationError,
isValidModelDefinition,
} from '../../../shared/src/constants';
checkModelDefinition,
} from '../shared/utils';

import type * as sharedConstants from '../../../shared/src/constants';
import type * as sharedUtils from '../shared/utils';
import type * as modelUtils from '../shared/model-utils';
import type * as errorsAndWarnings from '../shared/errors-and-warnings';
import type * as loadModelBrowser from './loadModel.browser';
Expand All @@ -47,18 +44,17 @@ vi.mock('../shared/model-utils', async () => {
});

vi.mock('../shared/errors-and-warnings', async () => {
const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings;
const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings;
return {
...rest,
getModelDefinitionError: vi.fn(getModelDefinitionError),
}
});

vi.mock('../../../shared/src/constants', async () => {
const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants;
vi.mock('../shared/utils', async () => {
const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils;
return {
...rest,
isValidModelDefinition: vi.fn(isValidModelDefinition),
checkModelDefinition: vi.fn(checkModelDefinition),
}
});

Expand Down Expand Up @@ -202,22 +198,18 @@ describe('loadModel browser tests', () => {
});

describe('loadModel', () => {
it('throws if not a valid model definition', async () => {
const e = new Error(ERROR_MODEL_DEFINITION_BUG);
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
it('throws if given a bad model definition', async () => {
vi.mocked(checkModelDefinition).mockImplementation(() => {
throw new Error();
});
vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e);

await expect(() => loadModel(tf, Promise.resolve({
path: 'foo',
scale: 2,
modelType: 'layers',
}))).rejects.toThrowError(e);
await expect(loadModel(tf, Promise.resolve({}) as Promise<ModelDefinition>))
.rejects
.toThrow();
});

it('loads a valid layers model successfully', async () => {
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true);
vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true);
const model = 'foo' as unknown as LayersModel;
vi.mocked(loadTfModel).mockImplementation(async () => model);
expect(loadTfModel).toHaveBeenCalledTimes(0);
Expand All @@ -240,7 +232,7 @@ describe('loadModel browser tests', () => {
});

it('loads a valid graph model successfully', async () => {
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true);
vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true);
const model = 'foo' as unknown as GraphModel;

const modelDefinition: ModelDefinition = {
Expand Down
25 changes: 6 additions & 19 deletions packages/upscalerjs/src/browser/loadModel.browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@ import {
} from '../shared/model-utils';
import {
ERROR_MODEL_DEFINITION_BUG,
getModelDefinitionError,
} from '../shared/errors-and-warnings';
import type {
TF,
} from '../../../shared/src/types';
import {
isValidModelDefinition,
} from '../../../shared/src/constants';
import {
errIsModelDefinitionValidationError,
} from '../shared/utils';
import { checkModelDefinition, } from '../shared/utils.js';

type CDN = 'jsdelivr' | 'unpkg';

Expand All @@ -38,7 +32,7 @@ export const CDNS: CDN[] = [
export const getLoadModelErrorMessage = (errs: Errors, modelPath: string, internals: ModelConfigurationInternals): Error => new Error([
`Could not resolve URL ${modelPath} for package ${internals?.name}@${internals?.version}`,
'Errors include:',
...errs.map(([cdn, err, ]) => `- ${cdn}: ${err.message}`),
...errs.map(([cdn, err,]) => `- ${cdn}: ${err.message}`),
].join('\n'));

export async function fetchModel<M extends ModelType, R = M extends 'graph' ? GraphModel : LayersModel>(tf: TF, modelConfiguration: {
Expand All @@ -50,7 +44,7 @@ export async function fetchModel<M extends ModelType, R = M extends 'graph' ? Gr
}
if (!_internals) {
// This should never happen. This should have been caught by isValidModelDefinition.
throw new Error(ERROR_MODEL_DEFINITION_BUG);
throw ERROR_MODEL_DEFINITION_BUG('Missing internals');
}
const errs: Errors = [];
for (const cdn of CDNS) {
Expand All @@ -60,23 +54,16 @@ export async function fetchModel<M extends ModelType, R = M extends 'graph' ? Gr
return await loadTfModel(tf, url, modelType);
} catch (err: unknown) {
// there was an issue with the CDN, try another
errs.push([cdn, err instanceof Error ? err : new Error(`There was an unknown error: ${JSON.stringify(err)}`), ]);
errs.push([cdn, err instanceof Error ? err : new Error(`There was an unknown error: ${JSON.stringify(err)}`),]);
}
}
throw getLoadModelErrorMessage(errs, modelPath || _internals.path, _internals);
}

export const loadModel: LoadModel<TF> = async (tf, _modelDefinition) => {
const modelDefinition = await _modelDefinition;

try {
isValidModelDefinition(modelDefinition);
} catch (err: unknown) {
if (errIsModelDefinitionValidationError(err)) {
throw getModelDefinitionError(err.type, modelDefinition);
}
throw new Error(ERROR_MODEL_DEFINITION_BUG);
}

checkModelDefinition(modelDefinition);

const parsedModelDefinition = parseModelDefinition(modelDefinition);

Expand Down
Loading

0 comments on commit 516a2c7

Please sign in to comment.