diff --git a/lib/backend.ts b/lib/backend.ts index 5b3cbacc..c467063e 100644 --- a/lib/backend.ts +++ b/lib/backend.ts @@ -3,6 +3,7 @@ import {Graph} from './graph'; import {Operator} from './operators'; +import {OpSet} from './opset'; import {Session} from './session'; export interface InferenceHandler { @@ -30,12 +31,11 @@ export interface SessionHandler { dispose(): void; /** - * Resolves the operator from the name; backend specific + * Resolves the operator from the name and opset version; backend specific * @param node - * @param domain - * @param version + * @param opsets */ - resolve(node: Graph.Node, domain: string, version: number): Operator; + resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator; /** * This method let's the sessionHandler know that the graph initialization is complete * @param graph the completely initialized graph diff --git a/lib/backends/cpu/op-resolve-rules.ts b/lib/backends/cpu/op-resolve-rules.ts new file mode 100644 index 00000000..759753c7 --- /dev/null +++ b/lib/backends/cpu/op-resolve-rules.ts @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +import {FLOAT_TYPES, NUMBER_TYPES} from '../../operators'; +import {OpSet} from '../../opset'; + +import {CpuArgMax} from './ops/argMax'; +import {CpuBatchNormalization} from './ops/batch-normalization'; +import {CpuBinaryOp} from './ops/binary-op'; +import {CpuConcat} from './ops/concat'; +import {CpuConv} from './ops/conv'; +import {CpuDropout} from './ops/dropout'; +import {CpuFlatten} from './ops/flatten'; +import {CpuGather} from './ops/gather'; +import {CpuGemm} from './ops/gemm'; +import {CpuImageScaler} from './ops/image-scaler'; +import {CpuInstanceNormalization} from './ops/instance-normalization'; +import {CpuLrn} from './ops/lrn'; +import {CpuMatMul} from './ops/matmul'; +import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool'; +import * as cpuReduce from './ops/reduce'; +import {CpuReshape} from './ops/reshape'; +import {CpuSlice} from './ops/slice'; +import {CpuSoftmax} from './ops/softmax'; +import {CpuSqueeze} from './ops/squeeze'; +import {CpuSum} from './ops/sum'; +import {CpuTile} from './ops/tile'; +import {CpuTranspose} from './ops/transpose'; +import * as unaryOps from './ops/unary-op'; +import {CpuUnsqueeze} from './ops/unsqueeze'; + +export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [ + ['Abs', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.abs)], + ['Acos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.acos)], + ['Add', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2))], + ['And', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2))], + ['ArgMax', '', '1+', () => new CpuArgMax()], + ['Asin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.asin)], + ['Atan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.atan)], + ['AveragePool', '', '7+', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10 + ['BatchNormalization', '', '7+', () => new CpuBatchNormalization()], + ['Ceil', '', 
'6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.ceil)], + ['Clip', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip)], + ['Concat', '', '4+', () => new CpuConcat()], + ['Conv', '', '1+', () => new CpuConv()], + ['Cos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.cos)], + ['Div', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 / e2))], + ['Dropout', '', '7+', () => new CpuDropout()], + ['Elu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.elu)], + ['Exp', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.exp)], + ['Flatten', '', '1+', () => new CpuFlatten()], + ['Floor', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.floor)], + ['Gather', '', '1+', () => new CpuGather()], + ['Gemm', '', '7+', () => new CpuGemm()], + ['GlobalAveragePool', '', '1+', () => new CpuGlobalAveragePool()], + ['GlobalMaxPool', '', '1+', () => new CpuGlobalMaxPool()], + ['ImageScaler', '', '1+', () => new CpuImageScaler()], + ['InstanceNormalization', '', '6+', () => new CpuInstanceNormalization()], + ['LeakyRelu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.leakyRelu)], + ['Log', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.log)], + ['LRN', '', '1+', () => new CpuLrn()], + ['MatMul', '', '1+', () => new CpuMatMul()], + ['MaxPool', '', '1+', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10 + ['Mul', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2))], + ['Neg', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.neg)], + ['Or', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 || e2))], + ['PRelu', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 >= 0 ? 
e1 : e1 * e2))], + ['ReduceLogSum', '', '1+', () => new cpuReduce.CpuReduceLogSum()], + ['ReduceMax', '', '1+', () => new cpuReduce.CpuReduceMax()], + ['ReduceMean', '', '1+', () => new cpuReduce.CpuReduceMean()], + ['ReduceMin', '', '1+', () => new cpuReduce.CpuReduceMin()], + ['ReduceProd', '', '1+', () => new cpuReduce.CpuReduceProd()], + ['ReduceSum', '', '1+', () => new cpuReduce.CpuReduceSum()], + ['ReduceSumSquare', '', '1+', () => new cpuReduce.CpuReduceSumSquare()], + ['Relu', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.relu)], + ['Reshape', '', '5+', () => new CpuReshape()], + ['Sigmoid', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sigmoid)], + ['Sin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sin)], + ['Slice', '', '1+', () => new CpuSlice()], + ['Softmax', '', '1+', () => new CpuSoftmax()], + ['Sqrt', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sqrt)], + ['Squeeze', '', '1+', () => new CpuSqueeze()], + ['Sub', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 - e2))], + ['Sum', '', '6+', () => new CpuSum()], // TODO: support multidirectional broadcast for Sum-8 + ['Tan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tan)], + ['Tanh', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tanh)], + ['Tile', '', '6+', () => new CpuTile()], + ['Transpose', '', '1+', () => new CpuTranspose()], + ['Unsqueeze', '', '1+', () => new CpuUnsqueeze()], + ['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))], +]; diff --git a/lib/backends/cpu/ops-resolve.ts b/lib/backends/cpu/ops-resolve.ts deleted file mode 100644 index 9310d144..00000000 --- a/lib/backends/cpu/ops-resolve.ts +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. 
- -import {Graph} from '../../graph'; -import {FLOAT_TYPES, NUMBER_TYPES, Operator} from '../../operators'; - -import {CpuArgMax} from './ops/argMax'; -import {CpuBatchNormalization} from './ops/batch-normalization'; -import {CpuBinaryOp} from './ops/binary-op'; -import {CpuConcat} from './ops/concat'; -import {CpuConv} from './ops/conv'; -import {CpuDropout} from './ops/dropout'; -import {CpuFlatten} from './ops/flatten'; -import {CpuGather} from './ops/gather'; -import {CpuGemm} from './ops/gemm'; -import {CpuImageScaler} from './ops/image-scaler'; -import {CpuInstanceNormalization} from './ops/instance-normalization'; -import {CpuLrn} from './ops/lrn'; -import {CpuMatMul} from './ops/matmul'; -import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool'; -import * as cpuReduce from './ops/reduce'; -import {CpuReshape} from './ops/reshape'; -import {CpuSlice} from './ops/slice'; -import {CpuSoftmax} from './ops/softmax'; -import {CpuSqueeze} from './ops/squeeze'; -import {CpuSum} from './ops/sum'; -import {CpuTile} from './ops/tile'; -import {CpuTranspose} from './ops/transpose'; -import * as unaryOps from './ops/unary-op'; -import {CpuUnsqueeze} from './ops/unsqueeze'; - -export function resolve(node: Graph.Node, domain: string, version: number): Operator { - const op = createOperator(node, domain, version); - op.initialize(node.attributes); - return op; -} - -function createOperator(node: Graph.Node, domain: string, version: number): Operator { - // assume domain=ai.onnx, version=v7 - switch (node.opType) { - // Unary ops - case 'Abs': - return new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.abs); - case 'Neg': - return new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.neg); - case 'Acos': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.acos); - case 'Ceil': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.ceil); - case 'Cos': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.cos); - case 'Clip': - return new 
unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip); - case 'Sin': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.sin); - case 'Tan': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.tan); - case 'Tanh': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.tanh); - case 'Exp': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.exp); - case 'Floor': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.floor); - case 'Atan': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.atan); - case 'Relu': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.relu); - case 'Log': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.log); - case 'Sqrt': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.sqrt); - case 'Asin': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.asin); - case 'Sigmoid': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.sigmoid); - // Binary arithmetic ops - case 'Add': - return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2)); - case 'Sub': - return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 - e2)); - case 'Mul': - return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2)); - case 'Div': - // TODO: Handle division by zero - return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 / e2)); - // Binary logical ops - case 'Xor': - return new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2)); - case 'Or': - return new CpuBinaryOp(['bool'], (e1, e2) => (e1 || e2)); - case 'And': - return new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2)); - // Non-unary and non-binary ops - case 'ArgMax': - return new CpuArgMax(); - case 'BatchNormalization': - return new CpuBatchNormalization(); - case 'Concat': - return new CpuConcat(); - case 'Conv': - return new CpuConv(); - case 'Dropout': - return new CpuDropout(); - case 'Flatten': - return new CpuFlatten(); - case 'Gemm': - return new CpuGemm(); - case 'ImageScaler': - return new CpuImageScaler(); - case 'LRN': - return new CpuLrn(); - case 'LeakyRelu': - // opLambda will be 
resolved when the op is initialized at which time it will have context of the attribute - // 'alpha' - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.leakyRelu); - case 'Elu': - return new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.elu); - case 'MatMul': - return new CpuMatMul(); - case 'AveragePool': - return new CpuAveragePool(); - case 'MaxPool': - return new CpuMaxPool(); - case 'Gather': - return new CpuGather(); - case 'GlobalMaxPool': - return new CpuGlobalMaxPool(); - case 'GlobalAveragePool': - return new CpuGlobalAveragePool(); - case 'InstanceNormalization': - return new CpuInstanceNormalization(); - case 'PRelu': - return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 >= 0 ? e1 : e1 * e2)); - case 'Reshape': - return new CpuReshape(); - case 'ReduceLogSum': - return new cpuReduce.CpuReduceLogSum(); - case 'ReduceMax': - return new cpuReduce.CpuReduceMax(); - case 'ReduceMean': - return new cpuReduce.CpuReduceMean(); - case 'ReduceMin': - return new cpuReduce.CpuReduceMin(); - case 'ReduceProd': - return new cpuReduce.CpuReduceProd(); - case 'ReduceSum': - return new cpuReduce.CpuReduceSum(); - case 'ReduceSumSquare': - return new cpuReduce.CpuReduceSumSquare(); - case 'Slice': - return new CpuSlice(); - case 'Softmax': - return new CpuSoftmax(); - case 'Squeeze': - return new CpuSqueeze(); - case 'Sum': - return new CpuSum(); - case 'Tile': - return new CpuTile(); - case 'Transpose': - return new CpuTranspose(); - case 'Unsqueeze': - return new CpuUnsqueeze(); - default: - throw new TypeError(`unrecognized operator '${node.opType}'`); - } -} diff --git a/lib/backends/cpu/session-handler.ts b/lib/backends/cpu/session-handler.ts index f4e7bd7c..0f7328c6 100644 --- a/lib/backends/cpu/session-handler.ts +++ b/lib/backends/cpu/session-handler.ts @@ -4,10 +4,11 @@ import {Backend, InferenceHandler, SessionHandler} from '../../backend'; import {Graph} from '../../graph'; import {Operator} from '../../operators'; +import {OpSet, resolveOperator} from 
'../../opset'; import {Session} from '../../session'; import {CpuInferenceHandler} from './inference-handler'; -import {resolve} from './ops-resolve'; +import {CPU_OP_RESOLVE_RULES} from './op-resolve-rules'; export class CpuSessionHandler implements SessionHandler { constructor(readonly backend: Backend, readonly context: Session.Context) {} @@ -18,9 +19,9 @@ export class CpuSessionHandler implements SessionHandler { dispose(): void {} - resolve(node: Graph.Node, domain: string, version: number): Operator { - // We have kept the ops resolve logic separately to be leveraged by other components (if needed) - // This is valid only if there is no statefulness associated with the op resolution logic (which is currently true) - return resolve(node, domain, version); + resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator { + const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES); + op.initialize(node.attributes); + return op; } } diff --git a/lib/backends/wasm/op-resolve-rules.ts b/lib/backends/wasm/op-resolve-rules.ts new file mode 100644 index 00000000..826426d5 --- /dev/null +++ b/lib/backends/wasm/op-resolve-rules.ts @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +import {OpSet} from '../../opset'; + +import {WasmBatchNormalization} from './ops/batch-normalization'; +import {WasmBinaryOp} from './ops/binary-op'; +import {WasmClip} from './ops/clip'; +import {WasmConv} from './ops/conv'; +import {WasmGemm} from './ops/gemm'; +import {WasmInstanceNormalization} from './ops/instance-normalization'; +import {WasmMatMul} from './ops/matmul'; +import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool'; +import {WasmSoftmax} from './ops/softmax'; +import {WasmSum} from './ops/sum'; + +export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [ + ['Add', '', '7+', () => new WasmBinaryOp(['float32'], 'Add')], + ['And', '', '7+', () => new WasmBinaryOp(['bool'], 'And')], + ['AveragePool', '', '7+', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10 + ['BatchNormalization', '', '7+', () => new WasmBatchNormalization()], + ['Clip', '', '6+', () => new WasmClip()], + ['Conv', '', '1+', () => new WasmConv()], + ['Div', '', '7+', () => new WasmBinaryOp(['float32'], 'Div')], + ['Gemm', '', '7+', () => new WasmGemm()], + ['GlobalAveragePool', '', '1+', () => new WasmGlobalAveragePool()], + ['GlobalMaxPool', '', '1+', () => new WasmGlobalMaxPool()], + ['InstanceNormalization', '', '6+', () => new WasmInstanceNormalization()], + ['MatMul', '', '1+', () => new WasmMatMul()], + ['MaxPool', '', '1+', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10 + ['Mul', '', '7+', () => new WasmBinaryOp(['float32'], 'Mul')], + ['Or', '', '7+', () => new WasmBinaryOp(['bool'], 'Or')], + ['PRelu', '', '7+', () => new WasmBinaryOp(['float32'], 'PRelu')], + ['Softmax', '', '1+', () => new WasmSoftmax()], + ['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')], + ['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8 + ['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')], +]; diff --git 
a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts index 7eb10274..87b9b421 100644 --- a/lib/backends/wasm/session-handler.ts +++ b/lib/backends/wasm/session-handler.ts @@ -4,23 +4,18 @@ import {Backend, InferenceHandler, SessionHandler} from '../../backend'; import {Graph} from '../../graph'; import {Operator} from '../../operators'; +import {OpSet, resolveOperator} from '../../opset'; import {Session} from '../../session'; -import {resolve} from '../cpu/ops-resolve'; +import {CPU_OP_RESOLVE_RULES} from '../cpu/op-resolve-rules'; import {WasmInferenceHandler} from './inference-handler'; -import {WasmBatchNormalization} from './ops/batch-normalization'; -import {WasmBinaryOp} from './ops/binary-op'; -import {WasmClip} from './ops/clip'; -import {WasmConv} from './ops/conv'; -import {WasmGemm} from './ops/gemm'; -import {WasmInstanceNormalization} from './ops/instance-normalization'; -import {WasmMatMul} from './ops/matmul'; -import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool'; -import {WasmSoftmax} from './ops/softmax'; -import {WasmSum} from './ops/sum'; +import {WASM_OP_RESOLVE_RULES} from './op-resolve-rules'; export class WasmSessionHandler implements SessionHandler { - constructor(readonly backend: Backend, readonly context: Session.Context, private fallbackToCpuOps: boolean) {} + private opResolveRules: ReadonlyArray<OpSet.ResolveRule>; + constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) { + this.opResolveRules = fallbackToCpuOps ? 
WASM_OP_RESOLVE_RULES.concat(CPU_OP_RESOLVE_RULES) : WASM_OP_RESOLVE_RULES; + } createInferenceHandler(): InferenceHandler { return new WasmInferenceHandler(this, this.context.profiler); @@ -28,64 +23,9 @@ export class WasmSessionHandler implements SessionHandler { dispose(): void {} - resolve(node: Graph.Node, domain: string, version: number): Operator { - const op = this.createOperator(node, domain, version); + resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator { + const op = resolveOperator(node, opsets, this.opResolveRules); op.initialize(node.attributes); return op; } - - private createOperator(node: Graph.Node, domain: string, version: number): Operator { - // assume domain=ai.onnx, version=v7 - switch (node.opType) { - // Binary arithmetic ops - case 'Add': - return new WasmBinaryOp(['float32'], 'Add'); - case 'Sub': - return new WasmBinaryOp(['float32'], 'Sub'); - case 'Mul': - return new WasmBinaryOp(['float32'], 'Mul'); - case 'Div': - // TODO: Handle division by zero - return new WasmBinaryOp(['float32'], 'Div'); - // Binary logical ops - case 'Xor': - return new WasmBinaryOp(['bool'], 'Xor'); - case 'Or': - return new WasmBinaryOp(['bool'], 'Or'); - case 'And': - return new WasmBinaryOp(['bool'], 'And'); - // Misc ops - case 'Conv': - return new WasmConv(); - case 'Clip': - return new WasmClip(); - case 'BatchNormalization': - return new WasmBatchNormalization(); - case 'Gemm': - return new WasmGemm(); - case 'MatMul': - return new WasmMatMul(); - case 'Softmax': - return new WasmSoftmax(); - case 'Sum': - return new WasmSum(); - case 'AveragePool': - return new WasmAveragePool(); - case 'MaxPool': - return new WasmMaxPool(); - case 'GlobalMaxPool': - return new WasmGlobalMaxPool(); - case 'GlobalAveragePool': - return new WasmGlobalAveragePool(); - case 'InstanceNormalization': - return new WasmInstanceNormalization(); - case 'PRelu': - return new WasmBinaryOp(['float32'], 'PRelu'); - default: - if (this.fallbackToCpuOps) { - return 
resolve(node, domain, version); - } - throw new TypeError(`unrecognized operator '${node.opType}'`); - } - } } diff --git a/lib/backends/webgl/op-resolve-rules.ts b/lib/backends/webgl/op-resolve-rules.ts new file mode 100644 index 00000000..c1cad4a0 --- /dev/null +++ b/lib/backends/webgl/op-resolve-rules.ts @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {FLOAT_TYPES, NUMBER_TYPES} from '../../operators'; +import {OpSet} from '../../opset'; + +import {WebGLBatchNormalization} from './ops/batch-normalization'; +import * as binaryOps from './ops/binary-op'; +import {WebGLClip} from './ops/clip'; +import {WebGLConcat} from './ops/concat'; +import {WebGLConv} from './ops/conv'; +import {WebGLDropout} from './ops/dropout'; +import {WebGLElu} from './ops/elu'; +import {WebGLFlatten} from './ops/flatten'; +import {WebGLGather} from './ops/gather'; +import {WebGLGemm} from './ops/gemm'; +import {WebGLImageScaler} from './ops/image-scaler'; +import {WebGLLeakyRelu} from './ops/leaky-relu'; +import {WebGLMatMul} from './ops/matmul'; +import {WebGLPad} from './ops/pad'; +import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPool} from './ops/pool'; +import * as reduceOps from './ops/reduce'; +import {WebGLReshape} from './ops/reshape'; +import {WebGLSlice} from './ops/slice'; +import {WebGLSoftmax} from './ops/softmax'; +import {WebGLSplit} from './ops/split'; +import {WebGLSqueeze} from './ops/squeeze'; +import {WebGLSum} from './ops/sum'; +import {WebGLTile} from './ops/tile'; +import {WebGLTranspose} from './ops/transpose'; +import * as unaryOps from './ops/unary-op'; +import {WebGLUnsqueeze} from './ops/unsqueeze'; + +export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [ + ['Abs', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslAbs())], + ['Acos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAcos())], + ['Add', '', '7+', () => 
new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslAdd())], + ['And', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslAnd())], + ['Asin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAsin())], + ['Atan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAtan())], + ['AveragePool', '', '7+', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10 + ['BatchNormalization', '', '7+', () => new WebGLBatchNormalization()], + ['Ceil', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil())], + ['Clip', '', '6+', () => new WebGLClip()], + ['Concat', '', '4+', () => new WebGLConcat()], + ['Conv', '', '1+', () => new WebGLConv()], + ['Cos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos())], + ['Div', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslDiv())], + ['Dropout', '', '7+', () => new WebGLDropout()], + ['Equal', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool')], + ['Elu', '', '6+', () => new WebGLElu()], + ['Exp', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp())], + ['Flatten', '', '1+', () => new WebGLFlatten()], + ['Floor', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor())], + ['Gather', '', '1+', () => new WebGLGather()], + ['Gemm', '', '7+', () => new WebGLGemm()], + ['GlobalAveragePool', '', '1+', () => new WebGLGlobalAveragePool()], + ['GlobalMaxPool', '', '1+', () => new WebGLGlobalMaxPool()], + ['Greater', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool')], + ['Identity', '', '1+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslIdentity())], + ['ImageScaler', '', '1+', () => new WebGLImageScaler()], + ['LeakyRelu', '', '6+', () => new WebGLLeakyRelu()], + ['Less', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, 
binaryOps.glslLess(), undefined, 'bool')], + ['Log', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog())], + ['MatMul', '', '1+', () => new WebGLMatMul()], + ['MaxPool', '', '1+', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10 + ['Mul', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslMul())], + ['Neg', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslNeg())], + ['Not', '', '1+', () => new unaryOps.WebGLUnaryOp(['bool'], unaryOps.glslNot())], + ['Or', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslOr())], + ['Pad', '', '2+', () => new WebGLPad()], + ['Pow', '', '7+', () => new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPow())], + ['PRelu', '', '7+', () => new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPRelu())], + ['ReduceLogSum', '', '1+', () => new reduceOps.WebGLReduceLogSum()], + ['ReduceMax', '', '1+', () => new reduceOps.WebGLReduceMax()], + ['ReduceMean', '', '1+', () => new reduceOps.WebGLReduceMean()], + ['ReduceMin', '', '1+', () => new reduceOps.WebGLReduceMin()], + ['ReduceProd', '', '1+', () => new reduceOps.WebGLReduceProd()], + ['ReduceSum', '', '1+', () => new reduceOps.WebGLReduceSum()], + ['ReduceSumSquare', '', '1+', () => new reduceOps.WebGLReduceSumSquare()], + ['Relu', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslRelu())], + ['Reshape', '', '5+', () => new WebGLReshape()], + ['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())], + ['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())], + ['Slice', '', '1+', () => new WebGLSlice()], + ['Softmax', '', '1+', () => new WebGLSoftmax()], + // 'Split' operator has an optional attribute 'split' + // this attribute determines how the specified axis of input data + // is split. 
When the attribute is missing, we need the count of number of outputs + // so that we can determine the 'split' attribute from the runtime input to the Operator + ['Split', '', '2+', (node) => new WebGLSplit(node.outputs.length)], + ['Sqrt', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSqrt())], + ['Squeeze', '', '1+', () => new WebGLSqueeze()], + ['Sub', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslSub())], + ['Sum', '', '6+', () => new WebGLSum()], // TODO: support multidirectional broadcast for Sum-8 + ['Tan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTan())], + ['Tanh', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTanh())], + ['Tile', '', '6+', () => new WebGLTile()], + ['Transpose', '', '1+', () => new WebGLTranspose()], + ['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()], + ['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())], +]; diff --git a/lib/backends/webgl/session-handler.ts b/lib/backends/webgl/session-handler.ts index 67cd79f3..a6659aef 100644 --- a/lib/backends/webgl/session-handler.ts +++ b/lib/backends/webgl/session-handler.ts @@ -1,47 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
+import {SessionHandler} from '../../backend'; import {Graph} from '../../graph'; import {Logger} from '../../instrument'; -import {FLOAT_TYPES, NUMBER_TYPES, Operator} from '../../operators'; +import {Operator} from '../../operators'; +import {OpSet, resolveOperator} from '../../opset'; import {Session} from '../../session'; import {Tensor} from '../../tensor'; import {WebGLBackend} from '../backend-webgl'; -import {SessionHandler} from './../../backend'; import {WebGLInferenceHandler} from './inference-handler'; -import {WebGLBatchNormalization} from './ops/batch-normalization'; -import * as binaryOps from './ops/binary-op'; -import {WebGLClip} from './ops/clip'; -import {WebGLConcat} from './ops/concat'; -import {WebGLConv} from './ops/conv'; -import {WebGLDropout} from './ops/dropout'; -import {WebGLElu} from './ops/elu'; -import {WebGLFlatten} from './ops/flatten'; -import {WebGLGather} from './ops/gather'; -import {WebGLGemm} from './ops/gemm'; -import {WebGLImageScaler} from './ops/image-scaler'; -import {WebGLLeakyRelu} from './ops/leaky-relu'; -import {WebGLMatMul} from './ops/matmul'; -import {WebGLPad} from './ops/pad'; -import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPool} from './ops/pool'; -import {WebGLReduceSum} from './ops/reduce'; -import {WebGLReduceMean} from './ops/reduce'; -import {WebGLReduceMax} from './ops/reduce'; -import {WebGLReduceMin} from './ops/reduce'; -import {WebGLReduceProd} from './ops/reduce'; -import {WebGLReduceLogSum} from './ops/reduce'; -import {WebGLReduceSumSquare} from './ops/reduce'; -import {WebGLReshape} from './ops/reshape'; -import {WebGLSlice} from './ops/slice'; -import {WebGLSoftmax} from './ops/softmax'; -import {WebGLSplit} from './ops/split'; -import {WebGLSqueeze} from './ops/squeeze'; -import {WebGLSum} from './ops/sum'; -import {WebGLTile} from './ops/tile'; -import {WebGLTranspose} from './ops/transpose'; -import * as unaryOps from './ops/unary-op'; -import {WebGLUnsqueeze} 
from './ops/unsqueeze'; +import {WEBGL_OP_RESOLVE_RULES} from './op-resolve-rules'; import {ProgramManager} from './program-manager'; import {TextureData} from './texture-data'; import {TextureHelper} from './texture-helper'; @@ -84,148 +54,9 @@ export class WebGLSessionHandler implements SessionHandler { this.textureDataCache.forEach(td => this.textureHelper.releaseTexture(td.texture)); this.textureDataCache = new Map(); } - resolve(node: Graph.Node, domain: string, version: number): Operator { - const op = this.createOperator(node, domain, version); + resolve(node: Graph.Node, opsets: ReadonlyArray): Operator { + const op = resolveOperator(node, opsets, WEBGL_OP_RESOLVE_RULES); op.initialize(node.attributes); return op; } - - private createOperator(node: Graph.Node, domain: string, version: number): Operator { - // assume domain=ai.onnx, version=v7 - switch (node.opType) { - // Unary ops - case 'Abs': - return new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslAbs()); - case 'Acos': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAcos()); - case 'Add': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslAdd()); - case 'And': - return new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslAnd()); - case 'Asin': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAsin()); - case 'Atan': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAtan()); - case 'AveragePool': - return new WebGLAveragePool(); - case 'BatchNormalization': - return new WebGLBatchNormalization(); - case 'Ceil': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil()); - case 'Clip': - return new WebGLClip(); - case 'Cos': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos()); - case 'Concat': - return new WebGLConcat(); - case 'Conv': - return new WebGLConv(); - case 'Div': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslDiv()); - case 'Dropout': - return new WebGLDropout(); - case 'Equal': - return new 
binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool'); - case 'Exp': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp()); - case 'Flatten': - return new WebGLFlatten(); - case 'Floor': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor()); - case 'Gather': - return new WebGLGather(); - case 'Gemm': - return new WebGLGemm(); - case 'GlobalAveragePool': - return new WebGLGlobalAveragePool(); - case 'GlobalMaxPool': - return new WebGLGlobalMaxPool(); - case 'Greater': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool'); - case 'Identity': - return new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslIdentity()); - case 'ImageScaler': - return new WebGLImageScaler(); - case 'LeakyRelu': - return new WebGLLeakyRelu(); - case 'Elu': - return new WebGLElu(); - case 'Less': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool'); - case 'Log': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog()); - case 'MatMul': - return new WebGLMatMul(); - case 'MaxPool': - return new WebGLMaxPool(); - case 'Mul': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslMul()); - case 'Neg': - return new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslNeg()); - case 'Not': - return new unaryOps.WebGLUnaryOp(['bool'], unaryOps.glslNot()); - case 'Or': - return new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslOr()); - case 'Pad': - return new WebGLPad(); - case 'Pow': - return new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPow()); - case 'PRelu': - return new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPRelu()); - case 'Relu': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslRelu()); - case 'Reshape': - return new WebGLReshape(); - case 'Sigmoid': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid()); - case 'Sin': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin()); 
- case 'ReduceSum': - return new WebGLReduceSum(); - case 'ReduceMean': - return new WebGLReduceMean(); - case 'ReduceMax': - return new WebGLReduceMax(); - case 'ReduceMin': - return new WebGLReduceMin(); - case 'ReduceProd': - return new WebGLReduceProd(); - case 'ReduceLogSum': - return new WebGLReduceLogSum(); - case 'ReduceSumSquare': - return new WebGLReduceSumSquare(); - case 'Softmax': - return new WebGLSoftmax(); - case 'Split': - // 'Split' operator has an optional attribute 'split' - // this attribute determines how the specified axis of input data - // is split. When the attribute is missing, we need the count of number of outputs - // so that we can determine the 'split' attribute from the runtime input to the Operator - return new WebGLSplit(node.outputs.length); - case 'Squeeze': - return new WebGLSqueeze(); - case 'Sqrt': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSqrt()); - case 'Sub': - return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslSub()); - case 'Sum': - return new WebGLSum(); - case 'Slice': - return new WebGLSlice(); - case 'Tan': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTan()); - case 'Tanh': - return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTanh()); - case 'Transpose': - return new WebGLTranspose(); - case 'Tile': - return new WebGLTile(); - case 'Unsqueeze': - return new WebGLUnsqueeze(); - case 'Xor': - return new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor()); - default: - throw new TypeError(`unrecognized operator '${node.opType}'`); - } - } } diff --git a/lib/model.ts b/lib/model.ts index 434b8bff..833f8200 100644 --- a/lib/model.ts +++ b/lib/model.ts @@ -4,13 +4,9 @@ import {onnx} from 'onnx-proto'; import {Graph} from './graph'; +import {OpSet} from './opset'; import {LongUtil} from './util'; -interface OpSet { - domain: string; - version: number; -} - export class Model { // empty model constructor() {} @@ -18,8 +14,8 @@ export class Model { load(buf: Buffer, 
graphInitializer?: Graph.Initializer): void { const modelProto = onnx.ModelProto.decode(buf); const irVersion = LongUtil.longToNumber(modelProto.irVersion); - if (irVersion !== 3) { - throw new Error('only support ONNX model with IR_VERSION=3'); + if (irVersion < 3) { + throw new Error('only support ONNX model with IR_VERSION>=3'); } this._opsets = modelProto.opsetImport.map(i => { diff --git a/lib/ops/pool.ts b/lib/ops/pool.ts index ee3921d9..7e42b771 100644 --- a/lib/ops/pool.ts +++ b/lib/ops/pool.ts @@ -24,6 +24,7 @@ class PoolBase { } protected autoPad: string; + protected ceilMode: number; protected countIncludePad: boolean; protected kernelShape: number[]; protected strides: number[]; @@ -39,6 +40,12 @@ export abstract class AveragePool extends PoolBase implements Operator { this.strides = attributes.getInts('strides', []); this.pads = attributes.getInts('pads', []); this.countIncludePad = (attributes.getInt('count_include_pad', 0) === 0 ? false : true); + this.ceilMode = attributes.getInt('ceil_mode', 0); + + // TODO: support attribute 'ceil_mode' + if (this.ceilMode !== 0) { + throw new Error(`using ceil() in shape computation is not yet supported for AveragePool`); + } } } @@ -58,7 +65,19 @@ export abstract class MaxPool extends PoolBase implements Operator { this.kernelShape = attributes.getInts('kernel_shape'); this.strides = attributes.getInts('strides', []); this.pads = attributes.getInts('pads', []); + this.ceilMode = attributes.getInt('ceil_mode', 0); + this.storageOrder = attributes.getInt('storage_order', 0); + + // TODO: support attribute 'ceil_mode' and 'storage_order' + if (this.storageOrder !== 0) { + throw new Error(`column major storage order is not yet supported for MaxPool`); + } + if (this.ceilMode !== 0) { + throw new Error(`using ceil() in shape computation is not yet supported for MaxPool`); + } } + + protected storageOrder: number; } export abstract class GlobalMaxPool extends PoolBase implements Operator { diff --git a/lib/opset.ts 
b/lib/opset.ts new file mode 100644 index 00000000..6952d690 --- /dev/null +++ b/lib/opset.ts @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Graph} from './graph'; +import {Operator} from './operators'; + +export interface OpSet { + domain: string; + version: number; +} + +export declare namespace OpSet { + interface OperatorConstructor { + (node: Graph.Node): Operator; + } + + /** + * Domain of an opset, it can be an empty string(default value, represent for ai.onnx), or 'ai.onnx.ml' + */ + type Domain = ''|'ai.onnx.ml'; + + /** + * A resolve rule consists of 4 items: opType, opSetDomain, versionSelector and operatorConstructor + */ + type ResolveRule = [string, Domain, string, OperatorConstructor]; +} + +export function resolveOperator( + node: Graph.Node, opsets: ReadonlyArray, rules: ReadonlyArray) { + for (const rule of rules) { + const opType = rule[0]; + const domain = rule[1]; + const versionSelector = rule[2]; + const opConstructor = rule[3]; + + if (node.opType === opType) { // operator type matches + for (const opset of opsets) { + // opset '' and 'ai.onnx' are considered the same. 
+ if (opset.domain === domain || (opset.domain === 'ai.onnx' && domain === '')) { // opset domain found + if (matchSelector(opset.version, versionSelector)) { + return opConstructor(node); + } + } + } + throw new TypeError(`cannot resolve operator '${opType}' with opsets: ${ + opsets.map(set => `${set.domain || 'ai.onnx'} v${set.version}`).join(', ')}`); + } + } + + throw new TypeError(`unrecognized operator '${node.opType}'`); +} + +function matchSelector(version: number, selector: string): boolean { + if (selector.endsWith('+')) { + // minimum version match ('7+' expects version>=7) + const rangeStart = Number.parseInt(selector.substring(0, selector.length - 1), 10); + return !isNaN(rangeStart) && rangeStart <= version; + } else if (selector.split('-').length === 2) { + // range match ('6-8' expects 6<=version<=8) + const pair = selector.split('-'); + const rangeStart = Number.parseInt(pair[0], 10); + const rangeEnd = Number.parseInt(pair[1], 10); + return !isNaN(rangeStart) && !isNaN(rangeEnd) && rangeStart <= version && version <= rangeEnd; + } else { + // exact match ('7' expects version===7) + return Number.parseInt(selector, 10) === version; + } +} diff --git a/lib/session.ts b/lib/session.ts index a2e02513..1748287e 100644 --- a/lib/session.ts +++ b/lib/session.ts @@ -247,9 +247,7 @@ export class Session { this._ops = new Array(nodes.length); for (let i = 0; i < nodes.length; i++) { - const domain = this._model.opsets[0].domain; - const version = this._model.opsets[0].version; - this._ops[i] = this.sessionHandler.resolve(nodes[i], domain, version); + this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets); } } diff --git a/test/test-runner.ts b/test/test-runner.ts index a3433d27..4c5556e7 100644 --- a/test/test-runner.ts +++ b/test/test-runner.ts @@ -256,7 +256,7 @@ function initializeOperator( sessionHandler: SessionHandler, opType: string, attributeValues: ReadonlyArray): Operator { const attributes = new Attribute(undefined); 
attributeValues.forEach(value => attributes.set(value.name, value.type, value.data)); - return sessionHandler.resolve({name: '', opType, inputs: [], outputs: [], attributes}, '', 0); + return sessionHandler.resolve({name: '', opType, inputs: [], outputs: [], attributes}, [{domain: '', version: 7}]); } /** diff --git a/test/unittests/backends/webgl/test_conv_new.ts b/test/unittests/backends/webgl/test_conv_new.ts index 9234a291..872218ef 100644 --- a/test/unittests/backends/webgl/test_conv_new.ts +++ b/test/unittests/backends/webgl/test_conv_new.ts @@ -25,7 +25,8 @@ function webglConv( attributes.set('pads', 'ints', pads); } attributes.set('strides', 'ints', strides); - const op = sessionhandler!.resolve({opType: 'Conv', attributes, inputs: [], outputs: [], name: `Conv`}, '', 0); + const op = sessionhandler!.resolve( + {opType: 'Conv', attributes, inputs: [], outputs: [], name: `Conv`}, [{domain: '', version: 7}]); if (!op.checkInputs([inputTensor, kernelTensor])) { throw new Error('Invalid inputs'); } diff --git a/test/unittests/index.ts b/test/unittests/index.ts index c9d1bafb..6ad3eac5 100644 --- a/test/unittests/index.ts +++ b/test/unittests/index.ts @@ -12,3 +12,5 @@ require('./api/onnx'); require('./api/inference-session'); require('./api/tensor'); require('./api/types'); + +require('./opset'); diff --git a/test/unittests/opset.ts b/test/unittests/opset.ts new file mode 100644 index 00000000..52e2aeae --- /dev/null +++ b/test/unittests/opset.ts @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +import {expect} from 'chai'; + +import {Attribute} from '../../lib/attribute'; +import {CPU_OP_RESOLVE_RULES} from '../../lib/backends/cpu/op-resolve-rules'; +import {WASM_OP_RESOLVE_RULES} from '../../lib/backends/wasm/op-resolve-rules'; +import {WEBGL_OP_RESOLVE_RULES} from '../../lib/backends/webgl/op-resolve-rules'; +import {Graph} from '../../lib/graph'; +import {Operator} from '../../lib/operators'; +import {OpSet, resolveOperator} from '../../lib/opset'; + +describe('#UnitTest# - resolveOperator', () => { + const nodeAbs = createTestGraphNode('Abs_1', 'Abs'); + const opset7 = [{domain: '', version: 7}]; + it('ExpectFail - no rule available', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, []); + }).to.throw(TypeError); + }); + it('ExpectFail - no matching rule', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, [['And', '', '7', dummyOpConstructor], ['Sub', '', '7', dummyOpConstructor]]); + }).to.throw(TypeError); + }); + it('ExpectFail - version not match (exact match)', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '6', dummyOpConstructor]]); + }).to.throw(TypeError); + }); + it('ExpectFail - version not match (minimum version match)', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '8+', dummyOpConstructor]]); + }).to.throw(TypeError); + }); + it('ExpectFail - version not match (range match 1)', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '4-6', dummyOpConstructor]]); + }).to.throw(TypeError); + }); + it('ExpectFail - version not match (range match 2)', () => { + expect(() => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '8-10', dummyOpConstructor]]); + }).to.throw(TypeError); + }); + it('ExpectPass - version match (exact match)', () => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '7', dummyOpConstructor]]); + }); + it('ExpectPass - version match (minimum version match)', () => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '5+', 
dummyOpConstructor]]); + }); + it('ExpectPass - version match (range match 1)', () => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '5-7', dummyOpConstructor]]); + }); + it('ExpectPass - version match (range match 2)', () => { + resolveOperator(nodeAbs, opset7, [['Abs', '', '6-9', dummyOpConstructor]]); + }); +}); + +describe('#UnitTest# - resolve rules', () => { + const cpuCheckOnlyRules = + CPU_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpConstructor] as OpSet.ResolveRule); + const wasmCheckOnlyRules = + WASM_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpConstructor] as OpSet.ResolveRule); + const webglCheckOnlyRules = + WEBGL_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpConstructor] as OpSet.ResolveRule); + it('Consistency check - onnx.ai - cpu', () => { + checkConsistency(cpuCheckOnlyRules); + }); + it('Consistency check - onnx.ai - wasm', () => { + checkConsistency(wasmCheckOnlyRules); + }); + it('Consistency check - onnx.ai - webgl', () => { + checkConsistency(webglCheckOnlyRules); + }); +}); + +function createTestGraphNode(name: string, opType: string): Graph.Node { + return {name, opType, inputs: [], outputs: [], attributes: new Attribute(null)}; +} + +function dummyOpConstructor(): Operator { + // tslint:disable-next-line:no-any + return {} as any as Operator; +} + +function checkConsistency(rules: ReadonlyArray) { + const VERSION_MIN = 1, VERSION_MAX = 10; + const typeRules = new Map(); + rules.forEach(rule => { + let ruleSet = typeRules.get(rule[0]); + if (!ruleSet) { + ruleSet = []; + typeRules.set(rule[0], ruleSet); + } + ruleSet.push(rule); + }); + + typeRules.forEach((rules, type) => { + for (let i = VERSION_MIN; i < VERSION_MAX; i++) { + let match = false; + for (const r of rules) { + try { + resolveOperator(createTestGraphNode('', type), [{domain: '', version: i}], [r]); + } catch { + continue; + } + expect(match, `multiple rules overlapped: opType='${type}', domain='', 
version=${i}`).to.be.false; + match = true; + } + } + }); +}