@@ -233,7 +233,7 @@ def optimize_loop(self):
233233 self .num_updates += 1
234234 self .residuals = self .get_residual_matrix ()
235235 self .objective_function = self .get_objective_function ()
236- print (f"Objective function after updateX : { self .objective_function :.5e} " )
236+ print (f"Objective function after update_comps : { self .objective_function :.5e} " )
237237 self ._objective_history .append (self .objective_function )
238238 if self .objective_difference is None :
239239 self .objective_difference = self ._objective_history [- 1 ] - self .objective_function
@@ -243,7 +243,7 @@ def optimize_loop(self):
243243 self .num_updates += 1
244244 self .residuals = self .get_residual_matrix ()
245245 self .objective_function = self .get_objective_function ()
246- print (f"Objective function after updateY2 : { self .objective_function :.5e} " )
246+ print (f"Objective function after update_weights : { self .objective_function :.5e} " )
247247 self ._objective_history .append (self .objective_function )
248248
249249 # Now we update stretch
@@ -266,14 +266,16 @@ def apply_interpolation(self, a, x, return_derivatives=False):
266266 a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
267267
268268 # Compute fractional indices, broadcasting over `a`
269- ii = np .arange (x_len )[:, None ] / a # Shape (N, M)
269+ fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
270270
271- II = np .floor (ii ).astype (int ) # Integer part (still (N, M))
272- valid_mask = II < (x_len - 1 ) # Ensure indices are within bounds
271+ integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
272+ valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
273273
274274 # Apply valid_mask to keep correct indices
275- idx_int = np .where (valid_mask , II , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
276- idx_frac = np .where (valid_mask , ii , II ) # Keep aligned (previously "i")
275+ idx_int = np .where (
276+ valid_mask , integer_indices , x_len - 2
277+ ) # Prevent out-of-bounds indexing (previously "I")
278+ idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
277279
278280 # Ensure x is a 1D array
279281 x = np .asarray (x ).ravel ()
@@ -351,7 +353,7 @@ def apply_interpolation_matrix(self, comps=None, weights=None, stretch=None, ret
351353 stretch_tiled = np .tile (stretch_flat , (self ._signal_len , 1 ))
352354
353355 # Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch)
354- ii = (
356+ fractional_indices = (
355357 np .tile (np .arange (self ._signal_len )[:, None ], (1 , self ._num_conditions * self ._n_components ))
356358 * stretch_tiled
357359 )
@@ -368,44 +370,45 @@ def apply_interpolation_matrix(self, comps=None, weights=None, stretch=None, ret
368370 ).reshape (self ._signal_len , self ._n_components * self ._num_conditions )
369371
370372 # Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)])
371- X1 = np .vstack ([comps , comps [- 1 , :]]) # Duplicate last row (like MATLAB)
373+ comps_bounded = np .vstack ([comps , comps [- 1 , :]]) # Duplicate last row (like MATLAB)
372374
373375 # Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1))
374- II = np .floor (ii ).astype (int )
376+ floor_indices = np .floor (fractional_indices ).astype (int )
375377
376- II1 = np .minimum (II + 1 , self ._signal_len )
377- II2 = np .minimum (II1 + 1 , self ._signal_len )
378+ floor_ind_1 = np .minimum (floor_indices + 1 , self ._signal_len )
379+ floor_ind_2 = np .minimum (floor_ind_1 + 1 , self ._signal_len )
378380
379381 # Compute fractional part (MATLAB: iI = ii - II)
380- iI = ii - II
382+ fractional_floor_indices = fractional_indices - floor_indices
381383
382384 # Compute offset indices (MATLAB: II1_ = II1 + bias; II2_ = II2 + bias)
383- II1_ = II1 + bias
384- II2_ = II2 + bias
385+ offset_floor_ind_1 = floor_ind_1 + bias
386+ offset_floor_ind_2 = floor_ind_2 + bias
385387
386388 # Extract values (MATLAB: XI1 = reshape(X1(II1_), N, K*M); XI2 = reshape(X1(II2_), N, K*M))
387389 # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
388- XI1 = X1 .flatten (order = "F" )[(II1_ - 1 ).ravel ()].reshape (
390+ # order = F uses FORTRAN, column major order
391+ comps_val_1 = comps_bounded .flatten (order = "F" )[(offset_floor_ind_1 - 1 ).ravel ()].reshape (
389392 self ._signal_len , self ._n_components * self ._num_conditions
390- ) # order = F uses FORTRAN, column major order
391- XI2 = X1 .flatten (order = "F" )[(II2_ - 1 ).ravel ()].reshape (
393+ )
394+ comps_val_2 = comps_bounded .flatten (order = "F" )[(offset_floor_ind_2 - 1 ).ravel ()].reshape (
392395 self ._signal_len , self ._n_components * self ._num_conditions
393396 )
394397
395398 # Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_comps=Ax2.*YY)
396- Ax2 = XI1 * (1 - iI ) + XI2 * iI
397- stretched_comps = Ax2 * weights_tiled # Apply weighting
399+ stretch_comps2 = comps_val_1 * (1 - fractional_floor_indices ) + comps_val_2 * fractional_floor_indices
400+ stretched_comps = stretch_comps2 * weights_tiled # Apply weighting
398401
399402 if return_derivatives :
400403 # Compute first derivative (MATLAB: Tx2=XI1.*(-di)+XI2.*di; d_str_cmps=Tx2.*YY)
401- di = - ii * stretch_tiled
402- d_x2 = XI1 * (- di ) + XI2 * di
403- d_str_cmps = d_x2 * weights_tiled
404+ di = - fractional_indices * stretch_tiled
405+ d_comps2 = comps_val_1 * (- di ) + comps_val_2 * di
406+ d_str_cmps = d_comps2 * weights_tiled
404407
405408 # Compute second derivative (MATLAB: Hx2=XI1.*(-ddi)+XI2.*ddi; dd_str_comps=Hx2.*YY)
406409 ddi = - di * stretch_tiled * 2
407- dd_x2 = XI1 * (- ddi ) + XI2 * ddi
408- dd_str_cmps = dd_x2 * weights_tiled
410+ dd_comps2 = comps_val_1 * (- ddi ) + comps_val_2 * ddi
411+ dd_str_cmps = dd_comps2 * weights_tiled
409412 else :
410413 shape = stretched_comps .shape
411414 d_str_cmps = np .empty (shape )
@@ -430,13 +433,17 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
430433 K = weights .shape [0 ]
431434
432435 # Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1))
433- AA = np .tile (stretch .reshape (1 , M * K , order = "F" ) ** - 1 , (N , 1 ))
436+ stretch_tiled = np .tile (
437+ stretch .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ) ** - 1 , (self ._signal_len , 1 )
438+ )
434439
435440 # Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA)
436- ii = np .arange (N )[:, None ] * AA # Shape (N, M*K), replacing `index`
441+ ii = np .arange (self . _signal_len )[:, None ] * stretch_tiled # Shape (N, M*K), replacing `index`
437442
438443 # Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1))
439- YY = np .tile (weights .reshape (1 , M * K , order = "F" ), (N , 1 ))
444+ weights_tiled = np .tile (
445+ weights .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ), (self ._signal_len , 1 )
446+ )
440447
441448 # Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1))
442449 II = np .floor (ii ).astype (int )
@@ -448,7 +455,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
448455 II2_ = II2
449456
450457 # Compute fractional part (MATLAB: iI = ii - II)
451- iI = ii - II
458+ fractional_indices = ii - II
452459
453460 # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
454461 repm = np .tile (np .arange (K ), (N , M ))
@@ -457,12 +464,14 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
457464 kron = np .kron (residuals , np .ones ((1 , K )))
458465
459466 # (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY)
460- kron_iI = kron * iI
461- iIYY = (iI - 1 ) * YY
467+ kron_iI = kron * fractional_indices
468+ iIYY = (fractional_indices - 1 ) * weights_tiled
462469
463470 # Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K))
464471 x2 = coo_matrix (((- kron * iIYY ).flatten (), (II1_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )).tocsc ()
465- x3 = coo_matrix (((kron_iI * YY ).flatten (), (II2_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )).tocsc ()
472+ x3 = coo_matrix (
473+ ((kron_iI * weights_tiled ).flatten (), (II2_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )
474+ ).tocsc ()
466475
467476 # Combine the last row into previous, then remove the last row
468477 x2 [N - 1 , :] += x2 [N , :]
@@ -527,46 +536,53 @@ def hess(y):
527536
528537 def update_comps (self ):
529538 """
530- Updates `comps` using gradient-based optimization with adaptive step size L .
539+ Updates `comps` using gradient-based optimization with adaptive step size step_size .
531540 """
532541 # Compute `stretched_comps` using the interpolation function
533542 stretched_comps , _ , _ = self .apply_interpolation_matrix () # Skip the other two outputs (derivatives)
534543 # Compute RA and RR
535- intermediate_RA = stretched_comps .flatten (order = "F" ).reshape (
544+ intermediate_reshaped = stretched_comps .flatten (order = "F" ).reshape (
536545 (self ._signal_len * self ._num_conditions , self ._n_components ), order = "F"
537546 )
538- RA = intermediate_RA .sum (axis = 1 ).reshape ((self ._signal_len , self ._num_conditions ), order = "F" )
539- RR = RA - self .source_matrix
547+ reshaped_stretched_components = intermediate_reshaped .sum (axis = 1 ).reshape (
548+ (self ._signal_len , self ._num_conditions ), order = "F"
549+ )
550+ component_residuals = reshaped_stretched_components - self .source_matrix
540551 # Compute gradient `GraX`
541552 self .grad_comps = self .apply_transformation_matrix (
542- residuals = RR
543- ).toarray () # toarray equivalent of full, make non-sparse
553+ residuals = component_residuals
554+ ).toarray () # toarray equivalent of MATLAB " full", makes non-sparse
544555
545- # Compute initial step size `L0 `
546- L0 = np .linalg .eigvalsh (self .weights .T @ self .weights ).max () * np .max (
556+ # Compute initial step size `initial_step_size `
557+ initial_step_size = np .linalg .eigvalsh (self .weights .T @ self .weights ).max () * np .max (
547558 [self .stretch .max (), 1 / self .stretch .min ()]
548559 )
549- # Compute adaptive step size `L `
560+ # Compute adaptive step size `step_size `
550561 if self ._prev_comps is None :
551- L = L0
562+ step_size = initial_step_size
552563 else :
553564 num = np .sum (
554565 (self .grad_comps - self ._prev_grad_comps ) * (self .comps - self ._prev_comps )
555566 ) # Elem-wise multiply
556567 denom = np .linalg .norm (self .comps - self ._prev_comps , "fro" ) ** 2 # Frobenius norm squared
557- L = num / denom if denom > 0 else L0
558- if L <= 0 :
559- L = L0
568+ step_size = num / denom if denom > 0 else initial_step_size
569+ if step_size <= 0 :
570+ step_size = initial_step_size
560571
561572 # Store our old component matrix before updating because it is used in step selection
562573 self ._prev_comps = self .comps .copy ()
563574
564575 while True : # iterate updating components
565- comps_step = self ._prev_comps - self .grad_comps / L
576+ comps_step = self ._prev_comps - self .grad_comps / step_size
566577 # Solve x^3 + p*x + q = 0 for the largest real root
567- self .comps = np .square (cubic_largest_real_root (- comps_step , self .eta / (2 * L )))
578+ self .comps = np .square (cubic_largest_real_root (- comps_step , self .eta / (2 * step_size )))
568579 # Mask values that should be set to zero
569- mask = self .comps ** 2 * L / 2 - L * self .comps * comps_step + self .eta * np .sqrt (self .comps ) < 0
580+ mask = (
581+ self .comps ** 2 * step_size / 2
582+ - step_size * self .comps * comps_step
583+ + self .eta * np .sqrt (self .comps )
584+ < 0
585+ )
570586 self .comps = mask * self .comps
571587
572588 objective_improvement = self ._objective_history [- 1 ] - self .get_objective_function (
@@ -576,9 +592,9 @@ def update_comps(self):
576592 # Check if objective function improves
577593 if objective_improvement > 0 :
578594 break
579- # If not, increase L (step size)
580- L *= 2
581- if np .isinf (L ):
595+ # If not, increase step_size (step size)
596+ step_size *= 2
597+ if np .isinf (step_size ):
582598 break
583599
584600 def update_weights (self ):
@@ -587,7 +603,7 @@ def update_weights(self):
587603 """
588604
589605 for m in range (self ._num_conditions ):
590- T = np .zeros ((self ._signal_len , self ._n_components )) # Initialize T as an (N, K) zero matrix
606+ T = np .zeros ((self ._signal_len , self ._n_components ))
591607
592608 # Populate T using apply_interpolation
593609 for k in range (self ._n_components ):
0 commit comments