Skip to content

Commit

Permalink
Fixed dangerous default values (i just learned what these were)
Browse files Browse the repository at this point in the history
  • Loading branch information
danyoungday committed Apr 26, 2024
1 parent 3f11496 commit 014da6e
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions use_cases/eluc/predictors/neural_network/neural_net_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,16 @@ class NeuralNetPredictor(Predictor):
in order to take advantage of the linear relationship in the data.
Data is automatically standardized and the scaler is saved with the model.
"""
def __init__(self, features=None, label=None, hidden_sizes=[4096], linear_skip=True,
dropout=0, device="mps", epochs=3, batch_size=2048, optim_params={},
train_pct=1, step_lr_params={"step_size": 1, "gamma": 0.1}):
def __init__(self, features=None, label=None, hidden_sizes=None, linear_skip=True,
dropout=0, device="mps", epochs=3, batch_size=2048, optim_params=None,
train_pct=1, step_lr_params=None):
# Fix dangerous default param values
if not step_lr_params:
step_lr_params = {"step_size": 1, "gamma": 0.1}
if not hidden_sizes:
hidden_sizes = [4096]
if not optim_params:
optim_params = {}

self.features=None
self.label=None
Expand Down

0 comments on commit 014da6e

Please sign in to comment.