Skip to content

Commit

Permalink
Fix batch indexes detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Kim committed Aug 16, 2023
1 parent 66bfd7f commit 0f82d8d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 4 additions & 2 deletions scvi/model/base/_vaemixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def get_batch_representation(

if not hasattr(self.module, "batch_embedding"):
raise NotImplementedError("Model does not support batch embeddings.")
if self.module.batch_embedding is None:
elif self.module.batch_embedding is None:
raise ValueError("Model was not trained with batch embeddings.")

adata = self._validate_anndata(adata)
Expand All @@ -251,7 +251,9 @@ def get_batch_representation(
elif not all((key in cat_mapping) for key in batch_keys):
raise ValueError("``batch_keys`` contains keys not present in ``adata``.")
else:
batch_indexes = np.where(cat_mapping == batch_keys)[0]
batch_indexes = np.concatenate(
[np.where(cat_mapping == key)[0] for key in batch_keys]
)

batch_embeddings = self.module.batch_embedding.weight.detach().cpu().numpy()
return batch_embeddings[batch_indexes]
10 changes: 10 additions & 0 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,20 @@ def test_scvi_batch_embedding(
)
model.train(max_epochs=1)

assert hasattr(model.module, "batch_embedding")
assert model.module.batch_embedding is not None
assert model.module.batch_embedding.weight.shape == (n_batches, batch_embedding_dim)

batch_representation = model.get_batch_representation()
assert isinstance(batch_representation, np.ndarray)
assert batch_representation.shape == (n_batches, batch_embedding_dim)

batch_representation = model.get_batch_representation(batch_keys=["batch_0"])
assert isinstance(batch_representation, np.ndarray)
assert batch_representation.shape == (1, batch_embedding_dim)

batch_representation = model.get_batch_representation(
batch_keys=["batch_0", "batch_1"]
)
assert isinstance(batch_representation, np.ndarray)
assert batch_representation.shape == (2, batch_embedding_dim)

0 comments on commit 0f82d8d

Please sign in to comment.