diff --git a/src/components/Simulation.tsx b/src/components/Simulation.tsx index d3f8aaa..7e7b5d2 100644 --- a/src/components/Simulation.tsx +++ b/src/components/Simulation.tsx @@ -65,7 +65,7 @@ function DiffusionPlane( // INITIALISATION // WebGPU capability test - if (WebGPU.isAvailable() === true) { + if (WebGPU.isAvailable()) { const webgpuRenderer = new WebGPURenderer({ antialias: true }); console.log('browser supports webgpu rendering'); console.log('webgpu renderer context', webgpuRenderer); diff --git a/src/services/model/ONNXService.ts b/src/services/model/ONNXService.ts index d5c26a3..e70744b 100644 --- a/src/services/model/ONNXService.ts +++ b/src/services/model/ONNXService.ts @@ -375,6 +375,7 @@ export default class ONNXService implements ModelService { private roundFloat(value: number, decimal = 4): number { return Math.round(value * 10 ** decimal) / 10 ** decimal; } + getInputTensor(): Float32Array { return this.matrixArray; } diff --git a/src/services/model/TfjsService.ts b/src/services/model/TfjsService.ts index 12bbb0d..cb5c333 100644 --- a/src/services/model/TfjsService.ts +++ b/src/services/model/TfjsService.ts @@ -136,12 +136,14 @@ export class TfjsService implements ModelService { } }, 1000); } + getInput(): tf.Tensor { const pressure = this.pressure.toTensor(); const input = tf.concat([this.density, this.velocity, pressure], 3); pressure.dispose(); return input; } + private iterate(): void { if (this.isPaused) { return; @@ -204,12 +206,14 @@ export class TfjsService implements ModelService { 4, ); } + getInputTensor(): Float32Array { const input = this.getInput(); const data = input.dataSync(); input.dispose(); return data as Float32Array; } + dispose(): void { this.density.dispose(); this.velocity.dispose(); diff --git a/src/services/model/modelService.ts b/src/services/model/modelService.ts index 8da185a..35e75e6 100644 --- a/src/services/model/modelService.ts +++ b/src/services/model/modelService.ts @@ -25,7 +25,7 @@ export async function createModelService( const modelType = modelPath.split('.').pop(); switch (modelType) { case 'json': - return TfjsService.createService( + return await TfjsService.createService( modelPath, gridSize, batchSize, @@ -34,7 +34,7 @@ export async function createModelService( fpsLimit, ); case 'onnx': - return ONNXService.createService( + return await ONNXService.createService( modelPath, gridSize, batchSize, diff --git a/src/workers/modelWorker.ts b/src/workers/modelWorker.ts index 20eef32..d75e437 100644 --- a/src/workers/modelWorker.ts +++ b/src/workers/modelWorker.ts @@ -2,7 +2,7 @@ import { type Vector2 } from 'three'; import { - ModelService, + type ModelService, createModelService, } from '../services/model/modelService'; import { type IncomingMessage } from './modelWorkerMessage'; @@ -94,8 +94,8 @@ async function initModelService( const modelService = await createModelService(modelPath, [64, 64], 1); modelService.bindOutput(outputCallback); // fetch the data - const data = (await fetch(dataPath).then((res) => - res.json(), + const data = (await fetch(dataPath).then(async (res) => + await res.json(), )) as number[][][][]; modelService.loadDataArray(data); return modelService;