diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py
index a32cd50..d856662 100644
--- a/src/grelu/lightning/__init__.py
+++ b/src/grelu/lightning/__init__.py
@@ -820,7 +820,7 @@ def test_on_dataset(
     def embed_on_dataset(
         self,
         dataset: Callable,
-        devices: Union[str, int, List[int]] = "cpu",
+        device: Union[str, int] = "cpu",
         num_workers: int = 1,
         batch_size: int = 256,
     ):
@@ -829,7 +829,7 @@ def embed_on_dataset(
 
         Args:
             dataset: Dataset object that yields one-hot encoded sequences
-            devices: Device IDs to use
+            device: Device ID to use
             num_workers: Number of workers for data loader
             batch_size: Batch size for data loader
 
@@ -844,13 +844,20 @@ def embed_on_dataset(
         )
 
         # Get device
-        orig_device = self.device
-        device = self.parse_devices(devices)[0]
         if isinstance(device, list):
             device = device[0]
             warnings.warn(
-                f"embed_on_dataset currently only uses a single GPU: {device}"
+                f"embed_on_dataset currently only uses a single GPU: Using {device}"
             )
+        if isinstance(device, str):
+            try:
+                device = int(device)
+            except Exception:
+                pass
+        device = torch.device(device)
+
+        # Move model to device
+        orig_device = self.device
         self.to(device)
 
         # Get embeddings
diff --git a/tests/test_lightning.py b/tests/test_lightning.py
index fd77f5c..d9a609d 100644
--- a/tests/test_lightning.py
+++ b/tests/test_lightning.py
@@ -252,7 +252,7 @@ def test_lightning_model_transform():
 
 
 def test_lightning_model_embed_on_dataset():
-    preds = single_task_reg_model.embed_on_dataset(dataset=udataset, devices="cpu")
+    preds = single_task_reg_model.embed_on_dataset(dataset=udataset, device="cpu")
     assert preds.shape == (2, 1, 3)
 
 
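
Below is a minimal usage sketch of the updated signature, not part of the patch itself: `model` stands for any trained grelu LightningModel and `seq_dataset` for a dataset yielding one-hot encoded sequences; both names are placeholders for illustration.

import torch

# CPU, the default
embeddings = model.embed_on_dataset(dataset=seq_dataset, device="cpu")

# After this change, a single GPU can be requested as an integer index,
# a numeric string (converted with int()), or a torch-style device string.
if torch.cuda.is_available():
    embeddings = model.embed_on_dataset(dataset=seq_dataset, device=0)
    embeddings = model.embed_on_dataset(dataset=seq_dataset, device="cuda:0")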