Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions tests/firedrake/adjoint/test_burgers_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,55 +35,87 @@ def setup_test(mesh):
return V, ic, nu


def _check_forward(tape):
def _control_bvs(controls):
# The control block variables legitimately retain their checkpoint
# across forward/adjoint replays (that is how the user-supplied
# control value is plumbed through to the next evaluation), so the
# post-replay clear-down assertions must skip them. Accepts either
# Control objects or the underlying overloaded variables.
if controls is None:
return set()
return {getattr(c, "block_variable", c) for c in controls}


def _check_forward(tape, controls=None):
skip = _control_bvs(controls)
for current_step in tape.timesteps[1:-1]:
for block in current_step:
for deps in block.get_dependencies():
if deps in skip:
continue
if (
deps not in tape.timesteps[0].checkpointable_state
and deps not in tape.timesteps[-1].checkpointable_state
):
assert deps._checkpoint is None
for out in block.get_outputs():
if out in skip:
continue
if out not in tape.timesteps[-1].checkpointable_state:
assert out._checkpoint is None


def _check_recompute(tape):
def _check_recompute(tape, controls=None):
skip = _control_bvs(controls)
for current_step in tape.timesteps[1:-1]:
for block in current_step:
for deps in block.get_dependencies():
if deps in skip:
continue
if deps not in tape.timesteps[0].checkpointable_state:
assert deps._checkpoint is None
for out in block.get_outputs():
if out in skip:
continue
assert out._checkpoint is None

for block in tape.timesteps[0]:
for out in block.get_outputs():
if out in skip:
continue
assert out._checkpoint is None
for block in tape.timesteps[len(tape.timesteps)-1]:
for deps in block.get_dependencies():
if deps in skip:
continue
if (
deps not in tape.timesteps[0].checkpointable_state
and deps not in tape.timesteps[len(tape.timesteps)-1].adjoint_dependencies
):
assert deps._checkpoint is None


def _check_reverse(tape):
def _check_reverse(tape, controls=None):
skip = _control_bvs(controls)
for step, current_step in enumerate(tape.timesteps):
if step > 0:
for block in current_step:
for deps in block.get_dependencies():
if deps in skip:
continue
if deps not in tape.timesteps[0].checkpointable_state:
assert deps._checkpoint is None

for out in block.get_outputs():
if out in skip:
continue
assert out._checkpoint is None
assert out.adj_value is None

for block in current_step:
for out in block.get_outputs():
if out in skip:
continue
assert out._checkpoint is None


Expand Down Expand Up @@ -157,23 +189,23 @@ def test_burgers_newton(solve_type, checkpointing, basics):
if checkpointing:
assert len(tape.timesteps) == total_steps
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_forward(tape)
_check_forward(tape, controls=[ic])

Jhat = ReducedFunctional(val, Control(ic))
if checkpointing != "NoneAdjoint":
dJ = Jhat.derivative()
if checkpointing is not None:
# Check if the reverse checkpointing is working correctly.
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_reverse(tape)
_check_reverse(tape, controls=[ic])

# Recomputing the functional with a modified control variable
# before the recompute test.
Jhat(project(sin(pi*SpatialCoordinate(mesh)[0]), V))
if checkpointing:
# Check is the checkpointing is working correctly.
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_recompute(tape)
_check_recompute(tape, controls=[ic])

# Recompute test
assert (np.allclose(Jhat(ic), val))
Expand Down
46 changes: 42 additions & 4 deletions tests/firedrake/adjoint/test_checkpointing_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from firedrake.adjoint import *
from .test_burgers_newton import _check_forward, \
_check_recompute, _check_reverse
from checkpoint_schedules import MixedCheckpointSchedule, StorageType
from checkpoint_schedules import MixedCheckpointSchedule, \
SingleMemoryStorageSchedule, StorageType
import numpy as np
from collections import deque

Expand Down Expand Up @@ -56,15 +57,15 @@ def test_multisteps(V):
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
displacement_0 = Function(V).assign(1.0)
val = J(displacement_0, V)
_check_forward(tape)
_check_forward(tape, controls=[displacement_0])
c = Control(displacement_0)
J_hat = ReducedFunctional(val, c)
dJ = J_hat.derivative()
_check_reverse(tape)
_check_reverse(tape, controls=[c])
# Recomputing the functional with a modified control variable
# before the recompute test.
J_hat(Function(V).assign(0.5))
_check_recompute(tape)
_check_recompute(tape, controls=[c])
# Recompute test
assert (np.allclose(J_hat(displacement_0), val))
# Test recompute adjoint-based gradient
Expand Down Expand Up @@ -92,3 +93,40 @@ def test_validity(V):
val_recomputed = J_hat(displacement_0)
assert np.allclose(val_recomputed, val_recomputed0)
assert np.allclose(dJ.dat.data_ro[:], dJ0.dat.data_ro[:])


@pytest.mark.skipcomplex
def test_control_value_survives_recompute():
"""Regression test for firedrakeproject/firedrake#5082.

Under SingleMemoryStorageSchedule, the checkpoint manager used to
clear the control's block-variable checkpoint during the forward
replay by writing var._checkpoint = None directly. That bypassed the
is_control guard in the BlockVariable setter, so the adjoint then
read back the underlying (stale) Function value instead of the new
control value installed by Control.update. For J = sum_k m**2 over
4 timesteps with m0 = 2, the correct derivative is 8 * m0 = 16; the
bug produced 8 (i.e. evaluated at the original m = 1).
"""
tape = get_working_tape()
tape.enable_checkpointing(SingleMemoryStorageSchedule())

mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "CG", 1)
m = Function(V).assign(1.0)
sumf = Function(V)
u = Function(V)
tst = TestFunction(V)
F = tst * u * dx - tst * m * m * dx
solver = NonlinearVariationalSolver(NonlinearVariationalProblem(F, u))

for _ in tape.timestepper(iter(range(4))):
solver.solve()
sumf.assign(sumf + u)

J_val = assemble(sumf * dx)
rf = ReducedFunctional(J_val, Control(m))

m0 = Function(V).assign(2.0)
assert np.allclose(rf(m0), 16.0)
assert np.allclose(rf.derivative(apply_riesz=True).dat.data_ro, 16.0)
Loading