Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
WIP: Fixing and tidying
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 23, 2024
1 parent f654fe7 commit 72dedc6
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 54 deletions.
20 changes: 14 additions & 6 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import hashlib
import os
import pickle
import weakref
from collections.abc import MutableMapping
from pathlib import Path
from warnings import warn # noqa F401
Expand All @@ -44,6 +45,7 @@
from functools import partial, wraps

from pyop2.configuration import configuration
from pyop2.exceptions import CachingError, HashError # noqa: F401
from pyop2.logger import debug
from pyop2.mpi import (
MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm
Expand Down Expand Up @@ -245,9 +247,8 @@ class _CacheMiss:
def _as_hexdigest(*args):
    """Return a single MD5 hex digest uniquely identifying ``args``.

    Each argument is folded into the hash via its ``str()`` representation,
    so every argument must have a stable, deterministic string form.

    :raises HashError: if any argument is an MPI communicator —
        communicators have no stable string representation, so hashing one
        would silently produce broken (rank-dependent) cache keys.
    """
    hash_ = hashlib.md5()
    for a in args:
        if isinstance(a, MPI.Comm):
            raise HashError("Communicators cannot be hashed, caching will be broken!")
        hash_.update(str(a).encode())
    return hash_.hexdigest()

Expand Down Expand Up @@ -385,6 +386,8 @@ class DEFAULT_CACHE(dict):
# - DictLikeDiskAccess = instrument(DictLikeDiskAccess)


# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT
# environment variable.
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
Expand Down Expand Up @@ -429,8 +432,13 @@ def wrapper(*args, **kwargs):
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]:
_KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache))
if (comm, comm.name, func, weakref.ref(local_cache)) not in [c[1:] for c in _KNOWN_CACHES]:
# JBTODO: When a comm is freed we will not hold a ref to the cache,
# but we should have a finalizer that extracts the stats before the object
# is deleted.
_KNOWN_CACHES.append(
(next(_CACHE_CIDX), comm, comm.name, func, weakref.ref(local_cache))
)

# JBTODO: Replace everything below here with:
# value = local_cache.get(key, CACHE_MISS)
Expand Down Expand Up @@ -508,7 +516,7 @@ def decorator(func):
# * Add some sort of cache statistics ✓
# * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method ✓
# * Systematic investigation into cache sizes/types for Firedrake
# - Is a mem cache needed for DLLs? No
# - Is a mem cache needed for DLLs? ~~No~~ Yes!!
# - Is LRUCache better than a simple dict? (memory profile test suite)
# - What is the optimal maxsize?
# * Add some docstrings and maybe some exposition!
94 changes: 46 additions & 48 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from pathlib import Path
from contextlib import contextmanager
from tempfile import gettempdir
from itertools import cycle
from uuid import uuid4


Expand All @@ -69,6 +68,8 @@ def _check_hashes(x, y, datatype):

_check_op = mpi.MPI.Op.Create(_check_hashes, commute=True)
_compiler = None
# Directory must be unique per user for shared machines
MEM_TMP_DIR = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")


def set_default_compiler(compiler):
Expand Down Expand Up @@ -421,8 +422,6 @@ def load_hashkey(*args, **kwargs):
return default_parallel_hashkey(code_hash, *args[1:], **kwargs)


# JBTODO: This should not be memory cached
# ...benchmarking disagrees with my assessment
@mpi.collective
@memory_cache(hashkey=load_hashkey, broadcast=False)
@PETSc.Log.EventDecorator()
Expand Down Expand Up @@ -476,7 +475,7 @@ def __init__(self, code, argtypes):
check_source_hashes(compiler_instance, code, extension, comm)
# This call is cached on disk
so_name = make_so(compiler_instance, code, extension, comm)
# This call is cached in memory by the OS
# This call might be cached in memory by the OS (system dependent)
dll = ctypes.CDLL(so_name)

if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
Expand Down Expand Up @@ -532,6 +531,14 @@ def _make_so_hashkey(compiler, jitmodule, extension, comm):


def check_source_hashes(compiler, jitmodule, extension, comm):
"""A check to see whether code generated on all ranks is identical.
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
"""
# Reconstruct hash from filename
hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm))
with mpi.temp_internal_comm(comm) as icomm:
Expand All @@ -549,9 +556,6 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
raise CompilationError(f"Generated code differs across ranks (see output in {output})")


FILE_CYCLER = cycle(f"{ii:02x}" for ii in range(256))


@mpi.collective
@parallel_cache(
hashkey=_make_so_hashkey,
Expand All @@ -570,16 +574,14 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""
if filename is None:
# JBTODO: Remove this directory at some point?
# Directory must be unique per user for shared machines
pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")
# A UUID should ensure we have a unique path
uuid = uuid4().hex
tempdir = pyop2_tempdir.joinpath(f"{uuid[:2]}")
# Taking the first two characters avoids using excessive filesystem inodes
tempdir = MEM_TMP_DIR.joinpath(f"{uuid[:2]}")
# This path + filename should be unique
filename = tempdir.joinpath(f"{uuid[2:]}.{extension}")
else:
pyop2_tempdir = None
tempdir = None
filename = Path(filename).absolute()

# Compilation communicators are reference counted on the PyOP2 comm
Expand All @@ -594,8 +596,8 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
exe = compiler.cc
compiler_flags = compiler.cflags

# JBTODO: Do we still need to worry about atomic file renaming in this function?
base = filename.name
# TODO: Do we still need to worry about atomic file renaming in this function?
base = filename.stem
path = filename.parent
pid = os.getpid()
cname = filename.with_name(f"{base}_p{pid}.{extension}")
Expand All @@ -606,11 +608,10 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):

# Compile on compilation communicator (ccomm) rank 0
if ccomm.rank == 0:
if pyop2_tempdir is None:
if tempdir is None:
filename.parent.mkdir(exist_ok=True)
else:
pyop2_tempdir.mkdir(exist_ok=True)
tempdir.mkdir(exist_ok=True)
tempdir.mkdir(parents=True, exist_ok=True)
logfile = path.joinpath(f"{base}_p{pid}.log")
errfile = path.joinpath(f"{base}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
Expand All @@ -632,13 +633,9 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
return ccomm.bcast(soname, root=0)


# JBTODO: Probably don't want to do this if we fail to compile...
# ~ @atexit
# ~ def _cleanup_tempdir():
# ~ pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")


def _run(cc, logfile, errfile, step="Compilation", filemode="w"):
""" Run a compilation command and handle logging + errors.
"""
debug(f"{step} command: {' '.join(cc)}")
try:
if configuration['no_fork_available']:
Expand Down Expand Up @@ -686,28 +683,29 @@ def clear_cache(prompt=False):
:arg prompt: if ``True`` prompt before removing any files
"""
cachedir = configuration['cache_dir']

if not os.path.exists(cachedir):
print("Cache directory could not be found")
return
if len(os.listdir(cachedir)) == 0:
print("No cached libraries to remove")
return

remove = True
if prompt:
user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ")

while user.lower() not in ['', 'y', 'n']:
print("Please answer y or n.")
user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ")

if user.lower() == 'n':
remove = False

if remove:
print(f"Removing cached libraries from {cachedir}")
shutil.rmtree(cachedir, ignore_errors=True)
else:
print("Not removing cached libraries")
cachedirs = [configuration['cache_dir'], MEM_TMP_DIR]

for directory in cachedirs:
if not os.path.exists(directory):
print("Cache directory could not be found")
return
if len(os.listdir(directory)) == 0:
print("No cached libraries to remove")
return

remove = True
if prompt:
user = input(f"Remove cached libraries from {directory}? [Y/n]: ")

while user.lower() not in ['', 'y', 'n']:
print("Please answer y or n.")
user = input(f"Remove cached libraries from {directory}? [Y/n]: ")

if user.lower() == 'n':
remove = False

if remove:
print(f"Removing cached libraries from {directory}")
shutil.rmtree(directory, ignore_errors=True)
else:
print("Not removing cached libraries")
1 change: 1 addition & 0 deletions pyop2/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from pyop2.exceptions import ConfigurationError


# JBTODO: Add a PYOP2_SPMD_STRICT environment variable to add various SPMD checks.
class Configuration(dict):
r"""PyOP2 configuration parameters
Expand Down
10 changes: 10 additions & 0 deletions pyop2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,13 @@ class CompilationError(RuntimeError):
class SparsityFormatError(ValueError):

"""Unable to produce a sparsity for this matrix format."""


class CachingError(ValueError):
    """Base error raised when something goes wrong with a PyOP2 cache."""


class HashError(CachingError):
    """Raised when a cache key cannot be safely hashed."""
2 changes: 2 additions & 0 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class PyOP2CommError(ValueError):
# PYOP2_FINALISED flag.


# JBTODO: Make this decorator infinitely more useful by adding barriers before
# and after the function call, if being run with PYOP2_SPMD_STRICT=1.
def collective(fn):
extra = trim("""
This function is logically collective over MPI ranks, it is an
Expand Down

0 comments on commit 72dedc6

Please sign in to comment.