From 903e1f4f86778ff2ab71588fd4ef90353f9d76c1 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Sep 2023 23:42:44 -0700 Subject: [PATCH] [PyTorch] Fix ONNX exports (#437) * Fix ONNX exports Signed-off-by: Kirthi Shankar Sivamani * docs Signed-off-by: Kirthi Shankar Sivamani * review Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_onnx_export.py | 173 ++-------------------- transformer_engine/pytorch/attention.py | 65 +++----- transformer_engine/pytorch/transformer.py | 34 ++--- 3 files changed, 48 insertions(+), 224 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 14640febde..533e0cff6a 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -763,156 +763,6 @@ def forward(self, inp): validate_result( fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) -@skip_FP8 -@pytest.mark.parametrize("softmax_fn", [ - softmax_defs.ScaledUpperTriangMaskedSoftmax, - softmax_defs.ScaledMaskedSoftmax, - softmax_defs.ScaledSoftmax, - te.softmax.FusedScaleMaskSoftmax, -]) -# Softmax kernel only supports FP16 or BF16! -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) -def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision): - class Test_Softmax(nn.Module): - def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False): - super().__init__() - self.softmax_fn = softmax_fn - self.scale = 8 # arbitrary value - self.mask_inp = mask_inp - self.fused_scaled_softmax = None - self.fake_bf16_io = fake_bf16_io - if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax: - self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - mask_func=te.utils.attention_mask_func, - softmax_in_fp32=True, - ) - - def forward(self, inp, mask): - if self.fake_bf16_io: - inp = inp.type(torch.bfloat16) - - if self.fused_scaled_softmax: - ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale) - else: - if self.mask_inp: - ret = self.softmax_fn.apply(inp, mask, self.scale) - else: - ret = self.softmax_fn.apply(inp, self.scale) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - fake_bf16_io = precision == "fake-torch.bfloat16" - precision = torch.bfloat16 if fake_bf16_io else precision - - # Set dimensions (these are arbitrary). - batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32 - mask = None - input_names = ["input", "mask"] - inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k] - if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax: - inp_shape = [batch_size, seq_len_q, seq_len_k] - kernel_str = "ScaledUpperTriangMaskedSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - elif softmax_fn == softmax_defs.ScaledMaskedSoftmax: - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision) - mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - kernel_str = "ScaledMaskedSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True) - elif softmax_fn == softmax_defs.ScaledSoftmax: - kernel_str = "ScaledSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - elif softmax_fn == te.softmax.FusedScaleMaskSoftmax: - kernel_str = "TorchSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - - input_tensor = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"{kernel_str}{high_prec_str}.onnx" - inp = (input_tensor, mask) - dynamic_axes = {} - if mask is not None: - dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}} - do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes) - te_outputs = te_infer(model, inp, is_fp8=False) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if fake_bf16_io or precision != torch.bfloat16: - atol = 5e-2 if fake_bf16_io else 1e-3 - validate_result(fname, inp, model, atol=atol, input_names=input_names, te_outputs=te_outputs) - - -# Test dynamically generated softmax mask. -# Softmax kernel only supports FP16 or BF16! -@skip_FP8 -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) -def test_softmax_mask_fn(seed_default_rng, precision): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if fake_bf16_io else precision - - class Test_Softmax(nn.Module): - def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool): - super().__init__() - self.scale = 1 # arbitrary value - self.fake_bf16_io = fake_bf16_io - - if use_default_te_mask_fn: - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0" - else: - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}" - - # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax - # even when is_in_onnx_export_mode()==False. - os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0" - self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - mask_func=te.utils.attention_mask_func, - softmax_in_fp32=True, - ) - - def forward(self, inp, mask): - if self.fake_bf16_io: - inp = inp.type(torch.bfloat16) - ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale) - if self.fake_bf16_io: - ret = ret.type(torch.float) - return ret - - # Set dimensions (these are arbitrary). - mask = None - batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32 - assert seq_len_q == seq_len_k # This is a causal (TRILU) mask - inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k] - input_tensor = torch.randn( - *inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision) - inp = (input_tensor, mask) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - - # Compare the outputs of TE when using the default softmax mask - # to the TE outputs produced when using the ONNX-compatible causal mask. - # This verifies that _get_onnx_export_causal_mask generates a correct mask. - model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io) - te_outputs_default_mask = te_infer(model, inp, is_fp8=True) - with te.onnx_export(True): - # ONNX export mode forces use of the ONNX-compatible causal mask. 
- model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io) - te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True) - compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask, - atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking") - - # Compare the outputs of TE when using the default softmax mask - # to the ORT ONNX outputs produced when using the ONNX-compatible causal mask. - input_names = ["input", "mask"] - kernel_str = "FusedScaleMaskSoftmax" - fname = f"{kernel_str}{high_prec_str}.onnx" - do_export(model, inp, fname, input_names=input_names) - serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names) - if fake_bf16_io or precision != torch.bfloat16: - atol = 1e-2 if fake_bf16_io else 1e-3 - validate_result( - fname, inp, model_onnx_mask, atol=atol, - input_names=input_names, te_outputs=te_outputs_default_mask) - @pytest.mark.parametrize("scale_factor", [1]) @pytest.mark.parametrize("use_fp8", [False, True]) @@ -1159,13 +1009,13 @@ def test_export_core_attention( query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"] + input_names = ["query", "key", "value", "attention_mask"] attention_mask = None if use_mask: # Generate a random mask with 50% probability for 0 or 1. probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type) + inp = (query_layer, key_layer, value_layer, attention_mask) mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) @@ -1175,6 +1025,7 @@ def test_export_core_attention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, attention_dropout=0.5, + attn_mask_type=attn_mask_type, ).to(device='cuda') do_export(model, inp, @@ -1190,8 +1041,9 @@ def test_export_core_attention( test_configs_multihead_attention = [ #"use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax + (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax (True, "padding"), # calls ScaledMaskedSoftmax + (False, "padding"), # calls ScaledSoftmax ] test_configs_attention_type = [ #"input_layernorm, attention_type, fuse_qkv_params" @@ -1265,6 +1117,7 @@ def test_export_multihead_attention( model = te.MultiheadAttention( *attention_args, + attn_mask_type=attn_mask_type, params_dtype=precision, return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm, @@ -1273,8 +1126,8 @@ def test_export_multihead_attention( return_bias=True, ).to(device='cuda') - inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type) - input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"] + inp_context = (hidden_states_context, attention_mask, encoder_output) + input_names = ["hidden_states", "attention_mask", "encoder_output"] output_names=["attention_output", "attention_bias"] do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, @@ -1342,13 +1195,13 @@ def test_export_transformer_layer( num_attention_heads = 4 input_tensor = 
torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - input_names = ["input", "attention_mask", "self_attn_mask_type"] + input_names = ["input", "attention_mask"] attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask, attn_mask_type) + inp = (input_tensor, attention_mask) fp8_str = "_fp8" if use_fp8 else "" fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" @@ -1360,6 +1213,7 @@ def test_export_transformer_layer( hidden_size, ffn_hidden_size, num_attention_heads, + self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, @@ -1541,16 +1395,17 @@ def test_export_gpt_generation( hidden_size, ffn_hidden_size, num_attention_heads, + self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma).to(device='cuda') # "Context phase": use full input sequence length - input_names = ["input", "attention_mask", "self_attn_mask_type"] + input_names = ["input"] output_names = ["output"] input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - inp = (input_tensor, None, attn_mask_type) + inp = (input_tensor,) do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"input": {0: "seq", 1:"bs"}, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 461725f59d..bba67903bd 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -610,6 +610,7 @@ def backward(ctx, tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2) return tensors[0], tensors[1], None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms BMM1 -> softmax + dropout -> BMM2 @@ -1324,11 +1325,6 @@ class DotProductAttention(torch.nn.Module): and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. - .. warning:: - - Argument :attr:`attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. - Parameters ---------- num_attention_heads : int @@ -1348,6 +1344,12 @@ class DotProductAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `DotProductAttention` when multiple such modules are concatenated, for instance in consecutive transformer blocks. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. 
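As a minimal sketch of the init-arg/forward-arg split documented above (not part of the patch; the head count, channel size, shapes and dtype are arbitrary assumptions), a DotProductAttention built with a static mask type can still override it per call:

import torch
import transformer_engine.pytorch as te

# Mask type is fixed at construction, so tracing/ONNX export sees a constant.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
).to("cuda")

# Query/key/value in [seq, batch, heads, head_dim] layout, as in the tests above.
q, k, v = (torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda") for _ in range(3))

out = attn(q, k, v)                                      # uses the init-time "causal" mask
out_unmasked = attn(q, k, v, attn_mask_type="no_mask")   # forward arg overrides it dynamically
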
Parallelism parameters ---------------------- @@ -1374,7 +1376,7 @@ def __init__( kv_channels: int, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, - attn_mask_type: Optional[str] = None, + attn_mask_type: str = "causal", sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -1387,13 +1389,6 @@ def __init__( ) -> None: super().__init__() - if attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - self.attn_mask_type = attn_mask_type self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group @@ -1487,7 +1482,7 @@ def forward( key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1543,7 +1538,7 @@ def forward( Value tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out softmax input when not using flash-attn. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None` type of attention mask passed into softmax operation. checkpoint_core_attention : bool, default = `False` If true, forward activations for attention are recomputed @@ -1558,13 +1553,7 @@ def forward( Whether to use the fast path to set output tensors to 0 or not. """ - if self.attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if attn_mask_type is None: attn_mask_type = self.attn_mask_type assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition @@ -1697,11 +1686,6 @@ class MultiheadAttention(torch.nn.Module): Argument :attr:`attention_mask` will be ignored in the `forward` call when :attr:`attn_mask_type` is set to `"causal"`. - .. warning:: - - Argument :attr:`attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. - Parameters ---------- hidden_size : int @@ -1727,6 +1711,12 @@ class MultiheadAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `TransformerLayer` when multiple such modules are concatenated to form a transformer block. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. num_gqa_groups : int, default = `None` number of GQA groups in the transformer layer. 
Grouped Query Attention is described in @@ -1817,7 +1807,7 @@ def __init__( init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, - attn_mask_type: Optional[str] = None, + attn_mask_type: str = "causal", tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -1843,13 +1833,6 @@ def __init__( ) -> None: super().__init__() - if attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - self.attn_mask_type = attn_mask_type self.layer_number = layer_number self.input_layernorm = input_layernorm @@ -2034,7 +2017,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[Any] = None, @@ -2057,7 +2040,7 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None` type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using @@ -2092,13 +2075,7 @@ def forward( """ # hidden_states: [sq, b, h] - if self.attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if attn_mask_type is None: attn_mask_type = self.attn_mask_type if attn_mask_type == "padding" and attention_mask is not None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index a2ecfbda45..2b436916ca 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -73,10 +73,9 @@ class TransformerLayer(torch.nn.Module): Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling` are deprecated and will be fully removed in future releases. - .. warning:: - - Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. + .. note:: + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`self_attn_mask_type` is set to `"causal"`. Parameters ---------- @@ -127,6 +126,12 @@ class TransformerLayer(torch.nn.Module): kv_channels: int, default = `None` number of key-value channels. defaults to :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. + self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`self_attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. 
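The same pattern is what makes TransformerLayer exportable again: the mask type is baked in at construction and no longer appears as a string input to forward(). A hedged sketch of the export flow, assuming arbitrary sizes, file name and opset, and following the `with te.onnx_export(True)` pattern used in the tests:

import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=256,
    ffn_hidden_size=1024,
    num_attention_heads=4,
    self_attn_mask_type="causal",   # set once here instead of per forward() call
    params_dtype=torch.float16,
).to("cuda")

inp = torch.rand(32, 4, 256, dtype=torch.float16, device="cuda")  # [seq, batch, hidden]

with te.onnx_export(True):
    torch.onnx.export(
        layer,
        (inp,),                     # no attn_mask_type string in the input tuple
        "transformer_layer.onnx",
        input_names=["input"],
        output_names=["output"],
        opset_version=15,           # assumed; use whatever opset your TE/ORT versions require
    )
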
zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -212,7 +217,7 @@ def __init__( output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, kv_channels: Optional[int] = None, - self_attn_mask_type: Optional[str] = None, + self_attn_mask_type: str = "causal", tp_group: Optional[dist_group_type] = None, tp_size: int = 1, params_dtype: Optional[torch.dtype] = None, @@ -239,13 +244,6 @@ def __init__( ) -> None: super().__init__() - if self_attn_mask_type is not None: - warnings.warn( - "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - warnings.warn( "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`" "are deprecated and will be fully removed in future releases.", @@ -445,7 +443,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - self_attn_mask_type: str = "causal", + self_attn_mask_type: Optional[str] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, @@ -470,7 +468,7 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. - self_attn_mask_type: {'causal', 'padding'}, default = `causal` + self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using @@ -507,13 +505,7 @@ def forward( Whether to set output tensors to 0 or not before use. """ - if self.self_attn_mask_type is not None: - warnings.warn( - "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if self_attn_mask_type is None: self_attn_mask_type = self.self_attn_mask_type assert (
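For existing callers, the migration implied by this patch is small; a hypothetical before/after sketch (sizes, dtype and the all-False padding mask are assumptions, not taken from the diff):

import torch
import transformer_engine.pytorch as te

hidden_states = torch.rand(32, 4, 256, dtype=torch.float16, device="cuda")  # [sq, b, h]
padding_mask = torch.zeros(4, 1, 32, 32, dtype=torch.bool, device="cuda")   # True = masked out

# Previously: mha(hidden_states, padding_mask, attn_mask_type="padding") at every call.
# Now the mask type is declared once, at construction:
mha = te.MultiheadAttention(
    256, 4,
    attn_mask_type="padding",
    params_dtype=torch.float16,
    return_bias=True,
).to("cuda")

out, bias = mha(hidden_states, padding_mask)                 # uses the init-time "padding" type
out_causal, _ = mha(hidden_states, attn_mask_type="causal")  # per-call override still works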