diff --git a/src/services/ONNXService.ts b/src/services/model/ONNXService.ts similarity index 94% rename from src/services/ONNXService.ts rename to src/services/model/ONNXService.ts index 1e39e7a..d5c26a3 100644 --- a/src/services/ONNXService.ts +++ b/src/services/model/ONNXService.ts @@ -1,6 +1,6 @@ import * as ort from 'onnxruntime-web'; import { type Vector2 } from 'three'; -import type ModelService from './modelService'; +import { type ModelService } from './modelService'; export default class ONNXService implements ModelService { session: ort.InferenceSession | null; @@ -43,7 +43,7 @@ export default class ONNXService implements ModelService { } // static async method to create an instance - static async createModelService( + static async createService( modelPath: string, gridSize: [number, number] = [64, 64], batchSize = 1, @@ -65,19 +65,6 @@ export default class ONNXService implements ModelService { return modelServices; } - async loadJSONFileFromUrl(path: string | URL): Promise { - // check if the path is a relative path - if (typeof path === 'string' && !path.startsWith('http')) { - path = new URL(path, import.meta.url); - } - const matrix = await fetch(path).then( - async (res) => (await res.json()) as number[][][][], - ); - if (matrix == null) { - throw new Error('Cannot fetch matrix from path'); - } - this.initMatrixFromArray(matrix); - } bindOutput(callback: (data: Float32Array) => void): void { this.outputCallback = callback; @@ -127,7 +114,7 @@ export default class ONNXService implements ModelService { this.outputSize = batchSize * gridSize[0] * gridSize[1] * outputChannelSize; } - private initMatrixFromArray(data: number[][][][]): void { + loadDataArray(data: number[][][][]): void { console.log( '🚀 ~ file: modelService.ts:132 ~ ModelService ~ initMatrixFromJSON ~ data:', data, diff --git a/src/services/TfjsService.ts b/src/services/model/TfjsService.ts similarity index 90% rename from src/services/TfjsService.ts rename to src/services/model/TfjsService.ts index 0606e48..12bbb0d 100644 --- a/src/services/TfjsService.ts +++ b/src/services/model/TfjsService.ts @@ -1,6 +1,6 @@ import * as tf from '@tensorflow/tfjs'; import { type Vector2 } from 'three'; -import type ModelService from './modelService'; +import { type ModelService } from './modelService'; export class TfjsService implements ModelService { model!: tf.GraphModel; @@ -29,7 +29,7 @@ export class TfjsService implements ModelService { this.curFrameCountbyLastSecond = 0; } - async createService( + static async createService( modelPath: string, gridSize: [number, number] = [64, 64], batchSize = 1, @@ -45,23 +45,10 @@ export class TfjsService implements ModelService { service.outputChannelSize = outputChannelSize; service.fpsLimit = fpsLimit; - this.isPaused = false; - return this; - } - - async loadJSONFileFromUrl(url: string): Promise { - const response = await fetch(url); - const json = (await response.json()) as JSON; - // check if json is valid - if ('density' in json && 'velocity' in json && 'pressure' in json) { - throw new Error('Invalid JSON file'); - } - // turn json into ModelData - this.loadMatrixFromJson(json); + return service; } - loadMatrixFromJson(json: JSON /*ModelData*/): void { - const array = json as unknown as number[][][][]; + loadDataArray(array: number[][][][]): void { console.log(array); const arrayTensor = tf.tensor4d( array, @@ -197,7 +184,7 @@ export class TfjsService implements ModelService { setTimeout(() => { this.curFrameCountbyLastSecond += 1; console.log(this.curFrameCountbyLastSecond); - void this.iterate(); + this.iterate(); }, 0); } diff --git a/src/services/model/modelService.ts b/src/services/model/modelService.ts new file mode 100644 index 0000000..8da185a --- /dev/null +++ b/src/services/model/modelService.ts @@ -0,0 +1,48 @@ +import { type Vector2 } from 'three'; +import { TfjsService } from './TfjsService'; +import ONNXService from './ONNXService'; + +export interface ModelService { + startSimulation: () => void; + pauseSimulation: () => void; + bindOutput: (callback: (data: Float32Array) => void) => void; + getInputTensor: () => Float32Array; + updateForce: (pos: Vector2, forceDelta: Vector2) => void; + loadDataArray: (array: number[][][][]) => void; +} + +// a simple factory function to create a model service +export async function createModelService( + modelPath: string, + gridSize: [number, number] = [64, 64], + batchSize = 1, + channelSize = 5, + outputChannelSize = 3, + fpsLimit = 15, +): Promise { + // detect the model type + // TODO: read the model type from the model definition file + const modelType = modelPath.split('.').pop(); + switch (modelType) { + case 'json': + return TfjsService.createService( + modelPath, + gridSize, + batchSize, + channelSize, + outputChannelSize, + fpsLimit, + ); + case 'onnx': + return ONNXService.createService( + modelPath, + gridSize, + batchSize, + channelSize, + outputChannelSize, + fpsLimit, + ); + default: + throw new Error('Invalid model type'); + } +} diff --git a/src/services/modelService.ts b/src/services/modelService.ts deleted file mode 100644 index c2b1e9b..0000000 --- a/src/services/modelService.ts +++ /dev/null @@ -1,10 +0,0 @@ -import { type Vector2 } from 'three'; - -export default interface ModelService { - startSimulation: () => void; - pauseSimulation: () => void; - bindOutput: (callback: (data: Float32Array) => void) => void; - getInputTensor: () => Float32Array; - updateForce: (pos: Vector2, forceDelta: Vector2) => void; - loadJSONFileFromUrl: (url: string) => Promise; -} diff --git a/src/workers/modelWorker.ts b/src/workers/modelWorker.ts index 07ac1f4..9daf10c 100644 --- a/src/workers/modelWorker.ts +++ b/src/workers/modelWorker.ts @@ -1,7 +1,10 @@ // a worker that can control the modelService via messages import { Vector2 } from 'three'; -import ModelService from '../services/modelService'; +import { + ModelService, + createModelService, +} from '../services/model/modelService'; import { IncomingMessage } from './modelWorkerMessage'; let modelService: ModelService | null = null; @@ -51,22 +54,13 @@ export function onmessage( case 'updateForce': updateForce(data.args as UpdateForceArgs); break; - case 'getFullMatrix': + case 'getInputTensor': if (modelService == null) { throw new Error('modelService is null'); } this.postMessage({ - type: 'fullMatrix', - matrix: modelService.getFullMatrix(), - }); - break; - case 'getDensity': - if (modelService == null) { - throw new Error('modelService is null'); - } - this.postMessage({ - type: 'density', - density: modelService.getDensity(), + type: 'inputTensor', + tensor: modelService.getInputTensor(), }); break; default: @@ -86,7 +80,10 @@ async function initModelService( event: DedicatedWorkerGlobalScope, ): Promise { const modelPath = '/model/bno_small_001.onnx'; - const dataPath = new URL('/initData/pvf_incomp_44_nonneg/pvf_incomp_44_nonneg_0.json', import.meta.url); + const dataPath = new URL( + '/initData/pvf_incomp_44_nonneg/pvf_incomp_44_nonneg_0.json', + import.meta.url, + ); const outputCallback = (output: Float32Array): void => { const density = new Float32Array(output.length / 3); for (let i = 0; i < density.length; i++) { @@ -94,12 +91,12 @@ async function initModelService( } event.postMessage({ type: 'output', density }); }; - const modelService = await ModelService.createModelService( - modelPath, - [64, 64], - 1, - ); + const modelService = await createModelService(modelPath, [64, 64], 1); modelService.bindOutput(outputCallback); - await modelService.initMatrixFromPath(dataPath); + // fetch the data + const data = (await fetch(dataPath).then((res) => + res.json(), + )) as number[][][][]; + modelService.loadDataArray(data); return modelService; }