Skip to content

Commit

Permalink
Fixes imports from _copt.
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Sep 11, 2024
1 parent c4de55f commit 94d5077
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions groupyr/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Create base classes based on the sparse group lasso."""

import contextlib
from . import _copt as cp
import numpy as np
import warnings

Expand All @@ -18,6 +17,10 @@
check_is_fitted,
)

from groupyr._copt.loss import SquareLoss, HuberLoss, LogLoss
import groupyr._copt.utils as cp_utils
from groupyr._copt.proximal_gradient import minimize_proximal_gradient

from ._prox import SparseGroupL1
from .utils import check_groups

Expand Down Expand Up @@ -212,14 +215,14 @@ def fit(self, X, y, loss="squared_loss"):
coef = np.zeros(n_features)

if loss == "huber":
f = cp.loss.HuberLoss(X, y)
f = HuberLoss(X, y)
elif loss == "log":
f = cp.loss.LogLoss(X, y)
f = LogLoss(X, y)
else:
f = cp.loss.SquareLoss(X, y)
f = SquareLoss(X, y)

if self.include_solver_trace:
self.solver_trace_ = cp.utils.Trace(f)
self.solver_trace_ = cp_utils.Trace(f)
else:
self.solver_trace_ = None

Expand Down Expand Up @@ -255,7 +258,7 @@ def fit(self, X, y, loss="squared_loss"):
if self.suppress_solver_warnings:
warnings.filterwarnings("ignore", category=RuntimeWarning)

pgd = cp.minimize_proximal_gradient(
pgd = minimize_proximal_gradient(
f.f_grad,
coef,
sg1.prox,
Expand Down

0 comments on commit 94d5077

Please sign in to comment.