18
18
from torch .utils .data import Dataset
19
19
20
20
from model_training .augmentations import image_augmentation , point_augmentation , transform_meshes
21
- from constants import POINT_DIR_COPD , IMG_DIR_COPD , ALIGN_CORNERS
21
+ from constants import ALIGN_CORNERS , POINT_DIR_TS , IMG_DIR_TS_PREPROC
22
22
from utils .general_utils import load_points , kpts_to_grid , kpts_to_world , load_meshes , o3d_to_pt3d_meshes
23
23
from utils .sitk_image_ops import resample_equal_spacing , sitk_image_to_tensor , multiple_objects_morphology , \
24
24
get_resample_factors , load_image_metadata
@@ -373,7 +373,7 @@ def normalize_img(img, min_val=IMG_MIN, max_val=IMG_MAX):
373
373
374
374
class PointDataset (CustomDataset ):
375
375
def __init__ (self , sample_points , kp_mode ,
376
- folder = POINT_DIR , image_folder = IMG_DIR ,
376
+ folder = POINT_DIR_TS , image_folder = IMG_DIR_TS_PREPROC ,
377
377
use_coords = True , patch_feat = None , exclude_rhf = False , lobes = False , binary = False , do_augmentation = True ,
378
378
copd = False ):
379
379
@@ -622,13 +622,12 @@ def get_obj_mesh(self, item):
622
622
623
623
624
624
class PointToMeshDS (PointDataset ):
625
- def __init__ (self , sample_points , kp_mode , folder = POINT_DIR , image_folder = IMG_DIR , use_coords = True ,
625
+ def __init__ (self , sample_points , kp_mode , folder = POINT_DIR_TS , image_folder = IMG_DIR_TS_PREPROC , use_coords = True ,
626
626
patch_feat = None , exclude_rhf = False , lobes = False , binary = False , do_augmentation = False , copd = False ):
627
627
super (PointToMeshDS , self ).__init__ (sample_points = sample_points , kp_mode = kp_mode , folder = folder ,
628
628
image_folder = image_folder ,
629
629
use_coords = use_coords , patch_feat = patch_feat , exclude_rhf = exclude_rhf ,
630
- lobes = lobes , binary = binary , do_augmentation = do_augmentation , copd = copd ,
631
- all_to_device = all_to_device )
630
+ lobes = lobes , binary = binary , do_augmentation = do_augmentation , copd = copd )
632
631
self .meshes = []
633
632
self .img_sizes = []
634
633
for case , sequence in self .ids :
@@ -655,7 +654,7 @@ def unnormalize_mesh(self, mesh: Meshes, index):
655
654
656
655
657
656
class PointToMeshAndLabelDataset (PointToMeshDS ):
658
- def __init__ (self , sample_points , kp_mode , folder = POINT_DIR , image_folder = IMG_DIR , use_coords = True ,
657
+ def __init__ (self , sample_points , kp_mode , folder = POINT_DIR_TS , image_folder = IMG_DIR_TS_PREPROC , use_coords = True ,
659
658
patch_feat = None , exclude_rhf = False , lobes = False , binary = False , do_augmentation = True , copd = False ):
660
659
super ().__init__ (sample_points = sample_points , kp_mode = kp_mode , folder = folder ,
661
660
image_folder = image_folder ,
0 commit comments