From 1fc9577c44f60a82d31b15015251af11e63c0215 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Wed, 18 Dec 2024 04:00:27 -0800 Subject: [PATCH] [BugFix] Fix output of `SipHash(as_tensor=False)` (#2664) --- test/test_storage_map.py | 9 +++++++++ torchrl/data/map/hash.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index db2d0bc2c49..90b16db00d3 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -46,6 +46,15 @@ def test_sip_hash(self): hash_b = torch.tensor(hash_module(b)) assert (hash_a == hash_b).all() + def test_sip_hash_nontensor(self): + a = torch.rand((3, 2)) + b = a.clone() + hash_module = SipHash(as_tensor=False) + hash_a = hash_module(a) + hash_b = hash_module(b) + assert len(hash_a) == 3 + assert hash_a == hash_b + @pytest.mark.parametrize("n_components", [None, 14]) @pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000]) def test_randomprojection_hash(self, n_components, scale): diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 59526628dbe..a3ae9ec1ae9 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -111,7 +111,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]: hash_value = x_i.tobytes() hash_values.append(hash_value) if not self.as_tensor: - return hash_value + return hash_values result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64) return result