Numerical accuracy test in 03-matrix-multiplication.py is failing; atol and rtol values #5283

ravil-mobile opened this issue Nov 28, 2024 · 1 comment

ravil-mobile commented Nov 28, 2024

Describe the bug

Hi all,

I am investigating the failure of the numerical accuracy test in 03-matrix-multiplication.py on AMD MI300 GPUs. This example uses float16 and compares the numerical results obtained with Torch and Triton.

rtol = 1e-2 if is_hip_mi200() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

As one can see, we increase rtol for MI200 GPUs because we know that the MFMA units flush denorms to zero in the case of float16. As far as I know, this issue has been resolved on MI300.

The specific atol and rtol values (1e-2 and 0.0, respectively) were introduced in this PR by @kernhanda. However, I couldn't find any reasoning for why those values were chosen. It seems to me that those values are too strict for float16. Below is a script which compares the numerical results obtained by Torch and Numpy. It fails on both MI300 and H100 machines with the default command line options (which replicate the ones in the example), i.e., when comparing the Torch results with the Numpy ones obtained with float16 and float64.

import torch
import numpy as np
import argparse
import sys


parser = argparse.ArgumentParser()
parser.add_argument("-t", "--type", choices=['fp16', 'fp32'],
                    default="fp16",
                    help="data type under the test")
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance")
parser.add_argument("--rtol", type=float, default=0.0, help="relative tolerance")
parser.add_argument("-v", "--verbose", action='store_true', help='verbose output')
args = parser.parse_args()

if args.type == 'fp64':
    print('using float64...')
    torch_fp = torch.float64
elif args.type == 'fp32':
    print('using float32...')
    torch_fp = torch.float32
elif args.type == 'fp16':
    print('using float16...')
    torch_fp = torch.float16

print(f'using {args.atol=}; {args.rtol=}')

torch.manual_seed(0)
M = N = 512
K = 512
print(f'{M=}; {N=}; {K=}')
a = torch.randn((M, K), device='cuda', dtype=torch_fp)
b = torch.randn((K, N), device='cuda', dtype=torch_fp)

torch_output = torch.matmul(a, b)

torch_output_host = torch_output.cpu()
a_host = a.cpu()
b_host = b.cpu()

numpy_output = np.matmul(a_host.numpy(), b_host.numpy())

# See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
def difference(a_vec, b_vec, atol, rtol):
    assert a_vec.shape == b_vec.shape
    m, n = a_vec.shape
    err_counter = 0
    max_diff = -sys.float_info.min
    min_diff = sys.float_info.max 
    for i in range(m):
        for j in range(n):
            a = a_vec[i,j].item()
            b = b_vec[i,j].item()
            limit = atol + rtol * abs(b)
            delta = abs(a - b)
            max_diff = delta if delta > max_diff else max_diff
            min_diff = delta if min_diff > delta else min_diff
            if delta > limit:
                if args.verbose:
                    print(f'[{i}, {j}]: {a} | {b} | {delta}')
                err_counter += 1
    print(f'total errors found: {err_counter}')
    return max_diff, min_diff

print('-' * 80)
print(f'Comparing with: {torch_fp}')
if torch.allclose(torch_output_host, torch.from_numpy(numpy_output), atol=args.atol, rtol=args.rtol):
    print("\U00002705: Torch and Numpy match")
else:
    print("\U0000274C: Torch and Numpy differ")


max_diff, min_diff = difference(torch_output_host, numpy_output, atol=args.atol, rtol=args.rtol)
print(f'{max_diff=:10.8f}; {min_diff=:10.8f}')

print('-' * 80)
print(f'Comparing with: {torch.float64}')
a_host = a_host.double()
b_host = b_host.double()

numpy_output = np.matmul(a_host.numpy(), b_host.numpy())
max_diff, min_diff = difference(torch_output_host, torch.from_numpy(numpy_output), atol=args.atol, rtol=args.rtol)
print(f'{max_diff=:10.8f}; {min_diff=:10.8f}')

I suggest that we need to increase the tolerance values (i.e., the current values are too low for float16 for the given GEMM configuration (M=512; N=512; K=512)).
I need to know whether the OpenAI developers are going to be OK with this change.
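
For illustration, here is a minimal sketch of what "increasing the tolerances" could look like: deriving rtol from the float16 machine epsilon and the reduction length K instead of hard-coding atol=1e-2 / rtol=0.0. The eps * sqrt(K) error model and the atol floor below are my own assumptions for this issue, not values taken from the tutorial or from PyTorch.

import math
import torch

M = N = K = 512
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

# Stand-ins: in the tutorial, triton_output comes from the Triton matmul(a, b) kernel.
triton_output = torch.matmul(a, b)
torch_output = torch.matmul(a, b)

# Each output element is a length-K dot product; a rough model of the rounding
# error is eps * sqrt(K) relative to the result magnitude (assuming fp32
# accumulation), which is ~2.2e-2 for float16 and K=512.
eps = torch.finfo(torch.float16).eps      # ~9.77e-04
rtol = eps * math.sqrt(K)
atol = 1e-3                               # small absolute floor for near-zero outputs

if torch.allclose(triton_output, torch_output, atol=atol, rtol=rtol):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")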

Environment details

Triton: not relevant, because the reproducer above uses only Torch and Numpy

Docker:

  • ROCm: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
  • CUDA: pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel

GPU:

  • MI300X
  • H100-PCIe

ravil-mobile commented Nov 29, 2024

Hi all, I observed 2 things while working with Nvidia H100 GPUs around the following lines:

a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_mi200() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

  1. I can set any small value of atol (e.g., 1e-16, 1e-32 or 1e-64) and the accuracy test passes
  2. The accuracy test fails when I switch to the CPU version of torch.matmul (a small diagnostic sketch follows the snippets below). This can be reproduced by replacing the following line

torch_output = torch.matmul(a, b)

with

torch_output = torch.matmul(a.cpu(), b.cpu())
torch_output = torch_output.to(device='cuda')
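
For what it's worth, here is a small diagnostic sketch (my own, not from the tutorial) that quantifies observation 2 by printing the largest element-wise differences between the GPU and CPU float16 matmul results; these are the numbers the atol=1e-2 check is sensitive to.

import torch

torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)

gpu_out = torch.matmul(a, b)
cpu_out = torch.matmul(a.cpu(), b.cpu()).to(device='cuda')

# Compute the differences in float32 to avoid extra rounding in the diagnostic itself.
abs_diff = (gpu_out.float() - cpu_out.float()).abs()
rel_diff = abs_diff / cpu_out.float().abs().clamp_min(1e-6)
print(f'max abs diff: {abs_diff.max().item():.6f}')
print(f'max rel diff: {rel_diff.max().item():.6f}')
print('allclose(atol=1e-2, rtol=0):',
      torch.allclose(gpu_out, cpu_out, atol=1e-2, rtol=0.0))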

@antiagainst @zhanglx13

My env:
Triton: v3.2.0; commit cc89dac
Torch: 2.4.1+cu118
