From 3cc7b86cfcee0a6769ff116d7ca489687e36d77f Mon Sep 17 00:00:00 2001 From: Jaroslav Fowkes Date: Mon, 9 May 2022 15:33:08 +0100 Subject: [PATCH 1/4] Option to use full polynomial in ML linesearch --- ptypy/engines/ML.py | 186 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 184 insertions(+), 2 deletions(-) diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index e7492b42f..30eb6c10d 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -100,6 +100,11 @@ class ML(PositionCorrectionEngine): lowlim = 0 help = Number of iterations before probe update starts + [all_line_coeffs] + default = False + type = bool + help = Whether to use all nine coefficients in the linesearch instead of three + """ SUPPORTED_MODELS = [Full, Vanilla, Bragg3dModel, BlockVanilla, BlockFull, GradFull, BlockGradFull] @@ -287,7 +292,10 @@ def engine_iterate(self, num=1): # In principle, the way things are now programmed this part # could be iterated over in a real Newton-Raphson style. t2 = time.time() - B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) + if self.p.all_line_coeffs: + B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h) + else: + B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) tc += time.time() - t2 if np.isinf(B).any() or np.isnan(B).any(): @@ -296,7 +304,17 @@ def engine_iterate(self, num=1): B[np.isinf(B)] = 0. B[np.isnan(B)] = 0. - self.tmin = dt(-.5 * B[1] / B[2]) + if self.p.all_line_coeffs: + diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative + roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double + real_roots = np.real(roots[np.isreal(roots)]) # not interested in complex roots + if real_roots.size == 1: # single real root + self.tmin = dt(real_roots[0]) + else: # find real root with smallest poly objective + evalp = lambda root: np.polyval(np.flip(B),root) + self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective + else: # same as above but quicker when poly quadratic + self.tmin = dt(-.5 * B[1] / B[2]) self.ob_h *= self.tmin self.pr_h *= self.tmin self.ob += self.ob_h @@ -427,6 +445,13 @@ def poly_line_coeffs(self, ob_h, pr_h): """ raise NotImplementedError + def poly_line_all_coeffs(self, ob_h, pr_h): + """ + Compute all the coefficients of the polynomial for line minimization + in direction h + """ + raise NotImplementedError + class GaussianModel(BaseModel): """ @@ -593,6 +618,85 @@ def poly_line_coeffs(self, ob_h, pr_h): return B + def poly_line_all_coeffs(self, ob_h, pr_h): + """ + Compute all the coefficients of the polynomial for line minimization + in direction h + """ + + B = np.zeros((9,), dtype=np.longdouble) + Brenorm = 1. / self.LL[0]**2 + + # Outer loop: through diffraction patterns + for dname, diff_view in self.di.views.items(): + if not diff_view.active: + continue + + # Weights and intensities for this view + w = self.weights[diff_view] + I = diff_view.data + + A0 = None + A1 = None + A2 = None + A3 = None + A4 = None + + for name, pod in diff_view.pods.items(): + if not pod.active: + continue + f = pod.fw(pod.probe * pod.object) + a = pod.fw(pod.probe * ob_h[pod.ob_view] + + pr_h[pod.pr_view] * pod.object) + b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view]) + + if A0 is None: + A0 = u.abs2(f).astype(np.longdouble) + A1 = 2 * np.real(f * a.conj()).astype(np.longdouble) + A2 = (2 * np.real(f * b.conj()).astype(np.longdouble) + + u.abs2(a).astype(np.longdouble)) + A3 = 2 * np.real(a * b.conj()).astype(np.longdouble) + A4 = u.abs2(b).astype(np.longdouble) + else: + A0 += u.abs2(f) + A1 += 2 * np.real(f * a.conj()) + A2 += 2 * np.real(f * b.conj()) + u.abs2(a) + A3 += 2 * np.real(a * b.conj()) + A4 += u.abs2(b) + + if self.p.floating_intensities: + A0 *= self.float_intens_coeff[dname] + A1 *= self.float_intens_coeff[dname] + A2 *= self.float_intens_coeff[dname] + A3 *= self.float_intens_coeff[dname] + A4 *= self.float_intens_coeff[dname] + + A0 = np.double(A0) - pod.upsample(I) + #A0 -= pod.upsample(I) + w = pod.upsample(w) + + B[0] += np.dot(w.flat, (A0**2).flat) * Brenorm + B[1] += np.dot(w.flat, (2*A0*A1).flat) * Brenorm + B[2] += np.dot(w.flat, (A1**2 + 2*A0*A2).flat) * Brenorm + B[3] += np.dot(w.flat, (2*A0*A3 + 2*A1*A2).flat) * Brenorm + B[4] += np.dot(w.flat, (A2**2 + 2*A1*A3 + 2*A0*A4).flat) * Brenorm + B[5] += np.dot(w.flat, (2*A1*A4 + 2*A2*A3).flat) * Brenorm + B[6] += np.dot(w.flat, (A3**2 + 2*A2*A4).flat) * Brenorm + B[7] += np.dot(w.flat, (2*A3*A4).flat) * Brenorm + B[8] += np.dot(w.flat, (A4**2).flat) * Brenorm + + parallel.allreduce(B) + + # Object regularizer + if self.regularizer: + for name, s in self.ob.storages.items(): + B[:3] += Brenorm * self.regularizer.poly_line_coeffs( + ob_h.storages[name].data, s.data) + + self.B = B + + return B + class PoissonModel(BaseModel): """ @@ -744,6 +848,84 @@ def poly_line_coeffs(self, ob_h, pr_h): return B + def poly_line_all_coeffs(self, ob_h, pr_h): + """ + Compute all the coefficients of the polynomial for line minimization + in direction h + """ + B = np.zeros((9,), dtype=np.longdouble) + Brenorm = 1/(self.tot_measpts * self.LL[0])**2 + + # Outer loop: through diffraction patterns + for dname, diff_view in self.di.views.items(): + if not diff_view.active: + continue + + # Weights and intensities for this view + I = diff_view.data + m = diff_view.pod.ma_view.data + + A0 = None + A1 = None + A2 = None + A3 = None + A4 = None + + for name, pod in diff_view.pods.items(): + if not pod.active: + continue + f = pod.fw(pod.probe * pod.object) + a = pod.fw(pod.probe * ob_h[pod.ob_view] + + pr_h[pod.pr_view] * pod.object) + b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view]) + + if A0 is None: + A0 = u.abs2(f).astype(np.longdouble) + A1 = 2 * np.real(f * a.conj()).astype(np.longdouble) + A2 = (2 * np.real(f * b.conj()).astype(np.longdouble) + + u.abs2(a).astype(np.longdouble)) + A3 = 2 * np.real(a * b.conj()).astype(np.longdouble) + A4 = u.abs2(b).astype(np.longdouble) + else: + A0 += u.abs2(f) + A1 += 2 * np.real(f * a.conj()) + A2 += 2 * np.real(f * b.conj()) + u.abs2(a) + A3 += 2 * np.real(a * b.conj()) + A4 += u.abs2(b) + + + if self.p.floating_intensities: + A0 *= self.float_intens_coeff[dname] + A1 *= self.float_intens_coeff[dname] + A2 *= self.float_intens_coeff[dname] + A3 *= self.float_intens_coeff[dname] + A4 *= self.float_intens_coeff[dname] + + A0 += 1e-6 + DI = 1. - I/A0 + + B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm + B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm + B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm + B[3] += (np.dot(m.flat, (A3*DI).flat) + np.dot(m.flat, (I*((A1*A2)/(A0**2.))).flat)) * Brenorm + B[4] += (np.dot(m.flat, (A4*DI).flat) + .5*np.dot(m.flat, (I*(A2/A0)**2.).flat) + np.dot(m.flat, (I*((A1*A3)/(A0**2.))).flat)) * Brenorm + B[5] += (np.dot(m.flat, (I*((A1*A4)/(A0**2.))).flat) + np.dot(m.flat, (I*((A2*A3)/(A0**2.))).flat)) * Brenorm + B[6] += (.5*np.dot(m.flat, (I*(A3/A0)**2.).flat) + np.dot(m.flat, (I*((A2*A4)/(A0**2.))).flat)) * Brenorm + B[7] += np.dot(m.flat, (I*((A3*A4)/(A0**2.))).flat) * Brenorm + B[8] += (.5*np.dot(m.flat, (I*(A4/A0)**2.).flat)) * Brenorm + + parallel.allreduce(B) + + # Object regularizer + if self.regularizer: + for name, s in self.ob.storages.items(): + B[:3] += Brenorm * self.regularizer.poly_line_coeffs( + ob_h.storages[name].data, s.data) + + self.B = B + + return B + class EuclidModel(BaseModel): """ From 8c4ca44019db12b465255f3c77df273c7eba3f54 Mon Sep 17 00:00:00 2001 From: Jaroslav Fowkes Date: Wed, 14 Feb 2024 13:44:43 +0000 Subject: [PATCH 2/4] Tidy and add full polynomial for Euclid model --- ptypy/engines/ML.py | 108 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 96 insertions(+), 12 deletions(-) diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index 30eb6c10d..1c2496b34 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -314,7 +314,7 @@ def engine_iterate(self, num=1): evalp = lambda root: np.polyval(np.flip(B),root) self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective else: # same as above but quicker when poly quadratic - self.tmin = dt(-.5 * B[1] / B[2]) + self.tmin = dt(-0.5 * B[1] / B[2]) self.ob_h *= self.tmin self.pr_h *= self.tmin self.ob += self.ob_h @@ -603,7 +603,7 @@ def poly_line_coeffs(self, ob_h, pr_h): w = pod.upsample(w) B[0] += np.dot(w.flat, (A0**2).flat) * Brenorm - B[1] += np.dot(w.flat, (2 * A0 * A1).flat) * Brenorm + B[1] += np.dot(w.flat, (2*A0*A1).flat) * Brenorm B[2] += np.dot(w.flat, (A1**2 + 2*A0*A2).flat) * Brenorm parallel.allreduce(B) @@ -691,7 +691,7 @@ def poly_line_all_coeffs(self, ob_h, pr_h): if self.regularizer: for name, s in self.ob.storages.items(): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( - ob_h.storages[name].data, s.data) + ob_h.storages[name].data, s.data) self.B = B @@ -834,7 +834,7 @@ def poly_line_coeffs(self, ob_h, pr_h): B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm - B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm + B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm parallel.allreduce(B) @@ -906,13 +906,13 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm - B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm - B[3] += (np.dot(m.flat, (A3*DI).flat) + np.dot(m.flat, (I*((A1*A2)/(A0**2.))).flat)) * Brenorm - B[4] += (np.dot(m.flat, (A4*DI).flat) + .5*np.dot(m.flat, (I*(A2/A0)**2.).flat) + np.dot(m.flat, (I*((A1*A3)/(A0**2.))).flat)) * Brenorm - B[5] += (np.dot(m.flat, (I*((A1*A4)/(A0**2.))).flat) + np.dot(m.flat, (I*((A2*A3)/(A0**2.))).flat)) * Brenorm - B[6] += (.5*np.dot(m.flat, (I*(A3/A0)**2.).flat) + np.dot(m.flat, (I*((A2*A4)/(A0**2.))).flat)) * Brenorm - B[7] += np.dot(m.flat, (I*((A3*A4)/(A0**2.))).flat) * Brenorm - B[8] += (.5*np.dot(m.flat, (I*(A4/A0)**2.).flat)) * Brenorm + B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm + B[3] += (np.dot(m.flat, (A3*DI).flat) + 0.5*np.dot(m.flat, (I*((2*A1*A2)/A0**2)).flat)) * Brenorm + B[4] += (np.dot(m.flat, (A4*DI).flat) + 0.5*np.dot(m.flat, (I*((A2**2 + 2*A1*A3)/A0**2)).flat)) * Brenorm + B[5] += 0.5*np.dot(m.flat, (I*((2*A1*A4 + 2*A2*A3)/A0**2)).flat) * Brenorm + B[6] += 0.5*np.dot(m.flat, (I*((A3**2 + 2*A2*A4)/A0**2)).flat) * Brenorm + B[7] += 0.5*np.dot(m.flat, (I*((2*A3*A4)/A0**2)).flat) * Brenorm + B[8] += 0.5*np.dot(m.flat, (I*(A4/A0)**2).flat) * Brenorm parallel.allreduce(B) @@ -1078,7 +1078,7 @@ def poly_line_coeffs(self, ob_h, pr_h): B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm - B[2] += (np.dot(w.flat, (A2*DA).flat) + .25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm + B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm parallel.allreduce(B) @@ -1092,6 +1092,90 @@ def poly_line_coeffs(self, ob_h, pr_h): return B + def poly_line_all_coeffs(self, ob_h, pr_h): + """ + Compute all the coefficients of the polynomial for line minimization + in direction h + """ + + B = np.zeros((9,), dtype=np.longdouble) + Brenorm = 1. / self.LL[0]**2 + + # Outer loop: through diffraction patterns + for dname, diff_view in self.di.views.items(): + if not diff_view.active: + continue + + # Weights and amplitudes for this view + w = self.weights[diff_view] + A = np.sqrt(diff_view.data) + + A0 = None + A1 = None + A2 = None + A3 = None + A4 = None + + for name, pod in diff_view.pods.items(): + if not pod.active: + continue + f = pod.fw(pod.probe * pod.object) + a = pod.fw(pod.probe * ob_h[pod.ob_view] + + pr_h[pod.pr_view] * pod.object) + b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view]) + + if A0 is None: + A0 = u.abs2(f).astype(np.longdouble) + A1 = 2 * np.real(f * a.conj()).astype(np.longdouble) + A2 = (2 * np.real(f * b.conj()).astype(np.longdouble) + + u.abs2(a).astype(np.longdouble)) + A3 = 2 * np.real(a * b.conj()).astype(np.longdouble) + A4 = u.abs2(b).astype(np.longdouble) + else: + A0 += u.abs2(f) + A1 += 2 * np.real(f * a.conj()) + A2 += 2 * np.real(f * b.conj()) + u.abs2(a) + A3 += 2 * np.real(a * b.conj()) + A4 += u.abs2(b) + + if self.p.floating_intensities: + A0 *= self.float_intens_coeff[dname] + A1 *= self.float_intens_coeff[dname] + A2 *= self.float_intens_coeff[dname] + A3 *= self.float_intens_coeff[dname] + A4 *= self.float_intens_coeff[dname] + + A0 += 1e-12 # cf Poisson model sqrt(1e-12) = 1e-6 + DA = 1. - A/np.sqrt(A0) + DA32 = A/A0**(3/2) + + B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm + B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm + B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * DA32).flat)) * Brenorm + B[3] += (np.dot(w.flat, (A3*DA).flat) + 0.25*np.dot(w.flat, (2*A1*A2 * DA32).flat) - 0.125*np.dot(w.flat, (A1**3/A0**2).flat)) * Brenorm + B[4] += (np.dot(w.flat, (A4*DA).flat) + 0.25*np.dot(w.flat, ((A2**2 + 2*A1*A3) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1**2*A2)/A0**2).flat) + + 0.015625*np.dot(w.flat, (A1**4/A0**3).flat)) * Brenorm + B[5] += (0.25*np.dot(w.flat, ((2*A2*A3 + 2*A1*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1*A2**2 + 3*A1**2*A3)/A0**2).flat) + + 0.015625*np.dot(w.flat, ((4*A1**3*A2)/A0**3).flat)) * Brenorm + B[6] += (0.25*np.dot(w.flat, ((A3**2 + 2*A2*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((A2**3 + 3*A1**2*A4 + 6*A1*A2*A3)/A0**2).flat) + + 0.015625*np.dot(w.flat, ((6*A1**2*A2**2 + 4*A1**3*A3)/A0**3).flat)) * Brenorm + B[7] += (0.25*np.dot(w.flat, (2*A3*A4 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2**2*A3 + 3*A1*A3**2 + 6*A1*A2*A4)/A0**2).flat) + + 0.015625*np.dot(w.flat, ((4*A1*A2**3 + 12*A1**2*A2*A3 + 4*A1**3*A4)/A0**3).flat)) * Brenorm + B[8] += (0.25*np.dot(w.flat, (A4**2 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2*A3**2 + 3*A2**2*A4 + 6*A1*A3*A4)/A0**2).flat) + + 0.015625*np.dot(w.flat, ((A2**4 + 12*A1*A2**2*A3 + 6*A1**2*A3**2 + 12*A1**2*A2*A4)/A0**3).flat)) * Brenorm + + parallel.allreduce(B) + + # Object regularizer + if self.regularizer: + for name, s in self.ob.storages.items(): + B[:3] += Brenorm * self.regularizer.poly_line_coeffs( + ob_h.storages[name].data, s.data) + + self.B = B + + return B + class Regul_del2(object): """\ From 66bfb1f4a63f35c33527b3374495279ae3c5c88d Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Tue, 5 Mar 2024 14:11:30 +0000 Subject: [PATCH 3/4] code restructure, less switching --- ptypy/engines/ML.py | 54 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index 1c2496b34..d9d99b9b7 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -294,17 +294,6 @@ def engine_iterate(self, num=1): t2 = time.time() if self.p.all_line_coeffs: B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h) - else: - B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) - tc += time.time() - t2 - - if np.isinf(B).any() or np.isnan(B).any(): - logger.warning( - 'Warning! inf or nan found! Trying to continue...') - B[np.isinf(B)] = 0. - B[np.isnan(B)] = 0. - - if self.p.all_line_coeffs: diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double real_roots = np.real(roots[np.isreal(roots)]) # not interested in complex roots @@ -313,8 +302,13 @@ def engine_iterate(self, num=1): else: # find real root with smallest poly objective evalp = lambda root: np.polyval(np.flip(B),root) self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective - else: # same as above but quicker when poly quadratic + else: + B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) + # same as above but quicker when poly quadratic self.tmin = dt(-0.5 * B[1] / B[2]) + + tc += time.time() - t2 + self.ob_h *= self.tmin self.pr_h *= self.tmin self.ob += self.ob_h @@ -614,6 +608,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -693,6 +693,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -844,6 +850,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -922,6 +934,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -1088,6 +1106,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -1172,6 +1196,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B From f279f00de6337c271b6a841ec0ce3743bc2dc208 Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Tue, 5 Mar 2024 15:18:37 +0000 Subject: [PATCH 4/4] change new parameter from boolean to str --- ptypy/engines/ML.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index d9d99b9b7..9d5bb7d1a 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -100,10 +100,11 @@ class ML(PositionCorrectionEngine): lowlim = 0 help = Number of iterations before probe update starts - [all_line_coeffs] - default = False - type = bool - help = Whether to use all nine coefficients in the linesearch instead of three + [poly_line_coeffs] + default = quadratic + type = str + help = How many coefficients to be used in the the linesearch + doc = choose between the 'quadratic' approximation (default) or 'all' """ @@ -292,7 +293,7 @@ def engine_iterate(self, num=1): # In principle, the way things are now programmed this part # could be iterated over in a real Newton-Raphson style. t2 = time.time() - if self.p.all_line_coeffs: + if self.p.poly_line_coeffs == "all": B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h) diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double @@ -302,10 +303,12 @@ def engine_iterate(self, num=1): else: # find real root with smallest poly objective evalp = lambda root: np.polyval(np.flip(B),root) self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective - else: + elif self.p.poly_line_coeffs == "quadratic": B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) # same as above but quicker when poly quadratic self.tmin = dt(-0.5 * B[1] / B[2]) + else: + raise NotImplementedError("poly_line_coeffs should be 'quadratic' or 'all'") tc += time.time() - t2