Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
zekun-shi committed Jan 12, 2024
1 parent 37be812 commit 5d3ac04
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions autofd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import optax
try:
import jax
import optax

def scale_by_learning_rate(
learning_rate,
*,
flip_sign: bool = True,
):
m = -1 if flip_sign else 1
# avoid calling tracer
if callable(learning_rate
) and not isinstance(learning_rate, jax.core.Tracer):
return optax._src.transform.scale_by_schedule(
lambda count: m * learning_rate(count)
)
return optax._src.transform.scale(m * learning_rate)

optax._src.alias._scale_by_learning_rate = scale_by_learning_rate
except ImportError:
print("optax not install, skip patching")

from . import operators # noqa
from .general_array import SpecTree # noqa
Expand All @@ -23,23 +42,6 @@
is_function, num_args, random_input, with_spec, zeros_like
)


def scale_by_learning_rate(
learning_rate,
*,
flip_sign: bool = True,
):
m = -1 if flip_sign else 1
# avoid calling tracer
if callable(learning_rate) and not isinstance(learning_rate, jax.core.Tracer):
return optax._src.transform.scale_by_schedule(
lambda count: m * learning_rate(count)
)
return optax._src.transform.scale(m * learning_rate)


optax._src.alias._scale_by_learning_rate = scale_by_learning_rate

__all__ = [
"Spec",
"SpecTree",
Expand Down

0 comments on commit 5d3ac04

Please sign in to comment.