diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 3fc1f731fc..51695cab56 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -243,7 +243,7 @@ def __init__( self.n_train, self.n_val = validate_data_split_with_external_indexing( self.adata_manager.adata.n_obs, self.external_indexing, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) else: @@ -251,7 +251,7 @@ def __init__( self.adata_manager.adata.n_obs, self.train_size, self.validation_size, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) @@ -423,7 +423,7 @@ def setup(self, stage: str | None = None): n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing( n_labeled_idx, [labeled_idx_train, labeled_idx_val, labeled_idx_test], - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) else: @@ -431,7 +431,7 @@ def setup(self, stage: str | None = None): n_labeled_idx, self.train_size, self.validation_size, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) @@ -463,7 +463,7 @@ def setup(self, stage: str | None = None): n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing( n_unlabeled_idx, [unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test], - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) else: @@ -471,7 +471,7 @@ def setup(self, stage: str | None = None): n_unlabeled_idx, self.train_size, self.validation_size, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) diff --git a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py index 312cefe131..2f5c819fa2 100644 --- a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py +++ b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py @@ -81,14 +81,14 @@ def __init__( self.n_background, self.train_size, self.validation_size, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) self.n_target_train, self.n_target_val = validate_data_split( self.n_target, self.train_size, self.validation_size, - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) else: @@ -101,7 +101,7 @@ def __init__( validate_data_split_with_external_indexing( self.n_background, [self.background_train_idx, self.background_val_idx, self.background_test_idx], - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) ) @@ -117,7 +117,7 @@ def __init__( self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing( self.n_target, [self.target_train_idx, self.target_val_idx, self.target_test_idx], - self.data_loader_kwargs["batch_size"], + self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, ) self.target_train_idx, self.target_val_idx, self.target_test_idx = (