diff --git a/src/nuclai/utils/datamodule.py b/src/nuclai/utils/datamodule.py index 3bc3ad5..13ee064 100644 --- a/src/nuclai/utils/datamodule.py +++ b/src/nuclai/utils/datamodule.py @@ -19,7 +19,7 @@ from monai import transforms from torch.utils.data import DataLoader -from nuclai.utils.utils import RandomSpatialPad, _get_mask +from nuclai.utils.utils import _get_mask class DataSetCls: @@ -221,8 +221,10 @@ def __init__( self.data = pd.read_csv(path_data) self.shape = shape self.trans = trans - self.padder = RandomSpatialPad(self.shape) - # self.padder = transforms.SpatialPad(spatial_size=self.shape, method="symmetric") + # self.padder = RandomSpatialPad(self.shape) + self.padder = transforms.SpatialPad( + spatial_size=self.shape, method="symmetric" + ) assert ( "image" in self.data.columns