From 891feed1f8aae2ba1411b4d533dde188fc6fd8f0 Mon Sep 17 00:00:00 2001 From: themattinthehatt Date: Thu, 6 Jul 2023 12:03:26 -0400 Subject: [PATCH] remove segment-anything backbones --- lightning_pose/models/backbones/vits.py | 81 --------------------- lightning_pose/models/base.py | 9 ++- setup.py | 3 +- tests/models/test_base.py | 58 --------------- tests/models/test_heatmap_tracker.py | 27 ------- tests/models/test_heatmap_tracker_mhcrnn.py | 27 ------- 6 files changed, 9 insertions(+), 196 deletions(-) delete mode 100644 lightning_pose/models/backbones/vits.py diff --git a/lightning_pose/models/backbones/vits.py b/lightning_pose/models/backbones/vits.py deleted file mode 100644 index 006d36f1..00000000 --- a/lightning_pose/models/backbones/vits.py +++ /dev/null @@ -1,81 +0,0 @@ -from functools import partial - -import torch -from segment_anything.modeling import ImageEncoderViT -from typeguard import typechecked - - -@typechecked -def build_backbone(backbone_arch: str, image_size: int = 256, **kwargs): - """Load backbone weights for resnet models. 
- - Args: - backbone_arch: which backbone version/weights to use - image_size: height/width in pixels of images (must be square) - - Returns: - tuple - - backbone: pytorch model - - num_fc_input_features (int): number of input features to fully connected layer - - """ - - # load backbone weights - if "vit_h_sam" in backbone_arch: - ckpt_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" - state_dict = torch.hub.load_state_dict_from_url(ckpt_url) - encoder_embed_dim = 1280 - encoder_depth = 32 - encoder_num_heads = 16 - encoder_global_attn_indexes = (7, 15, 23, 31) - prompt_embed_dim = 256 - image_size = image_size - vit_patch_size = 16 - base = ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - ) - base.load_state_dict(state_dict, strict=False) - - elif "vit_b_sam" in backbone_arch: - ckpt_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" - state_dict = torch.hub.load_state_dict_from_url(ckpt_url) - encoder_embed_dim = 768 - encoder_depth = 12 - encoder_num_heads = 12 - encoder_global_attn_indexes = (2, 5, 8, 11) - prompt_embed_dim = 256 - image_size = image_size - vit_patch_size = 16 - base = ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - ) - base.load_state_dict(state_dict, strict=False) - - else: - raise NotImplementedError - - num_fc_input_features = base.neck[-2].in_channels - - return base, 
num_fc_input_features diff --git a/lightning_pose/models/base.py b/lightning_pose/models/base.py index cf12f89d..877e71e5 100644 --- a/lightning_pose/models/base.py +++ b/lightning_pose/models/base.py @@ -85,7 +85,14 @@ def __init__( self.backbone_arch = backbone if "sam" in self.backbone_arch: - from lightning_pose.models.backbones.vits import build_backbone + raise NotImplementedError( + "segment-anything backbones are not supported in this branch. " + "If you are running the code from a github installation, switch to the branch " + "features/vit. " + "If you have pip installed lightning pose, there is no access to segment-anything " + "models due to dependency/installation issues. " + "For more information please contact the package maintainers." + ) else: from lightning_pose.models.backbones.torchvision import build_backbone diff --git a/setup.py b/setup.py index 739a899a..0f9cc120 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import find_packages, setup -VERSION = "0.0.2" # was previously None +VERSION = "0.0.2" # add the README.md file to the long_description with open("README.md", "r") as fh: @@ -70,7 +70,6 @@ def get_cuda_version(): "typeguard==3.0.2", "typing==3.7.4.3", "botocore==1.27.59", - "segment_anything @ git+https://github.com/facebookresearch/segment-anything.git", ] diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 63143fc5..382c8d05 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -13,7 +13,6 @@ WIDTHS = [120, 246, 380] # similar but not square RESNET_BACKBONES = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] EFFICIENTNET_BACKBONES = ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2"] -VIT_BACKBONES = ["vit_b_sam"] # "vit_h_sam" very large (2.6GB), takes too long to download/load def test_backbones_resnet(): @@ -48,18 +47,6 @@ def test_backbones_efficientnet(): torch.cuda.empty_cache() # remove tensors from gpu -def test_backbones_vit(): - for ind, backbone 
in enumerate(VIT_BACKBONES): - model = BaseFeatureExtractor(backbone=backbone).to(_TORCH_DEVICE) - assert ( - type(list(model.backbone.children())[0]) - == segment_anything.modeling.image_encoder.PatchEmbed - ) - # remove model from gpu; then cache can be cleared - del model - torch.cuda.empty_cache() # remove tensors from gpu - - def test_representation_shapes_resnet(): # loop over different backbone versions and make sure that the resulting @@ -162,48 +149,3 @@ def test_representation_shapes_efficientnet(): del model torch.cuda.empty_cache() # remove tensors from gpu - - -def test_representation_shapes_vit(): - - # loop over different backbone versions and make sure that the resulting - # representation shapes make sense - - # 128x128 - rep_shape_list_small_image = [ - torch.Size([BATCH_SIZE, 256, 8, 8]), # vit_b_sam - ] - # 256x256 - rep_shape_list_medium_image = [ - torch.Size([BATCH_SIZE, 256, 16, 16]), - ] - # 384x384 - rep_shape_list_large_image = [ - torch.Size([BATCH_SIZE, 256, 24, 24]), - ] - shape_list_pre_pool = [ - rep_shape_list_small_image, - rep_shape_list_medium_image, - rep_shape_list_large_image, - ] - - for idx_backbone, backbone in enumerate(VIT_BACKBONES): - for idx_image in range(len(HEIGHTS)): - if _TORCH_DEVICE == "cuda": - torch.cuda.empty_cache() - model = BaseFeatureExtractor( - backbone=backbone, image_size=HEIGHTS[idx_image] - ).to(_TORCH_DEVICE) - fake_image_batch = torch.rand( - size=(BATCH_SIZE, 3, HEIGHTS[idx_image], HEIGHTS[idx_image]), - device=_TORCH_DEVICE, - ) - # representation dim depends on both image size and backbone network - representations = model(fake_image_batch) - assert representations.shape == shape_list_pre_pool[idx_image][idx_backbone] - # remove model/data from gpu; then cache can be cleared - del fake_image_batch - del representations - del model - - torch.cuda.empty_cache() # remove tensors from gpu diff --git a/tests/models/test_heatmap_tracker.py b/tests/models/test_heatmap_tracker.py index 
6dbfa8f9..3af60c7d 100644 --- a/tests/models/test_heatmap_tracker.py +++ b/tests/models/test_heatmap_tracker.py @@ -91,30 +91,3 @@ def test_semisupervised_heatmap_pcasingleview_context( trainer=trainer, remove_logs_fn=remove_logs, ) - - -def test_semisupervised_heatmap_pcasingleview_context_vit( - cfg_context, - heatmap_data_module_combined_context, - video_dataloader, - trainer, - remove_logs, -): - """Test the initialization and training of a semi-supervised heatmap context model ViT backone. - - NOTE: the toy dataset is not a proper context dataset - - """ - - cfg_tmp = copy.deepcopy(cfg_context) - cfg_tmp.model.backbone = "vit_b_sam" - cfg_tmp.model.model_type = "heatmap" - cfg_tmp.model.losses_to_use = ["pca_singleview"] - - run_model_test( - cfg=cfg_tmp, - data_module=heatmap_data_module_combined_context, - video_dataloader=video_dataloader, - trainer=trainer, - remove_logs_fn=remove_logs, - ) diff --git a/tests/models/test_heatmap_tracker_mhcrnn.py b/tests/models/test_heatmap_tracker_mhcrnn.py index c6b9ef75..4fcec58d 100644 --- a/tests/models/test_heatmap_tracker_mhcrnn.py +++ b/tests/models/test_heatmap_tracker_mhcrnn.py @@ -47,30 +47,3 @@ def test_semisupervised_heatmap_mhcrnn_pcasingleview( trainer=trainer, remove_logs_fn=remove_logs, ) - - -def test_semisupervised_heatmap_mhcrnn_pcasingleview_vit( - cfg_context, - heatmap_data_module_combined_context, - video_dataloader, - trainer, - remove_logs, -): - """Test the initialization and training of a semi-supervised heatmap mhcrnn model ViT backbone. - - NOTE: the toy dataset is not a proper context dataset - - """ - - cfg_tmp = copy.deepcopy(cfg_context) - cfg_tmp.model.backbone = "vit_b_sam" - cfg_tmp.model.model_type = "heatmap_mhcrnn" - cfg_tmp.model.losses_to_use = ["pca_singleview"] - - run_model_test( - cfg=cfg_tmp, - data_module=heatmap_data_module_combined_context, - video_dataloader=video_dataloader, - trainer=trainer, - remove_logs_fn=remove_logs, - )