[SPEC] Update spec for broadcasting of batch dimensions, improve clarity
PiotrKrzem committed Dec 18, 2023
1 parent 7b00e80 commit 19ac143
Showing 1 changed file with 125 additions and 31 deletions.
@@ -58,8 +58,8 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *,
* **3**: ``value`` - at least 3 dimensional tensor of type *T* and shape ``[N, ..., S, Ev]``. **Required.**

* **4**: ``attention_mask`` - two options:
-** at least 3 dimensional tensor of type *T* or ``boolean`` and shape ``[M, ..., L, S]``, or
-** a scalar of type *T* with value ``0``. Scalar zero value is used to indicate that ``attention_mask`` is really not required to be applied (``attention_mask=None`` in the pseudo-code above) but ``scale`` is required to be set.
+** at least 3 dimensional tensor of type *T* or ``boolean`` and shape ``[N, ..., L, S]``, or
+** a scalar of type *T* with value ``0``. A scalar zero signals that no attention mask is applied (equivalent to ``attention_mask=None`` in the pseudo-code above).

``attention_mask`` is ignored if ``causal`` is set to ``True``. **Optional.**
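
The interplay of ``attention_mask``, ``causal`` and ``scale`` described above can be sketched in NumPy. This is a non-normative sketch: the default scale of ``1/sqrt(E)``, the convention that a ``True`` boolean-mask element keeps a position, and the additive floating-point mask are assumptions based on common scaled-dot-product-attention definitions, not quotes from this spec.

```python
import numpy as np

def sdpa_reference(query, key, value, attn_mask=None, scale=None, *, causal=False):
    """Non-normative sketch of ScaledDotProductAttention semantics."""
    L, S = query.shape[-2], key.shape[-2]
    if scale is None:
        scale = 1.0 / np.sqrt(query.shape[-1])  # assumed default: 1/sqrt(E)
    bias = np.zeros((L, S), dtype=query.dtype)
    if causal:
        # attention_mask is ignored when causal=True; mask out "future" positions
        bias[np.triu_indices(L, 1, S)] = -np.inf
    elif attn_mask is not None and not np.isscalar(attn_mask):
        if attn_mask.dtype == bool:
            # assumed convention: True keeps a position, False masks it out
            bias = np.where(attn_mask, 0.0, -np.inf).astype(query.dtype)
        else:
            bias = attn_mask  # floating-point mask is added to the scores
    # a scalar attn_mask of 0 falls through: no mask is applied
    scores = query @ np.swapaxes(key, -1, -2) * scale + bias
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value
```

With this sketch, ``causal=True`` and an explicit lower-triangular boolean mask produce the same result, which is why the spec allows ``attention_mask`` to be ignored in the causal case.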

@@ -77,7 +77,7 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *,

**Dimensions**

-* ``N, ...`` - one or more batch dimensions
+* ``N, ...`` - one or more batch dimensions. Each batch dimension must either match across the ``query``, ``key`` and ``value`` inputs, or be broadcastable to a common value.

* ``S`` - source sequence length

@@ -87,52 +87,146 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *,

* ``Ev`` - embedding dimension of the value

-* ``M, ...`` - one or more batch dimensions of the mask, should be broadcastable to ``N, ...``
-At least one batch dimension ``N`` is required and should match among ``query``, ``key`` and ``value`` inputs.
-Other batch dimensions ``...`` are optional; if present, they should match among ``query``, ``key`` and ``value`` inputs as well.
+At least one batch dimension ``N`` is required in the ``query``, ``key`` and ``value`` inputs.
+Other batch dimensions ``...`` are optional.

-**Example**
+**Examples**

*Example 1: One batch dimension, dynamic dimensions support*

.. code-block:: xml
   :force:

    <layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
        <data causal="false" />
        <input>
-           <port id="0" precision="FP32">
-               <dim>1</dim>
-               <dim>32</dim>
-               <dim>-1</dim>
-               <dim>80</dim>
+           <!-- Example with simple dimensions: N = 1, L = -1, S = -1, E = 80, Ev = 80 -->
+           <port id="0" precision="FP32"> <!-- query -->
+               <dim>1</dim>  <!-- N -->
+               <dim>-1</dim> <!-- L -->
+               <dim>80</dim> <!-- E -->
            </port>
-           <port id="1" precision="FP32">
-               <dim>1</dim>
-               <dim>32</dim>
-               <dim>-1</dim>
-               <dim>80</dim>
+           <port id="1" precision="FP32"> <!-- key -->
+               <dim>1</dim>  <!-- N -->
+               <dim>-1</dim> <!-- S -->
+               <dim>80</dim> <!-- E -->
            </port>
-           <port id="2" precision="FP32">
-               <dim>1</dim>
-               <dim>32</dim>
-               <dim>-1</dim>
-               <dim>80</dim>
+           <port id="2" precision="FP32"> <!-- value -->
+               <dim>1</dim>  <!-- N -->
+               <dim>-1</dim> <!-- S -->
+               <dim>80</dim> <!-- Ev -->
            </port>
-           <port id="3" precision="FP32">
-               <dim>1</dim>
-               <dim>1</dim>
-               <dim>-1</dim>
-               <dim>-1</dim>
+           <port id="3" precision="FP32"> <!-- attention_mask -->
+               <dim>1</dim>  <!-- N -->
+               <dim>-1</dim> <!-- L -->
+               <dim>-1</dim> <!-- S -->
            </port>
        </input>
        <output>
            <port id="4" precision="FP32">
-               <dim>1</dim>
-               <dim>32</dim>
-               <dim>-1</dim>
-               <dim>80</dim>
+               <dim>1</dim>  <!-- N -->
+               <dim>-1</dim> <!-- L -->
+               <dim>80</dim> <!-- Ev -->
            </port>
        </output>
    </layer>

*Example 2: Matching multiple batch dimensions*

.. code-block:: xml
   :force:

    <layer id="286" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
        <data causal="false" />
        <input>
            <!-- Multiple batch dimensions: N1 = 1, N2 = 2, N3 = 3 -->
            <port id="0" precision="FP32"> <!-- query -->
                <dim>1</dim>  <!-- N1 -->
                <dim>2</dim>  <!-- N2 -->
                <dim>3</dim>  <!-- N3 -->
                <dim>-1</dim> <!-- L -->
                <dim>80</dim> <!-- E -->
            </port>
            <port id="1" precision="FP32"> <!-- key -->
                <dim>1</dim>  <!-- N1 -->
                <dim>2</dim>  <!-- N2 -->
                <dim>3</dim>  <!-- N3 -->
                <dim>-1</dim> <!-- S -->
                <dim>80</dim> <!-- E -->
            </port>
            <port id="2" precision="FP32"> <!-- value -->
                <dim>1</dim>  <!-- N1 -->
                <dim>2</dim>  <!-- N2 -->
                <dim>3</dim>  <!-- N3 -->
                <dim>-1</dim> <!-- S -->
                <dim>80</dim> <!-- Ev -->
            </port>
            <port id="3" precision="FP32"> <!-- attention_mask -->
                <dim>1</dim>  <!-- N1 -->
                <dim>2</dim>  <!-- N2 -->
                <dim>3</dim>  <!-- N3 -->
                <dim>-1</dim> <!-- L -->
                <dim>-1</dim> <!-- S -->
            </port>
        </input>
        <output>
            <port id="4" precision="FP32">
                <dim>1</dim>  <!-- N1 -->
                <dim>2</dim>  <!-- N2 -->
                <dim>3</dim>  <!-- N3 -->
                <dim>-1</dim> <!-- L -->
                <dim>80</dim> <!-- Ev -->
            </port>
        </output>
    </layer>

*Example 3: Broadcasting of batch dimensions*

.. code-block:: xml
   :force:

    <layer id="287" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
        <data causal="false" />
        <input>
            <!-- Multiple batch dimensions, broadcastable to the following values: N1 = 4, N2 = 6, N3 = 10 -->
            <port id="0" precision="FP32"> <!-- query -->
                <dim>1</dim>  <!-- N1 (repeated 4 times) -->
                <dim>6</dim>  <!-- N2 (repeated 1 time) -->
                <dim>5</dim>  <!-- N3 (repeated 2 times) -->
                <dim>-1</dim> <!-- L -->
                <dim>80</dim> <!-- E -->
            </port>
            <port id="1" precision="FP32"> <!-- key -->
                <dim>2</dim>  <!-- N1 (repeated 2 times) -->
                <dim>2</dim>  <!-- N2 (repeated 3 times) -->
                <dim>2</dim>  <!-- N3 (repeated 5 times) -->
                <dim>-1</dim> <!-- S -->
                <dim>80</dim> <!-- E -->
            </port>
            <port id="2" precision="FP32"> <!-- value -->
                <dim>4</dim>  <!-- N1 (repeated 1 time) -->
                <dim>3</dim>  <!-- N2 (repeated 2 times) -->
                <dim>10</dim> <!-- N3 (repeated 1 time) -->
                <dim>-1</dim> <!-- S -->
                <dim>80</dim> <!-- Ev -->
            </port>
            <port id="3" precision="FP32"> <!-- attention_mask -->
                <dim>1</dim>  <!-- N1 (repeated 4 times) -->
                <dim>2</dim>  <!-- N2 (repeated 3 times) -->
                <dim>1</dim>  <!-- N3 (repeated 10 times) -->
                <dim>-1</dim> <!-- L -->
                <dim>-1</dim> <!-- S -->
            </port>
        </input>
        <output>
            <!-- Output batch dimensions are the broadcast values N1 = 4, N2 = 6, N3 = 10 -->
            <port id="4" precision="FP32">
                <dim>4</dim>  <!-- N1 -->
                <dim>6</dim>  <!-- N2 -->
                <dim>10</dim> <!-- N3 -->
                <dim>-1</dim> <!-- L -->
                <dim>80</dim> <!-- Ev -->
            </port>
        </output>
    </layer>
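
In Example 3 every input is repeated a whole number of times along each batch dimension, so the broadcast target of a dimension is a common multiple of the per-input sizes. A minimal sketch of that rule, assuming the least common multiple is the intended target (``broadcast_batch_dims`` is a hypothetical helper, not part of the operator):

```python
from functools import reduce
from math import lcm

def broadcast_batch_dims(*shapes):
    """Hypothetical helper: per-dimension broadcast target, assuming each
    input may be tiled a whole number of times (LCM of the sizes)."""
    assert len({len(s) for s in shapes}) == 1, "inputs must have the same rank"
    return tuple(reduce(lcm, sizes) for sizes in zip(*shapes))

# Batch dimensions of query, key, value and attention_mask from Example 3:
target = broadcast_batch_dims((1, 6, 5), (2, 2, 2), (4, 3, 10), (1, 2, 1))
print(target)  # -> (4, 6, 10), the broadcast values stated in the example
```

Note that this is more permissive than NumPy-style broadcasting, which only allows a size to match or be 1: here ``6`` and ``2`` broadcast to ``6`` because ``2`` is repeated 3 times.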
