Skip to content

Commit

Permalink
Merge pull request #110 from ViCCo-Group/flatten_acts_fix
Browse files Browse the repository at this point in the history
fix issue #109
  • Loading branch information
LukasMut authored Nov 16, 2022
2 parents 839808b + 6b98b17 commit 4a5090c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
'scikit-learn==1.1.*',
'scipy==1.8.1',
'h5py==3.7.0',
'CLIP @ git+https://github.com/openai/CLIP.git'
'CLIP',
# 'CLIP @ git+ssh://[email protected]/openai/[email protected]#egg=CLIP'
]

setuptools.setup(
Expand Down
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.2.5"
__version__ = "2.2.7"
13 changes: 5 additions & 8 deletions thingsvision/core/extraction/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,18 @@ def show_model(self):
print(n)
print("visual")

def forward(self, batch: Tensor, module_name: str = "visual") -> Tensor:
@staticmethod
def forward(batch: Tensor) -> Tensor:
img_features = model.encode_image(batch)
# if module_name == "visual":
# assert torch.unique(
# activations[module_name] == img_features
# ).item(), "\nFor CLIP, image features should represent activations in last encoder layer.\n"

return img_features

def flatten_acts(self, act: Tensor, img: Tensor, module_name: str) -> Tensor:
@staticmethod
def flatten_acts(act: Tensor, batch: Tensor, module_name: str) -> Tensor:
if module_name.endswith("attn"):
if isinstance(act, tuple):
act = act[0]
else:
if act.size(0) != img.shape[0] and len(act.shape) == 3:
if act.size(0) != batch.shape[0] and len(act.shape) == 3:
act = act.permute(1, 0, 2)
act = act.view(act.size(0), -1)
return act
Expand Down
5 changes: 4 additions & 1 deletion thingsvision/core/extraction/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def _extract_features(
_ = self.forward(batch)
act = activations[module_name]
if flatten_acts:
act = self.flatten_acts(act)
if self.model_name.lower().startswith('clip'):
act = self.flatten_acts(act, batch, module_name)
else:
act = self.flatten_acts(act)
act = self._to_numpy(act)
return act

Expand Down

0 comments on commit 4a5090c

Please sign in to comment.