Skip to content

Commit

Permalink
[Misc/Testing] Use torch.testing.assert_close (#7324)
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-chuang authored Aug 16, 2024
1 parent e165528 commit 50b8d08
Show file tree
Hide file tree
Showing 25 changed files with 197 additions and 188 deletions.
18 changes: 9 additions & 9 deletions tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
Expand Down Expand Up @@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
Expand Down Expand Up @@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
else:
recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
torch.testing.assert_close(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
Expand Down Expand Up @@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,

if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
torch.testing.assert_close(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
Expand All @@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
get_pp_group().send(test_tensor)

if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor)
torch.testing.assert_close(test_tensor, recv_tensor)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
Expand Down
8 changes: 4 additions & 4 deletions tests/distributed/test_custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
assert torch.allclose(out1, inp1)
assert torch.allclose(out2, inp2)
torch.testing.assert_close(out1, inp1)
torch.testing.assert_close(out2, inp2)


@ray.remote(num_gpus=1, max_calls=1)
Expand All @@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
torch.testing.assert_close(out, inp * (tp_size**num_communication))

inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
torch.testing.assert_close(out, inp * (tp_size**num_communication))


@pytest.mark.parametrize("tp_size", [2])
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale
return ref_out, ref_scale.view((1, ))
10 changes: 5 additions & 5 deletions tests/kernels/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_act_and_mul(
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)


@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
Expand All @@ -73,7 +73,7 @@ def test_activation(
layer = activation()
out = layer(x)
ref_out = layer.forward_native(x)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
4 changes: 2 additions & 2 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
Expand Down Expand Up @@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
4 changes: 2 additions & 2 deletions tests/kernels/test_blocksparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
Expand Down Expand Up @@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
scale,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
54 changes: 27 additions & 27 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def test_copy_blocks(

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
assert torch.allclose(key_cache, cloned_key_cache)
torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
Expand Down Expand Up @@ -184,17 +184,17 @@ def test_reshape_and_cache(
cloned_value_cache[block_idx, :, :, block_offset] = value[i]

if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
Expand Down Expand Up @@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
cloned_value_cache[block_idx, block_offset, :, :] = value[i]

if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
Expand Down Expand Up @@ -372,10 +372,10 @@ def test_swap_blocks(
block_mapping_tensor)

for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
torch.testing.assert_close(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
torch.testing.assert_close(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
Expand Down Expand Up @@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8)

assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
23 changes: 13 additions & 10 deletions tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)


def cutlass_int8_gemm_helper(m: int,
Expand Down Expand Up @@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding

a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

Expand All @@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
scale_b,
out_dtype=out_dtype,
bias=azp_bias[0, :])
assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)


@pytest.mark.parametrize("m", [32, 64, 128])
Expand Down Expand Up @@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding

a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
torch.testing.assert_close(a_dq,
scale_a * aq_f32 - azp_a,
rtol=1e-4,
atol=1e-3)

if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
Expand Down Expand Up @@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
atol = 1e-3
assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)


# Test working with a subset of A and B
Expand All @@ -363,7 +366,7 @@ def test_cutlass_subset():
scale_b,
out_dtype=torch.bfloat16)

assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
Expand Down Expand Up @@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):

baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
4 changes: 2 additions & 2 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
scale=scale,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


Expand Down Expand Up @@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
4 changes: 2 additions & 2 deletions tests/kernels/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


Expand Down Expand Up @@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
14 changes: 7 additions & 7 deletions tests/kernels/test_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(ref_scales, ops_scales)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
Expand All @@ -57,9 +57,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)

assert torch.allclose(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(ref_scale, ops_scale)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))


# Regression test for a case with large activations where an int32 index cannot
Expand All @@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
ref_out = ref_out.to(dtype=dtype)
ops_out = ops_out.to(dtype=dtype)

assert torch.allclose(ref_out, ops_out)
torch.testing.assert_close(ref_out, ops_out)
12 changes: 7 additions & 5 deletions tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
ops_out, ops_scales = scaled_int8_quant(x)

assert torch.allclose(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out,
atol=1) # big atol to account for rounding errors
torch.testing.assert_close(ops_scales, ref_scales)
torch.testing.assert_close(
ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
Expand All @@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits.max).to(torch.int8)
out2, _ = scaled_int8_quant(x, scale)

assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
torch.testing.assert_close(
out1, out2, atol=1,
rtol=0.0) # big atol to account for rounding errors
Loading

0 comments on commit 50b8d08

Please sign in to comment.