Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Add Cumsum on CPU, Wasm and Webgl #225

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ _This file is automatically generated from the def files via [this script](/tool
| [ConvTranspose](https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose) | | | |
| [Cos](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cos) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Cos-7) | | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Cos-7) |
| [Cosh](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cosh) | [9+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Cosh-9) | | |
| [CumSum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#CumSum) | | | |
| [CumSum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#CumSum) | [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#CumSum-11) | [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#CumSum-11) | [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#CumSum-11) |
| [DepthToSpace](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace) | | | |
| [DequantizeLinear](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DequantizeLinear) | | | |
| [Det](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Det) | | | |
Expand Down
2 changes: 2 additions & 0 deletions lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {CpuBinaryOp} from './ops/binary-op';
import {CpuCast} from './ops/cast';
import {CpuConcat} from './ops/concat';
import {CpuConv} from './ops/conv';
import {CpuCumSum} from './ops/cumsum';
import {CpuDropout} from './ops/dropout';
import {CpuExpand} from './ops/expand';
import {CpuFlatten} from './ops/flatten';
Expand Down Expand Up @@ -112,4 +113,5 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Upsample', '', '7-8', () => new CpuUpsample()],
['Upsample', '', '9', () => new CpuUpsampleV9()],
['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))],
['CumSum', '', '11+', () => new CpuCumSum()],
];
67 changes: 67 additions & 0 deletions lib/backends/cpu/ops/cumsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {CumSum} from '../../../ops/cumsum';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {CpuInferenceHandler} from '../inference-handler';

export class CpuCumSum extends CumSum {
  /** Executes CumSum on the CPU backend; the axis comes from the second input tensor. */
  run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
    const axis = inputs[1].integerData[0];
    return [cumsum(inputs[0], axis, this.exclusive, this.reverse)];
  }
}

/**
 * Computes the cumulative sum of `x` along `axis`.
 *
 * Walks every output element in lexicographic (row-major) order — reversed when
 * `reverse` is true — so the neighbouring element along `axis` that each sum
 * depends on is always computed before it is read.
 *
 * @param x Input tensor.
 * @param axis Axis to accumulate along; negative values count from the back.
 * @param exclusive If true, y[i] excludes x[i] (the first element along the axis becomes 0).
 * @param reverse If true, accumulate from the end of the axis toward the start.
 * @returns A new tensor of the same shape and type holding the cumulative sums.
 */
export function cumsum(x: Tensor, axis: number, exclusive: boolean, reverse: boolean) {
  const y = new Tensor(x.dims, x.type);

  // Normalize a negative axis into [0, rank).
  if (axis < 0) {
    axis = y.dims.length + axis;
  }

  // Multi-dimensional index of the element currently being written.
  const index: number[] = new Array(y.dims.length).fill(0);
  // Count of elements processed so far (counts down from the end when reversed).
  let i = 0;

  if (reverse) {
    // Start the walk at the very last element: index = [d0-1, d1-1, ...].
    i = y.data.length - 1;
    for (let j = 0; j < y.dims.length; j++) {
      index[j] = y.dims[j] - 1;
    }
  }

  while (i < y.data.length && i >= 0) {
    // Index of the previously computed neighbour along `axis` (one step back in
    // walk order). Out of bounds at a `start` element, but never read there.
    const prevIndex = updateIndex(index, axis, index[axis] + (reverse ? 1 : -1));

    // True for the first element visited along `axis` in walk order.
    const start = (index[axis] === 0 && !reverse) || (index[axis] === (y.dims[axis] - 1) && reverse);

    if (start && !exclusive) {
      // Inclusive scan begins with the element itself.
      y.set(index, x.get(index));
    } else if (start && exclusive) {
      // Exclusive scan begins with 0.
      y.set(index, 0);
    } else if (!start && !exclusive) {
      // Inclusive: y[i] = y[prev] + x[i].
      const prevValue = y.get(prevIndex) as number;
      y.set(index, prevValue + (x.get(index) as number));
    } else {
      // Exclusive: y[i] = y[prev] + x[prev] — the sum of everything strictly before i.
      const prevValue = y.get(prevIndex) as number;
      y.set(index, prevValue + (x.get(prevIndex) as number));
    }

    // Advance to the next element in walk order.
    if (reverse) {
      ShapeUtil.decrementIndex(index, x.dims);
      i--;
    } else {
      ShapeUtil.incrementIndex(index, x.dims);
      i++;
    }
  }

  return y;
}

/** Returns a copy of `index` whose component along `axis` is replaced by `value`. */
function updateIndex(index: number[], axis: number, value: number) {
  const updated = [...index];
  updated[axis] = value;
  return updated;
}
2 changes: 2 additions & 0 deletions lib/backends/wasm/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmClip} from './ops/clip';
import {WasmConv} from './ops/conv';
import {WasmCumSum} from './ops/cumsum';
import {WasmGemm} from './ops/gemm';
import {WasmInstanceNormalization} from './ops/instance-normalization';
import {WasmMatMul} from './ops/matmul';
Expand Down Expand Up @@ -36,4 +37,5 @@ export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')],
['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8
['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')],
['CumSum', '', '11+', () => new WasmCumSum()],
];
35 changes: 35 additions & 0 deletions lib/backends/wasm/ops/cumsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {CumSum} from '../../../ops/cumsum';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmCumSum extends CumSum {
  /**
   * Executes CumSum via the Wasm kernel `_cumsum_f32`.
   * The ccall argument order must match the parameter unpacking in
   * src/wasm-ops/cumsum.cpp: data, dims, rank, axis, exclusive, reverse, out.
   */
  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
    const axis = inputs[1].integerData[0];

    const outputSize = ShapeUtil.size(inputs[0].dims);
    const resultData = new Float32Array(outputSize);
    WasmBinding.getInstance().ccall(
        '_cumsum_f32', [inputs[0].floatData, 'float32ptr'], [inputs[0].dims, 'int32ptr'],
        [inputs[0].dims.length, 'int32'], [axis, 'int32'], [this.exclusive, 'bool'], [this.reverse, 'bool'],
        [resultData, 'float32ptr', 'out']);

    const result = new Tensor(inputs[0].dims, inputs[0].type);
    result.floatData.set(resultData);
    return [result];
  }

  // Overriding checkInputTypes() because the Wasm backend has special type
  // limitations: the kernel only supports 'float32' data. Fix: the previous
  // override dropped the base class's validation of the axis input tensor
  // entirely; delegate to super so that check is preserved.
  checkInputTypes(inputs: Tensor[]): boolean {
    if (inputs[0].type !== 'float32') {
      return false;
    }

    return super.checkInputTypes(inputs);
  }
}
2 changes: 2 additions & 0 deletions lib/backends/webgl/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import * as binaryOps from './ops/binary-op';
import {WebGLClip} from './ops/clip';
import {WebGLConcat} from './ops/concat';
import {WebGLConv} from './ops/conv';
import {WebGLCumSum} from './ops/cumsum';
import {WebGLDropout} from './ops/dropout';
import {WebGLElu} from './ops/elu';
import {WebGLFlatten} from './ops/flatten';
Expand Down Expand Up @@ -105,4 +106,5 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Upsample', '', '7-8', () => new WebGLUpsample()],
['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()],
['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())],
['CumSum', '', '11+', () => new WebGLCumSum()],
];
51 changes: 51 additions & 0 deletions lib/backends/webgl/ops/cumsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {CumSum} from '../../../ops/cumsum';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';

export class WebGLCumSum extends CumSum implements WebGLOperator {
  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
    return inferenceHandler.run(this, inputs);
  }

  /**
   * Builds a shader that, for every output element, loops along the cumulation
   * axis from the first element in scan order up to (and, in inclusive mode,
   * including) the element's own position, summing the sampled input values.
   */
  createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
    const rank = inputs[0].dims.length;
    const dims = inputs[0].dims;

    // Fix: normalize a negative axis (ONNX allows [-rank, rank - 1]). The CPU
    // and Wasm paths already normalize; here a negative value previously
    // produced an invalid `indices[-k]` subscript in the generated shader.
    let axis = inputs[1].integerData[0];
    if (axis < 0) {
      axis = rank + axis;
    }

    // Forward scans start at 0 and count up; reverse scans start at the last
    // element of the axis and count down.
    const startIx = this.reverse ? (dims[axis] - 1) : 0;
    // Exclusive mode stops before the current element; inclusive includes it.
    const comp = this.exclusive ? '' : '=';
    const condition = this.reverse ? `k >${comp} endIx` : `k <${comp} endIx`;
    const update = this.reverse ? 'k--' : 'k++';

    const shaderSource = `
      float process(int indices[${rank}]) {
        float value = 0.0;
        int endIx = indices[${axis}];
        for (int k=${startIx}; ${condition}; ${update}) {
          indices[${axis}] = k;
          value += _A(indices);
        }
        return value;
      }`;
    const inputLayouts = [inferenceHandler.getOrCreateTextureLayout(inputs[0])];
    return {
      inputLayouts,
      outputLayout: inferenceHandler.createTextureLayoutFromShape(inputs[0].dims),
      samplers: ['A'],
      shaderSource,
    };
  }

  createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
    const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
    return {
      inputTextureDatas: inputTDs,
      outputTextureData:
          inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
      uniformData: {}
    };
  }
}
34 changes: 34 additions & 0 deletions lib/ops/cumsum.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

export abstract class CumSum implements Operator {
  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

  /** Reads the 'exclusive' and 'reverse' attributes; both default to 0 (off) per ONNX CumSum-11. */
  initialize(attributes: Attribute): void {
    this.exclusive = attributes.getInt('exclusive', 0) === 1;
    this.reverse = attributes.getInt('reverse', 0) === 1;
  }

  /** CumSum takes exactly two inputs: the data tensor and the axis tensor. */
  checkInputs(inputs: Tensor[]): boolean {
    if (!inputs || inputs.length !== 2) {
      return false;
    }

    return this.checkInputTypes(inputs);
  }

  protected checkInputTypes(inputs: Tensor[]): boolean {
    // The ONNX spec defines the axis input as a 0-D (scalar) integer tensor.
    // Fix: also accept rank 0 — the previous check (dims.length !== 1) rejected
    // spec-conformant scalar axis tensors; single-element 1-D tensors remain
    // accepted for backward compatibility.
    if (inputs[1].type !== 'int32' || inputs[1].dims.length > 1) {
      return false;
    }

    return true;
  }

  protected exclusive: boolean;
  protected reverse: boolean;
}
29 changes: 29 additions & 0 deletions lib/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,35 @@ export class ShapeUtil {
}
}

/**
 * Decrement an index into a tensor (in lexicographic ordering), wrapping each
 * component that drops below zero around to its upper bound (dims[k] - 1) and
 * borrowing from the next-more-significant axis.
 * @param index Given index to decrement (will be mutated)
 * @param dims The dimensions of the tensor for which the given index corresponds to
 * @param axisToDecrementOn The 1-indexed axis to decrement on. If undefined, axisToDecrementOn == rank
 */
static decrementIndex(index: number[], dims: ReadonlyArray<number>, axisToDecrementOn?: number) {
  if (index.length === 0 || dims.length === 0) {
    throw new Error(`Index decrementing unsupported for scalar Tensor`);
  }
  if (axisToDecrementOn !== undefined && (axisToDecrementOn <= 0 || axisToDecrementOn > dims.length)) {
    throw new Error(`Incorrect axis to decrement on`);
  }
  const axis = axisToDecrementOn === undefined ? dims.length : axisToDecrementOn;

  // Like multi-digit subtraction: borrow from the least-significant axis upward.
  for (let k = axis - 1; k >= 0; --k) {
    if (--index[k] >= 0) {
      break;
    }
    index[k] = dims[k] - 1;
  }
}

/**
* Produces a new dimensions array based on the values in the 'originalDimensions' and 'shape' array
* Used in Reshape
Expand Down
3 changes: 2 additions & 1 deletion src/wasm-build-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"_clip_f32",
"_instance_normalization_f32",
"_sum_f32",
"_softmax_f32"
"_softmax_f32",
"_cumsum_f32"
]
}
66 changes: 66 additions & 0 deletions src/wasm-ops/cumsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "cumsum.h"
#include "common.h"
#include "utils/shape_utils.h"

// Wasm interop method.
// Unpacks the flat argument buffer produced by WasmBinding.ccall() on the JS
// side (lib/backends/wasm/ops/cumsum.ts) and forwards to the core
// implementation. The unpack order below must match that ccall's argument
// order exactly: data, dims, rank, axis, exclusive, reverse, out.
void cumsum_f32(void *data) {
  uint32_t *dataIndex = static_cast<uint32_t *>(data);
  // NOTE(review): presumably the argument count; unused here — confirm whether
  // other kernels keep it for parity or it can be dropped.
  uint32_t const argc = dataIndex[0];
  const float *x = PARAM_FLOAT_PTR(data, dataIndex[1]);       // input data
  const int32_t *dims = PARAM_INT32_PTR(data, dataIndex[2]);  // input shape
  const int32_t rank = PARAM_INT32(data, dataIndex[3]);
  const int32_t axis = PARAM_INT32(data, dataIndex[4]);  // may be negative; normalized in the impl
  const bool exclusive = PARAM_BOOL(data, dataIndex[5]);
  const bool reverse = PARAM_BOOL(data, dataIndex[6]);

  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
  cumsum_f32_imp(x, dims, rank, axis, exclusive, reverse, output);
}

// Core operator implementation.
// Computes Y = cumulative sum of X along 'axis' over a tensor of shape
// dims[0..rank), walking elements in linear (row-major) order — reversed when
// 'reverse' is set — so the neighbour each sum depends on is already written.
void cumsum_f32_imp(const float *X, const int32_t *dims, const int32_t rank,
                    int32_t axis, const bool exclusive, const bool reverse,
                    float *Y) {
  // Normalize a negative axis (ONNX allows [-rank, rank - 1]).
  if (axis < 0) {
    axis = rank + axis;
  }

  std::vector<int32_t> dimsVector(dims, dims + rank);
  std::vector<int32_t> strides = ShapeUtils::compute_strides(dimsVector);
  // Fix: use a signed counter. The original 'size_t i' made 'i >= 0' always
  // true and the reverse loop terminated only via unsigned wrap-around.
  const int64_t size = static_cast<int64_t>(ShapeUtils::size_from_dims(dimsVector));
  const int64_t axisStride = static_cast<int64_t>(strides.at(axis));
  const int64_t step = reverse ? -1 : 1;

  for (int64_t i = reverse ? size - 1 : 0; i >= 0 && i < size; i += step) {
    // Position of element i along the cumulation axis.
    const size_t indexAtAxis =
        ShapeUtils::offset_to_index(strides, static_cast<size_t>(i), axis);

    // 'start' marks the first element visited along the axis in scan order.
    // Fix: the reverse case previously compared against dims[axis], which a
    // valid index (at most dims[axis] - 1) can never equal, so every reverse
    // scan read Y out of bounds on its first element.
    const bool start =
        reverse ? indexAtAxis == static_cast<size_t>(dimsVector.at(axis) - 1)
                : indexAtAxis == 0;

    if (start) {
      // First element: inclusive scans copy X; exclusive scans start at 0.
      Y[i] = exclusive ? 0.0f : X[i];
    } else {
      // Offset of the previously computed element along the axis; only valid
      // (and only read) when this is not a start element.
      const int64_t prevIndex = i + (reverse ? axisStride : -axisStride);
      Y[i] = Y[prevIndex] + (exclusive ? X[prevIndex] : X[i]);
    }
  }
}
13 changes: 13 additions & 0 deletions src/wasm-ops/cumsum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include <stdint.h>

extern "C" {
void cumsum_f32(void *);
void cumsum_f32_imp(const float *X, const int32_t *dims, const int32_t rank,
int32_t axis, const bool exclusive, const bool reverse,
float *Y);
}
10 changes: 10 additions & 0 deletions src/wasm-ops/utils/shape_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,13 @@ void ShapeUtils::offset_to_indices(const std::vector<int32_t> &strides,
}
indices[indices.size() - 1] = offset;
}

// Gives the index along 'axis' for the element at linear 'offset'.
// Fix: the original declared an uninitialized 'index', shadowed it with a
// second 'index' inside the loop, and returned the uninitialized outer one —
// undefined behavior. It also compared a size_t counter against the signed
// 'axis' parameter.
size_t ShapeUtils::offset_to_index(const std::vector<int32_t> &strides,
                                   size_t offset, int32_t axis) {
  // Strip the contribution of every axis before 'axis' from the offset.
  for (int32_t i = 0; i < axis; ++i) {
    offset %= static_cast<size_t>(strides[i]);
  }
  // What remains, divided by this axis's stride, is the index along 'axis'.
  return offset / static_cast<size_t>(strides[axis]);
}
3 changes: 3 additions & 0 deletions src/wasm-ops/utils/shape_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ std::vector<int32_t> offset_to_indices(const std::vector<int32_t> &strides,
// Fills in values in the indices vector. Assumes it is of the required size.
void offset_to_indices(const std::vector<int32_t> &strides, size_t offset,
std::vector<int32_t> &indices);
// Gives the index at a specific axis from a given offset
size_t offset_to_index(const std::vector<int32_t> &strides, size_t offset,
int32_t axis);
}; // namespace ShapeUtils
Loading