diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index f7ccad4a86..a15368c9ba 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -40,20 +40,6 @@ def addmm_float8_unwrapped(
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()
 
-    post_inverse_scale = None
-    is_rowwise_scaling = a_scale.shape == (a_data.shape[0], 1) and b_scale.shape == (
-        1,
-        b_data.shape[1],
-    )
-
-    if is_rowwise_scaling and not use_fast_accum:
-        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
-        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
-        # manually afterwards (hoping Inductor will be able to fuse it).
-        post_inverse_scale = a_inverse_scale * b_inverse_scale
-        a_inverse_scale = a_inverse_scale.new_ones(())
-        b_inverse_scale = a_inverse_scale.new_ones(())
-
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -71,8 +57,6 @@ def addmm_float8_unwrapped(
         use_fast_accum=use_fast_accum,
     )
 
-    if post_inverse_scale is not None:
-        output *= post_inverse_scale
     if post_bias is not None:
         output += post_bias
     return output
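
For reference, the workaround being deleted relied on a simple algebraic identity: multiplying the matmul output by the product of the row-wise inverse scales afterwards gives the same result as applying those scales inside the kernel. The sketch below is a toy illustration of that identity only, not part of the patch; it uses a plain fp32 matmul as a stand-in for the fp8 `torch._scaled_mm` kernel, and the variable names mirror the patch rather than any torchao API.

```python
# Toy check (assumption: a plain fp32 matmul stands in for torch._scaled_mm)
# of why post-scaling by a_inverse_scale * b_inverse_scale was numerically
# equivalent to passing the rowwise scales into the kernel directly.
import torch

M, K, N = 4, 8, 5
a_data = torch.randn(M, K)          # stands in for the fp8 operand data of A
b_data = torch.randn(K, N)          # stands in for the fp8 operand data of B
a_inverse_scale = torch.rand(M, 1)  # rowwise inverse scale for A, shape (M, 1)
b_inverse_scale = torch.rand(1, N)  # rowwise inverse scale for B, shape (1, N)

# What the rowwise kernel computes: scales applied as part of the matmul.
scaled_inside = (a_data @ b_data) * a_inverse_scale * b_inverse_scale

# What the removed workaround computed: unit scales inside the matmul,
# then one elementwise multiply by the broadcast product of the scales.
post_inverse_scale = a_inverse_scale * b_inverse_scale  # broadcasts to (M, N)
scaled_after = (a_data @ b_data) * 1.0
scaled_after *= post_inverse_scale

torch.testing.assert_close(scaled_inside, scaled_after)
```

With the workaround removed, the rowwise inverse scales flow straight into `torch._scaled_mm` regardless of `use_fast_accum`, and no post-scaling multiply is emitted.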