Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Design of Gather ops #26

Merged
merged 1 commit into from
Jan 25, 2024
Merged

Conversation

navahgar
Copy link
Collaborator

@navahgar navahgar commented Jan 11, 2024

This PR proposes a design for gather ops in TCP.

Copy link
Contributor

@sanjoy sanjoy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can tcp.gather can always be written as a tcp.gather_elements + tcp.broadcast? Here's an example where I think it can, but I'm not sure if this is always possible:

Let's say we have indices of shape <5x6xi64>, elements of shape <7x8x9xf32> and output of shape <5x6x8x9xf32>, with gather index = 0.

Then we could:

  1. broadcast indices to shape <5x6x8x9xi64> (i.e. introduce the 8 and 9 dimensions)
  2. broadcast elements to <7x6x8x9xf32> (i.e. introduce the 6 dimension)
  3. do a gather elements on these with dimension 0

I think this will get the same result and we should be able to easily fuse the broadcasts.

@navahgar
Copy link
Collaborator Author

Can tcp.gather can always be written as a tcp.gather_elements + tcp.broadcast?

Thats very interesting. I had to write a couple of examples to convince myself that it works. My worry was that the tcp.gather op is supposed to extract an entire slice, but by broadcasting we are essentially replicating the indices, which I thought might not get the entire slice we are looking for. But my examples showed otherwise. So far, I haven't been able to find any case where this wouldn't work.

Do you see any benefits to doing it this way though, other than the obv one of reducing an op?

Assuming this always works, the tradeoff I see is between having an extra op and having a somewhat complicated lowering for some gather ops (only because it is not immediately obvious how this works).

@navahgar navahgar marked this pull request as ready for review January 11, 2024 07:15
@navahgar
Copy link
Collaborator Author

Folks, let me know your thoughts on this design of Gather ops in TCP.

@sanjoy
Copy link
Contributor

sanjoy commented Jan 11, 2024

Do you see any benefits to doing it this way though, other than the obv one of reducing an op?

That's the direct benefit I see.

The indirect benefit, IMO, is that since broadcast/gather fusions will be more common which will force the backend to be more resilient in supporting fusions, and this might have knock on beneficial effects.

@sjain-stanford sjain-stanford requested review from mabubakarpurdue and a team and removed request for AaronStGeorge, mabubakarpurdue, zezhang, sjarus and sjain-stanford January 11, 2024 20:24
@navahgar
Copy link
Collaborator Author

Can tcp.gather can always be written as a tcp.gather_elements + tcp.broadcast?

Thats very interesting. I had to write a couple of examples to convince myself that it works.

Here are couple of examples I tried (just for a record here):

Example 1

input <3x3xf32>
indices <2xi64>
gather_dim: 0
output <2x3xf32>

input : [[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]

indices : [2, 1]

Using tcp.gather:

output: [[6, 7, 8],
         [3, 4, 5]]

Using tcp.gather_elements + tcp.broadcast:

indices broacasted: [[2, 2, 2],
                     [1, 1, 1]]

out[i][j] = input[index[i][j]][j]  # if dim == 0

output: [[6, 7, 8],
         [3, 4, 5]]

Example 2

input <3x3xf32>
indices <2xi64>
gather_dim: 1
output <3x2xf32>

input : [[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]
indices: [2, 1]

Using tcp.gather:

output: [[2, 1],
         [5, 4],
         [8, 7]]

Using tcp.gather_elements + tcp.broadcast:

indices broadcasted: [[2, 1],
                      [2, 1],
                      [2, 1]]

out[i][j] = input[i][index[i][j]]  # if dim == 1

output: [[2, 1],
         [5, 4],
         [8, 7]]

@navahgar
Copy link
Collaborator Author

@sanjoy Updated the doc to use the broadcasting approach you proposed for gathering slices. PTAL.

Once we have an implementation of this, we can test it with a variety of gather cases to ensure this approach is correct for all of them (I haven't found any incorrect cases with this so far). In the worst case, we can revert to the original plan of having a separate op. Hence, I kept that op in the doc as an alternative that was considered.

@navahgar
Copy link
Collaborator Author

@sjarus As you had requested offline, I added examples of how gather ops from various frameworks get converted to TCP in this doc. PTAL.

@navahgar navahgar merged commit 6b5733f into cruise-automation:main Jan 25, 2024
1 check passed
@navahgar navahgar deleted the gather-doc branch January 25, 2024 19:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

3 participants