Skip to content

Commit

Permalink
fix empty kwargs element case
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Nov 3, 2024
1 parent 1e8550f commit a38423c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,15 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

Expand Down Expand Up @@ -423,15 +423,15 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

Expand Down Expand Up @@ -463,15 +463,15 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
Expand All @@ -101,7 +101,7 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
Expand All @@ -117,7 +117,7 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs["batch_size"],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
Expand Down

0 comments on commit a38423c

Please sign in to comment.