Error lowering index_put_: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 #3433

ScottTodd opened this issue Jun 7, 2024 · 2 comments


@ScottTodd

Here are assorted experiments that I'm trying to rework into concrete test cases suitable for use here in torch-mlir (at the moment they use FxProgramsBuilder from iree-turbine to get MLIR from Python): https://colab.research.google.com/gist/ScottTodd/f5e657c773e79be7a95aafb774cb3fd3/index_put-pytorch-torch-mlir-iree-turbine-iree.ipynb#scrollTo=UHFkgOtMz0k5
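For context, the export flow in that notebook looks roughly like this (a minimal sketch, assuming the mid-2024 iree-turbine API; the FxProgramsBuilder import path and export helpers may differ between versions):

import torch
from shark_turbine.aot import FxProgramsBuilder, export

class IndexPutModule(torch.nn.Module):
    def forward(self, a):
        a.index_put_(
            indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])],
            values=torch.tensor([0.3, 1.4, 2.5]),
        )
        return a

fxb = FxProgramsBuilder(IndexPutModule())

@fxb.export_program(args=(torch.zeros(3, 6),))
def simple_index_put(module, a):
    return module(a)

# The exported output carries the torch-dialect MLIR shown below.
print(export(fxb).mlir_module)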

https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html

This puts three values (0.3, 1.4, and 2.5) into place at indices [0, 3], [1, 4], and [2, 5]:

import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3, 1.4, 2.5]))
print(a)

tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5000]])

that imports to this IR:

module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[3],f32>, !torch.none -> !torch.vtensor<[3],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[3],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000",
      torch_tensor_3_torch.float32: "0x040000009A99993E3333B33F00002040"
    }
  }
#-}

which compiles successfully through IREE, and also through torch-mlir-opt --pass-pipeline='builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))'.

The index_put_ op also appears to support broadcasting the "values" from a single element to all indices:

import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3]))
print(a)

tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3000]])

that, however, imports to IR that fails to compile:

module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense<5.000000e-01> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[1],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000"
    }
  }
#-}
/tmp/index_put_broadcast.mlir:15:10: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
         ^
/tmp/index_put_broadcast.mlir:15:10: note: see current operation: 
%38 = "tm_tensor.scatter"(%36, %37, %35) <{dimension_map = array<i64: 0, 1>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = false}> ({
^bb0(%arg1: f32, %arg2: f32):
  "tm_tensor.yield"(%arg1) : (f32) -> ()
}) : (tensor<1x1x1xf32>, tensor<3x2xi32>, tensor<3x6xf32>) -> tensor<3x6xf32>

The error makes sense given the lowering: tm_tensor.scatter requires the update tensor's leading dimension to match the number of indices (3 here), but the broadcast values lower to a tensor<1x1x1xf32>. There are also other broadcasting semantics for "indices", some of which might already be handled correctly here in torch-mlir, but I'm not sure; two of those patterns are sketched below. I'd like to write a suite of e2e tests to verify all the edge cases, possibly drawing on https://github.com/pytorch/pytorch/blob/main/test/test_indexing.py
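For reference, here are two of those other patterns in eager PyTorch (a sketch of eager semantics only; whether the torch-mlir lowering handles them is exactly what the e2e tests should establish):

import torch

# Indexing only the leading dimension: the single value broadcasts
# across each selected [6]-element row, so rows 0 and 2 become all 1.0.
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 2])], values=torch.tensor([1.0]))

# Index tensors of shapes [3] and [1] broadcast against each other,
# selecting positions (0, 5), (1, 5), and (2, 5).
b = torch.zeros(3, 6)
b.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([5])], values=torch.tensor([0.3]))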

@ScottTodd

Wrote some test cases (TBD how these can land in an existing or new test suite): https://gist.github.com/ScottTodd/1e95795e79d17964078217ca98a3a398

iree runtime + compiler at 20240410.859:

test_single_value                  | PASS
test_multiple_values               | PASS
test_broadcast_value_along_axis    | FAIL
test_broadcast_value_along_indices | FAIL
test_broadcast_values_along_axis   | PASS

iree runtime + compiler at 20240606.916:

test_single_value                  | PASS
test_multiple_values               | PASS (then crash)
test_broadcast_value_along_axis    | FAIL
test_broadcast_value_along_indices | FAIL
test_broadcast_values_along_axis   | PASS (then crash)

The new "PASS (then crash)" cases are suspicious; I need to debug the source of that, either by looking at IR dumps or by bisecting through nightly releases to find a culprit commit range.

The "broadcast_value_along" cases look like they are simply unimplemented. That's not too surprising, since https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html contains basically no information about how the op should behave. A possible model-level workaround is sketched below.
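One model-level workaround (my assumption, not an existing torch-mlir decomposition) is to expand the values tensor to the broadcast indices shape before calling index_put_, so the traced aten.index_put already satisfies the dim#0 shape check in tm_tensor.scatter:

import torch

a = torch.zeros(3, 6)
indices = [torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])]
values = torch.tensor([0.3])

# Expand the single value to one element per index up front, so the
# traced aten.index_put sees values of shape [3] rather than [1].
target_shape = torch.broadcast_shapes(*(idx.shape for idx in indices))
a.index_put_(indices=indices, values=values.expand(target_shape))
print(a)

Note this only covers broadcasting values across the indices; the "broadcast along axis" cases, where values must also broadcast along trailing result dimensions, would need the same treatment on those dimensions.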

@ScottTodd

The "pass (then crash)" issues were unrelated to the op lowerings and can be worked around in IREE by using copy_buffer instead of wrap_buffer.

I'd still like to see the broadcasting cases of torch.aten.index_put / tm_tensor.scatter implemented here in torch-mlir.

@ScottTodd removed their assignment Jun 24, 2024