diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py
index 63bfe9a0..a0105c83 100644
--- a/src/diffpy/snmf/snmf_class.py
+++ b/src/diffpy/snmf/snmf_class.py
@@ -69,8 +69,8 @@ def __init__(
         init_stretch=None,
         rho=0,
         eta=0,
-        max_iter=500,
-        tol=5e-7,
+        max_iter=300,
+        tol=1e-6,
         n_components=None,
         random_state=None,
     ):
@@ -231,24 +231,26 @@ def __init__(
         print("Finished optimization.")
 
     def optimize_loop(self):
-        # Update components first
-        self._prev_grad_components = self.grad_components.copy()
-        self.update_components()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
-        if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
-        # Now we update weights
-        self.update_weights()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        for i in range(4):
+            # Update components first
+            self._prev_grad_components = self.grad_components.copy()
+            self.update_components()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
+            if self.objective_difference is None:
+                self.objective_difference = self._objective_history[-1] - self.objective_function
+
+            # Now we update weights
+            self.update_weights()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_weights: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
 
         # Now we update stretch
         self.update_stretch()
 
@@ -488,7 +490,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None):
 
         return stretch_transformed
 
-    def solve_quadratic_program(self, t, m, alg="trust-constr"):
+    def solve_quadratic_program(self, t, m, alg="L-BFGS-B"):
         """
         Solves the quadratic program for updating y in stretched NMF using scipy.optimize:
 
@@ -588,7 +590,7 @@ def update_components(self):
             + self.eta * np.sqrt(self.components)
             < 0
         )
-        self.components = mask * self.components
+        self.components[~mask] = 0
 
         objective_improvement = self._objective_history[-1] - self.get_objective_function(
             residuals=self.get_residual_matrix()
@@ -656,7 +658,18 @@ def regularize_function(self, stretch=None):
             der_reshaped.T
             + self.rho * stretch @ self._spline_smooth_operator.T @ self._spline_smooth_operator
         )
-        return reg_func, func_grad
+        # Hessian: diagonal of second derivatives
+        hess_diag_vals = np.sum(
+            dd_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0
+        ).ravel(order="F")
+
+        # Add the spline regularization Hessian (rho * PPPP)
+        smooth_hess = self.rho * np.kron(self._spline_smooth_penalty.toarray(), np.eye(self.n_components))
+
+        full_hess_diag = hess_diag_vals + np.diag(smooth_hess)
+        hessian = diags(full_hess_diag, format="csc")
+
+        return reg_func, func_grad, hessian
 
     def update_stretch(self):
         """
@@ -669,9 +682,9 @@ def objective(stretch_vec):
 
             stretch_matrix = stretch_vec.reshape(self.stretch.shape)  # Reshape back to matrix form
-            func, grad = self.regularize_function(stretch_matrix)
+            func, grad, hess = self.regularize_function(stretch_matrix)
             grad = grad.flatten()
-            return func, grad
+            return func, grad, hess
 
         # Optimization constraints: lower bound 0.1, no upper bound
         bounds = [(0.1, None)] * stretch_init_vec.size  # Equivalent to 0.1 * ones(K, M)
 
@@ -682,6 +695,7 @@ def objective(stretch_vec):
             x0=stretch_init_vec,  # Initial guess
             method="trust-constr",  # Equivalent to 'trust-region-reflective'
             jac=lambda stretch_vec: objective(stretch_vec)[1],  # Gradient
+            hess=lambda stretch_vec: objective(stretch_vec)[2],
             bounds=bounds,  # Lower bounds on stretch
             # TODO: A Hessian can be incorporated for better convergence.
         )