Skip to content

Commit

Permalink
refactor: let user pass base function and not implement load method b…
Browse files Browse the repository at this point in the history
…ecause it is always the same
  • Loading branch information
JasperHG90 committed Jan 16, 2021
1 parent a1ec1cc commit 0c11de8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
12 changes: 2 additions & 10 deletions src/piven/models/mlp_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def check_model_params(


# Make build function for the model wrapper
def piven_model(
def piven_mlp_model(
input_dim, dense_units, dropout_rate, lambda_, bias_init_low, bias_init_high, lr
):
model = build_keras_piven(
Expand All @@ -74,7 +74,7 @@ class PivenMlpModel(PivenBaseModel):
def build(self, preprocess: Union[None, Pipeline, TransformerMixin] = None):
# All build params are passed to init and should be checked here
check_model_params(**self.params)
model = PivenKerasRegressor(build_fn=piven_model, **self.params)
model = PivenKerasRegressor(build_fn=piven_mlp_model, **self.params)
if preprocess is None:
pipeline = Pipeline([("model", model)])
else:
Expand All @@ -84,11 +84,3 @@ def build(self, preprocess: Union[None, Pipeline, TransformerMixin] = None):
regressor=pipeline, transformer=StandardScaler()
)
return self

@classmethod
def load(cls, path: str):
model_config = PivenMlpModel.load_model_config(path)
model = PivenMlpModel.load_model_from_disk(piven_model, path)
run = cls(**model_config)
run.model = model
return run
4 changes: 2 additions & 2 deletions tests/test_mlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import tensorflow as tf
import numpy as np
from piven.models import PivenMlpModel
from piven.models import PivenMlpModel, piven_mlp_model
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_experiment_io(self, mock_data, experiment):
with tempfile.TemporaryDirectory() as tmpdir:
experiment.save(tmpdir)
assert (Path(tmpdir) / "experiment_params.json").is_file()
_ = PivenMlpModel.load(path=tmpdir)
_ = PivenMlpModel.load(path=tmpdir, build_fn=piven_mlp_model)

def test_experiment_scoring(self, mock_data, experiment):
x_train, x_valid, y_train, y_valid = mock_data
Expand Down

0 comments on commit 0c11de8

Please sign in to comment.