feat: use hessian in optimization #162

Closed · wants to merge 1 commit
src/diffpy/snmf/snmf_class.py (62 changes: 38 additions & 24 deletions)
@@ -69,8 +69,8 @@ def __init__(
         init_stretch=None,
         rho=0,
         eta=0,
-        max_iter=500,
-        tol=5e-7,
+        max_iter=300,
+        tol=1e-6,
         n_components=None,
         random_state=None,
     ):
@@ -231,24 +231,26 @@ def __init__(
         print("Finished optimization.")
 
     def optimize_loop(self):
-        # Update components first
-        self._prev_grad_components = self.grad_components.copy()
-        self.update_components()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
-        if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
-
-        # Now we update weights
-        self.update_weights()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        for i in range(4):
+            # Update components first
+            self._prev_grad_components = self.grad_components.copy()
+            self.update_components()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
+            if self.objective_difference is None:
+                self.objective_difference = self._objective_history[-1] - self.objective_function
+
+            # Now we update weights
+            self.update_weights()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_weights: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
 
         # Now we update stretch
         self.update_stretch()
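
The new loop amortizes several cheap component/weight rounds over each of the more expensive stretch updates. A minimal sketch of the schedule this hunk introduces (method names follow the class above; `model` and the `inner_rounds=4` parameter are illustrative, not part of the API):

    def run_outer_iteration(model, inner_rounds=4):
        # Four component/weight rounds per trust-region stretch solve,
        # mirroring the for-loop in the diff above.
        for _ in range(inner_rounds):
            model.update_components()
            model.update_weights()
        model.update_stretch()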
@@ -488,7 +490,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
 
         return stretch_transformed
 
-    def solve_quadratic_program(self, t, m, alg="trust-constr"):
+    def solve_quadratic_program(self, t, m, alg="L-BFGS-B"):
         """
         Solves the quadratic program for updating y in stretched NMF using scipy.optimize:
@@ -588,7 +590,7 @@ def update_components(self):
             + self.eta * np.sqrt(self.components)
             < 0
         )
-        self.components = mask * self.components
+        self.components[~mask] = 0
 
         objective_improvement = self._objective_history[-1] - self.get_objective_function(
             residuals=self.get_residual_matrix()
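
Both forms zero the entries that fail the sign test above; the new one does it in place instead of allocating a new array from a bool-float product. A small illustration with made-up values:

    import numpy as np

    components = np.array([0.5, 1.2, 0.3])
    mask = np.array([True, False, True])   # True where the entry is kept

    # Old: components = mask * components  (new array, bool upcast)
    components[~mask] = 0                  # in place -> [0.5, 0.0, 0.3]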
@@ -656,7 +658,18 @@ def regularize_function(self, stretch=None):
             der_reshaped.T + self.rho * stretch @ self._spline_smooth_operator.T @ self._spline_smooth_operator
         )
 
-        return reg_func, func_grad
+        # Hessian: diagonal of second derivatives
+        hess_diag_vals = np.sum(
+            dd_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0
+        ).ravel(order="F")
+
+        # Add the spline regularization Hessian (rho * PPPP)
+        smooth_hess = self.rho * np.kron(self._spline_smooth_penalty.toarray(), np.eye(self.n_components))
+
+        full_hess_diag = hess_diag_vals + np.diag(smooth_hess)
+        hessian = diags(full_hess_diag, format="csc")
+
+        return reg_func, func_grad, hessian
 
     def update_stretch(self):
         """
@@ -669,9 +682,9 @@ def update_stretch():
         # Define the optimization function
         def objective(stretch_vec):
             stretch_matrix = stretch_vec.reshape(self.stretch.shape)  # Reshape back to matrix form
-            func, grad = self.regularize_function(stretch_matrix)
+            func, grad, hess = self.regularize_function(stretch_matrix)
             grad = grad.flatten()
-            return func, grad
+            return func, grad, hess
 
         # Optimization constraints: lower bound 0.1, no upper bound
         bounds = [(0.1, None)] * stretch_init_vec.size  # Equivalent to 0.1 * ones(K, M)
@@ -682,6 +695,7 @@ def objective(stretch_vec):
             x0=stretch_init_vec,  # Initial guess
             method="trust-constr",  # Equivalent to 'trust-region-reflective'
             jac=lambda stretch_vec: objective(stretch_vec)[1],  # Gradient
+            hess=lambda stretch_vec: objective(stretch_vec)[2],
             bounds=bounds,  # Lower bounds on stretch
             # TODO: A Hessian can be incorporated for better convergence.
         )
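
End to end, the call above follows scipy's standard pattern for a bound-constrained trust-region solve with an explicit Hessian. A minimal self-contained analogue (the quadratic, the bound of 0.1, and all names here are illustrative only):

    import numpy as np
    from scipy.optimize import minimize
    from scipy.sparse import diags

    def fun(x):
        return np.sum((x - 2.0) ** 2)

    def jac(x):
        return 2.0 * (x - 2.0)

    def hess(x):
        # Constant diagonal Hessian, returned sparse as trust-constr allows.
        return diags(np.full(x.size, 2.0), format="csc")

    x0 = np.full(6, 0.5)
    res = minimize(fun, x0, method="trust-constr", jac=jac, hess=hess,
                   bounds=[(0.1, None)] * x0.size)
    print(res.x)  # entries converge toward 2.0 within the bounds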