Skip to content

Commit

Permalink
[WIP] Harmonize fixtures (#248)
Browse files Browse the repository at this point in the history
* fixtures return dataset

* fix test_mapping

* fix test_selector

* fix test_label_prop

* fix test_reweight

* fix test_ot

* fix test_pipeline

* getting rid of wrong fixture

* fix test_scorer

* bring back da_binary_dataset
  • Loading branch information
antoinecollas authored Oct 4, 2024
1 parent 78b2806 commit e80e205
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 26 deletions.
16 changes: 9 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ def set_seed():

@pytest.fixture(scope='session')
def da_reg_dataset():
X, y, sample_domain = make_shifted_datasets(
return make_shifted_datasets(
n_samples_source=20,
n_samples_target=21,
shift="concept_drift",
mean=0.5,
noise=0.3,
label="regression",
random_state=43,
return_dataset=True,
)
return X, y, sample_domain


@pytest.fixture(scope='session')
Expand All @@ -45,6 +45,7 @@ def da_reg_datasets():
noise=0.3,
label="regression",
random_state=42,
return_dataset=True,
)

da_reg_dataset_2 = make_shifted_datasets(
Expand All @@ -55,33 +56,34 @@ def da_reg_datasets():
noise=0.3,
label="regression",
random_state=42,
return_dataset=True,
)
return da_reg_dataset_1, da_reg_dataset_2

@pytest.fixture(scope='session')
def da_multiclass_dataset():
X, y, sample_domain = make_shifted_datasets(
return make_shifted_datasets(
n_samples_source=20,
n_samples_target=21,
shift="concept_drift",
noise=0.1,
label="multiclass",
random_state=42,
return_dataset=True,
)
return X, y, sample_domain


@pytest.fixture(scope='session')
def da_binary_dataset():
X, y, sample_domain = make_shifted_datasets(
return make_shifted_datasets(
n_samples_source=20,
n_samples_target=21,
shift="concept_drift",
noise=0.1,
label="binary",
random_state=42,
return_dataset=True,
)
return X, y, sample_domain


@pytest.fixture(scope='session')
Expand All @@ -95,7 +97,7 @@ def da_blobs_dataset():
shift=0.13,
random_state=42,
cluster_std=0.05,
return_X_y=True,
return_dataset=True,
)


Expand Down
4 changes: 2 additions & 2 deletions skada/tests/test_label_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
],
)
def test_label_prop_estimator(estimator, da_blobs_dataset):
X, y, sample_domain = da_blobs_dataset
X, y, sample_domain = da_blobs_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_label_prop_estimator(estimator, da_blobs_dataset):
],
)
def test_label_prop_estimator_reg(estimator, da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
Expand Down
4 changes: 2 additions & 2 deletions skada/tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
],
)
def test_mapping_estimator(estimator, da_blobs_dataset):
X, y, sample_domain = da_blobs_dataset
X, y, sample_domain = da_blobs_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_reg_new_X_adapt(estimator):
],
)
def test_mapping_source_samples(estimator, da_blobs_dataset):
X, y, sample_domain = da_blobs_dataset
X, y, sample_domain = da_blobs_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
Expand Down
10 changes: 5 additions & 5 deletions skada/tests/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def test_JDOTRegressor(da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
rng = np.random.default_rng(42)
w = rng.uniform(size=(X.shape[0],))

Expand All @@ -42,7 +42,7 @@ def test_JDOTRegressor(da_reg_dataset):


def test_JDOTRegressor_pipeline(da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

jdot = make_da_pipeline(
Expand All @@ -60,7 +60,7 @@ def test_JDOTRegressor_pipeline(da_reg_dataset):
def test_JDOTClassifier(da_multiclass_dataset, da_binary_dataset):
rng = np.random.default_rng(43)
for dataset in [da_multiclass_dataset, da_binary_dataset]:
X, y, sample_domain = dataset
X, y, sample_domain = dataset.pack(as_sources=["s"], as_targets=["t"])
w = rng.uniform(size=(X.shape[0],))
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

Expand Down Expand Up @@ -123,7 +123,7 @@ def test_JDOTClassifier(da_multiclass_dataset, da_binary_dataset):


def test_jdot_class_cost_matrix(da_multiclass_dataset):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

Expand Down Expand Up @@ -157,7 +157,7 @@ def test_jdot_class_cost_matrix(da_multiclass_dataset):


def test_jdot_class_tgt_loss(da_multiclass_dataset):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

Expand Down
2 changes: 1 addition & 1 deletion skada/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def fit_transform(self, X, y=None, **params):


def test_adaptation_output_propagate_labels(da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
_, X_target, _, target_domain = source_target_split(
X, sample_domain, sample_domain=sample_domain
)
Expand Down
8 changes: 5 additions & 3 deletions skada/tests/test_reweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def test_reg_reweight_estimator(estimator):


def _base_test_new_X_adapt(estimator, da_dataset):
X_train, y_train, sample_domain = da_dataset
X_train, y_train, sample_domain = da_dataset.pack_train(
as_sources=["s"], as_targets=["t"]
)

# fit works with no errors
estimator.fit(X_train, y_train, sample_domain=sample_domain)
Expand Down Expand Up @@ -281,7 +283,7 @@ def test_KMMReweight_new_X_adapt(da_dataset):
],
)
def test_adaptation_output_propagation_multiple_steps(da_reg_dataset, mediator):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
_, X_target, _, target_domain = source_target_split(
X, sample_domain, sample_domain=sample_domain
)
Expand All @@ -307,7 +309,7 @@ def predict(self, X, sample_weight=None):


def test_select_source_target_output_merge(da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])
_, X_target, _, target_domain = source_target_split(
X, sample_domain, sample_domain=sample_domain
)
Expand Down
2 changes: 1 addition & 1 deletion skada/tests/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_mixval_scorer(da_dataset):


def test_mixval_scorer_regression(da_reg_dataset):
X, y, sample_domain = da_reg_dataset
X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"])

estimator = make_da_pipeline(DensityReweightAdapter(), LinearRegression())

Expand Down
10 changes: 5 additions & 5 deletions skada/tests/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def transform(
],
)
def test_source_selector_with_estimator(da_multiclass_dataset, selector_cls, side):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target = source_target_split(X, sample_domain=sample_domain)
output = {}

Expand Down Expand Up @@ -240,7 +240,7 @@ def predict(self, X):
def test_source_selector_with_transformer(
da_multiclass_dataset, selector_cls, side, _fit_transform
):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])
X_source, X_target = source_target_split(X, sample_domain=sample_domain)
output = {}

Expand Down Expand Up @@ -282,7 +282,7 @@ def fit_transform(self, X, y=None):
],
)
def test_source_selector_with_weights(da_multiclass_dataset, selector_cls, side):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])
sample_weight = np.ones(X.shape[0])
X_source, X_target = source_target_split(X, sample_domain=sample_domain)
output = {}
Expand Down Expand Up @@ -318,7 +318,7 @@ def predict(self, X, sample_weight=None):
def test_source_target_selector(
da_multiclass_dataset, source_estimator, target_estimator
):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])
source_masks = extract_source_indices(sample_domain)
# make sure sources and targets have significantly different mean
X[source_masks] += 100 * np.ones((source_masks.sum(), X.shape[1]))
Expand Down Expand Up @@ -354,7 +354,7 @@ def test_source_target_selector(


def test_source_target_selector_fails_on_missing_domain(da_multiclass_dataset):
X, y, sample_domain = da_multiclass_dataset
X, y, sample_domain = da_multiclass_dataset.pack(as_sources=["s"], as_targets=["t"])
source_masks = extract_source_indices(sample_domain)
pipe = make_da_pipeline(SelectSourceTarget(StandardScaler()), SVC())

Expand Down

0 comments on commit e80e205

Please sign in to comment.