@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
3737  solve (const  Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
3838        ThreadPool *pool) const ;
3939
40-   template  <class  _Scalar , int  _Rows, int  _Cols>
41-   Eigen::Matrix<_Scalar, _Rows, _Cols>
42-   sqrt_solve (const  Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43-              ThreadPool *pool) const ;
40+   template  <typename  Derived>
41+   Eigen::MatrixXd sqrt_solve (const  Eigen::DenseBase<Derived> &rhs,
42+                              ThreadPool *pool) const ;
4443
4544  BlockDiagonal sqrt_transpose () const ;
4645
@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
5150  Eigen::Index rows () const ;
5251
5352  Eigen::Index cols () const ;
53+ 
54+   bool  operator ==(const  BlockDiagonalLDLT &other) const ;
5455};
5556
5657struct  BlockDiagonal  {
@@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141142  return  output;
142143}
143144
144- template  <class   _Scalar ,  int  _Rows,  int  _Cols >
145- inline  Eigen::Matrix<_Scalar, _Rows, _Cols> 
146- BlockDiagonalLDLT::sqrt_solve (const  Eigen::Matrix<_Scalar, _Rows, _Cols > &rhs,
145+ template  <typename  Derived >
146+ inline  Eigen::MatrixXd 
147+ BlockDiagonalLDLT::sqrt_solve (const  Eigen::DenseBase<Derived > &rhs,
147148                              ThreadPool *pool) const  {
148149  ALBATROSS_ASSERT (cols () == rhs.rows ());
149-   Eigen::Matrix<_Scalar, _Rows, _Cols>  output (rows (), rhs.cols ());
150+   Eigen::MatrixXd  output (rows (), rhs.cols ());
150151
151152  auto  solve_and_fill_one_block = [&](const  size_t  i, const  Eigen::Index row) {
152-     const  auto  rhs_chunk = rhs.block (row, 0 , blocks[i].rows (), rhs.cols ());
153+     const  auto  rhs_chunk =
154+         rhs.derived ().block (row, 0 , blocks[i].rows (), rhs.cols ());
153155    output.block (row, 0 , blocks[i].rows (), rhs.cols ()) =
154156        blocks[i].sqrt_solve (rhs_chunk);
155157  };
156158
157-   apply_map (block_to_row_map (), solve_and_fill_one_block, pool);
159+   //  Intentionally leaving pool out here due to an unknown bug
160+   //  in which the thread pool version crashes in sqrt_solve.
161+   apply_map (block_to_row_map (), solve_and_fill_one_block);
158162  return  output;
159163}
160164
@@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182186  return  n;
183187}
184188
189+ inline  bool 
190+ BlockDiagonalLDLT::operator ==(const  BlockDiagonalLDLT &other) const  {
191+   return  blocks == other.blocks ;
192+ }
185193/* 
186194 * Block Diagonal 
187195 */  
0 commit comments