From bd3b2ec818a31157c7fd875f5debdb2d4bbe7a0b Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 23 Jul 2023 09:50:16 +0800 Subject: [PATCH] fix autograd bugs --- .../_src/math/object_transform/autograd.py | 16 ++++++++------ .../object_transform/tests/test_autograd.py | 21 +++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index eb5571c4e..4c3045558 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -224,13 +224,17 @@ def __call__(self, *args, **kwargs): ) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if current_transform_number(): - return self._return(rets) + # if not the outermost transformation + if current_transform_number(): + return self._return(rets) + else: + self._dyn_vars = stack + self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index ff5d67e27..b4fefc056 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -1149,4 +1149,25 @@ def test_debug_correctness2(self): self.assertTrue(bm.allclose(r1[1], r2[1])) self.assertTrue(bm.allclose(r1[2], r2[2])) + def test_cache1(self): + file = tempfile.TemporaryFile(mode='w+') + + def f(a, b): + print('compiling f ...', file=file) + return a + b + + grad1 = bm.grad(f)(1., 2.) # call "f" twice, one for Variable finding, one for compiling + grad2 = bm.vector_grad(f)(1., 2.) # call "f" once for compiling + + file.seek(0) + print(file.read().strip()) + + expect_res = ''' +compiling f ... +compiling f ... +compiling f ... + ''' + file.seek(0) + self.assertTrue(file.read().strip() == expect_res.strip()) +