Skip to content

Commit

Permalink
add some tests for causal_conv1d_fn when initial_state is None
Browse files Browse the repository at this point in the history
Signed-off-by: xffxff <[email protected]>
  • Loading branch information
xffxff committed Dec 6, 2024
1 parent ba25e7b commit 3b83a96
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions tests/kernels/test_causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype):
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
Expand All @@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,

weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
Expand All @@ -183,31 +191,28 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
has_initial_state=has_initial_state_tensor)


@pytest.mark.parametrize("itype", [torch.bfloat16])
Expand Down

0 comments on commit 3b83a96

Please sign in to comment.