Commit
Showing 1 changed file with 122 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Gather Ops in Tcp | ||
|
||
## Gather elements along a given dim | ||
|
||
The `tcp.gather_elements` op gathers elements from a given tensor based on indices that index along a given dim. | ||
|
||
Syntax: | ||
|
||
operation ::= `tcp.gather_elements` $input `,` $indices attr-dict `:` | ||
type($input) `,` type($indices) `->` type($out) | ||
|
||
Attributes: | ||
|
||
dim : index | ||
|
||
Inputs: | ||
|
||
input : tensor of any supported type, rank r | ||
indices : tensor of int64, rank r | ||
|
||
Output: | ||
|
||
out : tensor of any supported type, rank r, same shape as indices | ||
|
||
Semantics: | ||
|
||
For rank 2 input and indices: | ||
out[i][j] = input[indices[i][j]][j] # if dim == 0 | ||
out[i][j] = input[i][indices[i][j]] # if dim == 1 | ||
|
||
For rank 3 input and indices: | ||
out[i][j][k] = input[indices[i][j][k]][j][k] # if dim == 0 | ||
out[i][j][k] = input[i][indices[i][j][k]][k] # if dim == 1 | ||
out[i][j][k] = input[i][j][indices[i][j][k]] # if dim == 2 | ||
|
||
This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/generated/torch.gather.html) and `onnx.GatherElements` [[2]](https://onnx.ai/onnx/operators/onnx__GatherElements.html#l-onnx-doc-gatherelements). | ||
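
The semantics above can be sketched as a rank-generic reference in plain Python. This is only an illustration, not the Tcp implementation: nested lists stand in for tensors, and the helpers `shape_of`, `get`, and `build` are hypothetical names introduced here.

```python
def shape_of(t):
    """Return the shape of a non-empty rectangular nested list as a tuple."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, pos):
    """Read the element of nested list `t` at multi-index `pos`."""
    for p in pos:
        t = t[p]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape whose element at `pos` is fn(pos)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def gather_elements(inp, indices, dim):
    """out[pos] = inp[pos with pos[dim] replaced by indices[pos]]."""
    def element(pos):
        src = list(pos)
        src[dim] = get(indices, pos)  # redirect only the `dim` coordinate
        return get(inp, src)
    # The output has the same shape as `indices`.
    return build(shape_of(indices), element)


inp = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
out_dim0 = gather_elements(inp, indices, 0)  # [[1, 2], [3, 2]]
out_dim1 = gather_elements(inp, indices, 1)  # [[1, 1], [4, 3]]
```

The `dim == 1` result matches the rank 2 rule `out[i][j] = input[i][indices[i][j]]` given above.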
|
||
|
||
## Gather slices along a given dim | ||
|
||
Some ops gather whole slices, rather than individual elements, from a given tensor based on indices that index along a given dim. | ||
|
||
Our design is to use the `tcp.gather_elements` op for these cases as follows. Suppose that `input` has shape `[a, b, c]`, `indices` has shape `[x, y]`, and `dim = 0`. The shape of `output` in this case will be `[x, y, b, c]`. | ||
* Broadcast `input` from `[a, b, c]` to `[a, y, b, c]` by introducing a `y` dim. | ||
* Broadcast `indices` from `[x, y]` to `[x, y, b, c]` by introducing the `b` and `c` dims. | ||
* Perform `tcp.gather_elements` with `dim = 0` on the broadcasted `input` and `indices`; the `output` will then have the shape `[x, y, b, c]`. | ||
|
||
|
||
This approach can be used to represent ops like `torch.index_select` [[3]](https://pytorch.org/docs/stable/generated/torch.index_select.html), `tf.gather` [[4]](https://www.tensorflow.org/api_docs/python/tf/gather), and `onnx.Gather` [[5]](https://onnx.ai/onnx/operators/onnx__Gather.html#l-onnx-doc-gather). | ||
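
The broadcast-then-gather steps above can be checked with a small plain-Python sketch. It assumes nested lists as stand-ins for tensors and uses hypothetical helpers (`shape_of`, `get`, `build`). The concrete sizes are arbitrary, and the "direct" slice gather is written out only for comparison.

```python
def shape_of(t):
    """Return the shape of a non-empty rectangular nested list as a tuple."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, pos):
    """Read the element of nested list `t` at multi-index `pos`."""
    for p in pos:
        t = t[p]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape whose element at `pos` is fn(pos)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def gather_elements(inp, indices, dim):
    """Element-wise gather: the output has the same shape as `indices`."""
    def element(pos):
        src = list(pos)
        src[dim] = get(indices, pos)
        return get(inp, src)
    return build(shape_of(indices), element)


a, b, c = 3, 2, 2          # input shape [a, b, c]
x, y = 2, 2                # indices shape [x, y]
inp = build((a, b, c), lambda p: 100 * p[0] + 10 * p[1] + p[2])
indices = [[2, 0], [1, 1]]

# Step 1: broadcast input [a, b, c] -> [a, y, b, c] (new y dim at position 1).
inp_b = build((a, y, b, c), lambda p: inp[p[0]][p[2]][p[3]])
# Step 2: broadcast indices [x, y] -> [x, y, b, c] (new trailing b and c dims).
idx_b = build((x, y, b, c), lambda p: indices[p[0]][p[1]])
# Step 3: element-wise gather along dim 0.
out = gather_elements(inp_b, idx_b, 0)

# Direct slice gather for comparison: out[i][j] is the slice input[indices[i][j]].
expected = build((x, y, b, c), lambda p: inp[indices[p[0]][p[1]]][p[2]][p[3]])
assert out == expected
```

Here `out[0][0]` equals the slice `inp[2]`, since `indices[0][0] == 2`.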
|
||
### Alternative considered | ||
|
||
We considered a separate `tcp.gather` op for this particular case with the following design. | ||
|
||
Syntax: | ||
|
||
operation ::= `tcp.gather` $input `,` $indices attr-dict `:` | ||
type($input) `,` type($indices) `->` type($out) | ||
|
||
Attributes: | ||
|
||
dim : index | ||
|
||
Inputs: | ||
|
||
input : tensor of any supported type, rank r | ||
indices : tensor of int64, rank q | ||
|
||
Output: | ||
|
||
out : tensor of any supported type, rank r + q - 1 | ||
|
||
Semantics: | ||
|
||
For input of rank 2 and indices of rank 2: | ||
out[i][j][k] = input[indices[i][j]][k] # if dim == 0 | ||
out[i][j][k] = input[i][indices[j][k]] # if dim == 1 | ||
|
||
For input of rank 3 and indices of rank 2: | ||
out[i][j][k][m] = input[indices[i][j]][k][m] # if dim == 0 | ||
out[i][j][k][m] = input[i][indices[j][k]][m] # if dim == 1 | ||
out[i][j][k][m] = input[i][j][indices[k][m]] # if dim == 2 | ||
|
||
The above approach of reusing `tcp.gather_elements` is preferred to avoid adding a new op here. | ||
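
For comparison, the semantics of the `tcp.gather` alternative can be sketched in plain Python for any `dim`. This is a hedged illustration of the rules listed above, not an existing op: the output shape used here, `input_shape[:dim] + indices_shape + input_shape[dim+1:]`, is inferred from the rank `r + q - 1` and the index patterns, and the helper names are hypothetical.

```python
def shape_of(t):
    """Return the shape of a non-empty rectangular nested list as a tuple."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, pos):
    """Read the element of nested list `t` at multi-index `pos`."""
    for p in pos:
        t = t[p]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape whose element at `pos` is fn(pos)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def gather(inp, indices, dim):
    """Slice gather: the q index dims replace dim `dim` of the input."""
    in_shape, idx_shape = shape_of(inp), shape_of(indices)
    q = len(idx_shape)
    out_shape = in_shape[:dim] + idx_shape + in_shape[dim + 1:]

    def element(pos):
        # Coordinates dim .. dim+q-1 of the output select one entry of `indices`.
        picked = get(indices, pos[dim:dim + q])
        return get(inp, pos[:dim] + (picked,) + pos[dim + q:])

    return build(out_shape, element)


inp = [[10, 11, 12], [20, 21, 22]]   # rank 2, shape [2, 3]
indices = [[1, 0], [0, 1]]           # rank 2, shape [2, 2]

out_dim0 = gather(inp, indices, 0)   # out[i][j][k] = inp[indices[i][j]][k]
out_dim1 = gather(inp, indices, 1)   # out[i][j][k] = inp[i][indices[j][k]]
```

Both results have rank `2 + 2 - 1 = 3`, with shapes `[2, 2, 3]` and `[2, 2, 2]` respectively.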
|
||
## Gather slices along N dims | ||
|
||
The `tcp.gather_nd` op gathers slices from a given tensor based on indices that index along the first `n` dims. | ||
|
||
Syntax: | ||
|
||
operation ::= `tcp.gather_nd` $input `,` $indices attr-dict `:` | ||
type($input) `,` type($indices) `->` type($out) | ||
|
||
Inputs: | ||
|
||
input : tensor of any supported type, rank r | ||
indices : tensor of int64, rank q | ||
|
||
Output: | ||
|
||
out : tensor of any supported type, rank r + q - indices_shape[-1] - 1 | ||
|
||
Semantics: | ||
|
||
For input of rank 2 and indices of shape (N, 2): | ||
a, b = indices[i] | ||
out[i] = input[a][b] | ||
|
||
For input of rank 3 and indices of shape (N, 2): | ||
a, b = indices[i] | ||
out[i][j] = input[a][b][j] | ||
|
||
For input of rank 4 and indices of shape (N, 2): | ||
a, b = indices[i] | ||
out[i][j][k] = input[a][b][j][k] | ||
|
||
For input of rank 4 and indices of shape (N, 3): | ||
a, b, c = indices[i] | ||
out[i][j] = input[a][b][c][j] | ||
|
||
This op can be used to represent ops like `tf.gather_nd` [[6]](https://www.tensorflow.org/api_docs/python/tf/gather_nd) and `onnx.GatherND` [[7]](https://onnx.ai/onnx/operators/onnx__GatherND.html#l-onnx-doc-gathernd), except when they are used with batching (a nonzero `batch_dims`). | ||
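
The `tcp.gather_nd` semantics listed above can be sketched in plain Python. This is an illustrative reference only, with no batching: nested lists stand in for tensors, the helper names are hypothetical, and `n` denotes `indices_shape[-1]`.

```python
def shape_of(t):
    """Return the shape of a non-empty rectangular nested list as a tuple."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, pos):
    """Read the element of nested list `t` at multi-index `pos`."""
    for p in pos:
        t = t[p]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape whose element at `pos` is fn(pos)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def gather_nd(inp, indices):
    """Each length-n row of `indices` selects a slice along the first n dims."""
    in_shape, idx_shape = shape_of(inp), shape_of(indices)
    q, n = len(idx_shape), idx_shape[-1]
    # Rank of the output: (q - 1) + (r - n) = r + q - n - 1.
    out_shape = idx_shape[:-1] + in_shape[n:]

    def element(pos):
        lead = tuple(get(indices, pos[:q - 1]))  # the n leading coordinates
        return get(inp, lead + pos[q - 1:])

    return build(out_shape, element)


# Rank 2 input, indices of shape (N, 2): out[i] = input[a][b]
out_r2 = gather_nd([[1, 2], [3, 4]], [[0, 0], [1, 1]])      # [1, 4]

# Rank 3 input, indices of shape (N, 2): out[i][j] = input[a][b][j]
inp3 = [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]
out_r3 = gather_nd(inp3, [[0, 1], [1, 0]])                  # [[3, 4, 5], [6, 7, 8]]
```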