Skip to content

Commit

Permalink
adding tests for trivial checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 18, 2023
1 parent c3c50d5 commit f39111b
Showing 1 changed file with 207 additions and 0 deletions.
207 changes: 207 additions & 0 deletions tests/test_trivial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# For tlm_adjoint copyright information see ACKNOWLEDGEMENTS in the tlm_adjoint
# root directory

# This file is part of tlm_adjoint.
#
# tlm_adjoint is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, version 3 of the License.
# tlm_adjoint is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.

import functools
import pytest
from checkpoint_schedules.schedule import \
Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType
from checkpoint_schedules import SingleDiskStorageSchedule, \
SingleMemoryStorageSchedule


def single_disk_copy(n):
checkpointing = SingleDiskStorageSchedule()
return (checkpointing,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_disk_move(n):
checkpointing = SingleDiskStorageSchedule(move_data=True)
return (checkpointing,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_memory(n):
checkpointing = SingleMemoryStorageSchedule()
return (checkpointing,
{StorageType.RAM: 0, StorageType.DISK: 0}, n)


@pytest.mark.parametrize(
"schedule",
[
single_disk_copy,
single_disk_move,
single_memory,
]
)


def test_validity(schedule, n=10):

Check failure on line 56 in tests/test_trivial.py

View workflow job for this annotation

GitHub Actions / Run linters

E304 blank lines found after function decorator
"""Test the checkpoint revolvers.
Parameters
----------
schedule : object
Revolver schedule.
n : int
Total forward steps.
S : int
Snapshots.
"""
@functools.singledispatch
def action(cp_action):
raise TypeError("Unexpected action")

@action.register(Forward)
def action_forward(cp_action):
nonlocal model_n
# Start at the current location of the forward
assert model_n is not None and model_n == cp_action.n0

assert cp_action.storage == StorageType.WORK or \
cp_action.storage == StorageType.DISK
# If the schedule has been finalized, end at or before the end of the
# forward
assert cp_schedule.max_n is None or cp_action.n1 <= n
if cp_schedule.max_n is not None:
# Do not advance further than the current location of the adjoint
assert cp_action.n1 <= n - model_r

n1 = min(cp_action.n1, n)

model_n = n1
data.clear()
assert not cp_action.write_ics
assert cp_action.write_adj_deps is True
assert len(data.intersection(range(cp_action.n0, n1))) == 0
if cp_action.storage == StorageType.DISK:
for step in range(cp_action.n0, n1):
data.update(range(step, step + 1))
snapshots[cp_action.storage][step] = (set(ics), set(data))
data.clear()
assert cp_action.n0 == min(snapshots[cp_action.storage])
else:
data.update(range(cp_action.n0, n1))
assert cp_action.n0 == min(data)

assert len(ics) == 0

if n1 == n:
cp_schedule.finalize(n1)

@action.register(Reverse)
def action_reverse(cp_action):
nonlocal model_r

# Start at the current location of the adjoint
assert cp_action.n1 == n - model_r
# Advance at least one step
assert cp_action.n0 < cp_action.n1
# Non-linear dependency data for these steps is stored
assert data.issuperset(range(cp_action.n0, cp_action.n1))

model_r += cp_action.n1 - cp_action.n0
if cp_action.clear_adj_deps:
data.clear()

@action.register(Copy)
def action_copy(cp_action):
nonlocal model_n
# The checkpoint exists
assert cp_action.n in snapshots[cp_action.from_storage]
cp = snapshots[cp_action.from_storage][cp_action.n]

# No data is currently stored for this step
assert cp_action.n not in ics
assert cp_action.n not in data

# The checkpoint contains forward data
assert len(cp[0]) == 0 and len(cp[1]) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r
model_n = None

data.clear()
data.update(cp[1])

@action.register(Move)
def action_move(cp_action):
nonlocal model_n
# The checkpoint exists
assert cp_action.n in snapshots[cp_action.from_storage]
assert cp_action.n == max(snapshots[cp_action.from_storage])
cp = snapshots[cp_action.from_storage][cp_action.n]

# No data is currently stored for this step
assert cp_action.n not in ics
assert cp_action.n not in data

# The checkpoint contains forward data
assert len(cp[0]) == 0 and len(cp[1]) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r
model_n = None
data.clear()
data.update(cp[1])
del snapshots[cp_action.from_storage][cp_action.n]

@action.register(EndForward)
def action_end_forward(cp_action):
# The correct number of forward steps has been taken
assert model_n is not None and model_n == n

@action.register(EndReverse)
def action_end_reverse(cp_action):
nonlocal model_r, cp_schedule

# The correct number of adjoint steps has been taken
assert model_r == n
is_exhausted = cp_schedule.is_exhausted
if is_exhausted is False:
model_r = 0

model_n = 0
model_r = 0
ics = set()
data = set()
assert isinstance(n, int)
snapshots = {StorageType.RAM: {}, StorageType.DISK: {}}
cp_schedule, storage_limits, data_limit = schedule(n)
if cp_schedule is None:
raise TypeError("Incompatible with schedule type.")
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.max_n is None or cp_schedule.max_n == n

for _, cp_action in enumerate(cp_schedule):
action(cp_action)
assert model_n is None or model_n == cp_schedule.n
assert model_r == cp_schedule.r

# Checkpoint storage limits are not exceeded
for storage_type, storage_limit in storage_limits.items():
assert len(snapshots[storage_type]) <= storage_limit
# Data storage limit is not exceeded
assert min(1, len(ics)) + len(data) <= data_limit

if isinstance(cp_action, EndReverse):
break

0 comments on commit f39111b

Please sign in to comment.