Skip to content

Commit

Permalink
Merge pull request #316 from MilesCranmer/backend-update
Browse files Browse the repository at this point in the history
Pass through `enable_autodiff` parameter
  • Loading branch information
MilesCranmer authored Apr 22, 2023
2 parents 02c54ae + c0ffbd2 commit 06d99a7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/param_groupings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
- precision
- fast_cycle
- turbo
- enable_autodiff
- random_state
- deterministic
- warm_start
Expand Down
8 changes: 8 additions & 0 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
If you pass complex data, the corresponding complex precision
will be used (i.e., `64` for complex128, `32` for complex64).
Default is `32`.
enable_autodiff : bool
Whether to create derivative versions of operators for automatic
differentiation. This is only necessary if you wish to compute
the gradients of an expression within a custom loss function.
Default is `False`.
random_state : int, Numpy RandomState instance or None
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.
Expand Down Expand Up @@ -747,6 +752,7 @@ def __init__(
fast_cycle=False,
turbo=False,
precision=32,
enable_autodiff=False,
random_state=None,
deterministic=False,
warm_start=False,
Expand Down Expand Up @@ -839,6 +845,7 @@ def __init__(
self.fast_cycle = fast_cycle
self.turbo = turbo
self.precision = precision
self.enable_autodiff = enable_autodiff
self.random_state = random_state
self.deterministic = deterministic
self.warm_start = warm_start
Expand Down Expand Up @@ -1623,6 +1630,7 @@ def _run(self, X, y, mutated_params, weights, seed):
maxdepth=maxdepth,
fast_cycle=self.fast_cycle,
turbo=self.turbo,
enable_autodiff=self.enable_autodiff,
migration=self.migration,
hof_migration=self.hof_migration,
fraction_replaced_hof=self.fraction_replaced_hof,
Expand Down
4 changes: 2 additions & 2 deletions pysr/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "0.12.1"
__symbolic_regression_jl_version__ = "0.16.2"
__version__ = "0.12.2"
__symbolic_regression_jl_version__ = "0.17.0"

0 comments on commit 06d99a7

Please sign in to comment.