diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py
index 5401df72fb67a..ddd3b304280e7 100644
--- a/vllm/model_executor/layers/quantization/fp8_rocm.py
+++ b/vllm/model_executor/layers/quantization/fp8_rocm.py
@@ -219,12 +219,12 @@ def apply_fp8_16(
         algo = self._config._tuned.get((m, n, k))
         if algo is None:
             _save_shape(m, n, k)
-            res, _ = torch._scaled_mm(x8,
-                weight.t(),
-                out_dtype=x.dtype,
-                scale_a=asf,
-                scale_b=wsf,
-                bias=bias)
+            res, _ = torch._scaled_mm(x8,
+                                      weight.t(),
+                                      out_dtype=x.dtype,
+                                      scale_a=asf,
+                                      scale_b=wsf,
+                                      bias=bias)
         else:
             res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
         return res
@@ -291,17 +291,16 @@ def _per_tensor_dequantize(tensor: torch.Tensor,
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
+
 def _save_shape(m, n, k):
     if os.getenv("TUNE_FP8") == "1":
         try:
             df = pd.read_csv("/tmp/fp8_shapes.csv")
-        except (IOError, pd.errors.EmptyDataError,
-                pd.errors.ParserError):
+        except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError):
             df = pd.DataFrame(columns=["M", "N", "K"])
-        df = pd.concat(
-            [df, pd.DataFrame({
-                "M": [m],
-                "N": [n],
-                "K": [k]
-            })]).drop_duplicates()
+        df = pd.concat([df, pd.DataFrame({
+            "M": [m],
+            "N": [n],
+            "K": [k]
+        })]).drop_duplicates()
         df.to_csv("/tmp/fp8_shapes.csv", index=False)