In gat_v2, clean up a pre-TF2.10 workaround for EinsumDense.
PiperOrigin-RevId: 576557728
arnoegw authored and tensorflower-gardener committed Oct 25, 2023
1 parent 5d67fe7 · commit 80fd6a9
Showing 1 changed file with 2 additions and 2 deletions.
tensorflow_gnn/models/gat_v2/layers.py (2 additions, 2 deletions)
@@ -222,9 +222,9 @@ def __init__(self,
     # use a single Dense layer that outputs `num_heads` units because we need
     # to apply a different attention function a_k to its corresponding
     # W_k-transformed features.
-    self._attention_logits_fn = tf.keras.layers.experimental.EinsumDense(
+    self._attention_logits_fn = tf.keras.layers.EinsumDense(
         "...ik,ki->...i",
-        output_shape=(None, num_heads, 1),  # TODO(b/205825425): (num_heads,)
+        output_shape=(num_heads,),
         kernel_initializer=tfgnn.keras.clone_initializer(
             self._kernel_initializer),
         kernel_regularizer=kernel_regularizer,
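For context, here is a minimal sketch of what the updated layer computes, with hypothetical sizes that are not part of the commit. Since TF 2.10, `EinsumDense` is a stable layer under `tf.keras.layers`, and with the equation `...ik,ki->...i` its `output_shape` only needs to name the dimensions on the right-hand side of the equation (here just the heads dimension), so the old `(None, num_heads, 1)` workaround is no longer needed.

```python
import tensorflow as tf

# Hypothetical sizes, for illustration only (not from the commit).
num_heads = 4
per_head_channels = 16

# Equation "...ik,ki->...i": i = heads, k = channels per head.
# The kernel has shape (per_head_channels, num_heads), so each head i
# applies its own attention vector (kernel column i) to its own
# transformed features, producing one logit per head.
attention_logits_fn = tf.keras.layers.EinsumDense(
    "...ik,ki->...i",
    output_shape=(num_heads,))

# Input: per-head transformed features, shape [batch, heads, channels].
x = tf.random.normal([32, num_heads, per_head_channels])
logits = attention_logits_fn(x)
print(logits.shape)  # (32, 4): one attention logit per head.
```

This is why the code cannot use a single `Dense` layer with `num_heads` output units, as the surrounding comment notes: a `Dense` layer would mix features across heads, whereas the einsum keeps each head's attention function applied only to that head's features.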
