Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 9, 2023
1 parent 93a418a commit 0176520
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions serket/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,4 @@ def map_func(view):
result = jax.vmap(map_func)(views)
return result.reshape(*output_shape, *result.shape[1:])

return single_call_wrapper


def frozen_field(**kwargs):
"""Freeze a field after setting it and unfreeze it before getting it."""
# this is useful for setting a field that is not a jax-type
# to allow the class to be passed across jax-boundaries]
return sk.field(
on_getattr=[*kwargs.pop("on_getattr", []), sk.unfreeze],
on_setattr=[*kwargs.pop("on_setattr", []), sk.freeze],
)
return single_call_wrapper

0 comments on commit 0176520

Please sign in to comment.