Skip to content

Commit

Permalink
Merge pull request #61 from aertslab/cas_small_fixes
Browse files Browse the repository at this point in the history
Cas small fixes
  • Loading branch information
casblaauw authored Nov 19, 2024
2 parents 8769271 + e9e7abb commit 0cb1419
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 47 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/multi_gpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"source": [
"Training on multiple GPUs is currently on the roadmap of keras 3.0 for both Tensorflow and Pytorch backend. \n",
"Until this is implemented in Keras 3.0, we don't include multi GPU training inside the `Crested` trainer class, but you can still train on multiple GPU's using the standard Tensorflow's `tf.distribute.MirroredStrategy`. \n",
"You only need to wrap your model creation and training inside the `strategy.scope()` context manager. \n",
"You only need to wrap your model and optimizer creation and training inside the `strategy.scope()` context manager. \n",
"Data preparation is the same as in the single GPU case."
]
},
Expand All @@ -28,8 +28,6 @@
"\n",
"strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"config = crested.tl.default_configs(\"peak_regression\")\n",
"\n",
"datamodule = crested.tl.data.AnnDataModule(\n",
" my_adata,\n",
" genome_file=my_genome,\n",
Expand All @@ -41,6 +39,8 @@
"\n",
"with strategy.scope():\n",
" model_architecture = crested.tl.zoo.chrombpnet(seq_len=2114, num_classes=4)\n",
"\n",
" config = crested.tl.default_configs(\"peak_regression\")\n",
" \n",
" trainer = crested.tl.Crested(\n",
" data=datamodule,\n",
Expand Down
36 changes: 9 additions & 27 deletions src/crested/pp/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def _split_by_chromosome_auto(
regions: list[str], val_fraction: float = 0.1, test_fraction: float = 0.1
regions: list[str], val_fraction: float = 0.1, test_fraction: float = 0.1, random_state: int | None = None,
) -> pd.Series:
"""Split the dataset based on chromosome, automatically selecting chromosomes for val and test sets.
Expand All @@ -31,6 +31,8 @@ def _split_by_chromosome_auto(
"""
chrom_count = defaultdict(int)
for region in regions:
if ":" not in region:
raise ValueError(f"Region names should start with the chromosome name, bound by a colon (:). Offending region: {region}")
chrom = region.split(":")[0]
chrom_count[chrom] += 1

Expand All @@ -39,6 +41,7 @@ def _split_by_chromosome_auto(
target_test_size = int(test_fraction * total_regions)

chromosomes = list(chrom_count.keys())
np.random.seed(seed=random_state)
np.random.shuffle(chromosomes)

val_chroms = set()
Expand Down Expand Up @@ -187,7 +190,6 @@ def train_val_test_split(
test_size: float = 0.1,
val_chroms: list[str] = None,
test_chroms: list[str] = None,
chr_var_key: str = "chr",
shuffle: bool = True,
random_state: None | int = None,
) -> None:
Expand All @@ -207,8 +209,8 @@ def train_val_test_split(
adata
AnnData object to which the 'train/val/test' split column will be added.
strategy
strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the "target" df should
have a column "chr" with the chromosome names.
strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the anndata's var_names should
contain the chromosome name at the start, followed by a `:` (e.g. I:2000-2500 or chr3:10-20:+).
region: Split randomly on region indices.
Expand All @@ -226,8 +228,6 @@ def train_val_test_split(
List of chromosomes to include in the validation set. Required if strategy='chr'.
test_chroms
List of chromosomes to include in the test set. Required if strategy='chr'.
chr_var_key
Key in `.var` for chromosome.
shuffle
Whether or not to shuffle the data before splitting (when strategy='region').
random_state
Expand Down Expand Up @@ -260,13 +260,9 @@ def train_val_test_split(
# Input checks
if strategy not in ["region", "chr", "chr_auto"]:
raise ValueError("`strategy` should be either 'region','chr', or 'chr_auto'")
if strategy == "region" and not 0 <= val_size <= 1:
if strategy in ["region", "chr_auto"] and not 0 <= val_size <= 1:
raise ValueError("`val_size` should be a float between 0 and 1.")
if strategy == "region" and not 0 <= test_size <= 1:
raise ValueError("`test_size` should be a float between 0 and 1.")
if strategy == "chr_auto" and not 0 <= val_size <= 1:
raise ValueError("`val_size` should be a float between 0 and 1.")
if strategy == "chr_auto" and not 0 <= test_size <= 1:
if strategy in ["region", "chr_auto"] and not 0 <= test_size <= 1:
raise ValueError("`test_size` should be a float between 0 and 1.")
if (strategy == "region") and (val_chroms is not None or test_chroms is not None):
logger.warning(
Expand All @@ -278,20 +274,6 @@ def train_val_test_split(
raise ValueError(
"If `strategy` is 'chr', `val_chroms` and `test_chroms` should be provided."
)
if chr_var_key not in adata.var.columns:
raise ValueError(
f"Column '{chr_var_key}' not found in `.var`. "
"Make sure to add the chromosome information to the `.var` DataFrame."
)
unique_chr = adata.var[chr_var_key].unique()
if not set(val_chroms).issubset(unique_chr):
raise ValueError(
"Some chromosomes in `val_chroms` are not present in the dataset."
)
if not set(test_chroms).issubset(unique_chr):
raise ValueError(
"Some chromosomes in `test_chroms` are not present in the dataset."
)

# Split
regions = list(adata.var_names)
Expand All @@ -303,7 +285,7 @@ def train_val_test_split(
regions, val_chroms=val_chroms, test_chroms=test_chroms
)
elif strategy == "chr_auto":
split = _split_by_chromosome_auto(regions, val_size, test_size)
split = _split_by_chromosome_auto(regions, val_size, test_size, random_state)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down
63 changes: 48 additions & 15 deletions src/crested/tl/zoo/utils/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def conv_block(
padding: str = "valid",
l2: float = 1e-5,
batchnorm_momentum: float = 0.99,
name_prefix: str | None = None
) -> keras.KerasTensor:
"""
Convolution building block.
Expand Down Expand Up @@ -129,6 +130,8 @@ def conv_block(
L2 regularization weight (default is 1e-5).
batchnorm_momentum
Batch normalization momentum (default is 0.99).
name_prefix
Prefix for layer names.
Returns
-------
Expand All @@ -143,23 +146,42 @@ def conv_block(
padding=padding,
kernel_regularizer=keras.regularizers.L2(l2),
use_bias=conv_bias,
name=name_prefix + "_conv" if name_prefix else None,
)(inputs)
if normalization == "batch":
x = keras.layers.BatchNormalization(momentum=batchnorm_momentum)(x)
x = keras.layers.BatchNormalization(
momentum=batchnorm_momentum,
name=name_prefix + "_batchnorm" if name_prefix else None
)(x)
elif normalization == "layer":
x = keras.layers.LayerNormalization()(x)
x = keras.layers.Activation(activation)(x)
x = keras.layers.LayerNormalization(
name=name_prefix + "_layernorm" if name_prefix else None,
)(x)
x = keras.layers.Activation(
activation,
name=name_prefix + "_activation" if name_prefix else None,
)(x)
if res:
if filters != residual.shape[2]:
residual = keras.layers.Convolution1D(
filters=filters, kernel_size=1, strides=1
filters=filters,
kernel_size=1,
strides=1,
name=name_prefix + "_resconv" if name_prefix else None,
)(residual)
x = keras.layers.Add()([x, residual])

if pool_size > 1:
x = keras.layers.MaxPooling1D(pool_size=pool_size, padding=padding)(x)
x = keras.layers.MaxPooling1D(
pool_size=pool_size,
padding=padding,
name=name_prefix + "_pool" if name_prefix else None,
)(x)
if dropout > 0:
x = keras.layers.Dropout(dropout)(x)
x = keras.layers.Dropout(
dropout,
name=name_prefix + "_dropout" if name_prefix else None
)(x)

return x

Expand Down Expand Up @@ -249,6 +271,7 @@ def conv_block_bs(
bn_type: str = "standard",
kernel_initializer: str = "he_normal",
padding: str = "same",
name_prefix: str | None = None,
):
"""
Construct a convolution block (for Basenji).
Expand Down Expand Up @@ -287,6 +310,8 @@ def conv_block_bs(
Convolution kernel initializer.
padding
Padding type.
name_prefix
Prefix for layer names.
Returns
-------
Expand Down Expand Up @@ -318,6 +343,7 @@ def conv_block_bs(
dilation_rate=dilation_rate,
kernel_initializer=kernel_initializer,
kernel_regularizer=keras.regularizers.l2(l2_scale),
name=name_prefix + "_conv" if name_prefix else None,
)(current)

# batch norm
Expand All @@ -328,11 +354,18 @@ def conv_block_bs(
bn_layer = keras.layers.experimental.SyncBatchNormalization
else:
bn_layer = keras.layers.BatchNormalization
current = bn_layer(momentum=bn_momentum, gamma_initializer=bn_gamma)(current)
current = bn_layer(
momentum=bn_momentum,
gamma_initializer=bn_gamma,
name=name_prefix + "_bnorm" if name_prefix else None,
)(current)

# dropout
if dropout > 0:
current = keras.layers.Dropout(rate=dropout)(current)
current = keras.layers.Dropout(
rate=dropout,
name=name_prefix + "_dropout" if name_prefix else None,
)(current)

# residual add
if residual:
Expand All @@ -345,14 +378,14 @@ def conv_block_bs(
# Pool
if pool_size > 1:
if w1:
current = keras.layers.MaxPool2D(pool_size=pool_size, padding=padding)(
current
)
pool_layer = keras.layers.MaxPool2D
else:
current = keras.layers.MaxPool1D(pool_size=pool_size, padding=padding)(
current
)

pool_layer = keras.layers.MaxPool1D
current = pool_layer(
pool_size=pool_size,
padding=padding,
name=name_prefix + "_pool" if name_prefix else None,
)(current)
return current


Expand Down
2 changes: 0 additions & 2 deletions tests/test_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_train_val_test_split_by_chromosome():
strategy="chr",
val_chroms=["chr1"],
test_chroms=["chr2"],
chr_var_key="chr",
)

split_counts = adata.var["split"].value_counts()
Expand Down Expand Up @@ -74,7 +73,6 @@ def test_train_val_test_split_by_chromosome_auto():
strategy="chr_auto",
val_size=0.2,
test_size=0.2,
chr_var_key="chr",
random_state=None,
)

Expand Down

0 comments on commit 0cb1419

Please sign in to comment.