diff --git a/src/services/TfjsService.ts b/src/services/TfjsService.ts index 5accda7..0606e48 100644 --- a/src/services/TfjsService.ts +++ b/src/services/TfjsService.ts @@ -1,18 +1,18 @@ import * as tf from '@tensorflow/tfjs'; import { type Vector2 } from 'three'; -import type Model from './model'; +import type ModelService from './modelService'; -export class TfjsService { +export class TfjsService implements ModelService { model!: tf.GraphModel; gridSize: [number, number]; batchSize: number; channelSize: number; outputChannelSize: number; - mass: tf.Tensor; + mass!: tf.Tensor; fpsLimit: number; - density: tf.Tensor; - velocity: tf.Tensor; - pressure: tf.TensorBuffer; + density!: tf.Variable; + velocity!: tf.Variable; + pressure!: tf.TensorBuffer; isPaused: boolean; curFrameCountbyLastSecond: number; @@ -27,9 +27,6 @@ export class TfjsService { this.mass = tf.variable(tf.zeros([0])); this.fpsLimit = 30; this.curFrameCountbyLastSecond = 0; - this.density = tf.zeros([0, 0, 0, 0]); - this.velocity = tf.zeros([0, 0, 0, 0]); - this.pressure = tf.buffer([0, 0, 0, 0]); } async createService( @@ -40,23 +37,19 @@ export class TfjsService { outputChannelSize = 3, fpsLimit = 15, ): Promise { - this.model = await tf.loadGraphModel(modelPath); - this.gridSize = gridSize; - this.batchSize = batchSize; - this.channelSize = channelSize; - this.outputChannelSize = outputChannelSize; - this.fpsLimit = fpsLimit; - - this.mass = tf.tensor(1.0); - this.density = tf.zeros([batchSize, ...gridSize, 1]); - this.velocity = tf.zeros([batchSize, ...gridSize, 2]); - this.pressure = tf.buffer([batchSize, ...gridSize, 1]); + const service = new TfjsService(); + service.model = await tf.loadGraphModel(modelPath); + service.gridSize = gridSize; + service.batchSize = batchSize; + service.channelSize = channelSize; + service.outputChannelSize = outputChannelSize; + service.fpsLimit = fpsLimit; this.isPaused = false; return this; } - async loadJSONFileFromUrl(url: string): Promise { + async loadJSONFileFromUrl(url: string): Promise { const response = await fetch(url); const json = (await response.json()) as JSON; // check if json is valid @@ -64,15 +57,71 @@ export class TfjsService { throw new Error('Invalid JSON file'); } // turn json into ModelData - - return json as unknown as ModelData; + this.loadMatrixFromJson(json); } - loadMatrixFromJson(json: ModelData): void { - this.density = tf.tensor4d(json.density); - this.velocity = tf.tensor4d(json.velocity); - this.pressure = tf.tensor4d(json.pressure).bufferSync(); + loadMatrixFromJson(json: JSON /*ModelData*/): void { + const array = json as unknown as number[][][][]; + console.log(array); + const arrayTensor = tf.tensor4d( + array, + [this.batchSize, ...this.gridSize, this.channelSize], + 'float32', + ); + // 0: partial density + // 1, 2: partial velocity + // 3, 4: Pressure + const density = arrayTensor.slice( + [0, 0, 0, 0], + [this.batchSize, ...this.gridSize, 1], + ); + const normalizedDensity = TfjsService.normalizeTensor(density); + density.dispose(); + this.density = tf.variable(normalizedDensity.maximum(0)); + const velocityX = arrayTensor.slice( + [0, 0, 0, 1], + [this.batchSize, ...this.gridSize, 1], + ); + const velocityY = arrayTensor.slice( + [0, 0, 0, 2], + [this.batchSize, ...this.gridSize, 1], + ); + const normalizedVelocityX = TfjsService.normalizeTensor(velocityX); + const normalizedVelocityY = TfjsService.normalizeTensor(velocityY); + velocityX.dispose(); + velocityY.dispose(); + this.velocity = tf.variable( + tf.concat([normalizedVelocityX, normalizedVelocityY], 3), + ) as tf.Variable; + normalizedVelocityX.dispose(); + normalizedVelocityY.dispose(); + const pressureX = arrayTensor.slice( + [0, 0, 0, 3], + [this.batchSize, ...this.gridSize, 1], + ); + const pressureY = arrayTensor.slice( + [0, 0, 0, 4], + [this.batchSize, ...this.gridSize, 1], + ); + const normalizedPressureX = TfjsService.normalizeTensor(pressureX); + const normalizedPressureY = TfjsService.normalizeTensor(pressureY); + pressureX.dispose(); + pressureY.dispose(); + this.pressure = tf + .concat([normalizedPressureX, normalizedPressureY], 3) + .bufferSync() as tf.TensorBuffer; + normalizedPressureX.dispose(); + + this.density = this.density.maximum(0); this.mass = this.density.sum(); + this.mass.print(); + } + + static normalizeTensor(tensor: tf.Tensor): tf.Tensor { + return tf.tidy(() => { + const { mean, variance } = tf.moments(tensor); + return tensor.sub(mean).div(variance.sqrt()); + }); } pauseSimulation(): void { @@ -101,10 +150,10 @@ export class TfjsService { }, 1000); } getInput(): tf.Tensor { - return tf.concat( - [this.density, this.velocity, this.pressure.toTensor()], - 3, - ); + const pressure = this.pressure.toTensor(); + const input = tf.concat([this.density, this.velocity, pressure], 3); + pressure.dispose(); + return input; } private iterate(): void { if (this.isPaused) { @@ -115,25 +164,40 @@ export class TfjsService { const energy = this.velocity.square().sum(); const output = this.model?.predict(input) as tf.Tensor; // update density, velocity - this.density = output?.slice( - [0, 0, 0, 0], - [this.batchSize, ...this.gridSize, 1], - ) as tf.Tensor4D; - this.velocity = output?.slice( - [0, 0, 0, 1], - [this.batchSize, ...this.gridSize, 2], - ) as tf.Tensor4D; + this.density.assign( + output?.slice( + [0, 0, 0, 0], + [this.batchSize, ...this.gridSize, 1], + ) as tf.Tensor4D, + ); + this.velocity.assign( + output?.slice( + [0, 0, 0, 1], + [this.batchSize, ...this.gridSize, 2], + ) as tf.Tensor4D, + ); // update density, velocity const newEnergy = this.velocity.square().sum(); const energyScale = energy.div(newEnergy); + energyScale.print(); + this.velocity = this.velocity.mul(energyScale.sqrt()); const newMass = this.density.sum(); const massScale = this.mass.div(newMass); this.density = this.density.mul(massScale); + massScale.print(); + newMass.dispose(); + newEnergy.dispose(); + energy.dispose(); + energyScale.dispose(); + this.outputCallback(output?.dataSync() as Float32Array); + output.dispose(); // set timeout to 0 to allow other tasks to run, like pause and apply force setTimeout(() => { - this.iterate(); + this.curFrameCountbyLastSecond += 1; + console.log(this.curFrameCountbyLastSecond); + void this.iterate(); }, 0); } @@ -153,9 +217,15 @@ export class TfjsService { 4, ); } -} -interface ModelData { - density: number[][][][]; - velocity: number[][][][]; - pressure: number[][][][]; + getInputTensor(): Float32Array { + const input = this.getInput(); + const data = input.dataSync(); + input.dispose(); + return data as Float32Array; + } + dispose(): void { + this.density.dispose(); + this.velocity.dispose(); + this.model.dispose(); + } }