Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
lib: update operator resolve (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire authored Mar 22, 2019
1 parent 0432a35 commit 563dbd6
Show file tree
Hide file tree
Showing 16 changed files with 466 additions and 432 deletions.
8 changes: 4 additions & 4 deletions lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import {Graph} from './graph';
import {Operator} from './operators';
import {OpSet} from './opset';
import {Session} from './session';

export interface InferenceHandler {
Expand Down Expand Up @@ -30,12 +31,11 @@ export interface SessionHandler {
dispose(): void;

/**
* Resolves the operator from the name; backend specific
* Resolves the operator from the name and opset version; backend specific
* @param node
* @param domain
* @param version
* @param opsets
*/
resolve(node: Graph.Node, domain: string, version: number): Operator;
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator;
/**
* This method let's the sessionHandler know that the graph initialization is complete
* @param graph the completely initialized graph
Expand Down
91 changes: 91 additions & 0 deletions lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {FLOAT_TYPES, NUMBER_TYPES} from '../../operators';
import {OpSet} from '../../opset';

import {CpuArgMax} from './ops/argMax';
import {CpuBatchNormalization} from './ops/batch-normalization';
import {CpuBinaryOp} from './ops/binary-op';
import {CpuConcat} from './ops/concat';
import {CpuConv} from './ops/conv';
import {CpuDropout} from './ops/dropout';
import {CpuFlatten} from './ops/flatten';
import {CpuGather} from './ops/gather';
import {CpuGemm} from './ops/gemm';
import {CpuImageScaler} from './ops/image-scaler';
import {CpuInstanceNormalization} from './ops/instance-normalization';
import {CpuLrn} from './ops/lrn';
import {CpuMatMul} from './ops/matmul';
import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool';
import * as cpuReduce from './ops/reduce';
import {CpuReshape} from './ops/reshape';
import {CpuSlice} from './ops/slice';
import {CpuSoftmax} from './ops/softmax';
import {CpuSqueeze} from './ops/squeeze';
import {CpuSum} from './ops/sum';
import {CpuTile} from './ops/tile';
import {CpuTranspose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {CpuUnsqueeze} from './ops/unsqueeze';

export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Abs', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.abs)],
['Acos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.acos)],
['Add', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2))],
['And', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2))],
['ArgMax', '', '1+', () => new CpuArgMax()],
['Asin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.asin)],
['Atan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.atan)],
['AveragePool', '', '7+', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new CpuBatchNormalization()],
['Ceil', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.ceil)],
['Clip', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip)],
['Concat', '', '4+', () => new CpuConcat()],
['Conv', '', '1+', () => new CpuConv()],
['Cos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.cos)],
['Div', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 / e2))],
['Dropout', '', '7+', () => new CpuDropout()],
['Elu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.elu)],
['Exp', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.exp)],
['Flatten', '', '1+', () => new CpuFlatten()],
['Floor', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.floor)],
['Gather', '', '1+', () => new CpuGather()],
['Gemm', '', '7+', () => new CpuGemm()],
['GlobalAveragePool', '', '1+', () => new CpuGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new CpuGlobalMaxPool()],
['ImageScaler', '', '1+', () => new CpuImageScaler()],
['InstanceNormalization', '', '6+', () => new CpuInstanceNormalization()],
['LeakyRelu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.leakyRelu)],
['Log', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.log)],
['LRN', '', '1+', () => new CpuLrn()],
['MatMul', '', '1+', () => new CpuMatMul()],
['MaxPool', '', '1+', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2))],
['Neg', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.neg)],
['Or', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 || e2))],
['PRelu', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 >= 0 ? e1 : e1 * e2))],
['ReduceLogSum', '', '1+', () => new cpuReduce.CpuReduceLogSum()],
['ReduceMax', '', '1+', () => new cpuReduce.CpuReduceMax()],
['ReduceMean', '', '1+', () => new cpuReduce.CpuReduceMean()],
['ReduceMin', '', '1+', () => new cpuReduce.CpuReduceMin()],
['ReduceProd', '', '1+', () => new cpuReduce.CpuReduceProd()],
['ReduceSum', '', '1+', () => new cpuReduce.CpuReduceSum()],
['ReduceSumSquare', '', '1+', () => new cpuReduce.CpuReduceSumSquare()],
['Relu', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.relu)],
['Reshape', '', '5+', () => new CpuReshape()],
['Sigmoid', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sigmoid)],
['Sin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sin)],
['Slice', '', '1+', () => new CpuSlice()],
['Softmax', '', '1+', () => new CpuSoftmax()],
['Sqrt', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sqrt)],
['Squeeze', '', '1+', () => new CpuSqueeze()],
['Sub', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 - e2))],
['Sum', '', '6+', () => new CpuSum()], // TODO: support multidirectional broadcast for Sum-8
['Tan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tan)],
['Tanh', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tanh)],
['Tile', '', '6+', () => new CpuTile()],
['Transpose', '', '1+', () => new CpuTranspose()],
['Unsqueeze', '', '1+', () => new CpuUnsqueeze()],
['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))],
];
167 changes: 0 additions & 167 deletions lib/backends/cpu/ops-resolve.ts

This file was deleted.

11 changes: 6 additions & 5 deletions lib/backends/cpu/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import {Backend, InferenceHandler, SessionHandler} from '../../backend';
import {Graph} from '../../graph';
import {Operator} from '../../operators';
import {OpSet, resolveOperator} from '../../opset';
import {Session} from '../../session';

import {CpuInferenceHandler} from './inference-handler';
import {resolve} from './ops-resolve';
import {CPU_OP_RESOLVE_RULES} from './op-resolve-rules';

export class CpuSessionHandler implements SessionHandler {
constructor(readonly backend: Backend, readonly context: Session.Context) {}
Expand All @@ -18,9 +19,9 @@ export class CpuSessionHandler implements SessionHandler {

dispose(): void {}

resolve(node: Graph.Node, domain: string, version: number): Operator {
// We have kept the ops resolve logic separately to be leveraged by other components (if needed)
// This is valid only if there is no statefulness associated with the op resolution logic (which is currently true)
return resolve(node, domain, version);
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES);
op.initialize(node.attributes);
return op;
}
}
38 changes: 38 additions & 0 deletions lib/backends/wasm/op-resolve-rules.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {OpSet} from '../../opset';

import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmClip} from './ops/clip';
import {WasmConv} from './ops/conv';
import {WasmGemm} from './ops/gemm';
import {WasmInstanceNormalization} from './ops/instance-normalization';
import {WasmMatMul} from './ops/matmul';
import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool';
import {WasmSoftmax} from './ops/softmax';
import {WasmSum} from './ops/sum';

export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Add', '', '7+', () => new WasmBinaryOp(['float32'], 'Add')],
['And', '', '7+', () => new WasmBinaryOp(['bool'], 'And')],
['AveragePool', '', '7+', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new WasmBatchNormalization()],
['Clip', '', '6+', () => new WasmClip()],
['Conv', '', '1+', () => new WasmConv()],
['Div', '', '7+', () => new WasmBinaryOp(['float32'], 'Div')],
['Gemm', '', '7+', () => new WasmGemm()],
['GlobalAveragePool', '', '1+', () => new WasmGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new WasmGlobalMaxPool()],
['InstanceNormalization', '', '6+', () => new WasmInstanceNormalization()],
['MatMul', '', '1+', () => new WasmMatMul()],
['MaxPool', '', '1+', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new WasmBinaryOp(['float32'], 'Mul')],
['Or', '', '7+', () => new WasmBinaryOp(['bool'], 'Or')],
['PRelu', '', '7+', () => new WasmBinaryOp(['float32'], 'PRelu')],
['Softmax', '', '1+', () => new WasmSoftmax()],
['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')],
['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8
['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')],
];
Loading

0 comments on commit 563dbd6

Please sign in to comment.