Add cublas FP8 tensorwise GEMM in fbgemm quantize bench (pytorch#3693)
Summary:
Pull Request resolved: pytorch#3693

X-link: facebookresearch/FBGEMM#769

As title

Reviewed By: jianyuh

Differential Revision: D69641673

fbshipit-source-id: c75191cc7b435464aac8ce60014e109789e36352
jiawenliu64 authored and facebook-github-bot committed Feb 14, 2025
1 parent 58b7680 commit a4be13a
Showing 1 changed file with 34 additions and 1 deletion.
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (34 additions, 1 deletion)
@@ -471,7 +471,7 @@ def cuda(self) -> bool:
 @register_quantize_op
 class FP8CublasRowwiseGemm(QuantizeOpBase):
     """
-    FP8 matmul with tensorwise scaling.
+    FP8 cublas matmul with rowwise scaling.
     """

     def quantize(self, x, w):
@@ -503,6 +503,39 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class FP8CublasTensorwiseGemm(QuantizeOpBase):
+    """
+    FP8 cublas matmul with tensorwise scaling.
+    """
+
+    def quantize(self, x, w):
+        # Quantize both input tensors.
+        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.fbgemm.f8f8bf16_cublas(xq, wq, x_scale * w_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        return "cublas_tensorwise"
+
+    @property
+    def hip(self) -> bool:
+        # This implementation is specific to cublas.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8RowwiseGemm(QuantizeOpBase):
     """
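For context, here is a minimal sketch of exercising the new cublas_tensorwise op directly rather than through the quantize_bench harness. It assumes a CUDA build of fbgemm_gpu with the experimental gen_ai operators available and that the bench module is importable as fbgemm_gpu.experimental.gen_ai.bench.quantize_ops; the shapes, dtypes, and (N, K) weight layout are illustrative assumptions, not taken from the commit.

import torch

# Importing the bench module brings in the registered quantize ops
# (module path assumed from the file location in this commit).
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP8CublasTensorwiseGemm

# Illustrative GEMM problem sizes: x is (M, K), w is (N, K).
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

op = FP8CublasTensorwiseGemm()

# Per-tensor FP8 quantization of both operands: one scale per tensor.
xq, wq, x_scale, w_scale = op.quantize(x, w)

# cuBLAS FP8 GEMM; compute() folds the two tensorwise scales into a single scalar.
out = op.compute(xq, wq, x_scale, w_scale)

# Or run both steps in one call.
out = op.quantize_and_compute(x, w)

Tensorwise scaling here means a single scale per input tensor, so the class can pass x_scale * w_scale to torch.ops.fbgemm.f8f8bf16_cublas as one combined multiplier, in contrast to the rowwise variant above, which carries a scale per row.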
