From 5d3ac0443d275d1d007642efbcd44dc81ab705cb Mon Sep 17 00:00:00 2001
From: Zekun Shi
Date: Fri, 12 Jan 2024 17:01:23 +0800
Subject: [PATCH] fix test

---
 autofd/__init__.py | 40 +++++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/autofd/__init__.py b/autofd/__init__.py
index 0e51d37..8fe0d61 100644
--- a/autofd/__init__.py
+++ b/autofd/__init__.py
@@ -13,8 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import jax
-import optax
+try:
+  import jax
+  import optax
+
+  def scale_by_learning_rate(
+    learning_rate,
+    *,
+    flip_sign: bool = True,
+  ):
+    m = -1 if flip_sign else 1
+    # avoid calling tracer
+    if callable(learning_rate
+                ) and not isinstance(learning_rate, jax.core.Tracer):
+      return optax._src.transform.scale_by_schedule(
+        lambda count: m * learning_rate(count)
+      )
+    return optax._src.transform.scale(m * learning_rate)
+
+  optax._src.alias._scale_by_learning_rate = scale_by_learning_rate
+except ImportError:
+  print("optax not install, skip patching")
 
 from . import operators  # noqa
 from .general_array import SpecTree  # noqa
@@ -23,23 +42,6 @@
   is_function, num_args, random_input, with_spec, zeros_like
 )
-
-def scale_by_learning_rate(
-  learning_rate,
-  *,
-  flip_sign: bool = True,
-):
-  m = -1 if flip_sign else 1
-  # avoid calling tracer
-  if callable(learning_rate) and not isinstance(learning_rate, jax.core.Tracer):
-    return optax._src.transform.scale_by_schedule(
-      lambda count: m * learning_rate(count)
-    )
-  return optax._src.transform.scale(m * learning_rate)
-
-
-optax._src.alias._scale_by_learning_rate = scale_by_learning_rate
-
 __all__ = [
   "Spec",
   "SpecTree",
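
Note (reviewer sketch, not part of the commit): upstream optax treats any
callable learning rate as a schedule and invokes it with the step count.
Under autofd, where functions are treated as generalized arrays, a traced
learning rate can itself look callable, so the added
`isinstance(learning_rate, jax.core.Tracer)` guard keeps traced values on
the plain `scale` path instead of calling the tracer. Below is a minimal,
hypothetical usage that sends a tracer through `_scale_by_learning_rate`,
assuming that importing autofd applies the monkey-patch when optax is
available; the names here are illustrative, not taken from the patch.

    import jax
    import jax.numpy as jnp
    import optax
    import autofd  # importing autofd installs the patched helper (if optax is present)

    params = jnp.array([1.0, 2.0])
    grads = jnp.array([0.1, -0.2])

    def loss_after_one_step(lr):
      # Under jax.grad, `lr` is a jax.core.Tracer by the time optax's
      # _scale_by_learning_rate sees it; the patched helper scales by it
      # directly instead of calling it as if it were a schedule.
      opt = optax.sgd(lr)
      state = opt.init(params)
      updates, _ = opt.update(grads, state)
      return jnp.sum(optax.apply_updates(params, updates) ** 2)

    # Differentiate the post-update loss with respect to the learning rate.
    d_loss_d_lr = jax.grad(loss_after_one_step)(0.01)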