Skip to content

Commit 38bbb00

Browse files
author
John Halloran
committed
refactor: get residual matrix without a helper
1 parent 0bf62a8 commit 38bbb00

File tree

1 file changed

+68
-4
lines changed

1 file changed

+68
-4
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,51 @@ def outer_loop(self):
302302
)
303303

304304
def get_residual_matrix(self, components=None, weights=None, stretch=None):
305+
"""
306+
Return the residuals (difference) between the source matrix and its reconstruction
307+
from the given components, weights, and stretch factors.
308+
309+
Each component profile is stretched, interpolated to fractional positions,
310+
weighted per signal, and summed to form the reconstruction. The residuals
311+
are the source matrix minus this reconstruction.
312+
313+
Parameters
314+
----------
315+
components : (signal_len, n_components) array, optional
316+
weights : (n_components, n_signals) array, optional
317+
stretch : (n_components, n_signals) array, optional
318+
319+
Returns
320+
-------
321+
residuals : (signal_len, n_signals) array
322+
"""
323+
324+
if components is None:
325+
components = self.components
326+
if weights is None:
327+
weights = self.weights
328+
if stretch is None:
329+
stretch = self.stretch
330+
331+
residuals = -self.source_matrix.copy()
332+
sample_indices = np.arange(components.shape[0]) # (signal_len,)
333+
334+
for comp in range(components.shape[1]): # loop over components
335+
residuals += (
336+
np.interp(
337+
sample_indices[:, None]
338+
/ stretch[comp][None, :], # fractional positions (signal_len, n_signals)
339+
sample_indices, # (signal_len,)
340+
components[:, comp], # component profile (signal_len,)
341+
left=components[0, comp],
342+
right=components[-1, comp],
343+
)
344+
* weights[comp][None, :] # broadcast (n_signals,) over rows
345+
)
346+
347+
return residuals
348+
349+
def old_get_residual_matrix(self, components=None, weights=None, stretch=None):
305350
# Initialize residual matrix as negative of source_matrix
306351
if components is None:
307352
components = self.components
@@ -310,10 +355,29 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
310355
if stretch is None:
311356
stretch = self.stretch
312357
residuals = -self.source_matrix.copy()
313-
# Compute transformed components for all (k, m) pairs
314-
for k in range(weights.shape[0]): # K
315-
stretched_components, _, _ = apply_interpolation(stretch[k, :], components[:, k]) # Only use Ax
316-
residuals += weights[k, :] * stretched_components # Element-wise scaling and sum
358+
359+
# Discrete sample positions along the component axis
360+
sample_indices = np.arange(components.shape[0]) # (N,)
361+
362+
for comp in range(components.shape[1]): # loop over components
363+
component_profile = components[:, comp] # (N,)
364+
stretch_factors = stretch[comp, :] # (M,)
365+
366+
# Compute scaled/fractional positions along component_profile
367+
fractional_positions = sample_indices[:, None] / stretch_factors[None, :]
368+
369+
# Interpolate component_profile at fractional positions, clamp to ends
370+
interpolated_component = np.interp(
371+
fractional_positions,
372+
sample_indices,
373+
component_profile,
374+
left=component_profile[0],
375+
right=component_profile[-1],
376+
)
377+
378+
# Accumulate weighted contribution into residuals
379+
residuals += interpolated_component * weights[comp, None, :] # (M,) broadcast
380+
317381
return residuals
318382

319383
def get_objective_function(self, residuals=None, stretch=None):

0 commit comments

Comments
 (0)