From 19ac1433a4bebf50b30f12ce275a2082eb7c92fc Mon Sep 17 00:00:00 2001 From: PiotrKrzem Date: Mon, 18 Dec 2023 02:35:03 +0100 Subject: [PATCH] [SPEC] Update spec for broadcasting of batch dimensions, improve clarity --- .../sequence/ScaledDotProductAttention.rst | 156 ++++++++++++++---- 1 file changed, 125 insertions(+), 31 deletions(-) diff --git a/docs/articles_en/documentation/openvino_ir/operation_sets/operations_specifications/sequence/ScaledDotProductAttention.rst b/docs/articles_en/documentation/openvino_ir/operation_sets/operations_specifications/sequence/ScaledDotProductAttention.rst index 6d0e8633dc3767..c685cb68f19745 100644 --- a/docs/articles_en/documentation/openvino_ir/operation_sets/operations_specifications/sequence/ScaledDotProductAttention.rst +++ b/docs/articles_en/documentation/openvino_ir/operation_sets/operations_specifications/sequence/ScaledDotProductAttention.rst @@ -58,8 +58,8 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, * **3**: ``value`` - at least 3 dimensional tensor of type *T* and shape ``[N, ..., S, Ev]``. **Required.** * **4**: ``attention_mask`` - two options: - ** at least 3 dimensional tensor of type *T* or ``boolean`` and shape ``[M, ..., L, S]``, or - ** a scalar of type *T* with value ``0``. Scalar zero value is used to indicate that `attention_mask` is really not required to be applied (``attention_mask=None`` in the pseudo-code above) but ``scale`` is required to be set. + ** at least 3 dimensional tensor of type *T* or ``boolean`` and shape ``[N, ..., L, S]``, or + ** a scalar of type *T* with value ``0``. Scalar zero value signals that applying an attention mask is not necessary (similar to specifying attention_mask=None in the provided pseudo-code). ``attention_mask`` is ignored if ``causal`` is set to ``True``. **Optional.** @@ -77,7 +77,7 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, **Dimensions** -* ``N, ...`` - one or more batch dimensions +* ``N, ...`` - one or more batch dimensions. Each batch dimension should be either constant across the input tensors (query, key, and value), indicating that they have the same batch size, or they should be broadcastable to the same value. * ``S`` - source sequence length @@ -87,13 +87,13 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, * ``Ev`` - embedding dimension of the value -* ``M, ...`` - one of more batch dimensions of the mask, should be broadcastable to ``N, ...`` +At least one batch dimension ``N`` is required in ``query``, ``key`` and ``value`` inputs. +Other batch dimensions ``...`` are optional. -At least one batch dimension ``N`` is required and should match among ``query``, ``key`` and ``value`` inputs. -Other batch dimensions ``...`` are optional, if present should match among ``query``, ``key`` and ``value`` inputs as well. +**Examples** -**Example** +*Example 1: One batch dimension, dynamic dimensions support* .. code-block:: xml :force: @@ -101,38 +101,132 @@ Other batch dimensions ``...`` are optional, if present should match among ``que - - 1 - 32 - -1 - 80 + + < !--query --> + 1 < !--N --> + -1 < !--L --> + 80 < !--E --> - - 1 - 32 - -1 - 80 + < !--key --> + 1 < !--N --> + -1 < !--S --> + 80 < !--E --> - - 1 - 32 - -1 - 80 + < !--value --> + 1 < !--N --> + -1 < !--S --> + 80 < !--Ev --> - - 1 - 1 - -1 - -1 + < !--attention_mask --> + 1 < !--N --> + -1 < !--L --> + -1 < !--S --> - 1 - 32 - -1 - 80 + 1 < !--N --> + -1 < !--L --> + 80 < !--Ev --> +*Example 2: Matching multiple batch dimensions* + +.. code-block:: xml + :force: + + + + + + < !--query --> + 1 < !--N1 --> + 2 < !--N2 --> + 3 < !--N3 --> + -1 < !--L --> + 80 < !--E --> + + < !--key --> + 1 < !--N1 --> + 2 < !--N2 --> + 3 < !--N3 --> + -1 < !--S --> + 80 < !--E --> + + < !--value --> + 1 < !--N1 --> + 2 < !--N2 --> + 3 < !--N3 --> + -1 < !--S --> + 80 < !--Ev --> + + < !--attention_mask --> + 1 < !--N1 --> + 2 < !--N2 --> + 3 < !--N3 --> + -1 < !--L --> + -1 < !--S --> + + + + + 1 < !--N1 --> + 2 < !--N2 --> + 3 < !--N3 --> + -1 < !--L --> + 80 < !--Ev --> + + + + +*Example 3: With batch dimensions broadcasting* + +.. code-block:: xml + :force: + + + + + + < !--query --> + 1 < !--N1 (repeat 4 times) --> + 6 < !--N2 (repeat 1 time)--> + 5 < !--N3 (repeat 2 times)--> + -1 < !--L --> + 80 < !--E --> + + < !--key --> + 2 (repeat 2 times)< !--N1 --> + 2 (repeat 3 times)< !--N2 --> + 2 (repeat 5 times)< !--N3 --> + -1 < !--S --> + 80 < !--E --> + + < !--value --> + 4 < !--N1 (repeat 1 time)--> + 3 < !--N2 (repeat 2 times)--> + 10 < !--N3 (repeat 1 time)--> + -1 < !--S --> + 80 < !--Ev --> + + < !--attention_mask --> + 1 < !--N1 (repeat 4 times)--> + 2 < !--N2 (repeat 3 times)--> + 1 < !--N3 (repeat 10 times)--> + -1 < !--L --> + -1 < !--S --> + + + + + + 4 < !--N1 --> + 6 < !--N2 --> + 10 < !--N3 --> + -1 < !--L --> + 80 < !--Ev --> + + + \ No newline at end of file