Skip to content

Commit

Permalink
improve PrecessingDescr docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 13, 2024
1 parent 33d3494 commit 818515b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
20 changes: 18 additions & 2 deletions bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,12 +520,14 @@ class ProcessingDescrBase(NodeWithExplicitlySetFields):


class BinarizeKwargs(ProcessingKwargs):
"""key word arguments for `BinarizeDescr`"""

threshold: float
"""The fixed threshold"""


class BinarizeDescr(ProcessingDescrBase):
"""BinarizeDescr the tensor with a fixed threshold.
"""BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`.
Values above the threshold will be set to one, values below the threshold to zero.
"""

Expand All @@ -534,21 +536,29 @@ class BinarizeDescr(ProcessingDescrBase):


class ClipKwargs(ProcessingKwargs):
"""key word arguments for `ClipDescr`"""

min: float
"""minimum value for clipping"""
max: float
"""maximum value for clipping"""


class ClipDescr(ProcessingDescrBase):
"""Set tensor values below min to min and above max to max."""
"""Clip tensor values to a range.
Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
and above `ClipKwargs.max` to `ClipKwargs.max`.
"""

name: Literal["clip"] = "clip"

kwargs: ClipKwargs


class ScaleLinearKwargs(ProcessingKwargs):
"""key word arguments for `ScaleLinearDescr`"""

axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
"""The subset of axes to scale jointly.
For example xy to scale the two image axes for 2d data jointly."""
Expand Down Expand Up @@ -597,6 +607,8 @@ def kwargs(self) -> ProcessingKwargs:


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
"""key word arguments for `ZeroMeanUnitVarianceDescr`"""

mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
"""Mode for computing mean and variance.
| mode | description |
Expand Down Expand Up @@ -642,6 +654,8 @@ class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):


class ScaleRangeKwargs(ProcessingKwargs):
"""key word arguments for `ScaleRangeDescr`"""

mode: Literal["per_dataset", "per_sample"]
"""Mode for computing percentiles.
| mode | description |
Expand Down Expand Up @@ -691,6 +705,8 @@ class ScaleRangeDescr(ProcessingDescrBase):


class ScaleMeanVarianceKwargs(ProcessingKwargs):
"""key word arguments for `ScaleMeanVarianceDescr`"""

mode: Literal["per_dataset", "per_sample"]
"""Mode for computing mean and variance.
| mode | description |
Expand Down
41 changes: 32 additions & 9 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,11 +789,15 @@ class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):


class BinarizeKwargs(ProcessingKwargs):
"""key word arguments for `BinarizeDescr`"""

threshold: float
"""The fixed threshold"""


class BinarizeAlongAxisKwargs(ProcessingKwargs):
"""key word arguments for `BinarizeDescr`"""

threshold: NotEmpty[List[float]]
"""The fixed threshold values along `axis`"""

Expand All @@ -803,7 +807,9 @@ class BinarizeAlongAxisKwargs(ProcessingKwargs):

class BinarizeDescr(ProcessingDescrBase):
"""Binarize the tensor with a fixed threshold.
Values above the threshold will be set to one, values below the threshold to zero.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
will be set to one, values below the threshold to zero.
"""

id: Literal["binarize"] = "binarize"
Expand All @@ -818,6 +824,8 @@ class ClipDescr(ProcessingDescrBase):


class EnsureDtypeKwargs(ProcessingKwargs):
"""key word arguments for `EnsureDtypeDescr`"""

dtype: Literal[
"float32",
"float64",
Expand All @@ -834,11 +842,15 @@ class EnsureDtypeKwargs(ProcessingKwargs):


class EnsureDtypeDescr(ProcessingDescrBase):
"""cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)"""

id: Literal["ensure_dtype"] = "ensure_dtype"
kwargs: EnsureDtypeKwargs


class ScaleLinearKwargs(ProcessingKwargs):
"""key word arguments for `ScaleLinearDescr`"""

gain: float = 1.0
"""multiplicative factor"""

Expand All @@ -857,6 +869,8 @@ def _validate(self) -> Self:


class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
"""key word arguments for `ScaleLinearDescr`"""

axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
"""The axis of of gains/offsets values."""

Expand Down Expand Up @@ -912,8 +926,7 @@ def kwargs(self) -> ProcessingKwargs:


class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
"""Normalize with fixed, precomputed values for mean and variance.
See `zero_mean_unit_variance` for data dependent normalization."""
"""key word arguments for `FixedZeroMeanUnitVarianceDescr`"""

mean: float
"""The mean value to normalize with."""
Expand All @@ -923,8 +936,7 @@ class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):


class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
"""Normalize with fixed, precomputed values for mean and variance.
See `zero_mean_unit_variance` for data dependent normalization."""
"""key word arguments for `FixedZeroMeanUnitVarianceDescr`"""

mean: NotEmpty[List[float]]
"""The mean value(s) to normalize with."""
Expand All @@ -949,7 +961,13 @@ def _mean_and_std_match(self) -> Self:


class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
"""Subtract a given mean and divide by a given variance."""
"""Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for
`FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
axes.
"""

id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
kwargs: Union[
Expand All @@ -958,6 +976,8 @@ class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
"""key word arguments for `ZeroMeanUnitVarianceDescr`"""

axes: Annotated[
Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
] = None
Expand All @@ -979,6 +999,8 @@ class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):


class ScaleRangeKwargs(ProcessingKwargs):
"""key word arguments for `ScaleRangeDescr`"""

axes: Annotated[
Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
] = None
Expand Down Expand Up @@ -1023,8 +1045,7 @@ class ScaleRangeDescr(ProcessingDescrBase):


class ScaleMeanVarianceKwargs(ProcessingKwargs):
"""Scale a tensor's data distribution to match another tensor's mean/std.
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
"""key word arguments for `ScaleMeanVarianceKwargs`"""

reference_tensor: TensorId
"""Name of tensor to match."""
Expand All @@ -1044,7 +1065,9 @@ class ScaleMeanVarianceKwargs(ProcessingKwargs):


class ScaleMeanVarianceDescr(ProcessingDescrBase):
"""Scale the tensor s.t. its mean and variance match a reference tensor."""
"""Scale a tensor's data distribution to match another tensor's mean/std.
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
"""

id: Literal["scale_mean_variance"] = "scale_mean_variance"
kwargs: ScaleMeanVarianceKwargs
Expand Down

0 comments on commit 818515b

Please sign in to comment.