diff --git a/heat/regression/lasso.py b/heat/regression/lasso.py index b858674e4a..7d99a72454 100644 --- a/heat/regression/lasso.py +++ b/heat/regression/lasso.py @@ -149,7 +149,7 @@ def fit(self, x: DNDarray, y: DNDarray) -> None: # Looping through each coordinate for j in range(n): - X_j = ht.array(x.larray[:, j : j + 1], is_split=0) + X_j = ht.array(x.larray[:, j : j + 1], is_split=0, device=x.device, comm=x.comm) y_est = x @ theta theta_j = theta.larray[j].item()