Add option in MultiHeadAttentionConv to share key transform between attention heads.

PiperOrigin-RevId: 637708640
Graph Learning Team authored and tensorflower-gardener committed Jul 8, 2024
1 parent d1167ca commit d4b1297
Showing 2 changed files with 131 additions and 16 deletions.
56 changes: 40 additions & 16 deletions tensorflow_gnn/models/multi_head_attention/layers.py
@@ -167,6 +167,9 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase):
only queries are transformed since the two transformations on queries and
keys are equivalent to one. (The presence of transformations on values is
independent of this arg.)
share_key_transform: If true, the same transformation is applied to the
keys within all attention heads. Otherwise, separate transformations
are applied.
score_scaling: One of either `"rsqrt_dim"` (default), `"trainable_elup1"`,
or `"none"`. If set to `"rsqrt_dim"`, the attention scores are
divided by the square root of the dimension of keys (i.e.,
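For reference, a minimal usage sketch of the new share_key_transform argument documented above; the import lines, hyperparameter values, edge set name, and graph_tensor variable are illustrative assumptions, not part of this commit:

import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

# Hypothetical layer construction; the sizes and "edges" are placeholder values.
conv = multi_head_attention.MultiHeadAttentionConv(
    num_heads=4,
    per_head_channels=16,
    receiver_tag=tfgnn.TARGET,
    share_key_transform=True,  # One key projection, reused by every attention head.
)
# Applied per edge set, e.g.: pooled = conv(graph_tensor, edge_set_name="edges")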
@@ -203,6 +206,7 @@ def __init__(
kernel_initializer: Any = None,
kernel_regularizer: Any = None,
transform_keys: bool = True,
share_key_transform: bool = False,
score_scaling: Literal["none", "rsqrt_dim",
"trainable_elup1"] = "rsqrt_dim",
transform_values_after_pooling: bool = False,
@@ -254,6 +258,7 @@ def __init__(
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._transform_keys = transform_keys
self._share_key_transform = share_key_transform
self._score_scaling = score_scaling
self._transform_values_after_pooling = transform_values_after_pooling

@@ -267,9 +272,14 @@ def __init__(
self._w_sender_node_to_key = None
self._w_sender_edge_to_key = None
else:
if self._share_key_transform:
key_width = per_head_channels
else:
key_width = per_head_channels * num_heads

if self.takes_sender_node_input:
self._w_sender_node_to_key = tf.keras.layers.Dense(
per_head_channels * num_heads,
key_width,
kernel_initializer=tfgnn.keras.clone_initializer(
self._kernel_initializer),
kernel_regularizer=kernel_regularizer,
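In effect, the branch above only changes the width of the key projection. A small illustration with assumed sizes (num_heads=2, per_head_channels=3, matching the test file below):

import tensorflow as tf

num_heads, per_head_channels = 2, 3  # Assumed sizes for illustration only.

# share_key_transform=True: a single key projection of width per_head_channels,
# reused by every attention head.
shared_key_transform = tf.keras.layers.Dense(per_head_channels)  # output width 3

# share_key_transform=False (default): one concatenated projection of width
# per_head_channels * num_heads, later split into one slice per head.
per_head_key_transform = tf.keras.layers.Dense(per_head_channels * num_heads)  # output width 6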
@@ -279,7 +289,7 @@ def __init__(
self._w_sender_node_to_key = None
if self.takes_sender_edge_input:
self._w_sender_edge_to_key = tf.keras.layers.Dense(
per_head_channels * num_heads,
key_width,
kernel_initializer=tfgnn.keras.clone_initializer(
self._kernel_initializer),
kernel_regularizer=kernel_regularizer,
@@ -409,7 +419,7 @@ def convolve(self,
assert receiver_input is not None, "__init__() should have checked this."
queries = self._w_query(receiver_input)
queries = self._attention_activation(queries)
queries = broadcast_from_receiver(self._split_heads(queries))
queries = broadcast_from_receiver(self._split_query_heads(queries))

# Form the attention key for each head.
# If transform_keys is true, the pieces of keys inputs are transformed to
@@ -432,20 +442,20 @@ def convolve(self,
if sender_node_input is not None and sender_edge_input is None:
# In this special case, we can apply the attention_activation first
# and then broadcast its results.
keys = broadcast_from_sender_node(
self._split_heads(
self._attention_activation(
self._w_sender_node_to_key(sender_node_input))))
keys = self._attention_activation(
self._w_sender_node_to_key(sender_node_input))
keys = self._split_key_heads(keys)
keys = broadcast_from_sender_node(keys)
else:
# In the general case, the attention_activation (if any) comes last.
if sender_node_input is not None:
keys.append(
broadcast_from_sender_node(
self._split_heads(
self._w_sender_node_to_key(sender_node_input))))
node_keys = self._w_sender_node_to_key(sender_node_input)
node_keys = self._split_key_heads(node_keys)
keys.append(broadcast_from_sender_node(node_keys))
if sender_edge_input is not None:
keys.append(
self._split_heads(self._w_sender_edge_to_key(sender_edge_input)))
edge_keys = self._w_sender_edge_to_key(sender_edge_input)
edge_keys = self._split_key_heads(edge_keys)
keys.append(edge_keys)
keys = tf.add_n(keys)
keys = self._attention_activation(keys)

@@ -501,11 +511,12 @@ def convolve(self,
if sender_node_input is not None:
value_terms.append(
broadcast_from_sender_node(
self._split_heads(
self._split_value_heads(
self._w_sender_node_to_value(sender_node_input))))
if sender_edge_input is not None:
value_terms.append(
self._split_heads(self._w_sender_edge_to_value(sender_edge_input)))
self._split_value_heads(
self._w_sender_edge_to_value(sender_edge_input)))
values = tf.add_n(value_terms)
# Compute the weighted sum.
# [num_receivers, *extra_dims, num_heads, per_head_channels]
@@ -540,9 +551,22 @@ def convolve(self,
return pooled_values

# The following helpers map back and forth between tensors with...
# - a separate heads dimension: shape [..., num_heads, channels_per_head],
# - a shape that is or can be broadcast to [..., num_heads,
# channels_per_head],
# - all heads concatenated: shape [..., num_heads * channels_per_head].

def _split_key_heads(self, key):
if self._share_key_transform:
return tf.expand_dims(key, axis=-2)
else:
return self._split_heads(key)

def _split_query_heads(self, query):
return self._split_heads(query)

def _split_value_heads(self, value):
return self._split_heads(value)

def _split_heads(self, tensor):
assert tensor.shape[-1] is not None
assert tensor.shape[-1] % self._num_heads == 0, (
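The tf.expand_dims in _split_key_heads above gives shared keys a size-1 heads axis, so one key per item broadcasts against every head's query. A standalone shape sketch (the plain dot-product score below only illustrates the broadcasting, not the layer's exact scoring code):

import tensorflow as tf

num_items, num_heads, per_head_channels = 5, 2, 3  # Assumed sizes for illustration.

# Per-head queries: [num_items, num_heads, per_head_channels].
queries = tf.random.normal([num_items, num_heads, per_head_channels])

# One shared key per item, expanded to a size-1 heads axis as in _split_key_heads
# when share_key_transform=True.
shared_keys = tf.random.normal([num_items, per_head_channels])
keys = tf.expand_dims(shared_keys, axis=-2)  # [num_items, 1, per_head_channels]

# Element-wise multiplication broadcasts the size-1 heads axis across all heads,
# giving one score per item and head.
scores = tf.reduce_sum(queries * keys, axis=-1)  # [num_items, num_heads]
print(scores.shape)  # (5, 2)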
91 changes: 91 additions & 0 deletions tensorflow_gnn/models/multi_head_attention/layers_test.py
@@ -636,6 +636,97 @@ def testMultihead(self, transform_values_after_pooling):
self.assertAllEqual(got.shape, (3, 6))
self.assertAllClose(got, want, atol=.0001)

@parameterized.named_parameters(("", False), ("TransformAfter", True))
def testShareTransformKeys(self, transform_values_after_pooling):
"""Extends testMultihead with shared key transformation."""
# The same test graph as in the testMultihead above.
gt_input = _get_test_bidi_cycle_graph(
tf.constant([
[1., 0., 0., 1.],
[0., 1., 0., 2.],
[0., 0., 1., 3.],
]))

conv = multi_head_attention.MultiHeadAttentionConv(
num_heads=2,
per_head_channels=3,
receiver_tag=tfgnn.TARGET,
activation="relu",
use_bias=False, # Don't create /bias variables.
score_scaling="none", # Disable score scaling.
transform_values_after_pooling=transform_values_after_pooling,
share_key_transform=True, # Share key transformation across heads.
)

_ = conv(gt_input, edge_set_name="edges") # Build weights.
weights = {v.name: v for v in conv.trainable_weights}
self.assertLen(weights, 3)

weights["multi_head_attention_conv/query/kernel:0"].assign(
# Attention head 0 uses the first three dimensions, which are used
# in the same way as for the testMultihead test above.
# Attention head 1 uses the last three dimensions, in which we
# now favor the clockwise incoming edges.
[
[0., 1., 0., 0., 0., 1.],
[0., 0., 1., 1., 0., 0.],
[1., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0.],
])

# No need for an inverse scaling factor as score scaling is disabled.
weights["multi_head_attention_conv/key_node/kernel:0"].assign(
# To favor the clockwise incoming edges, we use the three
# dimensions, and assign 100 and 0 to corresponding neighbors,
# which gives weights 1 to clockwise incoming edges and weights 0
# to counterclockwise incoming edges. Similar to testMultihead above.
[
[100., 0., 0.],
[0., 100., 0.],
[0., 0., 100.],
[0., 0., 0.],
])

if not transform_values_after_pooling:
# No matter where the -1s are, they got eliminated by ReLU.
# What we expect to see from the two heads is the different
# scaling factor for the last dimension: 1.1 vs 1.0.
weights["multi_head_attention_conv/value_node/kernel:0"].assign([
[0., -1., 0., 0., -1., 0.],
[-1., 0., 0., -1., 0., 0.],
[-1., -1., 0., -1., -1., 0.],
[0., 0., 1.1, 0., 0., 1.],
])
else:
# Same weights, but as Einsum kernel with axes "hvc".
weights["multi_head_attention_conv/value_pooled/kernel:0"].assign([[
[0., -1., 0.],
[-1., 0., 0.],
[-1., -1., 0.],
[0., 0., 1.1],
], [
[0., -1., 0.],
[-1., 0., 0.],
[-1., -1., 0.],
[0., 0., 1.],
]])

got = conv(gt_input, edge_set_name="edges")

# Attention head 0 generates the first three output dimensions, and attention
# head 1 the last three. Since we use the shared key transformation, and the
# same weights as for the second head in testMultihead above, both heads use
# weights 0 and 1. Attention head 1 has the same result as in testMultihead
# while the value transformation for attention head 0 scales by a factor of
# 1.1.
want = tf.constant([
[0., 0., 2.2, 0., 0., 3.0],
[0., 0., 3.3, 0., 0., 1.0],
[0., 0., 1.1, 0., 0., 2.0],
])
self.assertAllEqual(got.shape, (3, 6))
self.assertAllClose(got, want, atol=.0001)
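As a cross-check, the first row of want can be recomputed by hand from the kernels assigned above. The sketch below assumes node 0 receives from nodes 1 and 2 of the bidirectional test cycle; it is an independent illustration, not part of the test:

import numpy as np

node_feats = np.array([[1., 0., 0., 1.],
                       [0., 1., 0., 2.],
                       [0., 0., 1., 3.]])
query_kernel = np.array([[0., 1., 0., 0., 0., 1.],
                         [0., 0., 1., 1., 0., 0.],
                         [1., 0., 0., 0., 1., 0.],
                         [0., 0., 0., 0., 0., 0.]])
key_kernel = np.array([[100., 0., 0.],
                       [0., 100., 0.],
                       [0., 0., 100.],
                       [0., 0., 0.]])
value_kernel = np.array([[0., -1., 0., 0., -1., 0.],
                         [-1., 0., 0., -1., 0., 0.],
                         [-1., -1., 0., -1., -1., 0.],
                         [0., 0., 1.1, 0., 0., 1.]])

def softmax(x):
  e = np.exp(x - x.max())
  return e / e.sum()

receiver, senders = 0, [1, 2]  # Node 0 receives from its two cycle neighbors.
queries = (node_feats[receiver] @ query_kernel).reshape(2, 3)    # [heads, channels]
keys = node_feats[senders] @ key_kernel                          # shared by both heads
values = (node_feats[senders] @ value_kernel).reshape(2, 2, 3)   # [senders, heads, channels]

out = []
for h in range(2):
  w = softmax(keys @ queries[h])                    # attention over the two senders
  out.append(np.maximum(w @ values[:, h, :], 0.))   # weighted sum, then the final relu
print(np.concatenate(out))  # approx. [0., 0., 2.2, 0., 0., 3.0] == want[0]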

@parameterized.named_parameters(
("", tftu.ModelReloading.SKIP, False),
("TransformAfter", tftu.ModelReloading.SKIP, True),