
Commit

Update attention.py
ASEM000 committed Aug 10, 2023
1 parent 4b9188f commit 3193607
Showing 1 changed file with 69 additions and 29 deletions.
98 changes: 69 additions & 29 deletions serket/nn/attention.py
@@ -37,18 +37,30 @@ def merge_heads(array: jax.Array) -> jax.Array:


def is_lazy_call(instance, *_, **__) -> bool:
return getattr(instance, "qkv_features", False) is None
return getattr(instance, "q_features", False) is None


-def is_lazy_init(_, num_heads, qkv_features, *__, **___) -> bool:
-    return qkv_features is None
+def is_lazy_init(_, num_heads, q_features, *__, **___) -> bool:
+    return q_features is None


-def infer_qkv_features(_, q_array, *__, **___) -> int:
+def infer_q_features(_, q_array, *__, **___) -> int:
return q_array.shape[-1]


-attention_updates = dict(qkv_features=infer_qkv_features)
+def infer_k_features(_, __, k_array, *___, **____) -> int:
+    return k_array.shape[-1]


+def infer_v_features(_, __, ___, v_array, *____, **_____) -> int:
+    return v_array.shape[-1]


+attention_updates = dict(
+    q_features=infer_q_features,
+    k_features=infer_k_features,
+    v_features=infer_v_features,
+)


def calculate_attention(
@@ -105,7 +117,9 @@ class MultiHeadAttention(sk.TreeClass):
Args:
num_heads: Number of attention heads.
-        qkv_features: Number of features for the query.
+        q_features: Number of features for the query.
+        k_features: Number of features for the key.
+        v_features: Number of features for the value.
out_features: Number of features for the output.
q_weight_init: Initializer for the query weight. Defaults to glorot_uniform.
q_bias_init: Initializer for the query bias. Defaults to zeros. use
@@ -126,27 +140,40 @@ class MultiHeadAttention(sk.TreeClass):
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> batch = 3
>>> num_heads = 2
-        >>> qkv_features = 4
+        >>> q_features = 4
+        >>> k_features = 8
+        >>> v_features = 6
>>> q_length = 4
>>> kv_length = 2
>>> mask = jr.uniform(jr.PRNGKey(2), (batch, num_heads, q_length, kv_length))
>>> mask = (mask > 0.5).astype(jnp.float32)
-        >>> q = jr.uniform(jr.PRNGKey(0), (batch, q_length, qkv_features))
-        >>> k = jr.uniform(jr.PRNGKey(1), (batch, kv_length, qkv_features))
-        >>> v = jr.uniform(jr.PRNGKey(2), (batch, kv_length, qkv_features))
-        >>> layer = sk.nn.MultiHeadAttention(num_heads, qkv_features, drop_rate=0.0)
+        >>> q = jr.uniform(jr.PRNGKey(0), (batch, q_length, q_features))
+        >>> k = jr.uniform(jr.PRNGKey(1), (batch, kv_length, k_features))
+        >>> v = jr.uniform(jr.PRNGKey(2), (batch, kv_length, v_features))
+        >>> layer = sk.nn.MultiHeadAttention(
+        ...     num_heads,
+        ...     q_features,
+        ...     k_features,
+        ...     v_features,
+        ...     drop_rate=0.0,
+        ... )
>>> print(layer(q, k, v, mask=mask, key=jr.PRNGKey(0)).shape)
(3, 4, 4)
+    Note:
+        If ``k_features``, ``v_features``, ``out_features`` are not specified,
+        they are set to ``q_features``.
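For instance, a minimal sketch of this defaulting behaviour (the shapes and keys below are illustrative and not taken from the file): passing only ``num_heads`` and ``q_features`` makes the key, value, and output projections use ``q_features`` channels as well.

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> layer = sk.nn.MultiHeadAttention(2, 4, drop_rate=0.0, key=jr.PRNGKey(0))
>>> q = jnp.ones((1, 5, 4))
>>> kv = jnp.ones((1, 3, 4))
>>> print(layer(q, kv, kv, key=jr.PRNGKey(0)).shape)
(1, 5, 4)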
Note:
:class:`.MultiHeadAttention` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
-        To use lazy initialization, pass ``None`` as the ``qkv_features`` argument
+        To use lazy initialization, pass ``None`` as the ``q_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.
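A minimal sketch of that lazy pattern, assuming (as with other serket lazy layers) that ``.at["__call__"]`` returns the output together with the materialized layer:

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.MultiHeadAttention(2, None, drop_rate=0.0)
>>> q = k = v = jnp.ones((1, 5, 4))
>>> _, layer = lazy.at["__call__"](q, k, v, key=jr.PRNGKey(0))
>>> print(layer(q, k, v, key=jr.PRNGKey(0)).shape)
(1, 5, 4)

Here ``q_features`` is inferred as 4 from the last axis of the query, and ``k_features`` and ``v_features`` are likewise inferred from the key and value arrays.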
@@ -171,7 +198,9 @@ class MultiHeadAttention(sk.TreeClass):
def __init__(
self,
num_heads: int,
-        qkv_features: int,
+        q_features: int,
+        k_features: int | None = None,
+        v_features: int | None = None,
out_features: int | None = None,
q_weight_init: InitType = "glorot_uniform",
q_bias_init: InitType = "zeros",
@@ -185,36 +214,47 @@ def __init__(
drop_broadcast: bool = False,
key: jr.KeyArray = jr.PRNGKey(0),
):
-        if qkv_features % num_heads != 0:
-            raise ValueError(f"{qkv_features=} % {num_heads=} != 0.")
+        k_features = q_features if k_features is None else k_features
+        v_features = q_features if v_features is None else v_features
+        out_features = q_features if out_features is None else out_features

+        if q_features % num_heads != 0:
+            raise ValueError(f"{q_features=} % {num_heads=} != 0.")

+        if k_features % num_heads != 0:
+            raise ValueError(f"{k_features=} % {num_heads=} != 0.")

-        head_features = qkv_features // num_heads
-        out_features = qkv_features if out_features is None else out_features
+        if v_features % num_heads != 0:
+            raise ValueError(f"{v_features=} % {num_heads=} != 0.")

+        if out_features % num_heads != 0:
+            raise ValueError(f"{out_features=} % {num_heads=} != 0.")

+        head_features = q_features // num_heads
qkey, kkey, vkey, okey = jr.split(key, 4)

self.num_heads = num_heads
drop_axes = (-1, -2) if drop_broadcast else ...
self.dropout = sk.nn.GeneralDropout(drop_rate, drop_axes)

self.q_projection = sk.nn.Linear(
-            in_features=qkv_features,
+            in_features=q_features,
out_features=head_features * num_heads,
weight_init=q_weight_init,
bias_init=q_bias_init,
key=qkey,
)

self.k_projection = sk.nn.Linear(
-            in_features=qkv_features,
+            in_features=k_features,
out_features=head_features * num_heads,
weight_init=k_weight_init,
bias_init=k_bias_init,
key=kkey,
)

self.v_projection = sk.nn.Linear(
-            in_features=qkv_features,
+            in_features=v_features,
out_features=head_features * num_heads,
weight_init=v_weight_init,
bias_init=v_bias_init,
@@ -232,27 +272,27 @@ def __init__(
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=attention_updates)
def __call__(
self,
-        q_array: Annotated[jax.Array, "..., q_length, qkv_features"],
-        k_array: Annotated[jax.Array, "..., kv_length, qkv_features"],
-        v_array: Annotated[jax.Array, "..., kv_length, qkv_features"],
+        q_array: Annotated[jax.Array, "..., q_length, q_features"],
+        k_array: Annotated[jax.Array, "..., kv_length, k_features"],
+        v_array: Annotated[jax.Array, "..., kv_length, v_features"],
mask: Annotated[jax.Array, "..., num_heads, q_length, kv_length"] | None = None,
key: jr.KeyArray = jr.PRNGKey(0),
) -> Annotated[jax.Array, "..., q_length, out_features"]:
"""Applies multi-head attention to the given inputs.
Args:
-            q_array: Query array. [..., q_length, qkv_features]
-            k_array: Key array. [..., k_length, qkv_features]
-            v_array: Value array. [..., v_length, qkv_features]
+            q_array: Query array. [..., q_length, q_features]
+            k_array: Key array. [..., kv_length, k_features]
+            v_array: Value array. [..., kv_length, v_features]
mask: Mask array. [..., num_heads, q_length, kv_length]
key: Key for the random number generator.
"""

-        # [..., q_length, qkv_features] -> [..., q_length, head_features*num_heads]
+        # [..., q_length, q_features] -> [..., q_length, head_features*num_heads]
q_heads = self.q_projection(q_array)
-        # [..., k_length, qkv_features] -> [..., k_length, head_features*num_heads]
+        # [..., k_length, k_features] -> [..., k_length, head_features*num_heads]
k_heads = self.k_projection(k_array)
-        # [..., v_length, qkv_features] -> [..., v_length, head_features*num_heads]
+        # [..., v_length, v_features] -> [..., v_length, head_features*num_heads]
v_heads = self.v_projection(v_array)

attention = calculate_attention(
