Skip to content

Commit

Permalink
refactor: minor refactor of cuda device init
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed May 1, 2024
1 parent 4ab8780 commit 0f29e68
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions sdcat/cluster/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,10 @@ def compute_embedding(images: list, model_name: str):
info(f'Using patch size {patch_size} for model {model_name}')

# Load images and generate embeddings
device = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.no_grad():
# Set the cuda device
if torch.cuda.is_available():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for filename in images:
Expand All @@ -123,7 +122,7 @@ def compute_embedding(images: list, model_name: str):
# Convert the image to a tensor
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
if device:
if 'cuda' in device:
img_tensor = img_tensor.to(device)
features = model(img_tensor)

Expand Down

0 comments on commit 0f29e68

Please sign in to comment.