1- use  ndarray:: { Array2 ,  ArrayView1 ,  ArrayView2 ,  Zip } ; 
1+ use  ndarray:: { Array2 ,  ArrayView1 ,  ArrayView2 ,  ArrayViewMut1 } ; 
22use  numpy:: { PyArray2 ,  PyReadonlyArray1 ,  PyReadonlyArray2 } ; 
33use  pyo3:: prelude:: * ; 
44use  rayon:: prelude:: * ; 
5- use  std:: sync:: atomic:: { AtomicUsize ,  Ordering } ; 
5+ use  std:: sync:: atomic:: { AtomicBool ,  Ordering } ; 
66use  std:: sync:: Arc ; 
77
8- mod  internal { 
9-     pub ( super )  fn  sad_converged ( a :  & [ f64 ] ,  b :  & [ f64 ] ,  tol :  f64 )  -> bool  { 
10-         a. iter ( ) . zip ( b) . all ( |( & x,  & y) | ( x - y) . abs ( )  < tol) 
8+ 
9+ #[ allow( non_snake_case) ]  
10+ pub  fn  demean_impl ( X :  & mut  Array2 < f64 > ,  D :  ArrayView2 < usize > ,  weights :  ArrayView1 < f64 > ,  tol :  f64 ,  iterations :  usize )  -> bool  { 
11+     let  nsamples = X . nrows ( ) ; 
12+     let  nfactors = D . ncols ( ) ; 
13+     let  success = Arc :: new ( AtomicBool :: new ( true ) ) ; 
14+     let  group_weights = FactorGroupWeights :: new ( & D ,  & weights) ; 
15+     
16+     X . axis_iter_mut ( ndarray:: Axis ( 1 ) ) 
17+         . into_par_iter ( ) 
18+         . for_each ( |mut  column| { 
19+             let  mut  demeaner = ColumnDemeaner :: new ( nsamples,  group_weights. width ) ; 
20+             
21+             for  _ in  0 ..iterations { 
22+                 for  i in  0 ..nfactors { 
23+                     demeaner. demean_column ( 
24+                         & mut  column,  
25+                         & weights,  
26+                         & D . column ( i) ,  
27+                         group_weights. factor_weight_slice ( i) 
28+                     ) ; 
29+                 } 
30+ 
31+                 demeaner. check_convergence ( & column. view ( ) ,  tol) ; 
32+                 if  demeaner. converged  { 
33+                     break ; 
34+                 } 
35+             } 
36+ 
37+             if  !demeaner. converged  { 
38+                 // We can use a relaxed ordering since we only ever go from true to false 
39+                 // and it doesn't matter how many times we do this. 
40+                 success. store ( false ,  Ordering :: Relaxed ) ; 
41+             } 
42+         } ) ; 
43+ 
44+     success. load ( Ordering :: Relaxed ) 
45+ } 
46+ 
47+ // The column demeaner is in charge of subtracting group means until convergence. 
48+ struct  ColumnDemeaner  { 
49+     converged :  bool , 
50+     checkpoint :  Vec < f64 > , 
51+     group_sums :  Vec < f64 > , 
52+ } 
53+ 
54+ impl  ColumnDemeaner  { 
55+     fn  new ( n :  usize ,  k :  usize )  -> Self  { 
56+         Self  { 
57+             converged :  false , 
58+             checkpoint :  vec ! [ 0.0 ;  n] , 
59+             group_sums :  vec ! [ 0.0 ;  k] , 
60+         } 
1161    } 
1262
13-     pub ( super )  fn  subtract_weighted_group_mean ( 
14-         x :  & mut  [ f64 ] , 
15-         sample_weights :  & [ f64 ] , 
16-         group_ids :  & [ usize ] , 
63+     fn  demean_column ( 
64+         & mut  self , 
65+         x :  & mut  ArrayViewMut1 < f64 > , 
66+         weights :  & ArrayView1 < f64 > , 
67+         groups :  & ArrayView1 < usize > , 
1768        group_weights :  & [ f64 ] , 
18-         group_weighted_sums :  & mut  [ f64 ] , 
1969    )  { 
20-         group_weighted_sums. fill ( 0.0 ) ; 
70+         self . group_sums . fill ( 0.0 ) ; 
71+         
72+         // Compute group sums 
73+         for  ( ( & xi,  & wi) ,  & gid)  in  x. iter ( ) . zip ( weights) . zip ( groups)  { 
74+             self . group_sums [ gid]  += wi *  xi; 
75+         } 
2176
22-         // Accumulate weighted  sums per group  
23-         x . iter ( ) 
24-             . zip ( sample_weights ) 
25-             . zip ( group_ids ) 
26-             . for_each ( |( ( & xi ,  & wi ) ,   & gid ) | { 
27-                 group_weighted_sums [ gid ]  += wi  *  xi ; 
77+         // Convert  sums to means  
78+         self . group_sums 
79+             . iter_mut ( ) 
80+             . zip ( group_weights . iter ( ) ) 
81+             . for_each ( |( sum ,  & weight ) | { 
82+                 * sum /= weight 
2883            } ) ; 
2984
30-         // Compute group means 
31-         let  group_means:  Vec < f64 >  = group_weighted_sums
32-             . iter ( ) 
33-             . zip ( group_weights) 
34-             . map ( |( & sum,  & weight) | sum / weight) 
35-             . collect ( ) ; 
36- 
37-         // Subtract means from each sample 
38-         x. iter_mut ( ) . zip ( group_ids) . for_each ( |( xi,  & gid) | { 
39-             * xi -= group_means[ gid] ; 
40-         } ) ; 
85+         // Subtract group means 
86+         for  ( xi,  & gid)  in  x. iter_mut ( ) . zip ( groups)  { 
87+             * xi -= self . group_sums [ gid]  // Really these are means now 
88+         } 
4189    } 
4290
43-      pub ( super )   fn   calc_group_weights ( 
44-          sample_weights :   & [ f64 ] , 
45-          group_ids :   & [ usize ] , 
46-         n_samples :   usize , 
47-         n_factors :   usize , 
48-         n_groups :   usize , 
49-     )  ->  Vec < f64 >   { 
50-         let   mut  group_weights =  vec ! [ 0.0 ;  n_factors  *  n_groups ] ; 
51-         for  i  in   0 ..n_samples  { 
52-             let  weight = sample_weights [ i ] ; 
53-             for  j  in   0 ..n_factors  { 
54-                 let  id = group_ids [ i  *  n_factors + j ] ; 
55-                 group_weights [ j  *  n_groups + id ]  += weight ; 
56-             } 
57-         } 
58-         group_weights 
91+ 
92+     // Check elementwise convergence and update checkpoint 
93+     fn   check_convergence ( 
94+         & mut   self , 
95+         x :   & ArrayView1 < f64 > , 
96+         tol :   f64 , 
97+     )  { 
98+         self . converged  =  true ;   // Innocent until proven guilty 
99+         x . iter ( ) 
100+             . zip ( self . checkpoint . iter_mut ( ) ) 
101+             . for_each ( | ( & xi ,  cp ) |  { 
102+                 if   ( xi -  * cp ) . abs ( )  > tol  { 
103+                      self . converged  =  false ;   // Guilty! 
104+                  } 
105+                  * cp = xi ;   // Update checkpoint 
106+              } ) ; 
59107    } 
60108} 
61109
62- fn  demean_impl ( 
63-     x :  & ArrayView2 < f64 > , 
64-     flist :  & ArrayView2 < usize > , 
65-     weights :  & ArrayView1 < f64 > , 
66-     tol :  f64 , 
67-     maxiter :  usize , 
68- )  -> ( Array2 < f64 > ,  bool )  { 
69-     let  ( n_samples,  n_features)  = x. dim ( ) ; 
70-     let  n_factors = flist. ncols ( ) ; 
71-     let  n_groups = flist. iter ( ) . cloned ( ) . max ( ) . unwrap ( )  + 1 ; 
72- 
73-     let  sample_weights:  Vec < f64 >  = weights. iter ( ) . cloned ( ) . collect ( ) ; 
74-     let  group_ids:  Vec < usize >  = flist. iter ( ) . cloned ( ) . collect ( ) ; 
75-     let  group_weights =
76-         internal:: calc_group_weights ( & sample_weights,  & group_ids,  n_samples,  n_factors,  n_groups) ; 
77- 
78-     let  not_converged = Arc :: new ( AtomicUsize :: new ( 0 ) ) ; 
79- 
80-     // Precompute slices of group_ids for each factor 
81-     let  group_ids_by_factor:  Vec < Vec < usize > >  = ( 0 ..n_factors) 
82-         . map ( |j| { 
83-             ( 0 ..n_samples) 
84-                 . map ( |i| group_ids[ i *  n_factors + j] ) 
85-                 . collect ( ) 
86-         } ) 
87-         . collect ( ) ; 
88- 
89-     // Precompute group weight slices 
90-     let  group_weight_slices:  Vec < & [ f64 ] >  = ( 0 ..n_factors) 
91-         . map ( |j| & group_weights[ j *  n_groups..( j + 1 )  *  n_groups] ) 
92-         . collect ( ) ; 
93- 
94-     let  process_column = |( k,  mut  col) :  ( usize ,  ndarray:: ArrayViewMut1 < f64 > ) | { 
95-         let  mut  xk_curr:  Vec < f64 >  = ( 0 ..n_samples) . map ( |i| x[ [ i,  k] ] ) . collect ( ) ; 
96-         let  mut  xk_prev:  Vec < f64 >  = xk_curr. iter ( ) . map ( |& v| v - 1.0 ) . collect ( ) ; 
97-         let  mut  gw_sums = vec ! [ 0.0 ;  n_groups] ; 
98- 
99-         let  mut  converged = false ; 
100-         for  _ in  0 ..maxiter { 
101-             for  j in  0 ..n_factors { 
102-                 internal:: subtract_weighted_group_mean ( 
103-                     & mut  xk_curr, 
104-                     & sample_weights, 
105-                     & group_ids_by_factor[ j] , 
106-                     group_weight_slices[ j] , 
107-                     & mut  gw_sums, 
108-                 ) ; 
109-             } 
110+ // Instead of recomputing the denominators for the weighted group averages every time, 
111+ // we'll precompute them and store them in a grid-like structure. The grid will have 
112+ // dimensions (m, k) where m is the number of factors and k is the maximum group ID.  
113+ struct  FactorGroupWeights  { 
114+     values :  Vec < f64 > , 
115+     width :  usize , 
116+ } 
110117
111-             if  internal:: sad_converged ( & xk_curr,  & xk_prev,  tol)  { 
112-                 converged = true ; 
113-                 break ; 
118+ impl  FactorGroupWeights  { 
119+     fn  new ( flist :  & ArrayView2 < usize > ,  weights :  & ArrayView1 < f64 > )  -> Self  { 
120+         let  n_samples = flist. nrows ( ) ; 
121+         let  n_factors = flist. ncols ( ) ; 
122+         let  width = flist. iter ( ) . max ( ) . unwrap ( )  + 1 ; 
123+ 
124+         let  mut  values = vec ! [ 0.0 ;  n_factors *  width] ; 
125+         for  i in  0 ..n_samples { 
126+             let  weight = weights[ i] ; 
127+             for  j in  0 ..n_factors { 
128+                 let  id = flist[ [ i,  j] ] ; 
129+                 values[ j *  width + id]  += weight; 
114130            } 
115-             xk_prev. copy_from_slice ( & xk_curr) ; 
116131        } 
117132
118-         if  !converged { 
119-             not_converged. fetch_add ( 1 ,  Ordering :: SeqCst ) ; 
133+         Self  { 
134+             values, 
135+             width, 
120136        } 
121-         Zip :: from ( & mut  col) . and ( & xk_curr) . for_each ( |col_elm,  & val| { 
122-             * col_elm = val; 
123-         } ) ; 
124-     } ; 
125- 
126-     let  mut  res = Array2 :: < f64 > :: zeros ( ( n_samples,  n_features) ) ; 
127- 
128-     res. axis_iter_mut ( ndarray:: Axis ( 1 ) ) 
129-         . into_par_iter ( ) 
130-         . enumerate ( ) 
131-         . for_each ( process_column) ; 
137+     } 
132138
133-     let  success = not_converged. load ( Ordering :: SeqCst )  == 0 ; 
134-     ( res,  success) 
139+     fn  factor_weight_slice ( & self ,  factor_index :  usize )  -> & [ f64 ]  { 
140+         & self . values [ factor_index *  self . width ..( factor_index + 1 )  *  self . width ] 
141+     } 
135142} 
136143
137144
@@ -196,7 +203,6 @@ fn demean_impl(
196203/// print(x_demeaned) 
197204/// print("Converged:", converged) 
198205/// ``` 
199- 
200206#[ pyfunction]  
201207#[ pyo3( signature = ( x,  flist,  weights,  tol=1e-8 ,  maxiter=100_000 ) ) ]  
202208pub  fn  _demean_rs ( 
@@ -207,13 +213,20 @@ pub fn _demean_rs(
207213    tol :  f64 , 
208214    maxiter :  usize , 
209215)  -> PyResult < ( Py < PyArray2 < f64 > > ,  bool ) >  { 
210-     let  x_arr = x. as_array ( ) ; 
211-     let  flist_arr = flist. as_array ( ) ; 
212-     let  weights_arr = weights. as_array ( ) ; 
216+     let  mut  x_array = x. as_array ( ) . to_owned ( ) ; 
217+     let  flist_array = flist. as_array ( ) ; 
218+     let  weights_array = weights. as_array ( ) ; 
219+ 
220+     let  converged = demean_impl ( 
221+         & mut  x_array, 
222+         flist_array, 
223+         weights_array, 
224+         tol, 
225+         maxiter, 
226+     ) ; 
227+ 
228+     let  pyarray = PyArray2 :: from_owned_array ( py,  x_array) ; 
229+     Ok ( ( pyarray. into_py ( py) ,  converged) ) 
230+ } 
213231
214-     let  ( out,  success)  =
215-         py. allow_threads ( || demean_impl ( & x_arr,  & flist_arr,  & weights_arr,  tol,  maxiter) ) ; 
216232
217-     let  pyarray = PyArray2 :: from_owned_array ( py,  out) ; 
218-     Ok ( ( pyarray. into_py ( py) ,  success) ) 
219- } 
0 commit comments