From 28c1166cadf39622a3e6a8d3b16ddb025e41f9c4 Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Fri, 25 Oct 2024 00:51:33 +0900
Subject: [PATCH] [Test] attention implementation #228

---
 rl4co/models/nn/attention.py |  3 ++-
 tests/test_utils.py          | 15 +++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/rl4co/models/nn/attention.py b/rl4co/models/nn/attention.py
index b65169f0..0dfa5973 100644
--- a/rl4co/models/nn/attention.py
+++ b/rl4co/models/nn/attention.py
@@ -19,7 +19,8 @@
 def scaled_dot_product_attention_simple(
     q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
 ):
-    """Simple Scaled Dot-Product Attention in PyTorch without Flash Attention"""
+    """Simple (exact) Scaled Dot-Product Attention in RL4CO without customized kernels (i.e. no Flash Attention)."""
+
     # Check for causal and attn_mask conflict
     if is_causal and attn_mask is not None:
         raise ValueError("Cannot set both is_causal and attn_mask")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a2494e80..c0f6041a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,7 +2,9 @@
 import torch
 
 from tensordict import TensorDict
+from torch.nn.functional import scaled_dot_product_attention
 
+from rl4co.models.nn.attention import scaled_dot_product_attention_simple
 from rl4co.utils.decoding import process_logits
 from rl4co.utils.ops import batchify, unbatchify
 
@@ -35,3 +37,16 @@ def test_top_k_top_p_sampling(top_p, top_k):
     mask = torch.ones(8, 10).bool()
     logprobs = process_logits(logits, mask, top_p=top_p, top_k=top_k)
     assert len(logprobs) == logits.size(0)
+
+
+def test_scaled_dot_product_attention():
+    bs, ns, ds = 2, 3, 4
+    q = torch.rand(bs, ns, ds)
+    k = torch.rand(bs, ns, ds)
+    v = torch.rand(bs, ns, ds)
+    attn_mask = torch.rand(bs, ns, ns) > 0.5
+    attn_mask[:, 0, :] = True  # at least one row element is True
+    attn_mask[:, :, 0] = True  # at least one column element is True
+    attn_torch = scaled_dot_product_attention(q, k, v, attn_mask)
+    attn_rl4co = scaled_dot_product_attention_simple(q, k, v, attn_mask)
+    assert torch.allclose(attn_torch, attn_rl4co)
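
Note (not part of the patch): below is a minimal sketch of the kind of exact, non-fused scaled dot-product attention that scaled_dot_product_attention_simple presumably computes and that the new test compares against torch.nn.functional.scaled_dot_product_attention. The helper name naive_scaled_dot_product_attention and the boolean-mask convention (True = attend, False = mask out) are illustrative assumptions, not the actual RL4CO implementation.

# Illustrative sketch only: plain, exact scaled dot-product attention
# with no fused/Flash kernels. With dropout_p=0.0 it should match
# torch.nn.functional.scaled_dot_product_attention up to floating-point error.
import math

import torch


def naive_scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    # Attention scores of shape [..., seq_q, seq_k], scaled by sqrt(head dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        # Assumed boolean mask convention: True = keep, False = mask out
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        weights = torch.dropout(weights, dropout_p, train=True)
    return weights @ v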