Skip to content

Commit

Permalink
refactor(modelService): add factory method to create model service
Browse files Browse the repository at this point in the history
Signed-off-by: Bill ZHANG <[email protected]>
  • Loading branch information
Lutra-Fs committed Sep 11, 2023
1 parent 364c3d7 commit 216a7a9
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 64 deletions.
19 changes: 3 additions & 16 deletions src/services/ONNXService.ts → src/services/model/ONNXService.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as ort from 'onnxruntime-web';
import { type Vector2 } from 'three';
import type ModelService from './modelService';
import { type ModelService } from './modelService';

export default class ONNXService implements ModelService {
session: ort.InferenceSession | null;
Expand Down Expand Up @@ -43,7 +43,7 @@ export default class ONNXService implements ModelService {
}

// static async method to create an instance
static async createModelService(
static async createService(
modelPath: string,
gridSize: [number, number] = [64, 64],
batchSize = 1,
Expand All @@ -65,19 +65,6 @@ export default class ONNXService implements ModelService {
return modelServices;
}

async loadJSONFileFromUrl(path: string | URL): Promise<void> {
// check if the path is a relative path
if (typeof path === 'string' && !path.startsWith('http')) {
path = new URL(path, import.meta.url);
}
const matrix = await fetch(path).then(
async (res) => (await res.json()) as number[][][][],
);
if (matrix == null) {
throw new Error('Cannot fetch matrix from path');
}
this.initMatrixFromArray(matrix);
}

bindOutput(callback: (data: Float32Array) => void): void {
this.outputCallback = callback;
Expand Down Expand Up @@ -127,7 +114,7 @@ export default class ONNXService implements ModelService {
this.outputSize = batchSize * gridSize[0] * gridSize[1] * outputChannelSize;
}

private initMatrixFromArray(data: number[][][][]): void {
loadDataArray(data: number[][][][]): void {
console.log(
'🚀 ~ file: modelService.ts:132 ~ ModelService ~ initMatrixFromJSON ~ data:',
data,
Expand Down
23 changes: 5 additions & 18 deletions src/services/TfjsService.ts → src/services/model/TfjsService.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as tf from '@tensorflow/tfjs';
import { type Vector2 } from 'three';
import type ModelService from './modelService';
import { type ModelService } from './modelService';

export class TfjsService implements ModelService {
model!: tf.GraphModel;
Expand Down Expand Up @@ -29,7 +29,7 @@ export class TfjsService implements ModelService {
this.curFrameCountbyLastSecond = 0;
}

async createService(
static async createService(
modelPath: string,
gridSize: [number, number] = [64, 64],
batchSize = 1,
Expand All @@ -45,23 +45,10 @@ export class TfjsService implements ModelService {
service.outputChannelSize = outputChannelSize;
service.fpsLimit = fpsLimit;

this.isPaused = false;
return this;
}

async loadJSONFileFromUrl(url: string): Promise<void> {
const response = await fetch(url);
const json = (await response.json()) as JSON;
// check if json is valid
if ('density' in json && 'velocity' in json && 'pressure' in json) {
throw new Error('Invalid JSON file');
}
// turn json into ModelData
this.loadMatrixFromJson(json);
return service;
}

loadMatrixFromJson(json: JSON /*ModelData*/): void {
const array = json as unknown as number[][][][];
loadDataArray(array: number[][][][]): void {
console.log(array);
const arrayTensor = tf.tensor4d(
array,
Expand Down Expand Up @@ -197,7 +184,7 @@ export class TfjsService implements ModelService {
setTimeout(() => {
this.curFrameCountbyLastSecond += 1;
console.log(this.curFrameCountbyLastSecond);
void this.iterate();
this.iterate();
}, 0);
}

Expand Down
48 changes: 48 additions & 0 deletions src/services/model/modelService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { type Vector2 } from 'three';
import { TfjsService } from './TfjsService';
import ONNXService from './ONNXService';

export interface ModelService {
startSimulation: () => void;
pauseSimulation: () => void;
bindOutput: (callback: (data: Float32Array) => void) => void;
getInputTensor: () => Float32Array;
updateForce: (pos: Vector2, forceDelta: Vector2) => void;
loadDataArray: (array: number[][][][]) => void;
}

// a simple factory function to create a model service
export async function createModelService(
modelPath: string,
gridSize: [number, number] = [64, 64],
batchSize = 1,
channelSize = 5,
outputChannelSize = 3,
fpsLimit = 15,
): Promise<ModelService> {
// detect the model type
// TODO: read the model type from the model definition file
const modelType = modelPath.split('.').pop();
switch (modelType) {
case 'json':
return TfjsService.createService(
modelPath,
gridSize,
batchSize,
channelSize,
outputChannelSize,
fpsLimit,
);
case 'onnx':
return ONNXService.createService(
modelPath,
gridSize,
batchSize,
channelSize,
outputChannelSize,
fpsLimit,
);
default:
throw new Error('Invalid model type');
}
}
10 changes: 0 additions & 10 deletions src/services/modelService.ts

This file was deleted.

37 changes: 17 additions & 20 deletions src/workers/modelWorker.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// a worker that can control the modelService via messages

import { Vector2 } from 'three';
import ModelService from '../services/modelService';
import {
ModelService,
createModelService,
} from '../services/model/modelService';
import { IncomingMessage } from './modelWorkerMessage';

let modelService: ModelService | null = null;
Expand Down Expand Up @@ -51,22 +54,13 @@ export function onmessage(
case 'updateForce':
updateForce(data.args as UpdateForceArgs);
break;
case 'getFullMatrix':
case 'getInputTensor':
if (modelService == null) {
throw new Error('modelService is null');
}
this.postMessage({
type: 'fullMatrix',
matrix: modelService.getFullMatrix(),
});
break;
case 'getDensity':
if (modelService == null) {
throw new Error('modelService is null');
}
this.postMessage({
type: 'density',
density: modelService.getDensity(),
type: 'inputTensor',
tensor: modelService.getInputTensor(),
});
break;
default:
Expand All @@ -86,20 +80,23 @@ async function initModelService(
event: DedicatedWorkerGlobalScope,
): Promise<ModelService> {
const modelPath = '/model/bno_small_001.onnx';
const dataPath = new URL('/initData/pvf_incomp_44_nonneg/pvf_incomp_44_nonneg_0.json', import.meta.url);
const dataPath = new URL(
'/initData/pvf_incomp_44_nonneg/pvf_incomp_44_nonneg_0.json',
import.meta.url,
);
const outputCallback = (output: Float32Array): void => {
const density = new Float32Array(output.length / 3);
for (let i = 0; i < density.length; i++) {
density[i] = output[i * 3];
}
event.postMessage({ type: 'output', density });
};
const modelService = await ModelService.createModelService(
modelPath,
[64, 64],
1,
);
const modelService = await createModelService(modelPath, [64, 64], 1);
modelService.bindOutput(outputCallback);
await modelService.initMatrixFromPath(dataPath);
// fetch the data
const data = (await fetch(dataPath).then((res) =>
res.json(),
)) as number[][][][];
modelService.loadDataArray(data);
return modelService;
}

0 comments on commit 216a7a9

Please sign in to comment.