Skip to content

Commit

Permalink
Refactor scaffolding (#1155)
Browse files Browse the repository at this point in the history
  • Loading branch information
thekevinscott authored Oct 2, 2023
1 parent d3ea7fa commit 6eb0436
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 106 deletions.
40 changes: 12 additions & 28 deletions packages/upscalerjs/src/loadModel.browser.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { GraphModel, io, LayersModel } from '@tensorflow/tfjs';
import { vi } from 'vitest';
import { tf } from './dependencies.generated';
import {
CDNS,
CDN_PATH_DEFINITIONS,
Expand All @@ -11,6 +10,7 @@ import {
import {
loadTfModel,
} from './model-utils';
import * as tf from '@tensorflow/tfjs-node';

import {
getModelDefinitionError,
Expand All @@ -24,7 +24,6 @@ import {
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from '@upscalerjs/core';

import type * as dependenciesGenerated from './dependencies.generated';
import type * as core from '@upscalerjs/core';
import type * as modelUtils from './model-utils';
import type * as errorsAndWarnings from './errors-and-warnings';
Expand All @@ -41,7 +40,7 @@ vi.mock('./model-utils', async () => {
const { loadTfModel, ...rest } = await vi.importActual('./model-utils') as typeof modelUtils;
return {
...rest,
loadTfModel: vi.fn(loadTfModel),
loadTfModel: vi.fn(),
}
});

Expand All @@ -61,18 +60,6 @@ vi.mock('@upscalerjs/core', async () => {
}
});

vi.mock('./dependencies.generated', async () => {
const { tf, ...rest } = await vi.importActual('./dependencies.generated') as typeof dependenciesGenerated;
return {
...rest,
tf: {
...tf,
loadLayersModel: vi.fn(),
loadGraphModel: vi.fn(),
}
}
});

describe('loadModel browser tests', () => {
afterEach(() => {
vi.clearAllMocks();
Expand All @@ -91,7 +78,7 @@ describe('loadModel browser tests', () => {
version: 'version',
},
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(1);
expect(loadTfModel).toBeCalledWith(tf, 'foo', 'layers');
});
Expand All @@ -107,7 +94,7 @@ describe('loadModel browser tests', () => {
version: 'version',
},
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(1);
expect(loadTfModel).toBeCalledWith(tf, 'foo', 'graph');
});
Expand All @@ -118,7 +105,7 @@ describe('loadModel browser tests', () => {
path: 'foo',
modelType: 'layers',
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(1);
expect(loadTfModel).toBeCalledWith(tf, 'foo', 'layers');
});
Expand All @@ -138,7 +125,7 @@ describe('loadModel browser tests', () => {
},
modelType: 'layers',
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(1);
expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[0]](packageName, version, modelPath), 'layers');
});
Expand All @@ -156,7 +143,7 @@ describe('loadModel browser tests', () => {
},
modelType: 'graph',
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(1);
expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[0]](packageName, version, modelPath), 'graph');
});
Expand All @@ -180,7 +167,7 @@ describe('loadModel browser tests', () => {
},
modelType: 'layers',
};
await fetchModel(modelDefinition);
await fetchModel(tf, modelDefinition);
expect(loadTfModel).toBeCalledTimes(2);
expect(loadTfModel).toBeCalledWith(tf, CDN_PATH_DEFINITIONS[CDNS[1]](packageName, version, modelPath), 'layers');
});
Expand All @@ -201,7 +188,7 @@ describe('loadModel browser tests', () => {
},
modelType: 'layers',
};
await expect(() => fetchModel(modelDefinition))
await expect(() => fetchModel(tf, modelDefinition))
.rejects
.toThrowError(getLoadModelErrorMessage(CDNS.map((cdn, i) => [cdn, new Error(`next: ${i}`)]), modelPath, {
path: modelPath,
Expand All @@ -220,7 +207,7 @@ describe('loadModel browser tests', () => {
});
vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e);

await expect(() => loadModel(Promise.resolve({
await expect(() => loadModel(tf, Promise.resolve({
path: 'foo',
scale: 2,
modelType: 'layers',
Expand All @@ -239,7 +226,7 @@ describe('loadModel browser tests', () => {
modelType: 'layers',
};

const result = await loadModel(Promise.resolve(modelDefinition));
const result = await loadModel(tf, Promise.resolve(modelDefinition));

expect(loadTfModel).toHaveBeenCalledTimes(1);
expect(loadTfModel).toHaveBeenCalledWith(tf, modelDefinition.path, 'layers');
Expand All @@ -253,17 +240,14 @@ describe('loadModel browser tests', () => {
it('loads a valid graph model successfully', async () => {
vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true);
const model = 'foo' as unknown as GraphModel;
tf.loadLayersModel.mockImplementation(async () => 'layers model' as any);
tf.loadGraphModel.mockImplementation(async () => model);
expect(tf.loadLayersModel).toHaveBeenCalledTimes(0);

const modelDefinition: ModelDefinition = {
path: 'foo',
scale: 2,
modelType: 'graph',
};

const result = await loadModel(Promise.resolve(modelDefinition));
const result = await loadModel(tf, Promise.resolve(modelDefinition));

expect(loadTfModel).toHaveBeenCalledTimes(1);
expect(loadTfModel).toHaveBeenCalledWith(tf, modelDefinition.path, 'graph');
Expand Down
11 changes: 7 additions & 4 deletions packages/upscalerjs/src/loadModel.browser.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { tf, } from './dependencies.generated';
import { ModelDefinition, ModelType, ModelConfigurationInternals, } from '@upscalerjs/core';
import type { LayersModel, } from '@tensorflow/tfjs-layers';
import type { ModelDefinition, ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core';
import type { ParsedModelDefinition, ModelPackage, } from './types';
import {
loadTfModel,
Expand All @@ -10,6 +10,7 @@ import {
getModelDefinitionError,
} from './errors-and-warnings';
import {
TF,
isValidModelDefinition,
} from '@upscalerjs/core';
import {
Expand Down Expand Up @@ -38,7 +39,7 @@ export const getLoadModelErrorMessage = (errs: Errors, modelPath: string, intern
...errs.map(([cdn, err, ]) => `- ${cdn}: ${err.message}`),
].join('\n'));

export async function fetchModel<M extends ModelType, R = M extends 'graph' ? tf.GraphModel : tf.LayersModel>(modelConfiguration: {
export async function fetchModel<M extends ModelType, R = M extends 'graph' ? GraphModel : LayersModel>(tf: TF, modelConfiguration: {
modelType?: M;
} & Omit<ParsedModelDefinition, 'modelType'>): Promise<R> {
const { modelType, _internals, path: modelPath, } = modelConfiguration;
Expand All @@ -64,9 +65,11 @@ export async function fetchModel<M extends ModelType, R = M extends 'graph' ? tf
}

export const loadModel = async (
tf: TF,
_modelDefinition: Promise<ModelDefinition>,
): Promise<ModelPackage> => {
const modelDefinition = await _modelDefinition;

try {
isValidModelDefinition(modelDefinition);
} catch (err: unknown) {
Expand All @@ -78,7 +81,7 @@ export const loadModel = async (

const parsedModelDefinition = parseModelDefinition(modelDefinition);

const model = await fetchModel(parsedModelDefinition);
const model = await fetchModel(tf, parsedModelDefinition);

return {
model,
Expand Down
21 changes: 5 additions & 16 deletions packages/upscalerjs/src/loadModel.node.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { vi } from 'vitest';
import path from 'path';
import { resolver, } from './resolver';
import type { ModelDefinition } from "@upscalerjs/core";
import * as tf from '@tensorflow/tfjs-node';
import {
ERROR_MODEL_DEFINITION_BUG,
} from './errors-and-warnings';
Expand All @@ -21,7 +22,6 @@ import {
MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE,
} from '@upscalerjs/core';

import type * as dependenciesGenerated from './dependencies.generated';
import type * as core from '@upscalerjs/core';
import type * as modelUtils from './model-utils';
import type * as errorsAndWarnings from './errors-and-warnings';
Expand All @@ -31,7 +31,7 @@ vi.mock('./model-utils', async () => {
const { loadTfModel, ...rest } = await vi.importActual('./model-utils') as typeof modelUtils;
return {
...rest,
loadTfModel: vi.fn(loadTfModel),
loadTfModel: vi.fn(),
}
});

Expand All @@ -57,17 +57,6 @@ vi.mock('./resolver', async () => {
resolver: vi.fn(resolver),
};
});
vi.mock('./dependencies.generated', async () => {
const { tf, ...rest } = await vi.importActual('./dependencies.generated') as typeof dependenciesGenerated;
return {
...rest,
tf: {
...tf,
loadLayersModel: vi.fn(),
loadGraphModel: vi.fn(),
}
}
});

const getResolver = (fn: () => string) => (fn) as unknown as typeof require.resolve;

Expand Down Expand Up @@ -135,7 +124,7 @@ describe('loadModel.node', () => {
throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED);
});

await expect(loadModel(Promise.resolve({}) as Promise<ModelDefinition>))
await expect(loadModel(tf, Promise.resolve({}) as Promise<ModelDefinition>))
.rejects
.toThrow(error);
});
Expand All @@ -148,7 +137,7 @@ describe('loadModel.node', () => {
const path = 'foo';
const modelDefinition: ModelDefinition = { path, scale: 2, modelType: 'layers' };

const response = await loadModel(Promise.resolve(modelDefinition));
const response = await loadModel(tf, Promise.resolve(modelDefinition));
expect(loadTfModel).toHaveBeenCalledWith(tf, path, 'layers');
expect(response).toEqual({
model: 'layers model',
Expand All @@ -164,7 +153,7 @@ describe('loadModel.node', () => {
const path = 'foo';
const modelDefinition: ModelDefinition = { path, scale: 2, modelType: 'graph' };

const response = await loadModel(Promise.resolve(modelDefinition));
const response = await loadModel(tf, Promise.resolve(modelDefinition));
expect(loadTfModel).toHaveBeenCalledWith(tf, path, 'graph');
expect(response).toEqual({
model: 'graph model',
Expand Down
3 changes: 2 additions & 1 deletion packages/upscalerjs/src/loadModel.node.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import path from 'path';
import { tf, } from './dependencies.generated';
import type { ModelDefinition, } from "@upscalerjs/core";
import { loadTfModel, parseModelDefinition, } from './model-utils';
import { resolver, } from './resolver';
import { ParsedModelDefinition, ModelPackage, } from './types';
import {
isValidModelDefinition,
TF,
} from '@upscalerjs/core';
import {
ERROR_MODEL_DEFINITION_BUG,
Expand Down Expand Up @@ -43,6 +43,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string
};

export const loadModel = async (
tf: TF,
_modelDefinition: Promise<ModelDefinition>,
): Promise<ModelPackage> => {
const modelDefinition = await _modelDefinition;
Expand Down
Loading

0 comments on commit 6eb0436

Please sign in to comment.