Checkpointing and referencing of variables #169

Open
jrmaddison opened this issue Sep 24, 2024 · 0 comments

@jrmaddison (Contributor)

The pyadjoint tape holds references to backend variables. As a result, any memory allocated for forward variables during the forward calculation remains referenced by the tape and cannot be freed. This can prevent checkpointing from reducing memory usage.
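To make the mechanism concrete, here is a pure-Python toy, not pyadjoint code: `Variable`, `StrongTape`, `WeakTape` and `run_forward` are hypothetical stand-ins. A tape that records ordinary (strong) references keeps every recorded variable alive after the forward loop has finished, while one that records only `weakref` references lets them be freed. Weak references are used here only to isolate the effect of referencing; a real tape still needs the forward values (or checkpoints of them) for the adjoint calculation.

```python
import gc
import weakref


class Variable:
    """Hypothetical stand-in for a backend variable that owns a buffer."""

    def __init__(self, nbytes):
        self.data = bytearray(nbytes)


class StrongTape:
    """Toy tape that records strong references to every variable."""

    def __init__(self):
        self.records = []

    def record(self, var):
        self.records.append(var)


class WeakTape:
    """Toy tape that records only weak references."""

    def __init__(self):
        self.records = []

    def record(self, var):
        self.records.append(weakref.ref(var))


def run_forward(tape, steps):
    """Toy forward loop; returns weak references to every variable created."""
    refs = []
    for _ in range(steps):
        u = Variable(1024)  # each step allocates a new variable
        tape.record(u)
        refs.append(weakref.ref(u))
    return refs  # the local `u` is released when the function returns


def count_alive(refs):
    """Count how many of the referenced variables are still alive."""
    gc.collect()
    return sum(ref() is not None for ref in refs)


strong_tape = StrongTape()
print(count_alive(run_forward(strong_tape, 100)))  # 100: kept alive by the tape
print(count_alive(run_forward(WeakTape(), 100)))  # 0: nothing keeps them alive
```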

Example

```python
from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import MultistageCheckpointSchedule

N = 100  # number of timesteps

mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)

# Multistage checkpointing over N steps, with 3 snapshots in RAM and none on disk
tape = get_working_tape()
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

u = Function(space, name="u").interpolate(Constant(2.0))
continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    # Allocate a new Function each step and drop the local references to the
    # previous one, so that only the tape can keep earlier Functions alive
    u_ = Function(space, name="u")
    assemble(Interpolate(u + u, space), tensor=u_)
    u = u_
    del u_
pause_annotation()
del u

# Collect the distinct Functions (identified by count()) that the recorded
# blocks still reference, as dependencies and as outputs
deps = set()
outputs = set()
for block in tape._blocks:
    for dep in block.get_dependencies():
        if isinstance(dep.output, Function):
            deps.add(dep.output.count())
    for dep in block.get_outputs():
        if isinstance(dep.output, Function):
            outputs.add(dep.output.count())

print(f"{len(deps)=}")
print(f"{len(outputs)=}")
```

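This leads to the following output, showing that the tape still references 100 distinct `Function` objects as dependencies and 100 as outputs, even though the checkpointing schedule keeps only 3 snapshots in RAM:

```
len(deps)=100
len(outputs)=100
```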
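For contrast, the following pure-Python toy sketches the kind of behaviour checkpointing relies on: the tape keeps copies of the state at a few snapshot steps, holds no references to the variable objects themselves, and recomputes intermediate states on demand. `Variable`, `step`, `SnapshotTape` and `stride` are hypothetical; a uniform stride is used for simplicity rather than a multistage schedule, and this is not how pyadjoint or checkpoint_schedules work internally.

```python
import gc
import weakref


class Variable:
    """Hypothetical stand-in for a backend variable (e.g. a Function)."""

    def __init__(self, value):
        self.value = value


def step(u):
    """One forward timestep, mirroring u -> u + u in the example above."""
    return Variable(u.value + u.value)


class SnapshotTape:
    """Toy tape: stores the state value at every `stride`-th step and keeps
    no references to the Variable objects themselves."""

    def __init__(self, stride):
        self.stride = stride
        self.snapshots = {}

    def record(self, n, u):
        if n % self.stride == 0:
            self.snapshots[n] = u.value  # store the value, not the object

    def recompute(self, n):
        """Rebuild the state at step n from the nearest earlier snapshot."""
        start = max(k for k in self.snapshots if k <= n)
        u = Variable(self.snapshots[start])
        for _ in range(n - start):
            u = step(u)
        return u.value


def forward(tape, steps):
    """Run the toy forward loop; return weak references to all Variables."""
    u = Variable(2.0)
    tape.record(0, u)
    refs = [weakref.ref(u)]
    for n in range(1, steps + 1):
        u = step(u)  # the previous Variable is released here
        tape.record(n, u)
        refs.append(weakref.ref(u))
    return refs


tape = SnapshotTape(stride=25)
refs = forward(tape, 100)
gc.collect()
print(sum(r() is not None for r in refs))  # 0 Variables kept alive
print(sorted(tape.snapshots))  # [0, 25, 50, 75, 100]
print(tape.recompute(30) == 2.0 * 2**30)  # True
```

The memory held by the tape is then bounded by the number of snapshots rather than the number of timesteps, at the cost of recomputing forward steps.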