Skip to content

Commit

Permalink
fix(tfjs-runtime): attempt to fix the memory leak
Browse files Browse the repository at this point in the history
Signed-off-by: Bill ZHANG <[email protected]>
  • Loading branch information
Lutra-Fs committed Sep 9, 2023
1 parent 81f49e9 commit 364c3d7
Showing 1 changed file with 115 additions and 45 deletions.
160 changes: 115 additions & 45 deletions src/services/TfjsService.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import * as tf from '@tensorflow/tfjs';
import { type Vector2 } from 'three';
import type Model from './model';
import type ModelService from './modelService';

export class TfjsService {
export class TfjsService implements ModelService {
model!: tf.GraphModel;
gridSize: [number, number];
batchSize: number;
channelSize: number;
outputChannelSize: number;
mass: tf.Tensor;
mass!: tf.Tensor;
fpsLimit: number;
density: tf.Tensor;
velocity: tf.Tensor;
pressure: tf.TensorBuffer<tf.Rank.R4>;
density!: tf.Variable<tf.Rank.R4>;
velocity!: tf.Variable<tf.Rank.R4>;
pressure!: tf.TensorBuffer<tf.Rank.R4>;

isPaused: boolean;
curFrameCountbyLastSecond: number;
Expand All @@ -27,9 +27,6 @@ export class TfjsService {
this.mass = tf.variable(tf.zeros([0]));
this.fpsLimit = 30;
this.curFrameCountbyLastSecond = 0;
this.density = tf.zeros([0, 0, 0, 0]);
this.velocity = tf.zeros([0, 0, 0, 0]);
this.pressure = tf.buffer([0, 0, 0, 0]);
}

async createService(
Expand All @@ -40,39 +37,91 @@ export class TfjsService {
outputChannelSize = 3,
fpsLimit = 15,
): Promise<TfjsService> {
this.model = await tf.loadGraphModel(modelPath);
this.gridSize = gridSize;
this.batchSize = batchSize;
this.channelSize = channelSize;
this.outputChannelSize = outputChannelSize;
this.fpsLimit = fpsLimit;

this.mass = tf.tensor(1.0);
this.density = tf.zeros([batchSize, ...gridSize, 1]);
this.velocity = tf.zeros([batchSize, ...gridSize, 2]);
this.pressure = tf.buffer([batchSize, ...gridSize, 1]);
const service = new TfjsService();
service.model = await tf.loadGraphModel(modelPath);
service.gridSize = gridSize;
service.batchSize = batchSize;
service.channelSize = channelSize;
service.outputChannelSize = outputChannelSize;
service.fpsLimit = fpsLimit;

this.isPaused = false;
return this;
}

async loadJSONFileFromUrl(url: string): Promise<ModelData> {
async loadJSONFileFromUrl(url: string): Promise<void> {
const response = await fetch(url);
const json = (await response.json()) as JSON;
// check if json is valid
if ('density' in json && 'velocity' in json && 'pressure' in json) {
throw new Error('Invalid JSON file');
}
// turn json into ModelData

return json as unknown as ModelData;
this.loadMatrixFromJson(json);
}

loadMatrixFromJson(json: ModelData): void {
this.density = tf.tensor4d(json.density);
this.velocity = tf.tensor4d(json.velocity);
this.pressure = tf.tensor4d(json.pressure).bufferSync();
loadMatrixFromJson(json: JSON /*ModelData*/): void {
const array = json as unknown as number[][][][];
console.log(array);
const arrayTensor = tf.tensor4d(
array,
[this.batchSize, ...this.gridSize, this.channelSize],
'float32',
);
// 0: partial density
// 1, 2: partial velocity
// 3, 4: Pressure
const density = arrayTensor.slice(
[0, 0, 0, 0],
[this.batchSize, ...this.gridSize, 1],
);
const normalizedDensity = TfjsService.normalizeTensor(density);
density.dispose();
this.density = tf.variable(normalizedDensity.maximum(0));
const velocityX = arrayTensor.slice(
[0, 0, 0, 1],
[this.batchSize, ...this.gridSize, 1],
);
const velocityY = arrayTensor.slice(
[0, 0, 0, 2],
[this.batchSize, ...this.gridSize, 1],
);
const normalizedVelocityX = TfjsService.normalizeTensor(velocityX);
const normalizedVelocityY = TfjsService.normalizeTensor(velocityY);
velocityX.dispose();
velocityY.dispose();
this.velocity = tf.variable(
tf.concat([normalizedVelocityX, normalizedVelocityY], 3),
) as tf.Variable<tf.Rank.R4>;
normalizedVelocityX.dispose();
normalizedVelocityY.dispose();
const pressureX = arrayTensor.slice(
[0, 0, 0, 3],
[this.batchSize, ...this.gridSize, 1],
);
const pressureY = arrayTensor.slice(
[0, 0, 0, 4],
[this.batchSize, ...this.gridSize, 1],
);
const normalizedPressureX = TfjsService.normalizeTensor(pressureX);
const normalizedPressureY = TfjsService.normalizeTensor(pressureY);
pressureX.dispose();
pressureY.dispose();
this.pressure = tf
.concat([normalizedPressureX, normalizedPressureY], 3)
.bufferSync() as tf.TensorBuffer<tf.Rank.R4>;
normalizedPressureX.dispose();

this.density = this.density.maximum(0);
this.mass = this.density.sum();
this.mass.print();
}

static normalizeTensor(tensor: tf.Tensor): tf.Tensor {
return tf.tidy(() => {
const { mean, variance } = tf.moments(tensor);
return tensor.sub(mean).div(variance.sqrt());
});
}

pauseSimulation(): void {
Expand Down Expand Up @@ -101,10 +150,10 @@ export class TfjsService {
}, 1000);
}
getInput(): tf.Tensor<tf.Rank> {
return tf.concat(
[this.density, this.velocity, this.pressure.toTensor()],
3,
);
const pressure = this.pressure.toTensor();
const input = tf.concat([this.density, this.velocity, pressure], 3);
pressure.dispose();
return input;
}
private iterate(): void {
if (this.isPaused) {
Expand All @@ -115,25 +164,40 @@ export class TfjsService {
const energy = this.velocity.square().sum();
const output = this.model?.predict(input) as tf.Tensor<tf.Rank>;
// update density, velocity
this.density = output?.slice(
[0, 0, 0, 0],
[this.batchSize, ...this.gridSize, 1],
) as tf.Tensor4D;
this.velocity = output?.slice(
[0, 0, 0, 1],
[this.batchSize, ...this.gridSize, 2],
) as tf.Tensor4D;
this.density.assign(
output?.slice(
[0, 0, 0, 0],
[this.batchSize, ...this.gridSize, 1],
) as tf.Tensor4D,
);
this.velocity.assign(
output?.slice(
[0, 0, 0, 1],
[this.batchSize, ...this.gridSize, 2],
) as tf.Tensor4D,
);
// update density, velocity
const newEnergy = this.velocity.square().sum();
const energyScale = energy.div(newEnergy);
energyScale.print();

this.velocity = this.velocity.mul(energyScale.sqrt());
const newMass = this.density.sum();
const massScale = this.mass.div(newMass);
this.density = this.density.mul(massScale);
massScale.print();
newMass.dispose();
newEnergy.dispose();
energy.dispose();
energyScale.dispose();

this.outputCallback(output?.dataSync() as Float32Array);
output.dispose();
// set timeout to 0 to allow other tasks to run, like pause and apply force
setTimeout(() => {
this.iterate();
this.curFrameCountbyLastSecond += 1;
console.log(this.curFrameCountbyLastSecond);
void this.iterate();
}, 0);
}

Expand All @@ -153,9 +217,15 @@ export class TfjsService {
4,
);
}
}
interface ModelData {
density: number[][][][];
velocity: number[][][][];
pressure: number[][][][];
getInputTensor(): Float32Array {
const input = this.getInput();
const data = input.dataSync();
input.dispose();
return data as Float32Array;
}
dispose(): void {
this.density.dispose();
this.velocity.dispose();
this.model.dispose();
}
}

0 comments on commit 364c3d7

Please sign in to comment.