Skip to content

Commit

Permalink
Merge pull request #69 from LSSTDESC/ceci2
Browse files Browse the repository at this point in the history
Update for ceci version 2
  • Loading branch information
grantmerz authored Aug 2, 2024
2 parents 031adf8 + f49af59 commit 93d0ae7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/rail/estimation/algos/deepdisc.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ class DeepDiscInformer(CatInformer):
)
inputs = [('input', TableHandle), ('metadata', Hdf5Handle)]

def __init__(self, args, comm=None):
CatInformer.__init__(self, args, comm=comm)
def __init__(self, args, **kwargs):
super().__init__(args, **kwargs)

# check to make sure that batch_size is an even multiple of num_gpus
if self.config.batch_size % self.config.num_gpus != 0:
Expand Down Expand Up @@ -314,11 +314,11 @@ class DeepDiscPDFEstimator(CatEstimator):
("metadata", JsonHandle)]
outputs = [("output", QPHandle)]
def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
"""Constructor:
Do Estimator specific initialization"""
self.nnmodel = None
CatEstimator.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)
def estimate(self, input_data, input_metadata):
with tempfile.TemporaryDirectory() as temp_directory_name:
Expand Down Expand Up @@ -513,10 +513,10 @@ class DeepDiscPDFEstimatorWithChunking(CatEstimator):
("metadata", Hdf5Handle)]
outputs = [("output", QPHandle)]

def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
"""Constructor:
Do Estimator specific initialization"""
CatEstimator.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)

self.nnmodel = None
self.zgrid = np.linspace(0, 5, 200)
Expand Down

0 comments on commit 93d0ae7

Please sign in to comment.