tp.mean failure when dim is multi-dimensional with skipped dimensions #297

Open
farazkh80 opened this issue Oct 22, 2024 · 0 comments


@farazkh80
Collaborator

tp.mean fails when dim lists several dimensions but skips one in between, e.g. dim=[0, 2] on a rank-3 tensor.

Working examples

x = tp.reshape(tp.arange(12), (2,3,2))

Reducing over contiguous dimensions works:

>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
    [[[3.0000, 4.0000],
      [5.0000, 6.0000],
      [7.0000, 8.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))
>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
    [[[5.0000, 6.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 1, 2))
>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))
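For comparison, here is a minimal pure-Python reference for a mean over arbitrary dimensions with keepdim, written only to check expected values. The helper name `reference_mean` is hypothetical and not part of Tripy. It reproduces the outputs above and shows what dim=[0,2] should return: shape (1, 3, 1) with values 3.5, 5.5 and 7.5.

```python
from itertools import product
from math import prod


def reference_mean(flat, shape, dims, keepdim=True):
    """Mean of a row-major flat buffer over `dims`.

    Returns (values, out_shape), with values listed in row-major order.
    Hypothetical checking helper; not part of Tripy.
    """
    rank = len(shape)
    reduced = {d % rank for d in dims}  # normalize negative dims
    kept = [i for i in range(rank) if i not in reduced]
    count = prod(shape[d] for d in reduced)
    # Row-major strides: the last dimension varies fastest.
    strides = [prod(shape[i + 1:]) for i in range(rank)]

    # Accumulate sums keyed by the indices of the dimensions that are kept.
    sums = {}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(idx[i] for i in kept)
        offset = sum(i * s for i, s in zip(idx, strides))
        sums[key] = sums.get(key, 0.0) + flat[offset]

    # Emit results in row-major order over the kept dimensions.
    values = [sums[key] / count for key in product(*(range(shape[i]) for i in kept))]
    if keepdim:
        out_shape = tuple(1 if i in reduced else n for i, n in enumerate(shape))
    else:
        out_shape = tuple(shape[i] for i in kept)
    return values, out_shape


if __name__ == "__main__":
    data = [float(v) for v in range(12)]  # same values as tp.arange(12)
    for dims in ([0], [0, 1], [0, 1, 2], [0, 2]):
        print(dims, reference_mean(data, (2, 3, 2), dims))
```

The reference only sums and divides, so any mismatch with Tripy on the contiguous cases would point at the test setup rather than the reduction itself.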

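Failing example

But if a dimension is skipped (reducing over dims 0 and 2 but not 1), the call should return a tensor of shape (1, 3, 1). Instead, compilation of the underlying stablehlo.reduce fails when the result is evaluated for printing: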

>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
    data_list = self.tolist()
  File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
    data_memref = self.eval()
  File "/tripy/tripy/frontend/tensor.py", line 180, in eval
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> <stdin>:1 in <module>()

MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
    (t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t1926)): error: op: 
    %2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%7) : (tensor<f32>) -> ()
    }) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
          | 

    This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
          |
      145 |     return reshape_impl(input, shape, len(shape), output_len)
          |
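One possible workaround, assuming (unverified) that the failure is specific to reducing non-adjacent dimensions in a single op, is to split dim into runs of adjacent dimensions, like the contiguous cases that work above, and reduce one run at a time. A mean over a union of whole axes equals successive means over those axes, because every output element averages the same number of inputs. The helpers `contiguous_runs` and `mean_by_runs` are hypothetical; `mean_fn` stands in for tp.mean and is assumed to accept `dim` and `keepdim` keyword arguments as in the examples above.

```python
def contiguous_runs(dims):
    """Split reduction dims into runs of adjacent dimensions, e.g. [0, 2] -> [[0], [2]].

    Assumes dims are non-negative; duplicates are ignored.
    """
    runs = []
    for d in sorted(set(dims)):
        if runs and d == runs[-1][-1] + 1:
            runs[-1].append(d)  # extends the current run
        else:
            runs.append([d])  # starts a new run
    return runs


def mean_by_runs(mean_fn, x, dims, keepdim=False):
    """Reduce one contiguous run at a time, highest dimensions first.

    Going from the highest run down keeps the lower dimension indices valid
    even when keepdim=False removes the reduced axes.
    """
    if any(d < 0 for d in dims):
        raise ValueError("this sketch expects non-negative dims")
    y = x
    for run in reversed(contiguous_runs(dims)):
        y = mean_fn(y, dim=run, keepdim=keepdim)
    return y
```

With Tripy, the call would presumably be mean_by_runs(tp.mean, x, [0, 2], keepdim=True), which issues a mean over dim=[2] followed by a mean over dim=[0]; this has not been checked against the failing build.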