Skip to content

Commit

Permalink
Design of Gather ops
Browse files Browse the repository at this point in the history
  • Loading branch information
navahgar committed Jan 18, 2024
1 parent fe5bbe1 commit a5bff3d
Showing 1 changed file with 122 additions and 0 deletions.
122 changes: 122 additions & 0 deletions docs/gather.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Gather Ops in Tcp

## Gather elements along a given dim

`tcp.gather_elements` op gathers elements from a given tensor based on indices that index along a given dim.

Syntax:

operation ::= `tcp.gather_elements` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Attributes:

dim : index

Inputs:

input : tensor of any supported type, rank r
indices : tensor of int64, rank r

Output:

out : tensor of any supported type, rank r, same shape as indices

Semantics:

For rank 2 input and indices:
out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1

For rank 3 input and indices:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/generated/torch.gather.html) and `onnx.GatherElements` [[2]](https://onnx.ai/onnx/operators/onnx__GatherElements.html#l-onnx-doc-gatherelements).


## Gather slices along a given dim

This requires gathering slices from a given tensor based on indices that index along a given dim.

Our design is to use `tcp.gather_elements` op for these cases as follows. Suppose that the `input` has shape `[a, b, c]`, `indices` has shape `[x, y]` and `dim = 0`. Shape of `output` in this case will be `[x, y, b, c]`.
* Broadcast `input` from `[a, b, c]` to `[a, y, b, c]` by introducing `y` dim.
* Broadcast `indices` from `[x, y]` to `[x, y, b, c]` by introducing `b` and `c` dims.
* Perform `tcp.gather_elements` on these broadcasted `input` and `indices`, whose `output` will now have the shape `[x, y, b, c]`.


This approach can be used to represent ops like `torch.index_select` [[3]](https://pytorch.org/docs/stable/generated/torch.index_select.html), `tf.gather` [[4]](https://www.tensorflow.org/api_docs/python/tf/gather), and `onnx.Gather` [[5]](https://onnx.ai/onnx/operators/onnx__Gather.html#l-onnx-doc-gather).

### Alternative considered

We considered a separate `tcp.gather` op for this particular case with the following design.

Syntax:

operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Attributes:

dim : index

Inputs:

input : tensor of any supported type, rank r
indices : tensor of int64, rank q

Output:

out : tensor of any supported type, rank r + q - 1

Semantics:

For input of rank 2 and indices of rank 2:
out[i][j][k] = input[indices[i][j]][k] # if dim == 0
out[i][j][k] = input[i][indices[j][k]] # if dim == 1

For input of rank 3 and indices of rank 2:
out[i][j][k][m] = input[indices[i][j]][k][m] # if dim == 0
out[i][j][k][m] = input[i][indices[j][k]][m] # if dim == 1
out[i][j][k][m] = input[i][j][indices[k][m]] # if dim == 2

The above approach of reusing `tcp.gather_elements` is preferred to avoid adding a new op here.

## Gather slices along N dims

`tcp.gather_nd` op gathers slices from a given tensor based on indices that index along the first `n` dims.

Syntax:

operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Inputs:

input : tensor of any supported type, rank r
indices : tensor of int64, rank q

Output:

out : tensor of any supported type, rank r + q - indices_shape[-1] - 1

Semantics:

For input of rank 2 and indices of shape (N, 2):
a, b = indices[i]
out[i] = input[a][b]

For input of rank 3 and indices of shape (N, 2):
a, b = indices[i]
out[i][j] = input[a][b][j]

For input of rank 4 and indices of shape (N, 2):
a, b = indices[i]
out[i][j][k] = input[a][b][j][k]

For input of rank 4 and indices of shape (N, 3):
a, b, c = indices[i]
out[i][j] = input[a][b][c][j]

This op can be used to represent ops like `tf.gather_nd` [[6]](https://www.tensorflow.org/api_docs/python/tf/gather_nd) and `onnx.GatherND` [[7]](https://onnx.ai/onnx/operators/onnx__GatherND.html#l-onnx-doc-gathernd), except for cases when they support batching.

0 comments on commit a5bff3d

Please sign in to comment.