Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
still failing as `_scaled_mm` requires the secomd matrix to be column major: ``` E NotImplementedError: Failing to map `torch._scaled_mm` to `thunder.torch` op of [Symbol name=_scaled_mm] with args of [<TensorProxy(name="t166", dtype=thunder.dtypes.float8_e4m3fn, shape=(16, 32))>, <TensorProxy(name="t169", dtype=thunder.dtypes.float8_e4m3fn, shape=(32, 64))>, <TensorProxy(name="t170", dtype=thunder.dtypes.float32, shape=())>, <TensorProxy(name="t171", dtype=thunder.dtypes.float32, shape=())>, None, None, torch.float32, True] E BoundSymbol in question is E ```python E t165 = manual_float8_matmul_with_args_in_float8_127377658692416_2(input_fp8, t164) # t165: "cuda:0 f32[16, 64]" E # t102 = ltorch.reshape(input_fp8, -1, 32) # t102: "cuda:0 f32[16, 32]" E # t102 = prims.reshape(input_fp8, (16, 32)) # t102: "cuda:0 f32[16, 32]" E # t103 = ltorch.spmm(t102, t164) # t103: "cuda:0 f32[16, 64]" E # t165 = prims.shallow_copy(t103) # t165: "cuda:0 f32[16, 64]" E ``` E Corresponding torch.fx Graph is E ```python E class <lambda>(torch.nn.Module): E def forward(self, arg0, arg1, arg2, arg3, arg4, arg5): E arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]"; E E arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec) E # No stacktrace found for following nodes E view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]); arg0_1 = None E t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1); arg3_1 = None E clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format); t = None E t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone); clone = None E reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1); arg1_1 = None E reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1); arg4_1 = None E _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True); view = t_1 = reciprocal = reciprocal_1 = None E return pytree.tree_unflatten([_scaled_mm, None], self._out_spec) E E ``` E Original error is Exception encountered when doing automatic registration for _scaled_mm, please use manual registration: RuntimeError('mat2 must be col_major') ``` Signed-off-by: Masaki Kozuki <[email protected]>
- Loading branch information