Skip to content

Commit

Permalink
More bugfixes in ML cupy/pycuda engine
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Feb 2, 2025
1 parent b54c8bb commit dc2a127
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions ptypy/accelerate/cuda_cupy/engines/ML_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,11 @@ def new_grad(self):
# local references
ob = self.engine.ob.S[oID].data
obg = ob_grad.S[oID].data
obf = ob_fln.S[oID].data
pr = self.engine.pr.S[pID].data
prg = pr_grad.S[pID].data
prf = pr_fln.S[pID].data
if self.engine.p.wavefield_precond:
obf = ob_fln.S[oID].data
prf = pr_fln.S[pID].data

# Schedule w & I to device
ev_w, w, data_w = self.engine.w_data.to_gpu(
Expand Down
5 changes: 3 additions & 2 deletions ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,11 @@ def new_grad(self):
# local references
ob = self.engine.ob.S[oID].data
obg = ob_grad.S[oID].data
obf = ob_fln.S[oID].data
pr = self.engine.pr.S[pID].data
prg = pr_grad.S[pID].data
prf = pr_fln.S[pID].data
if self.engine.p.wavefield_precond:
obf = ob_fln.S[oID].data
prf = pr_fln.S[pID].data

# Schedule w & I to device
ev_w, w, data_w = self.engine.w_data.to_gpu(prep.weights, dID, qu_htod)
Expand Down

0 comments on commit dc2a127

Please sign in to comment.