Skip to content

Commit

Permalink
TST Refactor test fixture to be more atomic (skops-dev#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Aho authored Oct 7, 2022
1 parent 4d8d70f commit b069c5f
Showing 1 changed file with 58 additions and 25 deletions.
83 changes: 58 additions & 25 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,70 @@ def model_card(model_diagram=True):


@pytest.fixture
def model_card_metadata_from_config(destination_path):
def iris_data():
X, y = load_iris(return_X_y=True, as_frame=True)
yield X, y


@pytest.fixture
def iris_estimator(iris_data):
X, y = iris_data
est = LogisticRegression(solver="liblinear").fit(X, y)
yield est


@pytest.fixture
def iris_pkl_file(iris_estimator):
pkl_file = tempfile.mkstemp(suffix=".pkl", prefix="skops-test")[1]
with open(pkl_file, "wb") as f:
pickle.dump(est, f)
pickle.dump(iris_estimator, f)
yield pkl_file


@pytest.fixture
def iris_skops_file(iris_estimator):
skops_folder = tempfile.mkdtemp()
model_name = "model.skops"
skops_path = Path(skops_folder) / model_name
save(iris_estimator, skops_path)
yield skops_path


def _create_model_card_from_saved_model(
destination_path,
iris_estimator,
iris_data,
save_file,
):
X, y = iris_data
hub_utils.init(
model=pkl_file,
model=save_file,
requirements=[f"scikit-learn=={sklearn.__version__}"],
dst=destination_path,
task="tabular-classification",
data=X,
)
card = Card(
est, model_diagram=True, metadata=metadata_from_config(destination_path)
card = Card(iris_estimator, metadata=metadata_from_config(destination_path))
card.save(Path(destination_path) / "README.md")
return card


@pytest.fixture
def skops_model_card_metadata_from_config(
destination_path, iris_estimator, iris_skops_file, iris_data
):
yield _create_model_card_from_saved_model(
destination_path, iris_estimator, iris_data, iris_skops_file
)


@pytest.fixture
def pkl_model_card_metadata_from_config(
destination_path, iris_estimator, iris_pkl_file, iris_data
):
yield _create_model_card_from_saved_model(
destination_path, iris_estimator, iris_data, iris_pkl_file
)
yield card


@pytest.fixture
Expand Down Expand Up @@ -159,42 +206,28 @@ def test_add_metrics(destination_path, model_card):
assert ("acc" in card) and ("f1" in card) and ("0.1" in card)


def test_code_autogeneration(destination_path, model_card_metadata_from_config):
def test_code_autogeneration(destination_path, pkl_model_card_metadata_from_config):
# test if getting started code is automatically generated
model_card_metadata_from_config.save(Path(destination_path) / "README.md")
metadata = metadata_load(local_path=Path(destination_path) / "README.md")
filename = metadata["model_file"]
with open(Path(destination_path) / "README.md") as f:
assert f"joblib.load({filename})" in f.read()


def test_code_autogeneration_skops(destination_path):
def test_code_autogeneration_skops(
destination_path, skops_model_card_metadata_from_config
):
# test if getting started code is automatically generated for skops format
X, y = load_iris(return_X_y=True, as_frame=True)
model = fit_model()
skops_folder = tempfile.mkdtemp()
model_name = "model.skops"
save(model, Path(skops_folder) / model_name)
hub_utils.init(
model=Path(skops_folder) / model_name,
requirements=[f"scikit-learn=={sklearn.__version__}"],
dst=destination_path,
task="tabular-classification",
data=X,
)
card = Card(model, metadata=metadata_from_config(destination_path))
card.save(Path(destination_path) / "README.md")
metadata = metadata_load(local_path=Path(destination_path) / "README.md")
filename = metadata["model_file"]
with open(Path(destination_path) / "README.md") as f:
assert f'clf = load("{filename}")' in f.read()


def test_metadata_from_config_tabular_data(
model_card_metadata_from_config, destination_path
pkl_model_card_metadata_from_config, destination_path
):
# test if widget data is correctly set in the README
model_card_metadata_from_config.save(Path(destination_path) / "README.md")
metadata = metadata_load(local_path=Path(destination_path) / "README.md")
assert "widget" in metadata

Expand Down

0 comments on commit b069c5f

Please sign in to comment.