Skip to content

Commit

Permalink
remove segment-anything backbones
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Jul 6, 2023
1 parent 1daba74 commit 891feed
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 196 deletions.
81 changes: 0 additions & 81 deletions lightning_pose/models/backbones/vits.py

This file was deleted.

9 changes: 8 additions & 1 deletion lightning_pose/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,14 @@ def __init__(
self.backbone_arch = backbone

if "sam" in self.backbone_arch:
from lightning_pose.models.backbones.vits import build_backbone
raise NotImplementedError(
"segment-anything backbones are not supported in this branch."
"If you are running the code from a github installation, switch to the branch"
"features/vit."
"If you have pip installed lightning pose, there is no access to segment-anything"
"models due to dependency/installation issues. "
"For more information please contatct the package maintainers."
)
else:
from lightning_pose.models.backbones.torchvision import build_backbone

Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from setuptools import find_packages, setup

VERSION = "0.0.2" # was previously None
VERSION = "0.0.2"

# add the README.md file to the long_description
with open("README.md", "r") as fh:
Expand Down Expand Up @@ -70,7 +70,6 @@ def get_cuda_version():
"typeguard==3.0.2",
"typing==3.7.4.3",
"botocore==1.27.59",
"segment_anything @ git+https://github.com/facebookresearch/segment-anything.git",
]


Expand Down
58 changes: 0 additions & 58 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
WIDTHS = [120, 246, 380] # similar but not square
RESNET_BACKBONES = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
EFFICIENTNET_BACKBONES = ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2"]
VIT_BACKBONES = ["vit_b_sam"] # "vit_h_sam" very large (2.6GB), takes too long to download/load


def test_backbones_resnet():
Expand Down Expand Up @@ -48,18 +47,6 @@ def test_backbones_efficientnet():
torch.cuda.empty_cache() # remove tensors from gpu


def test_backbones_vit():
for ind, backbone in enumerate(VIT_BACKBONES):
model = BaseFeatureExtractor(backbone=backbone).to(_TORCH_DEVICE)
assert (
type(list(model.backbone.children())[0])
== segment_anything.modeling.image_encoder.PatchEmbed
)
# remove model from gpu; then cache can be cleared
del model
torch.cuda.empty_cache() # remove tensors from gpu


def test_representation_shapes_resnet():

# loop over different backbone versions and make sure that the resulting
Expand Down Expand Up @@ -162,48 +149,3 @@ def test_representation_shapes_efficientnet():
del model

torch.cuda.empty_cache() # remove tensors from gpu


def test_representation_shapes_vit():

# loop over different backbone versions and make sure that the resulting
# representation shapes make sense

# 128x128
rep_shape_list_small_image = [
torch.Size([BATCH_SIZE, 256, 8, 8]), # vit_b_sam
]
# 256x256
rep_shape_list_medium_image = [
torch.Size([BATCH_SIZE, 256, 16, 16]),
]
# 384x384
rep_shape_list_large_image = [
torch.Size([BATCH_SIZE, 256, 24, 24]),
]
shape_list_pre_pool = [
rep_shape_list_small_image,
rep_shape_list_medium_image,
rep_shape_list_large_image,
]

for idx_backbone, backbone in enumerate(VIT_BACKBONES):
for idx_image in range(len(HEIGHTS)):
if _TORCH_DEVICE == "cuda":
torch.cuda.empty_cache()
model = BaseFeatureExtractor(
backbone=backbone, image_size=HEIGHTS[idx_image]
).to(_TORCH_DEVICE)
fake_image_batch = torch.rand(
size=(BATCH_SIZE, 3, HEIGHTS[idx_image], HEIGHTS[idx_image]),
device=_TORCH_DEVICE,
)
# representation dim depends on both image size and backbone network
representations = model(fake_image_batch)
assert representations.shape == shape_list_pre_pool[idx_image][idx_backbone]
# remove model/data from gpu; then cache can be cleared
del fake_image_batch
del representations
del model

torch.cuda.empty_cache() # remove tensors from gpu
27 changes: 0 additions & 27 deletions tests/models/test_heatmap_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,30 +91,3 @@ def test_semisupervised_heatmap_pcasingleview_context(
trainer=trainer,
remove_logs_fn=remove_logs,
)


def test_semisupervised_heatmap_pcasingleview_context_vit(
cfg_context,
heatmap_data_module_combined_context,
video_dataloader,
trainer,
remove_logs,
):
"""Test the initialization and training of a semi-supervised heatmap context model ViT backone.
NOTE: the toy dataset is not a proper context dataset
"""

cfg_tmp = copy.deepcopy(cfg_context)
cfg_tmp.model.backbone = "vit_b_sam"
cfg_tmp.model.model_type = "heatmap"
cfg_tmp.model.losses_to_use = ["pca_singleview"]

run_model_test(
cfg=cfg_tmp,
data_module=heatmap_data_module_combined_context,
video_dataloader=video_dataloader,
trainer=trainer,
remove_logs_fn=remove_logs,
)
27 changes: 0 additions & 27 deletions tests/models/test_heatmap_tracker_mhcrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,3 @@ def test_semisupervised_heatmap_mhcrnn_pcasingleview(
trainer=trainer,
remove_logs_fn=remove_logs,
)


def test_semisupervised_heatmap_mhcrnn_pcasingleview_vit(
cfg_context,
heatmap_data_module_combined_context,
video_dataloader,
trainer,
remove_logs,
):
"""Test the initialization and training of a semi-supervised heatmap mhcrnn model ViT backbone.
NOTE: the toy dataset is not a proper context dataset
"""

cfg_tmp = copy.deepcopy(cfg_context)
cfg_tmp.model.backbone = "vit_b_sam"
cfg_tmp.model.model_type = "heatmap_mhcrnn"
cfg_tmp.model.losses_to_use = ["pca_singleview"]

run_model_test(
cfg=cfg_tmp,
data_module=heatmap_data_module_combined_context,
video_dataloader=video_dataloader,
trainer=trainer,
remove_logs_fn=remove_logs,
)

0 comments on commit 891feed

Please sign in to comment.