Skip to content

Commit

Permalink
Merge pull request #20 from Genentech/embed_device
Browse files Browse the repository at this point in the history
fixed device parsing in embed_on_dataset
  • Loading branch information
avantikalal authored Jul 17, 2024
2 parents 2ef0c4b + 294a496 commit 0a0f060
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions src/grelu/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def test_on_dataset(
def embed_on_dataset(
self,
dataset: Callable,
devices: Union[str, int, List[int]] = "cpu",
device: Union[str, int] = "cpu",
num_workers: int = 1,
batch_size: int = 256,
):
Expand All @@ -829,7 +829,7 @@ def embed_on_dataset(
Args:
dataset: Dataset object that yields one-hot encoded sequences
devices: Device IDs to use
device: Device ID to use
num_workers: Number of workers for data loader
batch_size: Batch size for data loader
Expand All @@ -844,13 +844,20 @@ def embed_on_dataset(
)

# Get device
orig_device = self.device
device = self.parse_devices(devices)[0]
if isinstance(device, list):
device = device[0]
warnings.warn(
f"embed_on_dataset currently only uses a single GPU: {device}"
f"embed_on_dataset currently only uses a single GPU: Using {device}"
)
if isinstance(device, str):
try:
device = int(device)
except Exception:
pass
device = torch.device(device)

# Move model to device
orig_device = self.device
self.to(device)

# Get embeddings
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_lightning_model_transform():


def test_lightning_model_embed_on_dataset():
preds = single_task_reg_model.embed_on_dataset(dataset=udataset, devices="cpu")
preds = single_task_reg_model.embed_on_dataset(dataset=udataset, device="cpu")
assert preds.shape == (2, 1, 3)


Expand Down

0 comments on commit 0a0f060

Please sign in to comment.