[PyTorch] Fix ONNX exports (#437)
* Fix ONNX exports

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Sep 23, 2023
1 parent 2da34d4 commit 903e1f4
Showing 3 changed files with 48 additions and 224 deletions.
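The test diff below removes the standalone fused-softmax export tests and, in the remaining module tests, moves attn_mask_type from the forward() call into the module constructor, so the mask type is no longer an input of the exported ONNX graph. As a rough illustration of the resulting export pattern (a minimal sketch, not code from this commit: sizes and the file name are invented, and the FP8 export path exercised elsewhere in these tests is omitted):

# Minimal sketch (not part of the diff): export a TE TransformerLayer after this change.
# Sizes and the output file name are arbitrary; FP8 export is omitted.
import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size, num_heads = 256, 1024, 4
seq_len, batch_size = 32, 2

layer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    self_attn_mask_type="causal",   # constructor argument now, not a forward()/ONNX input
    params_dtype=torch.float16,
).to(device="cuda").eval()

inp = torch.rand(seq_len, batch_size, hidden_size, dtype=torch.float16, device="cuda")

# Mirrors the "context phase" export in the last hunk below: only tensor inputs are named.
with torch.inference_mode(), te.onnx_export(True):
    torch.onnx.export(
        layer,
        (inp,),
        "transformer_layer.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "seq", 1: "bs"}},
    )

Fixing the mask type at construction time keeps the traced graph free of non-tensor inputs, which is what the updated input_names lists in the hunks below reflect.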
173 changes: 14 additions & 159 deletions tests/pytorch/test_onnx_export.py
@@ -763,156 +763,6 @@ def forward(self, inp):
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)

@skip_FP8
@pytest.mark.parametrize("softmax_fn", [
softmax_defs.ScaledUpperTriangMaskedSoftmax,
softmax_defs.ScaledMaskedSoftmax,
softmax_defs.ScaledSoftmax,
te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
class Test_Softmax(nn.Module):
def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
super().__init__()
self.softmax_fn = softmax_fn
self.scale = 8 # arbitrary value
self.mask_inp = mask_inp
self.fused_scaled_softmax = None
self.fake_bf16_io = fake_bf16_io
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)

def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)

if self.fused_scaled_softmax:
ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
else:
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, self.scale)
else:
ret = self.softmax_fn.apply(inp, self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret

fake_bf16_io = precision == "fake-torch.bfloat16"
precision = torch.bfloat16 if fake_bf16_io else precision

# Set dimensions (these are arbitrary).
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
mask = None
input_names = ["input", "mask"]
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [batch_size, seq_len_q, seq_len_k]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
elif softmax_fn == softmax_defs.ScaledSoftmax:
kernel_str = "ScaledSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
kernel_str = "TorchSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)

input_tensor = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask)
dynamic_axes = {}
if mask is not None:
dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}}
do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes)
te_outputs = te_infer(model, inp, is_fp8=False)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
atol = 5e-2 if fake_bf16_io else 1e-3
validate_result(fname, inp, model, atol=atol, input_names=input_names, te_outputs=te_outputs)


# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, precision):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if fake_bf16_io else precision

class Test_Softmax(nn.Module):
def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool):
super().__init__()
self.scale = 1 # arbitrary value
self.fake_bf16_io = fake_bf16_io

if use_default_te_mask_fn:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0"
else:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}"

# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)

def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)
ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float)
return ret

# Set dimensions (these are arbitrary).
mask = None
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
assert seq_len_q == seq_len_k # This is a causal (TRILU) mask
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
input_tensor = torch.randn(
*inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision)
inp = (input_tensor, mask)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)

# Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask.
# This verifies that _get_onnx_export_causal_mask generates a correct mask.
model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")

# Compare the outputs of TE when using the default softmax mask
# to the ORT ONNX outputs produced when using the ONNX-compatible causal mask.
input_names = ["input", "mask"]
kernel_str = "FusedScaleMaskSoftmax"
fname = f"{kernel_str}{high_prec_str}.onnx"
do_export(model, inp, fname, input_names=input_names)
serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
atol = 1e-2 if fake_bf16_io else 1e-3
validate_result(
fname, inp, model_onnx_mask, atol=atol,
input_names=input_names, te_outputs=te_outputs_default_mask)


@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
@@ -1159,13 +1009,13 @@ def test_export_core_attention(
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"]
input_names = ["query", "key", "value", "attention_mask"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-    inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type)
+    inp = (query_layer, key_layer, value_layer, attention_mask)

mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
@@ -1175,6 +1025,7 @@ def test_export_core_attention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
+        attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
inp,
@@ -1190,8 +1041,9 @@

test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
@@ -1265,6 +1117,7 @@ def test_export_multihead_attention(

model = te.MultiheadAttention(
*attention_args,
+        attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
@@ -1273,8 +1126,8 @@
return_bias=True,
).to(device='cuda')

-    inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type)
-    input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"]
+    inp_context = (hidden_states_context, attention_mask, encoder_output)
+    input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
@@ -1342,13 +1195,13 @@ def test_export_transformer_layer(
num_attention_heads = 4

input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_names = ["input", "attention_mask", "self_attn_mask_type"]
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-    inp = (input_tensor, attention_mask, attn_mask_type)
+    inp = (input_tensor, attention_mask)

fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
@@ -1360,6 +1213,7 @@
hidden_size,
ffn_hidden_size,
num_attention_heads,
+        self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
@@ -1541,16 +1395,17 @@ def test_export_gpt_generation(
hidden_size,
ffn_hidden_size,
num_attention_heads,
+        self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')

# "Context phase": use full input sequence length
input_names = ["input", "attention_mask", "self_attn_mask_type"]
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
-    inp = (input_tensor, None, attn_mask_type)
+    inp = (input_tensor,)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
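For the attention-only path covered by test_export_multihead_attention above, the same pattern looks roughly as follows (again a hedged sketch, not code from this commit: sizes and the file name are invented, FP8 export is omitted, and the mask inputs stay None for the causal self-attention case):

# Minimal sketch (not part of the diff): export te.MultiheadAttention with the
# mask type fixed at construction time. Sizes and file name are arbitrary.
import torch
import transformer_engine.pytorch as te

hidden_size, num_heads, seq_len, batch_size = 256, 4, 32, 2

attn = te.MultiheadAttention(
    hidden_size,
    num_heads,
    attn_mask_type="causal",   # no longer passed to forward(), so not an ONNX input
    params_dtype=torch.float16,
    return_bias=True,          # forward() returns (attention_output, attention_bias)
).to(device="cuda").eval()

hidden_states = torch.randn(
    seq_len, batch_size, hidden_size, dtype=torch.float16, device="cuda")

# attention_mask and encoder_output stay None for causal self-attention,
# matching the context-phase export in the diff above.
with torch.inference_mode(), te.onnx_export(True):
    torch.onnx.export(
        attn,
        (hidden_states, None, None),
        "te_mha_causal.onnx",
        input_names=["hidden_states", "attention_mask", "encoder_output"],
        output_names=["attention_output", "attention_bias"],
        dynamic_axes={"hidden_states": {0: "seq", 1: "bs"}},
    )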
