From 734947640a400af4d47e3a6f6acb433e40d03d7d Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Fri, 26 Feb 2021 03:53:25 -0800 Subject: [PATCH 1/4] update wasm backend to consume latest ORT --- docs/development.md | 4 +- karma.conf.js | 4 +- lib/backends/wasm/session-handler.ts | 184 +++++++++++++++++++++++++-- lib/session.ts | 10 ++ lib/util.ts | 32 +++++ lib/wasm-binding-core.ts | 34 ++++- tools/build.ts | 101 +-------------- 7 files changed, 257 insertions(+), 112 deletions(-) diff --git a/docs/development.md b/docs/development.md index 0b022202..67439ae9 100644 --- a/docs/development.md +++ b/docs/development.md @@ -21,7 +21,9 @@ Please follow the following steps to running tests: 1. run `npm ci` in the root folder of the repo. -2. (Optional) run `npm run build` in the root folder of the repo to enable WebAssebmly features. +2. (Optional) build WebAssembly backend: + 1. build ONNX Runtime WebAssembly and copy files "onnxruntime_wasm.\*" to /dist/. + 2. run `npm run build` in the root folder of the repo to enable WebAssebmly features. 3. run `npm test` to run suite0 test cases and check the console output. - if (2) is not run, please run `npm test -- -b=cpu,webgl` to skip WebAssebmly tests diff --git a/karma.conf.js b/karma.conf.js index c1bc7d55..8a9441e5 100644 --- a/karma.conf.js +++ b/karma.conf.js @@ -48,10 +48,10 @@ module.exports = function (config) { { pattern: 'test/data/**/*', included: false, nocache: true }, { pattern: 'deps/data/data/test/**/*', included: false, nocache: true }, { pattern: 'deps/onnx/onnx/backend/test/data/**/*', included: false, nocache: true }, - { pattern: 'dist/onnx-wasm.wasm', included: false }, + { pattern: 'dist/onnxruntime_wasm.wasm', included: false }, ], proxies: { - '/onnx-wasm.wasm': '/base/dist/onnx-wasm.wasm', + '/onnxruntime_wasm.wasm': '/base/dist/onnxruntime_wasm.wasm', '/onnx-worker.js': '/base/test/onnx-worker.js', }, plugins: karmaPlugins, diff --git a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts index a50f58c3..8b87737c 100644 --- a/lib/backends/wasm/session-handler.ts +++ b/lib/backends/wasm/session-handler.ts @@ -4,28 +4,190 @@ import {Backend, InferenceHandler, SessionHandler} from '../../backend'; import {Graph} from '../../graph'; import {Operator} from '../../operators'; -import {OpSet, resolveOperator} from '../../opset'; +import {OpSet} from '../../opset'; import {Session} from '../../session'; -import {CPU_OP_RESOLVE_RULES} from '../cpu/op-resolve-rules'; +import {Tensor} from '../../tensor'; +import {ProtoUtil} from '../../util'; +import {getInstance} from '../../wasm-binding-core'; import {WasmInferenceHandler} from './inference-handler'; -import {WASM_OP_RESOLVE_RULES} from './op-resolve-rules'; export class WasmSessionHandler implements SessionHandler { - private opResolveRules: ReadonlyArray; - constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) { - this.opResolveRules = fallbackToCpuOps ? WASM_OP_RESOLVE_RULES.concat(CPU_OP_RESOLVE_RULES) : WASM_OP_RESOLVE_RULES; + constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) {} + resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator { + throw new Error('Method not implemented.'); } createInferenceHandler(): InferenceHandler { return new WasmInferenceHandler(this, this.context.profiler); } - dispose(): void {} + // vNEXT latest: + ortInit: boolean; + sessionHandle: number; - resolve(node: Graph.Node, opsets: ReadonlyArray, graph: Graph): Operator { - const op = resolveOperator(node, opsets, this.opResolveRules); - op.initialize(node.attributes, node, graph); - return op; + inputNames: string[]; + inputNamesUTF8Encoded: number[]; + outputNames: string[]; + outputNamesUTF8Encoded: number[]; + + loadModel(model: Uint8Array) { + const wasm = getInstance(); + if (!this.ortInit) { + wasm._ort_init(); + this.ortInit = true; + } + + const modelDataOffset = wasm._malloc(model.byteLength); + try { + wasm.HEAPU8.set(model, modelDataOffset); + this.sessionHandle = wasm._ort_create_session(modelDataOffset, model.byteLength); + } finally { + wasm._free(modelDataOffset); + } + + const inputCount = wasm._ort_get_input_count(this.sessionHandle); + const outputCount = wasm._ort_get_output_count(this.sessionHandle); + + this.inputNames = []; + this.inputNamesUTF8Encoded = []; + this.outputNames = []; + this.outputNamesUTF8Encoded = []; + for (let i = 0; i < inputCount; i++) { + const name = wasm._ort_get_input_name(this.sessionHandle, i); + this.inputNamesUTF8Encoded.push(name); + this.inputNames.push(wasm.UTF8ToString(name)); + } + for (let i = 0; i < outputCount; i++) { + const name = wasm._ort_get_output_name(this.sessionHandle, i); + this.outputNamesUTF8Encoded.push(name); + this.outputNames.push(wasm.UTF8ToString(name)); + } + } + + run(inputs: Map|Tensor[]): Map { + const wasm = getInstance(); + + let inputIndices: number[] = []; + if (!Array.isArray(inputs)) { + const inputArray: Tensor[] = []; + inputs.forEach((tensor, name) => { + const index = this.inputNames.indexOf(name); + if (index === -1) { + throw new Error(`invalid input '${name}'`); + } + inputArray.push(tensor); + inputIndices.push(index); + }); + inputs = inputArray; + } else { + inputIndices = inputs.map((t, i) => i); + } + + const inputCount = inputs.length; + const outputCount = this.outputNames.length; + + const inputValues: number[] = []; + const inputDataOffsets: number[] = []; + // create input tensors + for (let i = 0; i < inputCount; i++) { + const data = inputs[i].numberData; + const dataOffset = wasm._malloc(data.byteLength); + inputDataOffsets.push(dataOffset); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength), dataOffset); + + const dims = inputs[i].dims; + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._ort_create_tensor( + ProtoUtil.tensorDataTypeStringToEnum(inputs[i].type), dataOffset, data.byteLength, dimsOffset, dims.length); + inputValues.push(tensor); + } finally { + wasm.stackRestore(stack); + } + } + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); + try { + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAP32[inputValuesIndex++] = inputValues[i]; + wasm.HEAP32[inputNamesIndex++] = this.inputNamesUTF8Encoded[i]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAP32[outputValuesIndex++] = 0; + wasm.HEAP32[outputNamesIndex++] = this.outputNamesUTF8Encoded[i]; + } + + wasm._ort_run( + this.sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, + outputValuesOffset); + + const output = new Map(); + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + try { + wasm._ort_get_tensor_data( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + const dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._ort_free(dimsOffset); + + const t = new Tensor(dims, ProtoUtil.tensorDataTypeFromProto(dataType)); + new Uint8Array(t.numberData.buffer, t.numberData.byteOffset, t.numberData.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + t.numberData.byteLength)); + output.set(this.outputNames[i], t); + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + } + + wasm._ort_release_tensor(tensor); + } + + inputValues.forEach(t => wasm._ort_release_tensor(t)); + inputDataOffsets.forEach(i => wasm._free(i)); + + return output; + } finally { + wasm.stackRestore(beforeRunStack); + } + } + dispose() { + const wasm = getInstance(); + if (this.inputNamesUTF8Encoded) { + this.inputNamesUTF8Encoded.forEach(str => wasm._ort_free(str)); + this.inputNamesUTF8Encoded = []; + } + if (this.outputNamesUTF8Encoded) { + this.outputNamesUTF8Encoded.forEach(str => wasm._ort_free(str)); + this.outputNamesUTF8Encoded = []; + } + if (this.sessionHandle) { + wasm._ort_release_session(this.sessionHandle); + this.sessionHandle = 0; + } } } diff --git a/lib/session.ts b/lib/session.ts index 912b2a34..9c4ef54d 100644 --- a/lib/session.ts +++ b/lib/session.ts @@ -5,6 +5,7 @@ import {readFile} from 'fs'; import {promisify} from 'util'; import {Backend, SessionHandlerType} from './backend'; +import {WasmSessionHandler} from './backends/wasm/session-handler'; import {ExecutionPlan} from './execution-plan'; import {Graph} from './graph'; import {Profiler} from './instrument'; @@ -79,6 +80,11 @@ export class Session { } this.profiler.event('session', 'Session.initialize', () => { + if ((this.sessionHandler as {run?: unknown}).run) { + (this.sessionHandler as WasmSessionHandler).loadModel(modelProtoBlob); + return; + } + // load graph const graphInitializer = this.sessionHandler.transformGraph ? this.sessionHandler as Graph.Initializer : undefined; @@ -104,6 +110,10 @@ export class Session { } return this.profiler.event('session', 'Session.run', async () => { + if ((this.sessionHandler as {run?: unknown}).run) { + return (this.sessionHandler as WasmSessionHandler).run(inputs); + } + const inputTensors = this.normalizeAndValidateInputs(inputs); const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors); diff --git a/lib/util.ts b/lib/util.ts index f269ae7e..f09d69fb 100644 --- a/lib/util.ts +++ b/lib/util.ts @@ -406,6 +406,38 @@ export class ProtoUtil { } } + static tensorDataTypeStringToEnum(type: string): onnx.TensorProto.DataType { + switch (type) { + case 'int8': + return onnx.TensorProto.DataType.INT8; + case 'uint8': + return onnx.TensorProto.DataType.UINT8; + case 'bool': + return onnx.TensorProto.DataType.BOOL; + case 'int16': + return onnx.TensorProto.DataType.INT16; + case 'uint16': + return onnx.TensorProto.DataType.UINT16; + case 'int32': + return onnx.TensorProto.DataType.INT32; + case 'uint32': + return onnx.TensorProto.DataType.UINT32; + case 'float32': + return onnx.TensorProto.DataType.FLOAT; + case 'float64': + return onnx.TensorProto.DataType.DOUBLE; + case 'string': + return onnx.TensorProto.DataType.STRING; + case 'int64': + return onnx.TensorProto.DataType.INT64; + case 'uint64': + return onnx.TensorProto.DataType.UINT64; + + default: + throw new Error(`unsupported data type: ${type}`); + } + } + static tensorDimsFromProto(dims: Array): number[] { // get rid of Long type for dims return dims.map(d => Long.isLong(d) ? d.toNumber() : d); diff --git a/lib/wasm-binding-core.ts b/lib/wasm-binding-core.ts index c1924cd4..037e20b1 100644 --- a/lib/wasm-binding-core.ts +++ b/lib/wasm-binding-core.ts @@ -17,6 +17,34 @@ declare interface OnnxWasmBindingJs { HEAPU32: Uint32Array; HEAPF32: Float32Array; HEAPF64: Float64Array; + + stackSave(): number; + stackRestore(stack: number): void; + stackAlloc(size: number): number; + + UTF8ToString(offset: number): string; + lengthBytesUTF8(str: string): number; + stringToUTF8(str: string, offset: number, maxBytes: number): void; + + _ort_init(): void; + + _ort_create_session(dataOffset: number, dataLength: number): number; + _ort_release_session(sessionHandle: number): void; + _ort_get_input_count(sessionHandle: number): number; + _ort_get_output_count(sessionHandle: number): number; + _ort_get_input_name(sessionHandle: number, index: number): number; + _ort_get_output_name(sessionHandle: number, index: number): number; + + _ort_free(stringHandle: number): void; + + _ort_create_tensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): + number; + _ort_get_tensor_data( + tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): void; + _ort_release_tensor(tensorHandle: number): void; + _ort_run( + sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, + outputNamesOffset: number, outputCount: number, outputsOffset: number): void; } // an interface to define argument handling @@ -70,7 +98,7 @@ export function init(): Promise { return new Promise((resolve, reject) => { // tslint:disable-next-line:no-require-imports - binding = require('../dist/onnx-wasm') as OnnxWasmBindingJs; + binding = require('../dist/onnxruntime_wasm') as OnnxWasmBindingJs; binding(binding).then( () => { // resolve init() promise @@ -327,6 +355,10 @@ export class WasmBinding { } } +export function getInstance(): OnnxWasmBindingJs { + return binding!; +} + /** * returns a number to represent the current timestamp in a resolution as high as possible. */ diff --git a/tools/build.ts b/tools/build.ts index 66b83f5c..f9f9af48 100644 --- a/tools/build.ts +++ b/tools/build.ts @@ -3,7 +3,6 @@ import {execSync, spawnSync} from 'child_process'; import * as fs from 'fs-extra'; -import * as globby from 'globby'; import npmlog from 'npmlog'; import * as path from 'path'; @@ -22,40 +21,13 @@ const buildBundle = process.argv.indexOf('--build-bundle') !== -1; // Path variables const ROOT = path.join(__dirname, '..'); const DEPS = path.join(ROOT, 'deps'); -const DEPS_EIGEN = path.join(DEPS, 'eigen'); -const DEPS_EMSDK = path.join(DEPS, 'emsdk'); -const DEPS_EMSDK_EMSCRIPTEN = path.join(DEPS_EMSDK, 'upstream', 'emscripten'); const DEPS_ONNX = path.join(DEPS, 'onnx'); -const EMSDK_BIN = path.join(DEPS_EMSDK, 'emsdk'); -const SRC = path.join(ROOT, 'src'); -const SRC_WASM_BUILD_CONFIG = path.join(SRC, 'wasm-build-config.json'); const TEST = path.join(ROOT, 'test'); const TEST_DATA = path.join(TEST, 'data'); const TEST_DATA_NODE = path.join(TEST_DATA, 'node'); const OUT = path.join(ROOT, 'dist'); -const OUT_WASM_JS = path.join(OUT, 'onnx-wasm.js'); -const OUT_WASM = path.join(OUT, 'onnx-wasm.wasm'); - -// Emcc (for Wasm) compile flags -// Add new compiler flags here (if needed) -const BUILD_OPTIONS = [ - '-I' + DEPS_EIGEN, - '-DEIGEN_MPL2_ONLY', - '-std=c++11', - '-s WASM=1', - '-s NO_EXIT_RUNTIME=0', - '-s ALLOW_MEMORY_GROWTH=1', - '-s SAFE_HEAP=0', - '-s MODULARIZE=1', - '-s SAFE_HEAP_LOG=0', - '-s STACK_OVERFLOW_CHECK=0', - // '-s DEBUG_LEVEL=0', // DEBUG_LEVEL is disabled in emsdk 1.39.16 - '-s VERBOSE=0', - '-s EXPORT_ALL=0', - '-o ' + OUT_WASM_JS, - '-O2', - '--llvm-lto 3', -]; +const OUT_WASM_JS = path.join(OUT, 'onnxruntime_wasm.js'); +const OUT_WASM = path.join(OUT, 'onnxruntime_wasm.wasm'); npmlog.info('Build', 'Initialization completed. Start to build...'); @@ -138,75 +110,10 @@ if (!buildWasm) { fs.writeFileSync(OUT_WASM_JS, `;throw new Error("please build WebAssembly before use wasm backend.");`); } } else { - // Step 1: emsdk install (if needed) - npmlog.info('Build.Wasm', '(1/4) Setting up emsdk...'); - if (!fs.existsSync(DEPS_EMSDK_EMSCRIPTEN)) { - npmlog.info('Build.Wasm', 'Installing emsdk...'); - const install = spawnSync(`${EMSDK_BIN} install latest`, {shell: true, stdio: 'inherit', cwd: DEPS_EMSDK}); - if (install.status !== 0) { - if (install.error) { - console.error(install.error); - } - process.exit(install.status === null ? undefined : install.status); - } - npmlog.info('Build.Wasm', 'Installing emsdk... DONE'); - - npmlog.info('Build.Wasm', 'Activating emsdk...'); - const activate = spawnSync(`${EMSDK_BIN} activate latest`, {shell: true, stdio: 'inherit', cwd: DEPS_EMSDK}); - if (activate.status !== 0) { - if (activate.error) { - console.error(activate.error); - } - process.exit(activate.status === null ? undefined : activate.status); - } - npmlog.info('Build.Wasm', 'Activating emsdk... DONE'); - } - npmlog.info('Build.Wasm', '(1/4) Setting up emsdk... DONE'); - - // Step 2: Find path to emcc - npmlog.info('Build.Wasm', '(2/4) Find path to emcc...'); - let emcc = globby.sync('./**/emcc', {cwd: DEPS_EMSDK_EMSCRIPTEN})[0]; - if (!emcc) { - npmlog.error('Build.Wasm', 'Unable to find emcc. Try re-building with --clean-install flag.'); - process.exit(2); - } - emcc = path.join(DEPS_EMSDK_EMSCRIPTEN, emcc); - npmlog.info('Build.Wasm', `(2/4) Find path to emcc... DONE, emcc: ${emcc}`); - - // Step 3: Prepare build config - npmlog.info('Build.Wasm', '(3/4) Preparing build config...'); - // tslint:disable-next-line:non-literal-require - const wasmBuildConfig = require(SRC_WASM_BUILD_CONFIG); - const exportedFunctions = wasmBuildConfig.exported_functions as string[]; - const srcPatterns = wasmBuildConfig.src as string[]; - if (exportedFunctions.length === 0) { - npmlog.error('Build.Wasm', `No exported functions specified in the file: ${SRC_WASM_BUILD_CONFIG}`); + if (!fs.existsSync(OUT_WASM)) { + npmlog.error('Build.Wasm', 'Please make sure onnxruntime_wasm.wasm is built and exists in /dist/'); process.exit(1); } - - BUILD_OPTIONS.push(`-s "EXPORTED_FUNCTIONS=[${exportedFunctions.map(f => `${f}`).join(',')}]"`); - - const cppFileNames = globby.sync(srcPatterns, {cwd: SRC}); - if (cppFileNames.length === 0) { - npmlog.error('Build.Wasm', 'Unable to find any cpp source files to compile and generate the WASM file'); - process.exit(2); - } - - const compileSourcesString = cppFileNames.map(f => path.join(SRC, f)).join(' '); - BUILD_OPTIONS.push(compileSourcesString); - npmlog.info('Build.Wasm', '(3/4) Preparing build config... DONE'); - - // Step 4: Compile the source code to generate the Wasm file - npmlog.info('Build.Wasm', '(4/4) Building...'); - npmlog.info('Build.Wasm', `CMD: ${emcc} ${BUILD_OPTIONS}`); - - const emccBuild = spawnSync(emcc, BUILD_OPTIONS, {shell: true, stdio: 'inherit', cwd: __dirname}); - - if (emccBuild.error) { - console.error(emccBuild.error); - process.exit(emccBuild.status === null ? undefined : emccBuild.status); - } - npmlog.info('Build.Wasm', '(4/4) Building... DONE'); } npmlog.info('Build', `Building WebAssembly sources... ${buildWasm ? 'DONE' : 'SKIPPED'}`); From ff65c6614e0bc52c96e5ed704c5c61bc9791a2b3 Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Tue, 23 Mar 2021 15:39:58 -0700 Subject: [PATCH 2/4] update function names to align with latest change --- lib/backends/wasm/session-handler.ts | 30 ++++++++++++++-------------- lib/wasm-binding-core.ts | 26 ++++++++++++------------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts index 8b87737c..97f1c3f6 100644 --- a/lib/backends/wasm/session-handler.ts +++ b/lib/backends/wasm/session-handler.ts @@ -34,32 +34,32 @@ export class WasmSessionHandler implements SessionHandler { loadModel(model: Uint8Array) { const wasm = getInstance(); if (!this.ortInit) { - wasm._ort_init(); + wasm._OrtInit(); this.ortInit = true; } const modelDataOffset = wasm._malloc(model.byteLength); try { wasm.HEAPU8.set(model, modelDataOffset); - this.sessionHandle = wasm._ort_create_session(modelDataOffset, model.byteLength); + this.sessionHandle = wasm._OrtCreateSession(modelDataOffset, model.byteLength); } finally { wasm._free(modelDataOffset); } - const inputCount = wasm._ort_get_input_count(this.sessionHandle); - const outputCount = wasm._ort_get_output_count(this.sessionHandle); + const inputCount = wasm._OrtGetInputCount(this.sessionHandle); + const outputCount = wasm._OrtGetOutputCount(this.sessionHandle); this.inputNames = []; this.inputNamesUTF8Encoded = []; this.outputNames = []; this.outputNamesUTF8Encoded = []; for (let i = 0; i < inputCount; i++) { - const name = wasm._ort_get_input_name(this.sessionHandle, i); + const name = wasm._OrtGetInputName(this.sessionHandle, i); this.inputNamesUTF8Encoded.push(name); this.inputNames.push(wasm.UTF8ToString(name)); } for (let i = 0; i < outputCount; i++) { - const name = wasm._ort_get_output_name(this.sessionHandle, i); + const name = wasm._OrtGetOutputName(this.sessionHandle, i); this.outputNamesUTF8Encoded.push(name); this.outputNames.push(wasm.UTF8ToString(name)); } @@ -103,7 +103,7 @@ export class WasmSessionHandler implements SessionHandler { try { let dimIndex = dimsOffset / 4; dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._ort_create_tensor( + const tensor = wasm._OrtCreateTensor( ProtoUtil.tensorDataTypeStringToEnum(inputs[i].type), dataOffset, data.byteLength, dimsOffset, dims.length); inputValues.push(tensor); } finally { @@ -130,7 +130,7 @@ export class WasmSessionHandler implements SessionHandler { wasm.HEAP32[outputNamesIndex++] = this.outputNamesUTF8Encoded[i]; } - wasm._ort_run( + wasm._OrtRun( this.sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset); @@ -143,7 +143,7 @@ export class WasmSessionHandler implements SessionHandler { // stack allocate 4 pointer value const tensorDataOffset = wasm.stackAlloc(4 * 4); try { - wasm._ort_get_tensor_data( + wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); let tensorDataIndex = tensorDataOffset / 4; const dataType = wasm.HEAPU32[tensorDataIndex++]; @@ -154,7 +154,7 @@ export class WasmSessionHandler implements SessionHandler { for (let i = 0; i < dimsLength; i++) { dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); } - wasm._ort_free(dimsOffset); + wasm._OrtFree(dimsOffset); const t = new Tensor(dims, ProtoUtil.tensorDataTypeFromProto(dataType)); new Uint8Array(t.numberData.buffer, t.numberData.byteOffset, t.numberData.byteLength) @@ -164,10 +164,10 @@ export class WasmSessionHandler implements SessionHandler { wasm.stackRestore(beforeGetTensorDataStack); } - wasm._ort_release_tensor(tensor); + wasm._OrtReleaseTensor(tensor); } - inputValues.forEach(t => wasm._ort_release_tensor(t)); + inputValues.forEach(t => wasm._OrtReleaseTensor(t)); inputDataOffsets.forEach(i => wasm._free(i)); return output; @@ -178,15 +178,15 @@ export class WasmSessionHandler implements SessionHandler { dispose() { const wasm = getInstance(); if (this.inputNamesUTF8Encoded) { - this.inputNamesUTF8Encoded.forEach(str => wasm._ort_free(str)); + this.inputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str)); this.inputNamesUTF8Encoded = []; } if (this.outputNamesUTF8Encoded) { - this.outputNamesUTF8Encoded.forEach(str => wasm._ort_free(str)); + this.outputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str)); this.outputNamesUTF8Encoded = []; } if (this.sessionHandle) { - wasm._ort_release_session(this.sessionHandle); + wasm._OrtReleaseSession(this.sessionHandle); this.sessionHandle = 0; } } diff --git a/lib/wasm-binding-core.ts b/lib/wasm-binding-core.ts index 037e20b1..4a9c1e2f 100644 --- a/lib/wasm-binding-core.ts +++ b/lib/wasm-binding-core.ts @@ -26,23 +26,23 @@ declare interface OnnxWasmBindingJs { lengthBytesUTF8(str: string): number; stringToUTF8(str: string, offset: number, maxBytes: number): void; - _ort_init(): void; + _OrtInit(): void; - _ort_create_session(dataOffset: number, dataLength: number): number; - _ort_release_session(sessionHandle: number): void; - _ort_get_input_count(sessionHandle: number): number; - _ort_get_output_count(sessionHandle: number): number; - _ort_get_input_name(sessionHandle: number, index: number): number; - _ort_get_output_name(sessionHandle: number, index: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number): number; + _OrtReleaseSession(sessionHandle: number): void; + _OrtGetInputCount(sessionHandle: number): number; + _OrtGetOutputCount(sessionHandle: number): number; + _OrtGetInputName(sessionHandle: number, index: number): number; + _OrtGetOutputName(sessionHandle: number, index: number): number; - _ort_free(stringHandle: number): void; + _OrtFree(stringHandle: number): void; - _ort_create_tensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): + _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): number; - _ort_get_tensor_data( - tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): void; - _ort_release_tensor(tensorHandle: number): void; - _ort_run( + _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): + void; + _OrtReleaseTensor(tensorHandle: number): void; + _OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, outputNamesOffset: number, outputCount: number, outputsOffset: number): void; } From e15d7c5e8b90494dda5fae527258daeb8ca105f2 Mon Sep 17 00:00:00 2001 From: Sunghoon <35605090+hanbitmyths@users.noreply.github.com> Date: Mon, 12 Apr 2021 23:39:02 -0700 Subject: [PATCH 3/4] Support WebAssembly multi-threads --- docs/development.md | 3 +- karma.conf.js | 8 ++++ lib/api/inference-session-impl.ts | 2 +- lib/wasm-binding-core.ts | 73 +++++++++++++++++++++---------- lib/wasm-binding.ts | 5 ++- webpack.config.js | 8 +++- 6 files changed, 70 insertions(+), 29 deletions(-) diff --git a/docs/development.md b/docs/development.md index 67439ae9..8889bb3f 100644 --- a/docs/development.md +++ b/docs/development.md @@ -23,7 +23,8 @@ Please follow the following steps to running tests: 1. run `npm ci` in the root folder of the repo. 2. (Optional) build WebAssembly backend: 1. build ONNX Runtime WebAssembly and copy files "onnxruntime_wasm.\*" to /dist/. - 2. run `npm run build` in the root folder of the repo to enable WebAssebmly features. + 2. if building ONNX Runtime WebAssembly with multi-threads support, copy files "onnxruntime_wasm_threads.\*" to /dist/. + 3. run `npm run build` in the root folder of the repo to enable WebAssebmly features. 3. run `npm test` to run suite0 test cases and check the console output. - if (2) is not run, please run `npm test -- -b=cpu,webgl` to skip WebAssebmly tests diff --git a/karma.conf.js b/karma.conf.js index 8a9441e5..5f9358c7 100644 --- a/karma.conf.js +++ b/karma.conf.js @@ -48,10 +48,18 @@ module.exports = function (config) { { pattern: 'test/data/**/*', included: false, nocache: true }, { pattern: 'deps/data/data/test/**/*', included: false, nocache: true }, { pattern: 'deps/onnx/onnx/backend/test/data/**/*', included: false, nocache: true }, + { pattern: 'dist/onnxruntime_wasm.js', included: true }, { pattern: 'dist/onnxruntime_wasm.wasm', included: false }, + { pattern: 'dist/onnxruntime_wasm_threads.js', included: true }, + { pattern: 'dist/onnxruntime_wasm_threads.wasm', included: false }, + { pattern: 'dist/onnxruntime_wasm_threads.worker.js', included: false }, ], proxies: { + '/onnxruntime_wasm.js': '/base/dist/onnxruntime_wasm.js', '/onnxruntime_wasm.wasm': '/base/dist/onnxruntime_wasm.wasm', + '/onnxruntime_wasm_threads.js': '/base/dist/onnxruntime_wasm_threads.js', + '/onnxruntime_wasm_threads.wasm': '/base/dist/onnxruntime_wasm_threads.wasm', + '/onnxruntime_wasm_threads.worker.js': '/base/dist/onnxruntime_wasm_threads.worker.js', '/onnx-worker.js': '/base/test/onnx-worker.js', }, plugins: karmaPlugins, diff --git a/lib/api/inference-session-impl.ts b/lib/api/inference-session-impl.ts index 7931e3fa..126c2d6e 100644 --- a/lib/api/inference-session-impl.ts +++ b/lib/api/inference-session-impl.ts @@ -50,7 +50,7 @@ export class InferenceSession implements InferenceSessionInterface { output = await this.session.run(modelInputFeed); } else if (Array.isArray(inputFeed)) { const modelInputFeed: InternalTensor[] = []; - inputFeed.forEach((value) => { + inputFeed.forEach((value: ApiTensor) => { modelInputFeed.push(value.internalTensor); }); output = await this.session.run(modelInputFeed); diff --git a/lib/wasm-binding-core.ts b/lib/wasm-binding-core.ts index 4a9c1e2f..71b574cc 100644 --- a/lib/wasm-binding-core.ts +++ b/lib/wasm-binding-core.ts @@ -76,8 +76,16 @@ export interface PerformanceData { endTimeFunc?: number; } +// an interface to load wasm into global window instance +declare global { + interface Window { + onnxWasmBindingJs?: OnnxWasmBindingJs; + onnxWasmThreadsBindingJs?: OnnxWasmBindingJs; + } +} + // some global parameters to deal with wasm binding initialization -let binding: OnnxWasmBindingJs|undefined; +let onnxWasmBindingJs: OnnxWasmBindingJs|undefined; let initialized = false; let initializing = false; @@ -86,7 +94,7 @@ let initializing = false; * * this function should be called before any other calls to the WASM binding. */ -export function init(): Promise { +export function init(numWorkers: number): Promise { if (initialized) { return Promise.resolve(); } @@ -97,19 +105,36 @@ export function init(): Promise { initializing = true; return new Promise((resolve, reject) => { - // tslint:disable-next-line:no-require-imports - binding = require('../dist/onnxruntime_wasm') as OnnxWasmBindingJs; - binding(binding).then( - () => { - // resolve init() promise - resolve(); - initializing = false; - initialized = true; - }, - err => { - initializing = false; - reject(err); - }); + if (typeof window !== 'undefined') { // Browser + if (numWorkers > 0 && window.hasOwnProperty('onnxWasmThreadsBindingJs')) { + onnxWasmBindingJs = window.onnxWasmThreadsBindingJs as OnnxWasmBindingJs; + } else if (window.hasOwnProperty('onnxWasmBindingJs')) { + onnxWasmBindingJs = window.onnxWasmBindingJs as OnnxWasmBindingJs; + } + } else { // Node + if (numWorkers > 0) { + // tslint:disable-next-line:no-require-imports + onnxWasmBindingJs = require('../dist/onnxruntime_wasm_threads') as OnnxWasmBindingJs; + } else { + // tslint:disable-next-line:no-require-imports + onnxWasmBindingJs = require('../dist/onnxruntime_wasm') as OnnxWasmBindingJs; + } + } + if (typeof onnxWasmBindingJs === 'undefined') { + throw new Error('Wasm is not defined'); + } + onnxWasmBindingJs(onnxWasmBindingJs) + .then( + () => { + // resolve init() promise + resolve(); + initializing = false; + initialized = true; + }, + err => { + initializing = false; + reject(err); + }); }); } @@ -138,13 +163,13 @@ export class WasmBinding { if (size > this.numBytesAllocated) { this.expandMemory(size); } - WasmBinding.ccallSerialize(binding!.HEAPU8.subarray(this.ptr8, this.ptr8 + size), offset, params); + WasmBinding.ccallSerialize(onnxWasmBindingJs!.HEAPU8.subarray(this.ptr8, this.ptr8 + size), offset, params); const startTimeFunc = now(); this.func(functionName, this.ptr8); const endTimeFunc = now(); - WasmBinding.ccallDeserialize(binding!.HEAPU8.subarray(this.ptr8, this.ptr8 + size), offset, params); + WasmBinding.ccallDeserialize(onnxWasmBindingJs!.HEAPU8.subarray(this.ptr8, this.ptr8 + size), offset, params); const endTime = now(); return {startTime, endTime, startTimeFunc, endTimeFunc}; @@ -164,14 +189,14 @@ export class WasmBinding { } // copy input memory (data) to WASM heap - binding!.HEAPU8.subarray(this.ptr8, this.ptr8 + size).set(data); + onnxWasmBindingJs!.HEAPU8.subarray(this.ptr8, this.ptr8 + size).set(data); const startTimeFunc = now(); this.func(functionName, this.ptr8); const endTimeFunc = now(); // copy Wasm heap to output memory (data) - data.set(binding!.HEAPU8.subarray(this.ptr8, this.ptr8 + size)); + data.set(onnxWasmBindingJs!.HEAPU8.subarray(this.ptr8, this.ptr8 + size)); const endTime = now(); return {startTime, endTime, startTimeFunc, endTimeFunc}; @@ -179,7 +204,7 @@ export class WasmBinding { protected func(functionName: string, ptr8: number): void { // tslint:disable-next-line:no-any - const func = (binding as any)[functionName] as (data: number) => void; + const func = (onnxWasmBindingJs as any)[functionName] as (data: number) => void; func(ptr8); } @@ -335,11 +360,11 @@ export class WasmBinding { private expandMemory(minBytesRequired: number) { // free already held memory if applicable if (this.ptr8 !== 0) { - binding!._free(this.ptr8); + onnxWasmBindingJs!._free(this.ptr8); } // current simplistic strategy is to allocate 2 times the minimum bytes requested this.numBytesAllocated = 2 * minBytesRequired; - this.ptr8 = binding!._malloc(this.numBytesAllocated); + this.ptr8 = onnxWasmBindingJs!._malloc(this.numBytesAllocated); if (this.ptr8 === 0) { throw new Error('Unable to allocate requested amount of memory. Failing.'); } @@ -350,13 +375,13 @@ export class WasmBinding { throw new Error(`wasm not initialized. please ensure 'init()' is called.`); } if (this.ptr8 !== 0) { - binding!._free(this.ptr8); + onnxWasmBindingJs!._free(this.ptr8); } } } export function getInstance(): OnnxWasmBindingJs { - return binding!; + return onnxWasmBindingJs!; } /** diff --git a/lib/wasm-binding.ts b/lib/wasm-binding.ts index eb15000f..1ce4f446 100644 --- a/lib/wasm-binding.ts +++ b/lib/wasm-binding.ts @@ -61,7 +61,7 @@ export function init(numWorkers: number, initTimeout: number): Promise { initializing = false; }; - const bindingInitTask = bindingCore.init(); + const bindingInitTask = bindingCore.init(numWorkers); // a promise that gets rejected after 5s to work around the fact that // there is an unrejected promise in the wasm glue logic file when // it has some problem instantiating the wasm file @@ -78,7 +78,8 @@ export function init(numWorkers: number, initTimeout: number): Promise { if (areWebWorkersSupported()) { Logger.verbose( 'WebAssembly-Workers', `Environment supports usage of Workers. Will spawn ${numWorkers} Workers`); - WORKER_NUMBER = numWorkers; + // TODO: This code will be replaced to control the number of WebAssembly threads later. + WORKER_NUMBER = 0; } else { Logger.error('WebAssembly-Workers', 'Environment does not support usage of Workers. Will not spawn workers.'); WORKER_NUMBER = 0; diff --git a/webpack.config.js b/webpack.config.js index 154b7134..c481ea2a 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -10,7 +10,13 @@ module.exports = (env, argv) => { resolve: {extensions: ['.ts', '.js']}, plugins: [new webpack.WatchIgnorePlugin([/\.js$/, /\.d\.ts$/])], module: {rules: [{test: /\.tsx?$/, loader: 'ts-loader'}]}, - node: {fs: 'empty'} + node: {fs: 'empty'}, + externals: { + '../dist/onnxruntime_wasm':'onnxWasmBindingJs', + '../dist/onnxruntime_wasm_threads':'onnxWasmBindingJs', + 'perf_hooks':'perf_hooks', + 'worker_threads':'worker_threads' + } }; if (bundleMode === 'perf' || bundleMode === 'dev') { From 92b12dccef72170077885b0440b064bf68fb9534 Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Wed, 5 May 2021 15:17:15 -0700 Subject: [PATCH 4/4] update to use latest filename --- docs/development.md | 4 ++-- karma.conf.js | 20 ++++++++++---------- lib/backends/wasm/session-handler.ts | 2 +- lib/wasm-binding-core.ts | 18 +++++++++--------- tools/build.ts | 6 +++--- webpack.config.js | 4 ++-- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/docs/development.md b/docs/development.md index 8889bb3f..8ef957f7 100644 --- a/docs/development.md +++ b/docs/development.md @@ -22,8 +22,8 @@ Please follow the following steps to running tests: 1. run `npm ci` in the root folder of the repo. 2. (Optional) build WebAssembly backend: - 1. build ONNX Runtime WebAssembly and copy files "onnxruntime_wasm.\*" to /dist/. - 2. if building ONNX Runtime WebAssembly with multi-threads support, copy files "onnxruntime_wasm_threads.\*" to /dist/. + 1. build ONNX Runtime WebAssembly and copy files "ort-wasm.\*" to /dist/. + 2. if building ONNX Runtime WebAssembly with multi-threads support, copy files "ort-wasm-threaded.\*" to /dist/. 3. run `npm run build` in the root folder of the repo to enable WebAssebmly features. 3. run `npm test` to run suite0 test cases and check the console output. - if (2) is not run, please run `npm test -- -b=cpu,webgl` to skip WebAssebmly tests diff --git a/karma.conf.js b/karma.conf.js index 5f9358c7..9f96a8ff 100644 --- a/karma.conf.js +++ b/karma.conf.js @@ -48,18 +48,18 @@ module.exports = function (config) { { pattern: 'test/data/**/*', included: false, nocache: true }, { pattern: 'deps/data/data/test/**/*', included: false, nocache: true }, { pattern: 'deps/onnx/onnx/backend/test/data/**/*', included: false, nocache: true }, - { pattern: 'dist/onnxruntime_wasm.js', included: true }, - { pattern: 'dist/onnxruntime_wasm.wasm', included: false }, - { pattern: 'dist/onnxruntime_wasm_threads.js', included: true }, - { pattern: 'dist/onnxruntime_wasm_threads.wasm', included: false }, - { pattern: 'dist/onnxruntime_wasm_threads.worker.js', included: false }, + { pattern: 'dist/ort-wasm.js', included: true }, + { pattern: 'dist/ort-wasm.wasm', included: false }, + { pattern: 'dist/ort-wasm-threaded.js', included: true }, + { pattern: 'dist/ort-wasm-threaded.wasm', included: false }, + { pattern: 'dist/ort-wasm-threaded.worker.js', included: false }, ], proxies: { - '/onnxruntime_wasm.js': '/base/dist/onnxruntime_wasm.js', - '/onnxruntime_wasm.wasm': '/base/dist/onnxruntime_wasm.wasm', - '/onnxruntime_wasm_threads.js': '/base/dist/onnxruntime_wasm_threads.js', - '/onnxruntime_wasm_threads.wasm': '/base/dist/onnxruntime_wasm_threads.wasm', - '/onnxruntime_wasm_threads.worker.js': '/base/dist/onnxruntime_wasm_threads.worker.js', + '/ort-wasm.js': '/base/dist/ort-wasm.js', + '/ort-wasm.wasm': '/base/dist/ort-wasm.wasm', + '/ort-wasm-threaded.js': '/base/dist/ort-wasm-threaded.js', + '/ort-wasm-threaded.wasm': '/base/dist/ort-wasm-threaded.wasm', + '/ort-wasm-threaded.worker.js': '/base/dist/ort-wasm-threaded.worker.js', '/onnx-worker.js': '/base/test/onnx-worker.js', }, plugins: karmaPlugins, diff --git a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts index 97f1c3f6..47cfd64d 100644 --- a/lib/backends/wasm/session-handler.ts +++ b/lib/backends/wasm/session-handler.ts @@ -34,7 +34,7 @@ export class WasmSessionHandler implements SessionHandler { loadModel(model: Uint8Array) { const wasm = getInstance(); if (!this.ortInit) { - wasm._OrtInit(); + wasm._OrtInit(4, 0); this.ortInit = true; } diff --git a/lib/wasm-binding-core.ts b/lib/wasm-binding-core.ts index 71b574cc..3b05bcba 100644 --- a/lib/wasm-binding-core.ts +++ b/lib/wasm-binding-core.ts @@ -26,7 +26,7 @@ declare interface OnnxWasmBindingJs { lengthBytesUTF8(str: string): number; stringToUTF8(str: string, offset: number, maxBytes: number): void; - _OrtInit(): void; + _OrtInit(numThreads: number, loggingLevel: number): void; _OrtCreateSession(dataOffset: number, dataLength: number): number; _OrtReleaseSession(sessionHandle: number): void; @@ -79,8 +79,8 @@ export interface PerformanceData { // an interface to load wasm into global window instance declare global { interface Window { - onnxWasmBindingJs?: OnnxWasmBindingJs; - onnxWasmThreadsBindingJs?: OnnxWasmBindingJs; + ortWasm?: OnnxWasmBindingJs; + ortWasmThreaded?: OnnxWasmBindingJs; } } @@ -106,18 +106,18 @@ export function init(numWorkers: number): Promise { return new Promise((resolve, reject) => { if (typeof window !== 'undefined') { // Browser - if (numWorkers > 0 && window.hasOwnProperty('onnxWasmThreadsBindingJs')) { - onnxWasmBindingJs = window.onnxWasmThreadsBindingJs as OnnxWasmBindingJs; - } else if (window.hasOwnProperty('onnxWasmBindingJs')) { - onnxWasmBindingJs = window.onnxWasmBindingJs as OnnxWasmBindingJs; + if (numWorkers > 0 && window.hasOwnProperty('ortWasmThreaded')) { + onnxWasmBindingJs = window.ortWasmThreaded as OnnxWasmBindingJs; + } else if (window.hasOwnProperty('ortWasm')) { + onnxWasmBindingJs = window.ortWasm as OnnxWasmBindingJs; } } else { // Node if (numWorkers > 0) { // tslint:disable-next-line:no-require-imports - onnxWasmBindingJs = require('../dist/onnxruntime_wasm_threads') as OnnxWasmBindingJs; + onnxWasmBindingJs = require('../dist/ort-wasm-threaded') as OnnxWasmBindingJs; } else { // tslint:disable-next-line:no-require-imports - onnxWasmBindingJs = require('../dist/onnxruntime_wasm') as OnnxWasmBindingJs; + onnxWasmBindingJs = require('../dist/ort-wasm') as OnnxWasmBindingJs; } } if (typeof onnxWasmBindingJs === 'undefined') { diff --git a/tools/build.ts b/tools/build.ts index f9f9af48..ccc41297 100644 --- a/tools/build.ts +++ b/tools/build.ts @@ -26,8 +26,8 @@ const TEST = path.join(ROOT, 'test'); const TEST_DATA = path.join(TEST, 'data'); const TEST_DATA_NODE = path.join(TEST_DATA, 'node'); const OUT = path.join(ROOT, 'dist'); -const OUT_WASM_JS = path.join(OUT, 'onnxruntime_wasm.js'); -const OUT_WASM = path.join(OUT, 'onnxruntime_wasm.wasm'); +const OUT_WASM_JS = path.join(OUT, 'ort-wasm.js'); +const OUT_WASM = path.join(OUT, 'ort-wasm.wasm'); npmlog.info('Build', 'Initialization completed. Start to build...'); @@ -111,7 +111,7 @@ if (!buildWasm) { } } else { if (!fs.existsSync(OUT_WASM)) { - npmlog.error('Build.Wasm', 'Please make sure onnxruntime_wasm.wasm is built and exists in /dist/'); + npmlog.error('Build.Wasm', 'Please make sure ort-wasm.wasm is built and exists in /dist/'); process.exit(1); } } diff --git a/webpack.config.js b/webpack.config.js index c481ea2a..dde6cdfa 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -12,8 +12,8 @@ module.exports = (env, argv) => { module: {rules: [{test: /\.tsx?$/, loader: 'ts-loader'}]}, node: {fs: 'empty'}, externals: { - '../dist/onnxruntime_wasm':'onnxWasmBindingJs', - '../dist/onnxruntime_wasm_threads':'onnxWasmBindingJs', + '../dist/ort-wasm':'onnxWasmBindingJs', + '../dist/ort-wasm-threaded':'onnxWasmBindingJs', 'perf_hooks':'perf_hooks', 'worker_threads':'worker_threads' }