Skip to content

Commit

Permalink
Add checkpoint to textsam.LangSAM() (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
forestbat authored Oct 13, 2023
1 parent 0029a6d commit 22d26eb
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions samgeo/text_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class LangSAM:
A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.
"""

def __init__(self, model_type="vit_h"):
def __init__(self, model_type="vit_h", checkpoint=None):
"""Initialize the LangSAM instance.
Args:
Expand All @@ -119,7 +119,7 @@ def __init__(self, model_type="vit_h"):

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.build_groundingdino()
self.build_sam(model_type)
self.build_sam(model_type, checkpoint)

self.source = None
self.image = None
Expand All @@ -129,17 +129,21 @@ def __init__(self, model_type="vit_h"):
self.logits = None
self.prediction = None

def build_sam(self, model_type):
def build_sam(self, model_type, checkpoint_url=None):
"""Build the SAM model.
Args:
model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
checkpoint_url:
"""
checkpoint_url = SAM_MODELS[model_type]
sam = sam_model_registry[model_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
if checkpoint_url is not None:
sam = sam_model_registry[model_type](checkpoint=checkpoint_url)
else:
checkpoint_url = SAM_MODELS[model_type]
sam = sam_model_registry[model_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
sam.to(device=self.device)
self.sam = SamPredictor(sam)

Expand Down

0 comments on commit 22d26eb

Please sign in to comment.