Skip to content

Commit

Permalink
Testing the idea of skipping gather_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Feb 6, 2024
1 parent 1798bd1 commit 9c4176c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
26 changes: 17 additions & 9 deletions ptypy/accelerate/cuda_cupy/engines/projectional_cupy_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def engine_iterate(self, num=1):

for it in range(num):

error = {}
reduced_error = np.zeros((3,))
reduced_error_count = 0
local_error = {}

for inner in range(self.p.overlap_max_iterations):

Expand Down Expand Up @@ -403,20 +405,26 @@ def engine_iterate(self, num=1):
cp.asnumpy(s.gpu, stream=self.queue, out=s.data)
for name, s in self.pr.S.items():
cp.asnumpy(s.gpu, stream=self.queue, out=s.data)

self.queue.synchronize()

# costly but needed to sync back with
# for name, s in self.ex.S.items():
# s.data[:] = s.gpu.get()
# Gather errors from device
for dID, prep in self.diff_info.items():
err_fourier = prep.err_fourier_gpu.get()
err_phot = prep.err_phot_gpu.get()
err_exit = prep.err_exit_gpu.get()
errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T)
error.update(zip(prep.view_IDs, errs))

self.error = error
if self.p.record_local_error:
local_error.update(zip(prep.view_IDs, errs))
else:
reduced_error += errs.sum(axis=0)
reduced_error_count += errs.shape[0]

if self.p.record_local_error:
error = local_error
else:
# Gather errors across all MPI ranks
error = parallel.allreduce(reduced_error)
count = parallel.allreduce(reduced_error_count)
error /= count
return error

# probe update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,7 @@ def engine_iterate(self, num=1):
for name, s in self.pr.S.items():
s.data[:] = s.gpu.get()

# costly but needed to sync back with
# for name, s in self.ex.S.items():
# s.data[:] = s.gpu.get()
# Gather errors
for dID, prep in self.diff_info.items():
err_fourier = prep.err_fourier_gpu.get()
err_phot = prep.err_phot_gpu.get()
Expand Down
16 changes: 11 additions & 5 deletions ptypy/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,17 @@ def iterate(self, num=None):
parallel.barrier()

def _fill_runtime(self):
local_error = u.parallel.gather_dict(self.error)
if local_error:
error = np.array(list(local_error.values())).mean(0)
local_error = None
if isinstance(self.error, np.ndarray) and (len(self.error)== 3):
error = self.error
elif isinstance(self.error, dict):
local_error = u.parallel.gather_dict(self.error)
if local_error:
error = np.array(list(local_error.values())).mean(0)
else:
error = np.zeros((3,))
else:
error = np.zeros((1,))
logger.error("Reconstruction error should be dictionary or ndarray of shape (3,)")
info = dict(
iteration=self.curiter,
iterations=self.alliter,
Expand All @@ -277,7 +283,7 @@ def _fill_runtime(self):
)

self.ptycho.runtime.iter_info.append(info)
if self.p.record_local_error:
if self.p.record_local_error and (local_error is not None):
self.ptycho.runtime.error_local = local_error

def finalize(self):
Expand Down

0 comments on commit 9c4176c

Please sign in to comment.