diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 84d960d..679e10b 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -275,25 +275,24 @@ pub fn qr(tensor: ExTensor) -> Result<(ExTensor, ExTensor), CandlexError> { let side = tensor.dims()[0]; let device = tensor.device(); - let qr = - nalgebra::linalg::QR::new( - nalgebra::DMatrix::from_vec( - side, - side, - tensor.t()?.flatten_all()?.to_vec1::()? - ) - ); - - Ok( - ( - ExTensor::new( - Tensor::new(qr.q().as_slice(), &device)?.reshape((side, side))?.t()? - ), - ExTensor::new( - Tensor::new(qr.r().as_slice(), &device)?.reshape((side, side))?.t()? - ) - ) - ) + let qr = nalgebra::linalg::QR::new(nalgebra::DMatrix::from_vec( + side, + side, + tensor.t()?.flatten_all()?.to_vec1::()?, + )); + + Ok(( + ExTensor::new( + Tensor::new(qr.q().as_slice(), &device)? + .reshape((side, side))? + .t()?, + ), + ExTensor::new( + Tensor::new(qr.r().as_slice(), &device)? + .reshape((side, side))? + .t()?, + ), + )) } #[rustler::nif(schedule = "DirtyCpu")]