Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

wasm: update backend to consume latest ONNX Runtime #270

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
Please follow these steps to run the tests:

1. run `npm ci` in the root folder of the repo.
2. (Optional) run `npm run build` in the root folder of the repo to enable WebAssembly features.
2. (Optional) build WebAssembly backend:
1. build ONNX Runtime WebAssembly and copy files "ort-wasm.\*" to /dist/.
2. if building ONNX Runtime WebAssembly with multi-threads support, copy files "ort-wasm-threaded.\*" to /dist/.
3. run `npm run build` in the root folder of the repo to enable WebAssembly features.
3. run `npm test` to run suite0 test cases and check the console output.
- if (2) is not run, please run `npm test -- -b=cpu,webgl` to skip WebAssembly tests

Expand Down
12 changes: 10 additions & 2 deletions karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,18 @@ module.exports = function (config) {
{ pattern: 'test/data/**/*', included: false, nocache: true },
{ pattern: 'deps/data/data/test/**/*', included: false, nocache: true },
{ pattern: 'deps/onnx/onnx/backend/test/data/**/*', included: false, nocache: true },
{ pattern: 'dist/onnx-wasm.wasm', included: false },
{ pattern: 'dist/ort-wasm.js', included: true },
{ pattern: 'dist/ort-wasm.wasm', included: false },
{ pattern: 'dist/ort-wasm-threaded.js', included: true },
{ pattern: 'dist/ort-wasm-threaded.wasm', included: false },
{ pattern: 'dist/ort-wasm-threaded.worker.js', included: false },
],
proxies: {
'/onnx-wasm.wasm': '/base/dist/onnx-wasm.wasm',
'/ort-wasm.js': '/base/dist/ort-wasm.js',
'/ort-wasm.wasm': '/base/dist/ort-wasm.wasm',
'/ort-wasm-threaded.js': '/base/dist/ort-wasm-threaded.js',
'/ort-wasm-threaded.wasm': '/base/dist/ort-wasm-threaded.wasm',
'/ort-wasm-threaded.worker.js': '/base/dist/ort-wasm-threaded.worker.js',
'/onnx-worker.js': '/base/test/onnx-worker.js',
},
plugins: karmaPlugins,
Expand Down
2 changes: 1 addition & 1 deletion lib/api/inference-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export class InferenceSession implements InferenceSessionInterface {
output = await this.session.run(modelInputFeed);
} else if (Array.isArray(inputFeed)) {
const modelInputFeed: InternalTensor[] = [];
inputFeed.forEach((value) => {
inputFeed.forEach((value: ApiTensor) => {
modelInputFeed.push(value.internalTensor);
});
output = await this.session.run(modelInputFeed);
Expand Down
184 changes: 173 additions & 11 deletions lib/backends/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,190 @@
import {Backend, InferenceHandler, SessionHandler} from '../../backend';
import {Graph} from '../../graph';
import {Operator} from '../../operators';
import {OpSet, resolveOperator} from '../../opset';
import {OpSet} from '../../opset';
import {Session} from '../../session';
import {CPU_OP_RESOLVE_RULES} from '../cpu/op-resolve-rules';
import {Tensor} from '../../tensor';
import {ProtoUtil} from '../../util';
import {getInstance} from '../../wasm-binding-core';

import {WasmInferenceHandler} from './inference-handler';
import {WASM_OP_RESOLVE_RULES} from './op-resolve-rules';

export class WasmSessionHandler implements SessionHandler {
private opResolveRules: ReadonlyArray<OpSet.ResolveRule>;
constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) {
this.opResolveRules = fallbackToCpuOps ? WASM_OP_RESOLVE_RULES.concat(CPU_OP_RESOLVE_RULES) : WASM_OP_RESOLVE_RULES;
constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) {}
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator {
throw new Error('Method not implemented.');
}

createInferenceHandler(): InferenceHandler {
return new WasmInferenceHandler(this, this.context.profiler);
}

dispose(): void {}
// vNEXT latest:
ortInit: boolean;
sessionHandle: number;

resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator {
const op = resolveOperator(node, opsets, this.opResolveRules);
op.initialize(node.attributes, node, graph);
return op;
inputNames: string[];
inputNamesUTF8Encoded: number[];
outputNames: string[];
outputNamesUTF8Encoded: number[];

loadModel(model: Uint8Array) {
const wasm = getInstance();
if (!this.ortInit) {
wasm._OrtInit(4, 0);
this.ortInit = true;
}

const modelDataOffset = wasm._malloc(model.byteLength);
try {
wasm.HEAPU8.set(model, modelDataOffset);
this.sessionHandle = wasm._OrtCreateSession(modelDataOffset, model.byteLength);
} finally {
wasm._free(modelDataOffset);
}

const inputCount = wasm._OrtGetInputCount(this.sessionHandle);
const outputCount = wasm._OrtGetOutputCount(this.sessionHandle);

this.inputNames = [];
this.inputNamesUTF8Encoded = [];
this.outputNames = [];
this.outputNamesUTF8Encoded = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(this.sessionHandle, i);
this.inputNamesUTF8Encoded.push(name);
this.inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(this.sessionHandle, i);
this.outputNamesUTF8Encoded.push(name);
this.outputNames.push(wasm.UTF8ToString(name));
}
}

run(inputs: Map<string, Tensor>|Tensor[]): Map<string, Tensor> {
const wasm = getInstance();

let inputIndices: number[] = [];
if (!Array.isArray(inputs)) {
const inputArray: Tensor[] = [];
inputs.forEach((tensor, name) => {
const index = this.inputNames.indexOf(name);
if (index === -1) {
throw new Error(`invalid input '${name}'`);
}
inputArray.push(tensor);
inputIndices.push(index);
});
inputs = inputArray;
} else {
inputIndices = inputs.map((t, i) => i);
}

const inputCount = inputs.length;
const outputCount = this.outputNames.length;

const inputValues: number[] = [];
const inputDataOffsets: number[] = [];
// create input tensors
for (let i = 0; i < inputCount; i++) {
const data = inputs[i].numberData;
const dataOffset = wasm._malloc(data.byteLength);
inputDataOffsets.push(dataOffset);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength), dataOffset);

const dims = inputs[i].dims;

const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
ProtoUtil.tensorDataTypeStringToEnum(inputs[i].type), dataOffset, data.byteLength, dimsOffset, dims.length);
inputValues.push(tensor);
} finally {
wasm.stackRestore(stack);
}
}

const beforeRunStack = wasm.stackSave();
const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
try {
let inputValuesIndex = inputValuesOffset / 4;
let inputNamesIndex = inputNamesOffset / 4;
let outputValuesIndex = outputValuesOffset / 4;
let outputNamesIndex = outputNamesOffset / 4;
for (let i = 0; i < inputCount; i++) {
wasm.HEAP32[inputValuesIndex++] = inputValues[i];
wasm.HEAP32[inputNamesIndex++] = this.inputNamesUTF8Encoded[i];
}
for (let i = 0; i < outputCount; i++) {
wasm.HEAP32[outputValuesIndex++] = 0;
wasm.HEAP32[outputNamesIndex++] = this.outputNamesUTF8Encoded[i];
}

wasm._OrtRun(
this.sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset);

const output = new Map<string, Tensor>();

for (let i = 0; i < outputCount; i++) {
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];

const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer value
const tensorDataOffset = wasm.stackAlloc(4 * 4);
try {
wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
const dataOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
const dims = [];
for (let i = 0; i < dimsLength; i++) {
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
}
wasm._OrtFree(dimsOffset);

const t = new Tensor(dims, ProtoUtil.tensorDataTypeFromProto(dataType));
new Uint8Array(t.numberData.buffer, t.numberData.byteOffset, t.numberData.byteLength)
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + t.numberData.byteLength));
output.set(this.outputNames[i], t);
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
}

wasm._OrtReleaseTensor(tensor);
}

inputValues.forEach(t => wasm._OrtReleaseTensor(t));
inputDataOffsets.forEach(i => wasm._free(i));

return output;
} finally {
wasm.stackRestore(beforeRunStack);
}
}
dispose() {
const wasm = getInstance();
if (this.inputNamesUTF8Encoded) {
this.inputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str));
this.inputNamesUTF8Encoded = [];
}
if (this.outputNamesUTF8Encoded) {
this.outputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str));
this.outputNamesUTF8Encoded = [];
}
if (this.sessionHandle) {
wasm._OrtReleaseSession(this.sessionHandle);
this.sessionHandle = 0;
}
}
}
10 changes: 10 additions & 0 deletions lib/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {readFile} from 'fs';
import {promisify} from 'util';

import {Backend, SessionHandlerType} from './backend';
import {WasmSessionHandler} from './backends/wasm/session-handler';
import {ExecutionPlan} from './execution-plan';
import {Graph} from './graph';
import {Profiler} from './instrument';
Expand Down Expand Up @@ -79,6 +80,11 @@ export class Session {
}

this.profiler.event('session', 'Session.initialize', () => {
if ((this.sessionHandler as {run?: unknown}).run) {
(this.sessionHandler as WasmSessionHandler).loadModel(modelProtoBlob);
return;
}

// load graph
const graphInitializer =
this.sessionHandler.transformGraph ? this.sessionHandler as Graph.Initializer : undefined;
Expand All @@ -104,6 +110,10 @@ export class Session {
}

return this.profiler.event('session', 'Session.run', async () => {
if ((this.sessionHandler as {run?: unknown}).run) {
return (this.sessionHandler as WasmSessionHandler).run(inputs);
}

const inputTensors = this.normalizeAndValidateInputs(inputs);

const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);
Expand Down
32 changes: 32 additions & 0 deletions lib/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,38 @@ export class ProtoUtil {
}
}

static tensorDataTypeStringToEnum(type: string): onnx.TensorProto.DataType {
  // Table mapping onnxjs tensor type strings to the corresponding ONNX
  // TensorProto.DataType enum values.
  const lookup: {[name: string]: onnx.TensorProto.DataType|undefined} = {
    int8: onnx.TensorProto.DataType.INT8,
    uint8: onnx.TensorProto.DataType.UINT8,
    bool: onnx.TensorProto.DataType.BOOL,
    int16: onnx.TensorProto.DataType.INT16,
    uint16: onnx.TensorProto.DataType.UINT16,
    int32: onnx.TensorProto.DataType.INT32,
    uint32: onnx.TensorProto.DataType.UINT32,
    float32: onnx.TensorProto.DataType.FLOAT,
    float64: onnx.TensorProto.DataType.DOUBLE,
    string: onnx.TensorProto.DataType.STRING,
    int64: onnx.TensorProto.DataType.INT64,
    uint64: onnx.TensorProto.DataType.UINT64,
  };

  const mapped = lookup[type];
  if (mapped === undefined) {
    throw new Error(`unsupported data type: ${type}`);
  }
  return mapped;
}

static tensorDimsFromProto(dims: Array<number|Long>): number[] {
// get rid of Long type for dims
return dims.map(d => Long.isLong(d) ? d.toNumber() : d);
Expand Down
Loading