Skip to content

Commit

Permalink
Update coefficient assignment (#914)
Browse files Browse the repository at this point in the history
  • Loading branch information
kchare authored May 27, 2022
1 parent 7b8c8e0 commit fa40fa3
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 5 deletions.
1 change: 1 addition & 0 deletions dask_ml/linear_model/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def fit(self, X, y=None):
self.intercept_ = self._coef[-1]
else:
self.coef_ = self._coef
self.intercept_ = 0.0
return self

def _check_array(self, X):
Expand Down
4 changes: 2 additions & 2 deletions dask_ml/linear_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def log1p(A):

@dispatch(np.ndarray)
def add_intercept(X):
    """Return ``X`` with an intercept column of ones appended (NumPy input).

    Delegates to ``_add_intercept`` so the column position and dtype
    handling live in one place.

    Parameters
    ----------
    X : np.ndarray
        Two-dimensional array; ``X.shape[0]`` rows.

    Returns
    -------
    np.ndarray
        Array with one extra column, the new column being last.
    """
    return _add_intercept(X)


def _add_intercept(x):
    """Append a column of ones (same dtype as ``x``) as the LAST column.

    The position matters: the estimator's ``fit`` takes ``intercept_``
    from the final fitted coefficient (``self._coef[-1]``), so the
    intercept column has to be the final column of the design matrix.
    """
    ones = np.ones((x.shape[0], 1), dtype=x.dtype)
    # Features first, intercept column last.
    return np.concatenate([x, ones], axis=1)


@dispatch(da.Array) # noqa: F811
Expand Down
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,37 @@ def scheduler(request):
yield cluster
else:
yield not_cluster


@pytest.fixture
def medium_size_regression():
    """Regression dataset with more samples than features.

    Calls ``make_regression`` for 500 samples and 100 features, all of
    them informative, with ``chunks=100`` and a fixed ``random_state`` so
    every run sees the same data.

    Returns
    -------
    tuple
        The ``(X, y)`` pair produced by ``make_regression``.
    """
    n_samples, n_features = 500, 100
    features, target = make_regression(
        chunks=100,
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_features,
        random_state=0,
    )
    return features, target


@pytest.fixture
def medium_size_counts():
    """Count-valued dataset with more samples than features.

    Calls ``make_counts`` for 500 samples and 100 features, all of them
    informative, with ``chunks=100`` and a fixed ``random_state``. The
    ``scale`` argument is set to ``1 / sqrt(n_features)``.

    Returns
    -------
    tuple
        The ``(X, y)`` pair produced by ``make_counts``.
    """
    n_features = 100
    features, counts = make_counts(
        chunks=100,
        n_samples=500,
        n_features=n_features,
        n_informative=n_features,
        random_state=0,
        scale=1 / np.sqrt(n_features),
    )
    return features, counts
92 changes: 89 additions & 3 deletions tests/linear_model/test_glm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import dask.array as da
import dask.dataframe as dd
import numpy as np
import numpy.linalg as LA
import pandas as pd
import pytest
import sklearn.linear_model
from dask.dataframe.utils import assert_eq
from dask_glm.regularizers import Regularizer
from sklearn.pipeline import make_pipeline

import dask_ml.linear_model
from dask_ml.datasets import make_classification, make_counts, make_regression
from dask_ml.linear_model import LinearRegression, LogisticRegression, PoissonRegression
from dask_ml.linear_model.utils import add_intercept
Expand Down Expand Up @@ -39,9 +42,6 @@ def get_params(self, deep=True):
return {}


X, y = make_classification(chunks=50)


def test_lr_init(solver):
    """Constructing ``LogisticRegression`` with the given solver must not raise."""
    LogisticRegression(solver=solver)

Expand Down Expand Up @@ -174,6 +174,15 @@ def test_add_intercept_raises_chunks():
assert m.match("Chunking is only allowed")


def test_add_intercept_ordering():
    """``add_intercept`` must give equal results for NumPy and dask arrays."""
    dense = np.arange(100).reshape(20, 5)
    lazy = da.from_array(dense, chunks=(20, 5))

    from_numpy = add_intercept(dense)
    from_dask = add_intercept(lazy)

    # Same data through both dispatch paths has to compare equal.
    da.utils.assert_eq(from_numpy, from_dask)


def test_lr_score():
X = da.from_array(np.arange(1000).reshape(1000, 1))
lr = LinearRegression()
Expand Down Expand Up @@ -203,3 +212,80 @@ def test_logistic_predict_proba_shape():
lr.fit(X, y)
prob = lr.predict_proba(X)
assert prob.shape == (100, 2)


@pytest.mark.parametrize(
    "est,data",
    [
        (LinearRegression, "single_chunk_regression"),
        (LogisticRegression, "single_chunk_classification"),
        (PoissonRegression, "single_chunk_count_classification"),
    ],
)
def test_model_coef_dask_numpy(est, data, request):
    """Fitting on dask arrays and on their NumPy equivalents must agree.

    The fitted ``coef_`` and ``intercept_`` of both models are stacked into
    one vector each and compared by relative error in the 2-norm.
    """

    def fitted_params(model):
        # Coefficients followed by the intercept, as a single flat vector.
        return np.hstack((model.coef_, model.intercept_))

    X, y = request.getfixturevalue(data)

    numpy_model = est(fit_intercept=True)
    dask_model = est(fit_intercept=True)
    dask_model.fit(X, y)
    numpy_model.fit(X.compute(), y.compute())

    dask_params = fitted_params(dask_model)
    numpy_params = fitted_params(numpy_model)

    rel_error = LA.norm(dask_params - numpy_params) / LA.norm(numpy_params)
    assert rel_error < 1e-8


# fmt: off
@pytest.mark.parametrize("solver", ["newton", "lbfgs"])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize(
    "est, skl_params, data_generator",
    [
        ("LinearRegression", {}, "medium_size_regression"),
        ("LogisticRegression", {"penalty": "none"}, "single_chunk_classification"),
        ("PoissonRegression", {"alpha": 0}, "medium_size_counts"),
    ],
)
def test_model_against_sklearn(
    est, skl_params, data_generator, fit_intercept, solver, request
):
    """Compare dask-ml coefficients and predictions with scikit-learn.

    Coefficients (intercept first, then ``coef_``) are checked twice: by
    relative error in the 2-norm and element-wise with
    ``np.testing.assert_allclose``. Predictions on the training data are
    then compared with ``assert_allclose`` as well.
    """
    X, y = request.getfixturevalue(data_generator)

    # sklearn uses 'PoissonRegressor' while dask-ml uses 'PoissonRegression'
    assert est in ["LinearRegression", "LogisticRegression", "PoissonRegression"]
    skl_name = "PoissonRegressor" if "Poisson" in est else est
    EstDask = getattr(dask_ml.linear_model, est)
    EstSklearn = getattr(sklearn.linear_model, skl_name)

    dask_ml_model = EstDask(
        fit_intercept=fit_intercept, solver=solver, penalty="l2", C=1e8, max_iter=500
    )
    dask_ml_model.fit(X, y)

    # scikit-learn cannot consume dask collections, so materialise first.
    skl_model = EstSklearn(fit_intercept=fit_intercept, **skl_params)
    skl_model.fit(X.compute(), y.compute())

    # Coefficients: intercept first, then the (flattened) weights.
    estimate = np.hstack((dask_ml_model.intercept_, dask_ml_model.coef_))
    truth = np.hstack((skl_model.intercept_, skl_model.coef_.flatten()))

    rel_error = LA.norm(estimate - truth) / LA.norm(truth)
    assert rel_error < 1e-3

    np.testing.assert_allclose(truth, estimate, rtol=1e-3, atol=2e-4)

    # Predictions on the training data.
    skl_preds = skl_model.predict(X.compute())
    dml_preds = dask_ml_model.predict(X)

    np.testing.assert_allclose(skl_preds, dml_preds, rtol=1e-3, atol=2e-3)

0 comments on commit fa40fa3

Please sign in to comment.