diff --git a/kornia/models/segmentor/segmentation_models.py b/kornia/models/segmentor/segmentation_models.py index 40bc6c9b54..8e619bc97f 100644 --- a/kornia/models/segmentor/segmentation_models.py +++ b/kornia/models/segmentor/segmentation_models.py @@ -21,11 +21,12 @@ class SegmentationModels(Module): classes: Number of classes to predict. **kwargs: Additional arguments to pass to the model. Detailed arguments can be found at: https://github.com/qubvel-org/segmentation_models.pytorch/tree/main/segmentation_models_pytorch/decoders - + Note: Only encoder weights are available. Pretrained weights for the whole model are not available. """ + def __init__( self, model_name: str = "Unet", @@ -33,7 +34,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", in_channels: int = 3, classes: int = 1, - **kwargs + **kwargs, ) -> None: super().__init__() self.preproc_params = smp.encoders.get_preprocessing_params(encoder_name) # type: ignore @@ -42,7 +43,7 @@ def __init__( encoder_weights=encoder_weights, in_channels=in_channels, classes=classes, - **kwargs + **kwargs, ) def preprocessing(self, input: Tensor) -> Tensor: @@ -52,7 +53,7 @@ def preprocessing(self, input: Tensor) -> Tensor: input = kornia.color.rgb_to_bgr(input) else: raise ValueError(f"Unsupported input space: {self.preproc_params['input_space']}") - + if self.preproc_params["input_range"] is not None: if input.max() > 1 and self.preproc_params["input_range"][1] == 1: input = input / 255.0 @@ -61,7 +62,7 @@ def preprocessing(self, input: Tensor) -> Tensor: mean = tensor(self.preproc_params["mean"]).to(input.device) else: mean = tensor(self.preproc_params["mean"]).to(input.device) - + if self.preproc_params["std"] is None: std = tensor(self.preproc_params["std"]).to(input.device) else: