diff --git a/docs/tutorials/multi_gpu.ipynb b/docs/tutorials/multi_gpu.ipynb
index 2609e50..d84283e 100644
--- a/docs/tutorials/multi_gpu.ipynb
+++ b/docs/tutorials/multi_gpu.ipynb
@@ -13,7 +13,7 @@
    "source": [
     "Training on multiple GPUs is currently on the roadmap of keras 3.0 for both Tensorflow and Pytorch backend. \n",
     "Until this is implemented in Keras 3.0, we don't include multi GPU training inside the `Crested` trainer class, but you can still train on multiple GPU's using the standard Tensorflow's `tf.distribute.MirroredStrategy`. \n",
-    "You only need to wrap your model creation and training inside the `strategy.scope()` context manager. \n",
+    "You only need to wrap your model and optimizer creation and training inside the `strategy.scope()` context manager. \n",
     "Data preparation is the same as in the single GPU case."
    ]
   },
@@ -28,8 +28,6 @@
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
-    "config = crested.tl.default_configs(\"peak_regression\")\n",
-    "\n",
    "datamodule = crested.tl.data.AnnDataModule(\n",
    "    my_adata,\n",
    "    genome_file=my_genome,\n",
@@ -41,6 +39,8 @@
    "\n",
    "with strategy.scope():\n",
    "    model_architecture = crested.tl.zoo.chrombpnet(seq_len=2114, num_classes=4)\n",
+    "\n",
+    "    config = crested.tl.default_configs(\"peak_regression\")\n",
    " \n",
    "    trainer = crested.tl.Crested(\n",
    "        data=datamodule,\n",
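The config moves into the scope because `tf.distribute.MirroredStrategy` only mirrors variables that are created under `strategy.scope()`, and, as the updated tutorial text puts it, that covers both the model and the optimizer. The sketch below shows the pattern with plain TensorFlow/Keras rather than the CREsted trainer; the model, dataset and hyperparameters are placeholders.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Model and optimizer are both created inside the scope so that their
    # variables are mirrored across all available GPUs.
    inputs = tf.keras.Input(shape=(100,))
    hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(hidden)
    model = tf.keras.Model(inputs, outputs)
    optimizer = tf.keras.optimizers.Adam(1e-3)
    model.compile(optimizer=optimizer, loss="mse")

# Data pipelines can be built outside the scope; model.fit() distributes the
# batches over the replicas.
# model.fit(train_dataset, validation_data=val_dataset, epochs=10)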
diff --git a/src/crested/pp/_split.py b/src/crested/pp/_split.py
index bc5ff58..149a2f4 100644
--- a/src/crested/pp/_split.py
+++ b/src/crested/pp/_split.py
@@ -15,7 +15,7 @@


 def _split_by_chromosome_auto(
-    regions: list[str], val_fraction: float = 0.1, test_fraction: float = 0.1
+    regions: list[str], val_fraction: float = 0.1, test_fraction: float = 0.1, random_state: int | None = None,
 ) -> pd.Series:
     """Split the dataset based on chromosome, automatically selecting chromosomes for val and test sets.

@@ -31,6 +31,8 @@
     """
     chrom_count = defaultdict(int)
     for region in regions:
+        if ":" not in region:
+            raise ValueError(f"Region names should start with the chromosome name, bound by a colon (:). Offending region: {region}")
         chrom = region.split(":")[0]
         chrom_count[chrom] += 1

@@ -39,6 +41,7 @@
     target_test_size = int(test_fraction * total_regions)

     chromosomes = list(chrom_count.keys())
+    np.random.seed(seed=random_state)
     np.random.shuffle(chromosomes)

     val_chroms = set()
@@ -187,7 +190,6 @@ def train_val_test_split(
     test_size: float = 0.1,
     val_chroms: list[str] = None,
     test_chroms: list[str] = None,
-    chr_var_key: str = "chr",
     shuffle: bool = True,
     random_state: None | int = None,
 ) -> None:
@@ -207,8 +209,8 @@
     adata
         AnnData object to which the 'train/val/test' split column will be added.
     strategy
-        strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the "target" df should
-        have a column "chr" with the chromosome names.
+        strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the anndata's var_names should
+        contain the chromosome name at the start, followed by a `:` (e.g. I:2000-2500 or chr3:10-20:+).

         region:
             Split randomly on region indices.
@@ -226,8 +228,6 @@
         List of chromosomes to include in the validation set. Required if strategy='chr'.
     test_chroms
         List of chromosomes to include in the test set. Required if strategy='chr'.
-    chr_var_key
-        Key in `.var` for chromosome.
     shuffle
         Whether or not to shuffle the data before splitting (when strategy='region').
     random_state
@@ -260,13 +260,9 @@ def train_val_test_split(
     # Input checks
     if strategy not in ["region", "chr", "chr_auto"]:
         raise ValueError("`strategy` should be either 'region','chr', or 'chr_auto'")
-    if strategy == "region" and not 0 <= val_size <= 1:
+    if strategy in ["region", "chr_auto"] and not 0 <= val_size <= 1:
         raise ValueError("`val_size` should be a float between 0 and 1.")
-    if strategy == "region" and not 0 <= test_size <= 1:
-        raise ValueError("`test_size` should be a float between 0 and 1.")
-    if strategy == "chr_auto" and not 0 <= val_size <= 1:
-        raise ValueError("`val_size` should be a float between 0 and 1.")
-    if strategy == "chr_auto" and not 0 <= test_size <= 1:
+    if strategy in ["region", "chr_auto"] and not 0 <= test_size <= 1:
         raise ValueError("`test_size` should be a float between 0 and 1.")
     if (strategy == "region") and (val_chroms is not None or test_chroms is not None):
         logger.warning(
@@ -278,20 +274,6 @@
         raise ValueError(
             "If `strategy` is 'chr', `val_chroms` and `test_chroms` should be provided."
         )
-    if chr_var_key not in adata.var.columns:
-        raise ValueError(
-            f"Column '{chr_var_key}' not found in `.var`. "
-            "Make sure to add the chromosome information to the `.var` DataFrame."
-        )
-    unique_chr = adata.var[chr_var_key].unique()
-    if not set(val_chroms).issubset(unique_chr):
-        raise ValueError(
-            "Some chromosomes in `val_chroms` are not present in the dataset."
-        )
-    if not set(test_chroms).issubset(unique_chr):
-        raise ValueError(
-            "Some chromosomes in `test_chroms` are not present in the dataset."
-        )

     # Split
     regions = list(adata.var_names)
@@ -303,7 +285,7 @@
             regions, val_chroms=val_chroms, test_chroms=test_chroms
         )
     elif strategy == "chr_auto":
-        split = _split_by_chromosome_auto(regions, val_size, test_size)
+        split = _split_by_chromosome_auto(regions, val_size, test_size, random_state)

     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
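With `chr_var_key` removed, the chromosome is now parsed directly from `var_names`, and `random_state` also seeds the chromosome shuffle in the automatic split. A usage sketch, assuming the function is exposed as `crested.pp.train_val_test_split` (the toy AnnData below is made up for illustration):

import anndata as ad
import numpy as np
import pandas as pd
import crested

# Region names must carry the chromosome before the first colon,
# e.g. "chr1:0-500"; malformed names now raise a ValueError.
regions = [f"chr{c}:{i * 500}-{(i + 1) * 500}" for c in (1, 2, 3) for i in range(100)]
adata = ad.AnnData(
    X=np.random.rand(4, len(regions)),
    var=pd.DataFrame(index=regions),
)

crested.pp.train_val_test_split(
    adata,
    strategy="chr_auto",
    val_size=0.1,
    test_size=0.1,
    random_state=42,  # makes the automatic chromosome selection reproducible
)
print(adata.var["split"].value_counts())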
diff --git a/src/crested/tl/zoo/utils/_layers.py b/src/crested/tl/zoo/utils/_layers.py
index 9f790d4..e4456bc 100644
--- a/src/crested/tl/zoo/utils/_layers.py
+++ b/src/crested/tl/zoo/utils/_layers.py
@@ -99,6 +99,7 @@ def conv_block(
     padding: str = "valid",
     l2: float = 1e-5,
     batchnorm_momentum: float = 0.99,
+    name_prefix: str | None = None
 ) -> keras.KerasTensor:
     """
     Convolution building block.
@@ -129,6 +130,8 @@
         L2 regularization weight (default is 1e-5).
     batchnorm_momentum
         Batch normalization momentum (default is 0.99).
+    name_prefix
+        Prefix for layer names.

     Returns
     -------
@@ -143,23 +146,42 @@
         padding=padding,
         kernel_regularizer=keras.regularizers.L2(l2),
         use_bias=conv_bias,
+        name=name_prefix + "_conv" if name_prefix else None,
     )(inputs)
     if normalization == "batch":
-        x = keras.layers.BatchNormalization(momentum=batchnorm_momentum)(x)
+        x = keras.layers.BatchNormalization(
+            momentum=batchnorm_momentum,
+            name=name_prefix + "_batchnorm" if name_prefix else None
+        )(x)
     elif normalization == "layer":
-        x = keras.layers.LayerNormalization()(x)
-    x = keras.layers.Activation(activation)(x)
+        x = keras.layers.LayerNormalization(
+            name=name_prefix + "_layernorm" if name_prefix else None,
+        )(x)
+    x = keras.layers.Activation(
+        activation,
+        name=name_prefix + "_activation" if name_prefix else None,
+    )(x)
     if res:
         if filters != residual.shape[2]:
             residual = keras.layers.Convolution1D(
-                filters=filters, kernel_size=1, strides=1
+                filters=filters,
+                kernel_size=1,
+                strides=1,
+                name=name_prefix + "_resconv" if name_prefix else None,
             )(residual)
         x = keras.layers.Add()([x, residual])
     if pool_size > 1:
-        x = keras.layers.MaxPooling1D(pool_size=pool_size, padding=padding)(x)
+        x = keras.layers.MaxPooling1D(
+            pool_size=pool_size,
+            padding=padding,
+            name=name_prefix + "_pool" if name_prefix else None,
+        )(x)
     if dropout > 0:
-        x = keras.layers.Dropout(dropout)(x)
+        x = keras.layers.Dropout(
+            dropout,
+            name=name_prefix + "_dropout" if name_prefix else None
+        )(x)

     return x


@@ -249,6 +271,7 @@
     bn_type: str = "standard",
     kernel_initializer: str = "he_normal",
     padding: str = "same",
+    name_prefix: str | None = None,
 ):
     """
     Construct a convolution block (for Basenji).
@@ -287,6 +310,8 @@
         Convolution kernel initializer.
     padding
         Padding type.
+    name_prefix
+        Prefix for layer names.

     Returns
     -------
@@ -318,6 +343,7 @@
         dilation_rate=dilation_rate,
         kernel_initializer=kernel_initializer,
         kernel_regularizer=keras.regularizers.l2(l2_scale),
+        name=name_prefix + "_conv" if name_prefix else None,
     )(current)

     # batch norm
@@ -328,11 +354,18 @@
             bn_layer = keras.layers.experimental.SyncBatchNormalization
         else:
             bn_layer = keras.layers.BatchNormalization
-        current = bn_layer(momentum=bn_momentum, gamma_initializer=bn_gamma)(current)
+        current = bn_layer(
+            momentum=bn_momentum,
+            gamma_initializer=bn_gamma,
+            name=name_prefix + "_bnorm" if name_prefix else None,
+        )(current)

     # dropout
     if dropout > 0:
-        current = keras.layers.Dropout(rate=dropout)(current)
+        current = keras.layers.Dropout(
+            rate=dropout,
+            name=name_prefix + "_dropout" if name_prefix else None,
+        )(current)

     # residual add
     if residual:
@@ -345,14 +378,14 @@
     # Pool
     if pool_size > 1:
         if w1:
-            current = keras.layers.MaxPool2D(pool_size=pool_size, padding=padding)(
-                current
-            )
+            pool_layer = keras.layers.MaxPool2D
         else:
-            current = keras.layers.MaxPool1D(pool_size=pool_size, padding=padding)(
-                current
-            )
-
+            pool_layer = keras.layers.MaxPool1D
+        current = pool_layer(
+            pool_size=pool_size,
+            padding=padding,
+            name=name_prefix + "_pool" if name_prefix else None,
+        )(current)
     return current

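The only behavioural change here is layer naming: with a prefix, the blocks produce layers such as `<prefix>_conv`, `<prefix>_batchnorm`, `<prefix>_activation`, `<prefix>_pool` and `<prefix>_dropout`, while without one `name=None` keeps Keras' automatic names. A small, self-contained sketch of the same idiom; the `named_conv` helper is illustrative and not part of CREsted:

import keras

def named_conv(x, filters, kernel_size, name_prefix=None):
    # Same pattern as in conv_block: pass an explicit name when a prefix is
    # given, otherwise fall back to Keras' automatic layer naming.
    x = keras.layers.Convolution1D(
        filters=filters,
        kernel_size=kernel_size,
        name=name_prefix + "_conv" if name_prefix else None,
    )(x)
    x = keras.layers.Activation(
        "relu",
        name=name_prefix + "_activation" if name_prefix else None,
    )(x)
    return x

inputs = keras.layers.Input(shape=(2114, 4))
outputs = named_conv(inputs, filters=64, kernel_size=17, name_prefix="block1")
model = keras.Model(inputs, outputs)
model.get_layer("block1_conv")  # stable names make layers easy to look up, freeze or transfer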
diff --git a/tests/test_pp.py b/tests/test_pp.py
index a5a18da..f1d4fe1 100644
--- a/tests/test_pp.py
+++ b/tests/test_pp.py
@@ -45,7 +45,6 @@ def test_train_val_test_split_by_chromosome():
         strategy="chr",
         val_chroms=["chr1"],
         test_chroms=["chr2"],
-        chr_var_key="chr",
     )

     split_counts = adata.var["split"].value_counts()
@@ -74,7 +73,6 @@ def test_train_val_test_split_by_chromosome_auto():
         strategy="chr_auto",
         val_size=0.2,
         test_size=0.2,
-        chr_var_key="chr",
         random_state=None,
     )