Skip to content

Commit

Permalink
Using scaled_mm for untuned gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
charlifu committed Jun 13, 2024
1 parent 7e09aea commit d61cbff
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions vllm/model_executor/layers/quantization/fp8_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,23 +220,15 @@ def apply_fp8_16(

algo = self._config._tuned.get((m, n, k))
if algo is None:
import os

if os.getenv("TUNE_FP8") == "1":
try:
df = pd.read_csv("/tmp/fp8_shapes.csv")
except (IOError, pd.errors.EmptyDataError,
pd.errors.ParserError):
df = pd.DataFrame(columns=["M", "N", "K"])
df = pd.concat(
[df, pd.DataFrame({
"M": [m],
"N": [n],
"K": [k]
})]).drop_duplicates()
df.to_csv("/tmp/fp8_shapes.csv", index=False)
algo = 0
res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
_save_shape(m, n, k)
res, _ = torch._scaled_mm(x8,
weight.t(),
out_dtype=x.dtype,
scale_a=asf,
scale_b=wsf,
bias=bias)
else:
res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
return res

def apply_fp8_8(
Expand All @@ -257,24 +249,16 @@ def apply_fp8_8(

algo = self._config._tuned.get((m, n, k))
if algo is None:
import os

if os.getenv("TUNE_FP8") == "1":
try:
df = pd.read_csv("/projects/fp8_shapes.csv")
except (IOError, pd.errors.EmptyDataError,
pd.errors.ParserError):
df = pd.DataFrame(columns=["M", "N", "K"])
df = pd.concat(
[df, pd.DataFrame({
"M": [m],
"N": [n],
"K": [k]
})]).drop_duplicates()
df.to_csv("/tmp/fp8_shapes.csv", index=False)
algo = 0

res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
_save_shape(m, n, k)
res, _ = torch._scaled_mm(x8,
weight.t(),
out_dtype=x8.dtype,
scale_a=asf,
scale_b=wsf,
scale_result=osf,
bias=bias)
else:
res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
res16 = torch.empty_like(res, dtype=torch.float16)
vllm_ops.convert_fp8(res16, res, 1 / osf)
return res16
Expand Down Expand Up @@ -308,3 +292,18 @@ def _per_tensor_dequantize(tensor: torch.Tensor,
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight

def _save_shape(m, n, k):
if os.getenv("TUNE_FP8") == "1":
try:
df = pd.read_csv("/tmp/fp8_shapes.csv")
except (IOError, pd.errors.EmptyDataError,
pd.errors.ParserError):
df = pd.DataFrame(columns=["M", "N", "K"])
df = pd.concat(
[df, pd.DataFrame({
"M": [m],
"N": [n],
"K": [k]
})]).drop_duplicates()
df.to_csv("/tmp/fp8_shapes.csv", index=False)

0 comments on commit d61cbff

Please sign in to comment.