From 6b98b174e62d081520e57183776d5b1b04cfb98b Mon Sep 17 00:00:00 2001 From: LukasMut Date: Wed, 16 Nov 2022 14:56:26 +0100 Subject: [PATCH] flatten acts fix --- setup.py | 3 ++- thingsvision/_version.py | 2 +- thingsvision/core/extraction/helpers.py | 13 +++++-------- thingsvision/core/extraction/mixin.py | 5 ++++- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 6c9ff5b6..bfb2ddba 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,8 @@ 'scikit-learn==1.1.*', 'scipy==1.8.1', 'h5py==3.7.0', - 'CLIP @ git+https://github.com/openai/CLIP.git' + 'CLIP', + # 'CLIP @ git+ssh://git@github.com/openai/CLIP@v1.0#egg=CLIP' ] setuptools.setup( diff --git a/thingsvision/_version.py b/thingsvision/_version.py index f1e49f68..90a1f38f 100644 --- a/thingsvision/_version.py +++ b/thingsvision/_version.py @@ -1 +1 @@ -__version__ = "2.2.5" +__version__ = "2.2.7" diff --git a/thingsvision/core/extraction/helpers.py b/thingsvision/core/extraction/helpers.py index 7806fa13..12690dbd 100644 --- a/thingsvision/core/extraction/helpers.py +++ b/thingsvision/core/extraction/helpers.py @@ -58,21 +58,18 @@ def show_model(self): print(n) print("visual") - def forward(self, batch: Tensor, module_name: str = "visual") -> Tensor: + @staticmethod + def forward(batch: Tensor) -> Tensor: img_features = model.encode_image(batch) - # if module_name == "visual": - # assert torch.unique( - # activations[module_name] == img_features - # ).item(), "\nFor CLIP, image features should represent activations in last encoder layer.\n" - return img_features - def flatten_acts(self, act: Tensor, img: Tensor, module_name: str) -> Tensor: + @staticmethod + def flatten_acts(act: Tensor, batch: Tensor, module_name: str) -> Tensor: if module_name.endswith("attn"): if isinstance(act, tuple): act = act[0] else: - if act.size(0) != img.shape[0] and len(act.shape) == 3: + if act.size(0) != batch.shape[0] and len(act.shape) == 3: act = act.permute(1, 0, 2) act = act.view(act.size(0), -1) return act diff --git a/thingsvision/core/extraction/mixin.py b/thingsvision/core/extraction/mixin.py index a880dfaf..23a2509a 100644 --- a/thingsvision/core/extraction/mixin.py +++ b/thingsvision/core/extraction/mixin.py @@ -55,7 +55,10 @@ def _extract_features( _ = self.forward(batch) act = activations[module_name] if flatten_acts: - act = self.flatten_acts(act) + if self.model_name.lower().startswith('clip'): + act = self.flatten_acts(act, batch, module_name) + else: + act = self.flatten_acts(act) act = self._to_numpy(act) return act