Skip to content

Commit ca0afad

Browse files
committed
fix imports and file constants
1 parent f23e387 commit ca0afad

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

data_processing/datasets.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.utils.data import Dataset
1919

2020
from model_training.augmentations import image_augmentation, point_augmentation, transform_meshes
21-
from constants import POINT_DIR_COPD, IMG_DIR_COPD, ALIGN_CORNERS
21+
from constants import ALIGN_CORNERS, POINT_DIR_TS, IMG_DIR_TS_PREPROC
2222
from utils.general_utils import load_points, kpts_to_grid, kpts_to_world, load_meshes, o3d_to_pt3d_meshes
2323
from utils.sitk_image_ops import resample_equal_spacing, sitk_image_to_tensor, multiple_objects_morphology, \
2424
get_resample_factors, load_image_metadata
@@ -373,7 +373,7 @@ def normalize_img(img, min_val=IMG_MIN, max_val=IMG_MAX):
373373

374374
class PointDataset(CustomDataset):
375375
def __init__(self, sample_points, kp_mode,
376-
folder=POINT_DIR, image_folder=IMG_DIR,
376+
folder=POINT_DIR_TS, image_folder=IMG_DIR_TS_PREPROC,
377377
use_coords=True, patch_feat=None, exclude_rhf=False, lobes=False, binary=False, do_augmentation=True,
378378
copd=False):
379379

@@ -622,13 +622,12 @@ def get_obj_mesh(self, item):
622622

623623

624624
class PointToMeshDS(PointDataset):
625-
def __init__(self, sample_points, kp_mode, folder=POINT_DIR, image_folder=IMG_DIR, use_coords=True,
625+
def __init__(self, sample_points, kp_mode, folder=POINT_DIR_TS, image_folder=IMG_DIR_TS_PREPROC, use_coords=True,
626626
patch_feat=None, exclude_rhf=False, lobes=False, binary=False, do_augmentation=False, copd=False):
627627
super(PointToMeshDS, self).__init__(sample_points=sample_points, kp_mode=kp_mode, folder=folder,
628628
image_folder=image_folder,
629629
use_coords=use_coords, patch_feat=patch_feat, exclude_rhf=exclude_rhf,
630-
lobes=lobes, binary=binary, do_augmentation=do_augmentation, copd=copd,
631-
all_to_device=all_to_device)
630+
lobes=lobes, binary=binary, do_augmentation=do_augmentation, copd=copd)
632631
self.meshes = []
633632
self.img_sizes = []
634633
for case, sequence in self.ids:
@@ -655,7 +654,7 @@ def unnormalize_mesh(self, mesh: Meshes, index):
655654

656655

657656
class PointToMeshAndLabelDataset(PointToMeshDS):
658-
def __init__(self, sample_points, kp_mode, folder=POINT_DIR, image_folder=IMG_DIR, use_coords=True,
657+
def __init__(self, sample_points, kp_mode, folder=POINT_DIR_TS, image_folder=IMG_DIR_TS_PREPROC, use_coords=True,
659658
patch_feat=None, exclude_rhf=False, lobes=False, binary=False, do_augmentation=True, copd=False):
660659
super().__init__(sample_points=sample_points, kp_mode=kp_mode, folder=folder,
661660
image_folder=image_folder,

0 commit comments

Comments
 (0)