From 0f82d8dbebded972eaef9bcf9dbe9fae4bc6ca69 Mon Sep 17 00:00:00 2001 From: Martin Kim Date: Tue, 15 Aug 2023 17:05:46 -0700 Subject: [PATCH] Fix batch indexes detection --- scvi/model/base/_vaemixin.py | 6 ++++-- tests/model/test_scvi.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py index c94e851816..9d32b7cf2f 100644 --- a/scvi/model/base/_vaemixin.py +++ b/scvi/model/base/_vaemixin.py @@ -238,7 +238,7 @@ def get_batch_representation( if not hasattr(self.module, "batch_embedding"): raise NotImplementedError("Model does not support batch embeddings.") - if self.module.batch_embedding is None: + elif self.module.batch_embedding is None: raise ValueError("Model was not trained with batch embeddings.") adata = self._validate_anndata(adata) @@ -251,7 +251,9 @@ def get_batch_representation( elif not all((key in cat_mapping) for key in batch_keys): raise ValueError("``batch_keys`` contains keys not present in ``adata``.") else: - batch_indexes = np.where(cat_mapping == batch_keys)[0] + batch_indexes = np.concatenate( + [np.where(cat_mapping == key)[0] for key in batch_keys] + ) batch_embeddings = self.module.batch_embedding.weight.detach().cpu().numpy() return batch_embeddings[batch_indexes] diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 0019cd7ded..f1a8043d68 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -14,6 +14,10 @@ def test_scvi_batch_embedding( ) model.train(max_epochs=1) + assert hasattr(model.module, "batch_embedding") + assert model.module.batch_embedding is not None + assert model.module.batch_embedding.weight.shape == (n_batches, batch_embedding_dim) + batch_representation = model.get_batch_representation() assert isinstance(batch_representation, np.ndarray) assert batch_representation.shape == (n_batches, batch_embedding_dim) @@ -21,3 +25,9 @@ def test_scvi_batch_embedding( batch_representation = model.get_batch_representation(batch_keys=["batch_0"]) assert isinstance(batch_representation, np.ndarray) assert batch_representation.shape == (1, batch_embedding_dim) + + batch_representation = model.get_batch_representation( + batch_keys=["batch_0", "batch_1"] + ) + assert isinstance(batch_representation, np.ndarray) + assert batch_representation.shape == (2, batch_embedding_dim)