Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Oct 14, 2023
1 parent a41779e commit b7fc70f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
7 changes: 6 additions & 1 deletion bluecast/tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def transform(
print("Predicting successful.")
assert len(y_probs) == len(df_val.index)
assert len(y_classes) == len(df_val.index)
assert automl.experiment_tracker.experiment_id == 0
assert (
len(automl.experiment_tracker.experiment_id)
<= automl.conf_training.hyperparameter_tuning_rounds
)


class CustomModel(BaseClassMlModel):
Expand Down Expand Up @@ -250,3 +253,5 @@ def transform(
# Assert the expected results
assert isinstance(predicted_probas, np.ndarray)
assert isinstance(predicted_classes, np.ndarray)
print(bluecast.experiment_tracker.experiment_id)
assert len(bluecast.experiment_tracker.experiment_id) == 0 # due to custom model
12 changes: 9 additions & 3 deletions bluecast/tests/test_cast_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def test_blueprint_cv_xgboost(synthetic_train_test_data):
target_col="target",
)
print(automl_cv.experiment_tracker.experiment_id)
assert len(automl_cv.experiment_tracker.experiment_id) == train_config.hyperparameter_tuning_rounds
assert automl_cv.experiment_tracker.experiment_id[-1] == 9
assert (
len(automl_cv.experiment_tracker.experiment_id)
<= train_config.hyperparameter_tuning_rounds * nb_models
)
assert automl_cv.experiment_tracker.experiment_id[-1] == 22
print("Autotuning successful.")
y_probs, y_classes = automl_cv.predict(df_val.drop("target", axis=1))
print("Predicting successful.")
Expand All @@ -66,7 +69,10 @@ def test_blueprint_cv_xgboost(synthetic_train_test_data):
target_col="target",
)
assert automl_cv.stratifier
assert len(automl_cv.experiment_tracker.experiment_id) == train_config.hyperparameter_tuning_rounds
assert (
len(automl_cv.experiment_tracker.experiment_id)
<= train_config.hyperparameter_tuning_rounds
) # due to Optuna pruning
assert automl_cv.experiment_tracker.experiment_id[-1] == 9

# Assert that the bluecast_models attribute is updated
Expand Down

0 comments on commit b7fc70f

Please sign in to comment.