@@ -608,24 +608,29 @@ def update_components(self):
608608
609609 def update_weights (self ):
610610 """
611- Updates weights using matrix operations, solving a quadratic program to do so.
611+ Updates weights by building the stretched component matrix `stretched_comps` with np.interp
612+ and solving a quadratic program for each signal.
612613 """
613614
614- signal_length = self .signal_length
615- n_signals = self .n_signals
616-
617- for m in range (n_signals ):
618- t = np .zeros ((signal_length , self .n_components ))
619-
620- # Populate t using apply_interpolation
621- for k in range (self .n_components ):
622- t [:, k ] = apply_interpolation (self .stretch [k , m ], self .components [:, k ]).squeeze ()
623-
624- # Solve quadratic problem for y
625- y = self .solve_quadratic_program (t = t , m = m )
615+ sample_indices = np .arange (self .signal_length )
616+ for signal in range (self .n_signals ):
617+ # Stretch factors for this signal across components:
618+ this_stretch = self .stretch [:, signal ]
619+ # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
620+ stretched_comps = np .empty ((self .signal_length , self .n_components ), dtype = self .components .dtype )
621+ for comp in range (self .n_components ):
622+ pos = sample_indices / this_stretch [comp ]
623+ stretched_comps [:, comp ] = np .interp (
624+ pos ,
625+ sample_indices ,
626+ self .components [:, comp ],
627+ left = self .components [0 , comp ],
628+ right = self .components [- 1 , comp ],
629+ )
626630
627- # Update Y
628- self .weights [:, m ] = y
631+ # Solve quadratic problem for a given signal and update its weight
632+ new_weight = self .solve_quadratic_program (t = stretched_comps , m = signal )
633+ self .weights [:, signal ] = new_weight
629634
630635 def regularize_function (self , stretch = None ):
631636 if stretch is None :
@@ -712,37 +717,3 @@ def cubic_largest_real_root(p, q):
712717 y = np .max (real_roots , axis = 0 ) * (delta < 0 ) # Keep only real roots when delta < 0
713718
714719 return y
715-
716-
717- def apply_interpolation (a , x ):
718- """
719- Applies an interpolation-based transformation to `x` based on scaling `a`.
720- """
721- x_len = len (x )
722-
723- # Ensure `a` is an array and reshape for broadcasting
724- a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
725-
726- # Compute fractional indices, broadcasting over `a`
727- fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
728-
729- integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
730- valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
731-
732- # Apply valid_mask to keep correct indices
733- idx_int = np .where (valid_mask , integer_indices , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
734- idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
735-
736- # Ensure x is a 1D array
737- x = np .asarray (x ).ravel ()
738-
739- # Compute interpolated_x (linear interpolation)
740- interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
741- idx_frac - idx_int
742- )
743-
744- # Fill the tail with the last valid value
745- intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
746- interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
747-
748- return interpolated_x
0 commit comments