Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMut committed Apr 1, 2024
1 parent 044c5c5 commit e134552
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
34 changes: 19 additions & 15 deletions tests/extractor/extraction/test_torch_vs_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,25 @@ def test_custom_torch_vs_tf_extraction(self):
pt_model.backend = pt_backend

layer_name = "relu"
tf_features = tf_model.extract_features(
batches=tf_dl,
module_name=layer_name,
flatten_acts=False,
)
pt_features = pt_model.extract_features(
batches=pt_dl,
module_name=layer_name,
flatten_acts=False,
output_type="tensor",
)
expected_features_pt = torch.tensor([[2, 2], [0, 0]])
expected_features_tf = np.array([[2, 2], [0, 0]])
np.testing.assert_allclose(pt_features, expected_features_pt)
np.testing.assert_allclose(tf_features, expected_features_tf)
expected_features_pt = torch.tensor([[2., 2.], [0., 0.]])
expected_features_tf = np.array([[2., 2.], [0, 0.]])

for i, batch in enumerate(tf_dl):
tf_features = tf_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
)
np.testing.assert_allclose(tf_features, expected_features_tf[i][None,:])

for i, batch in enumerate(pt_dl):
pt_features = pt_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
output_type="tensor",
)
np.testing.assert_allclose(pt_features, expected_features_pt[i][None,:])

layer_name = "relu2"
expected_features = np.array([[4., 4.], [0., 0.]])
Expand Down
6 changes: 6 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@
"pretrained": True,
"source": "keras",
},
"VGG19_keras": {
"model_name": "VGG19",
"modules": ["block1_conv1", "flatten"],
"pretrained": False,
"source": "keras",
},
# Vissl models
"simclr-rn50": {
"model_name": "simclr-rn50",
Expand Down
3 changes: 1 addition & 2 deletions thingsvision/core/extraction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def extract_batch(
Returns the feature matrix (e.g., $X \in \mathbb{R}^{B \times d}$ if penultimate or logits layer or flatten_acts = True).
"""
raise NotImplementedError

Check warning on line 107 in thingsvision/core/extraction/base.py

View check run for this annotation

Codecov / codecov/patch

thingsvision/core/extraction/base.py#L107

Added line #L107 was not covered by tests



@abc.abstractmethod
def _extract_batch(
self,
Expand Down
2 changes: 1 addition & 1 deletion thingsvision/core/extraction/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
if not self.model:
self.load_model()
self.prepare_inference()

def _extract_batch(
self,
batch: Array,
Expand Down
2 changes: 1 addition & 1 deletion thingsvision/core/extraction/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def extract_batch(
act = self._to_numpy(act)
self._unregister_hook()
return act

@torch.no_grad()
def _extract_batch(
self,
Expand Down

0 comments on commit e134552

Please sign in to comment.