From 01765209dfa860157d7c8f23bbe885552e3077ad Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sun, 10 Dec 2023 00:35:33 +0900 Subject: [PATCH] Update utils.py --- serket/_src/utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/serket/_src/utils.py b/serket/_src/utils.py index c1c66bf..eafb5cd 100644 --- a/serket/_src/utils.py +++ b/serket/_src/utils.py @@ -640,14 +640,4 @@ def map_func(view): result = jax.vmap(map_func)(views) return result.reshape(*output_shape, *result.shape[1:]) - return single_call_wrapper - - -def frozen_field(**kwargs): - """Freeze a field after setting it and unfreeze it before getting it.""" - # this is useful for setting a field that is not a jax-type - # to allow the class to be passed across jax-boundaries] - return sk.field( - on_getattr=[*kwargs.pop("on_getattr", []), sk.unfreeze], - on_setattr=[*kwargs.pop("on_setattr", []), sk.freeze], - ) + return single_call_wrapper \ No newline at end of file