Skip to content

Commit

Permalink
feat(model-runtime): add tfjs as a backend (#88)
Browse files Browse the repository at this point in the history
* feat(model): add tfjs bno model

This commit adds a TensorFlow.js-format BNO model that was originally developed in PyTorch
and has the following export path: PyTorch -> ONNX -> TensorFlow SavedModel -> TensorFlow.js.
Note that the model has not been tested yet.
This commit also includes @tensorflow/tfjs as a dependency.

* refactor(modelService): rename it to ONNXService

Signed-off-by: Bill ZHANG <[email protected]>

* feat(tfjsService): initialize tensorflowjs backend

Signed-off-by: Bill ZHANG <[email protected]>

* perf(tfjs-runtime): use tensor op instead of array iter

Signed-off-by: Bill ZHANG <[email protected]>

* feat(model): update modelService interface

update it to reflect newest code change for tfjs runtime

Signed-off-by: Bill ZHANG <[email protected]>

* fix(tfjs-runtime): attempt to fix the memory leak

Signed-off-by: Bill ZHANG <[email protected]>

* refactor(modelService): add factory method to create model service

Signed-off-by: Bill ZHANG <[email protected]>

---------

Signed-off-by: Bill ZHANG <[email protected]>
  • Loading branch information
Lutra-Fs authored Sep 14, 2023
1 parent c390a4a commit 774b781
Show file tree
Hide file tree
Showing 12 changed files with 571 additions and 80 deletions.
296 changes: 276 additions & 20 deletions package-lock.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"@ant-design/icons": "^5.2.5",
"@react-three/drei": "^9.83.3",
"@react-three/fiber": "^8.13.7",
"@tensorflow/tfjs": "^4.10.0",
"@vitejs/plugin-react": "^4.0.4",
"antd": "^5.8.4",
"onnxruntime-web": "^1.15.1",
Expand Down Expand Up @@ -63,6 +64,6 @@
}
},
"volta": {
"node": "18.15.0"
"node": "18.17.1"
}
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions public/model/bno_small_new_web/model.json

Large diffs are not rendered by default.

7 changes: 0 additions & 7 deletions src/services/model.ts

This file was deleted.

41 changes: 9 additions & 32 deletions src/services/modelService.ts → src/services/model/ONNXService.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import * as ort from 'onnxruntime-web';
import { type Vector2 } from 'three';
import type Model from './model';
import { type ModelService } from './modelService';

export default class ModelService implements Model {
export default class ONNXService implements ModelService {
session: ort.InferenceSession | null;
gridSize: [number, number];
batchSize: number;
Expand All @@ -20,12 +20,11 @@ export default class ModelService implements Model {
// 1, 2: partial velocity
// 3, 4: Force (currently not used)

private isPaused: boolean;
private curFrameCountbyLastSecond: number;
// hold constructor private to prevent direct instantiation
// ort.InferenceSession.create() is async,
// so we need to use a static async method to create an instance
private isPaused: boolean;
private curFrameCountbyLastSecond: number;

private constructor() {
this.session = null;
this.matrixArray = new Float32Array();
Expand All @@ -44,16 +43,16 @@ export default class ModelService implements Model {
}

// static async method to create an instance
static async createModelService(
static async createService(
modelPath: string,
gridSize: [number, number] = [64, 64],
batchSize = 1,
channelSize = 5,
outputChannelSize = 3,
fpsLimit = 15,
): Promise<ModelService> {
): Promise<ONNXService> {
console.log('createModelService called');
const modelServices = new ModelService();
const modelServices = new ONNXService();
await modelServices.init(
modelPath,
gridSize,
Expand All @@ -66,19 +65,6 @@ export default class ModelService implements Model {
return modelServices;
}

async initMatrixFromPath(path: string | URL): Promise<void> {
// check if the path is a relative path
if (typeof path === 'string' && !path.startsWith('http')) {
path = new URL(path, import.meta.url);
}
const matrix = await fetch(path).then(
async (res) => (await res.json()) as number[][][][],
);
if (matrix == null) {
throw new Error('Cannot fetch matrix from path');
}
this.initMatrixFromArray(matrix);
}

bindOutput(callback: (data: Float32Array) => void): void {
this.outputCallback = callback;
Expand Down Expand Up @@ -128,7 +114,7 @@ export default class ModelService implements Model {
this.outputSize = batchSize * gridSize[0] * gridSize[1] * outputChannelSize;
}

private initMatrixFromArray(data: number[][][][]): void {
loadDataArray(data: number[][][][]): void {
console.log(
'🚀 ~ file: modelService.ts:132 ~ ModelService ~ initMatrixFromJSON ~ data:',
data,
Expand Down Expand Up @@ -389,16 +375,7 @@ export default class ModelService implements Model {
private roundFloat(value: number, decimal = 4): number {
return Math.round(value * 10 ** decimal) / 10 ** decimal;
}
getFullMatrix(): Float32Array {
getInputTensor(): Float32Array {
return this.matrixArray;
}
getDensity(): Float32Array {
return this.matrixMap(this.matrixArray, [0, 1], (v) => v);
}
getVelocity(): Float32Array {
return this.matrixMap(this.matrixArray, [1, 3], (v) => v);
}
getForce(): Float32Array {
return this.matrixMap(this.matrixArray, [3, 5], (v) => v);
}
}
218 changes: 218 additions & 0 deletions src/services/model/TfjsService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import * as tf from '@tensorflow/tfjs';
import { type Vector2 } from 'three';
import { type ModelService } from './modelService';

/**
 * ModelService implementation backed by a TensorFlow.js GraphModel.
 *
 * Input tensor layout is [batch, height, width, channel] with channels:
 *   0: density, 1-2: velocity (x, y), 3-4: pressure/force (x, y).
 * Density and velocity live as tf.Variables updated in place each step;
 * pressure is kept in a CPU-side TensorBuffer so updateForce() can poke
 * individual cells between frames without a GPU round trip.
 */
export class TfjsService implements ModelService {
  model!: tf.GraphModel;
  gridSize: [number, number];
  batchSize: number;
  channelSize: number;
  outputChannelSize: number;
  // Total density captured at load time; each frame is rescaled to conserve it.
  mass!: tf.Tensor;
  fpsLimit: number;
  density!: tf.Variable<tf.Rank.R4>;
  velocity!: tf.Variable<tf.Rank.R4>;
  pressure!: tf.TensorBuffer<tf.Rank.R4>;

  isPaused: boolean;
  curFrameCountbyLastSecond: number;
  private outputCallback!: (data: Float32Array) => void;

  constructor() {
    // Placeholder values; real configuration happens in createService(),
    // and tensor state is populated by loadDataArray().
    this.gridSize = [0, 0];
    this.batchSize = 0;
    this.isPaused = true;
    this.channelSize = 0;
    this.outputChannelSize = 0;
    this.mass = tf.variable(tf.zeros([0]));
    this.fpsLimit = 30;
    this.curFrameCountbyLastSecond = 0;
  }

  /**
   * Async factory: loads the graph model from modelPath and returns a
   * configured (but not yet data-initialized) service.
   */
  static async createService(
    modelPath: string,
    gridSize: [number, number] = [64, 64],
    batchSize = 1,
    channelSize = 5,
    outputChannelSize = 3,
    fpsLimit = 15,
  ): Promise<TfjsService> {
    const service = new TfjsService();
    service.model = await tf.loadGraphModel(modelPath);
    service.gridSize = gridSize;
    service.batchSize = batchSize;
    service.channelSize = channelSize;
    service.outputChannelSize = outputChannelSize;
    service.fpsLimit = fpsLimit;

    return service;
  }

  /**
   * Initializes simulation state from a nested JS array shaped
   * [batch, height, width, channel]. Each physical quantity is normalized
   * to zero mean / unit variance independently; density is then clamped
   * to be non-negative.
   */
  loadDataArray(array: number[][][][]): void {
    // Release the placeholder (or previously loaded) mass tensor before
    // replacing it below — the original leaked it.
    this.mass.dispose();
    // tidy() disposes every intermediate tensor created inside; variables
    // and kept tensors survive.
    tf.tidy(() => {
      const full = tf.tensor4d(
        array,
        [this.batchSize, ...this.gridSize, this.channelSize],
        'float32',
      );
      // Slice `size` channels starting at channel index `start`.
      const channels = (start: number, size: number): tf.Tensor =>
        full.slice([0, 0, 0, start], [this.batchSize, ...this.gridSize, size]);

      // Channel 0: density, clamped to >= 0 after normalization.
      this.density = tf.variable(
        TfjsService.normalizeTensor(channels(0, 1)).maximum(0),
      ) as tf.Variable<tf.Rank.R4>;

      // Channels 1-2: velocity components, normalized per component.
      this.velocity = tf.variable(
        tf.concat(
          [
            TfjsService.normalizeTensor(channels(1, 1)),
            TfjsService.normalizeTensor(channels(2, 1)),
          ],
          3,
        ),
      ) as tf.Variable<tf.Rank.R4>;

      // Channels 3-4: pressure/force, normalized and mirrored to a CPU
      // buffer so updateForce() can mutate single cells cheaply.
      this.pressure = tf
        .concat(
          [
            TfjsService.normalizeTensor(channels(3, 1)),
            TfjsService.normalizeTensor(channels(4, 1)),
          ],
          3,
        )
        .bufferSync() as tf.TensorBuffer<tf.Rank.R4>;

      // keep() exempts the mass reference tensor from tidy()'s cleanup.
      this.mass = tf.keep(this.density.sum());
    });
  }

  /**
   * Zero-mean / unit-variance normalization over the whole tensor.
   * NOTE(review): divides by sqrt(variance) with no epsilon — a constant
   * input channel would produce NaNs; confirm inputs are never constant.
   */
  static normalizeTensor(tensor: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
      const { mean, variance } = tf.moments(tensor);
      return tensor.sub(mean).div(variance.sqrt());
    });
  }

  /** Pauses the simulation; the running iterate/heartbeat chains exit. */
  pauseSimulation(): void {
    this.isPaused = true;
  }

  /** Registers the sink that receives each frame's raw model output. */
  bindOutput(callback: (data: Float32Array) => void): void {
    this.outputCallback = callback;
  }

  /** Starts (or resumes) the simulation loop and its FPS watchdog. */
  startSimulation(): void {
    this.isPaused = false;
    this.curFrameCountbyLastSecond = 0;
    this.fpsHeartbeat();
    this.iterate();
  }

  /**
   * Once-a-second watchdog that zeroes the frame counter.
   * NOTE(review): the original compared the counter to fpsLimit *after*
   * zeroing it, so the >= branch (which re-entered startSimulation and
   * would have spawned a second iterate loop) was dead code. The effective
   * reset-and-continue behavior is kept here; revisit if real FPS
   * throttling is intended.
   */
  private fpsHeartbeat(): void {
    setTimeout(() => {
      if (this.isPaused) {
        return; // stop the timer chain once the simulation is paused
      }
      this.curFrameCountbyLastSecond = 0;
      this.fpsHeartbeat();
    }, 1000);
  }

  /** Assembles the 5-channel model input [density | velocity | pressure]. */
  getInput(): tf.Tensor<tf.Rank> {
    const pressure = this.pressure.toTensor();
    const input = tf.concat([this.density, this.velocity, pressure], 3);
    pressure.dispose();
    return input;
  }

  /**
   * Runs one prediction step, rescales the state to conserve total mass
   * and kinetic energy, emits the raw model output through the bound
   * callback, and schedules the next step on a macrotask so UI events
   * (pause, updateForce) can interleave.
   */
  private iterate(): void {
    if (this.isPaused) {
      return;
    }
    // Count the frame exactly once (the original also incremented inside
    // the setTimeout below, double-counting every frame).
    this.curFrameCountbyLastSecond += 1;
    // tidy() disposes every intermediate tensor; the density/velocity
    // variables survive and are updated in place via assign() — the
    // original reassigned the variable fields with plain tensors, which
    // leaked the variables (the "memory leak" this commit mentions).
    const frame = tf.tidy(() => {
      const input = this.getInput();
      const prevEnergy = this.velocity.square().sum();
      const output = this.model.predict(input) as tf.Tensor4D;
      // Model output layout: channel 0 density, channels 1-2 velocity.
      this.density.assign(
        output.slice(
          [0, 0, 0, 0],
          [this.batchSize, ...this.gridSize, 1],
        ) as tf.Tensor4D,
      );
      this.velocity.assign(
        output.slice(
          [0, 0, 0, 1],
          [this.batchSize, ...this.gridSize, 2],
        ) as tf.Tensor4D,
      );
      // Rescale velocity so kinetic energy matches the pre-step value.
      const energyScale = prevEnergy.div(this.velocity.square().sum());
      this.velocity.assign(
        this.velocity.mul(energyScale.sqrt()) as tf.Tensor4D,
      );
      // Rescale density so total mass matches the value captured at load.
      const massScale = this.mass.div(this.density.sum());
      this.density.assign(this.density.mul(massScale) as tf.Tensor4D);
      return output.dataSync() as Float32Array;
    });
    this.outputCallback(frame);
    // Timeout 0 yields to the event loop between frames, so pause and
    // applied forces can take effect.
    setTimeout(() => {
      this.iterate();
    }, 0);
  }

  /**
   * Adds a force delta at grid cell (pos.x, pos.y). The pressure buffer
   * holds exactly two channels (x, y), addressed as 0 and 1. (The
   * original wrote to channels 3/4 — valid indices in the full input
   * layout, but out of range for this 2-channel buffer.)
   */
  updateForce(pos: Vector2, forceDelta: Vector2, batchIndex = 0): void {
    this.pressure.set(
      this.pressure.get(batchIndex, pos.x, pos.y, 0) + forceDelta.x,
      batchIndex,
      pos.x,
      pos.y,
      0,
    );
    this.pressure.set(
      this.pressure.get(batchIndex, pos.x, pos.y, 1) + forceDelta.y,
      batchIndex,
      pos.x,
      pos.y,
      1,
    );
  }

  /** Returns a CPU copy of the current 5-channel input state. */
  getInputTensor(): Float32Array {
    const input = this.getInput();
    const data = input.dataSync();
    input.dispose();
    return data as Float32Array;
  }

  /** Releases all tensor and model resources owned by this service. */
  dispose(): void {
    this.density.dispose();
    this.velocity.dispose();
    this.mass.dispose(); // the original leaked the mass tensor
    this.model.dispose();
  }
}
48 changes: 48 additions & 0 deletions src/services/model/modelService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { type Vector2 } from 'three';
import { TfjsService } from './TfjsService';
import ONNXService from './ONNXService';

/**
 * Common contract for simulation model backends (implemented by the
 * TF.js and ONNX services in this directory). A backend loads state via
 * loadDataArray, streams frames to the callback registered with
 * bindOutput, and is driven by start/pauseSimulation.
 */
export interface ModelService {
  // Begins the frame loop; frames arrive via the bindOutput callback.
  startSimulation: () => void;
  // Halts the frame loop; state is retained and can be resumed.
  pauseSimulation: () => void;
  // Registers the sink that receives each frame's output as a flat array.
  bindOutput: (callback: (data: Float32Array) => void) => void;
  // Returns a flat CPU copy of the current model input state.
  getInputTensor: () => Float32Array;
  // Applies a force delta at a grid position before the next frame.
  updateForce: (pos: Vector2, forceDelta: Vector2) => void;
  // Initializes state from a nested [batch][height][width][channel] array.
  loadDataArray: (array: number[][][][]) => void;
}

/**
 * Simple factory that selects a ModelService backend from the model
 * file's extension: `.json` → TensorFlow.js GraphModel, `.onnx` → ONNX
 * Runtime. All sizing parameters are forwarded unchanged to the backend.
 *
 * @param modelPath path or URL to the model file
 * @param gridSize simulation grid dimensions [height, width]
 * @param batchSize number of simulations run per inference call
 * @param channelSize channels in the model input tensor
 * @param outputChannelSize channels in the model output tensor
 * @param fpsLimit target frames per second for the simulation loop
 * @throws Error when the extension is not a supported model type
 */
export async function createModelService(
  modelPath: string,
  gridSize: [number, number] = [64, 64],
  batchSize = 1,
  channelSize = 5,
  outputChannelSize = 3,
  fpsLimit = 15,
): Promise<ModelService> {
  // Detect the backend from the file extension, case-insensitively
  // (the original missed uppercase extensions like ".JSON").
  // TODO: read the model type from the model definition file instead.
  const modelType = modelPath.split('.').pop()?.toLowerCase();
  switch (modelType) {
    case 'json':
      return TfjsService.createService(
        modelPath,
        gridSize,
        batchSize,
        channelSize,
        outputChannelSize,
        fpsLimit,
      );
    case 'onnx':
      return ONNXService.createService(
        modelPath,
        gridSize,
        batchSize,
        channelSize,
        outputChannelSize,
        fpsLimit,
      );
    default:
      // Name the offending extension so misconfigured paths are debuggable
      // (the original's bare "Invalid model type" gave no clue).
      throw new Error(`Invalid model type: ${modelType ?? '(none)'}`);
  }
}
Loading

0 comments on commit 774b781

Please sign in to comment.