diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 99254a798f..5f160c90de 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -105,8 +105,24 @@ def qr( except AttributeError: q, r = a.larray.qr(some=False) - q = factories.array(q, device=a.device, comm=a.comm) - r = factories.array(r, device=a.device, comm=a.comm) + q = DNDarray( + q, + gshape=q.shape, + dtype=a.dtype, + split=a.split, + device=a.device, + comm=a.comm, + balanced=True, + ) + r = DNDarray( + r, + gshape=r.shape, + dtype=a.dtype, + split=a.split, + device=a.device, + comm=a.comm, + balanced=True, + ) ret = QR(q if calc_q else None, r) return ret # =============================== Prep work ==================================================== diff --git a/heat/core/linalg/svdtools.py b/heat/core/linalg/svdtools.py index 3ff273a79c..5b6bbd12fc 100644 --- a/heat/core/linalg/svdtools.py +++ b/heat/core/linalg/svdtools.py @@ -446,9 +446,27 @@ def hsvd( A.comm.Bcast(U_loc, root=0) # separate U_loc and err_squared_loc again err_squared_loc = U_loc[-1, 0] - U = factories.array(U_loc[:-1], device=A.device, split=None, comm=A.comm) + U_shape = U_loc[:-1].shape + U = DNDarray( + U_loc[:-1], + gshape=U_shape, + dtype=A.dtype, + device=A.device, + split=None, + comm=A.comm, + balanced=True, + ) rel_error_estimate = ( - factories.array(err_squared_loc**0.5, device=A.device, split=None, comm=A.comm) / Anorm + DNDarray( + err_squared_loc**0.5, + gshape=err_squared_loc.shape, + dtype=A.dtype, + device=A.device, + split=None, + comm=A.comm, + balanced=True, + ) + / Anorm ) # Postprocessing: