Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Add additional cache tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 23, 2024
1 parent 72dedc6 commit 2df3f35
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions test/unit/test_updated_caching.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import ctypes
import pytest
import os
import tempfile
from functools import partial
from itertools import chain
from textwrap import dedent

from pyop2.caching import ( # noqa: F401
from pyop2.caching import (
disk_only_cache,
memory_cache,
memory_and_disk_cache,
default_parallel_hashkey,
clear_memory_cache
)
from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval # noqa: F401
from pyop2.compilation import load
from pyop2.mpi import MPI, COMM_WORLD


class StateIncrement:
Expand Down Expand Up @@ -138,3 +143,44 @@ def test_function_over_different_comms(request, state, decorator, uncached_funct
comm23.Free()

clear_memory_cache(COMM_WORLD)


# pyop2/compilation.py uses a custom cache which we test here
@pytest.mark.parallel(nprocs=2)
def test_writing_large_so():
# This test exercises the compilation caching when handling larger files
if COMM_WORLD.rank == 0:
preamble = dedent("""\
#include <stdio.h>\n
void big(double *result){
""")
variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024))
lines = (f" double {v} = {hash(v)/1000000000};\n *result += {v};\n" for v in variables)
program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", ))))
with open("big.c", "w") as fh:
fh.write(program)

COMM_WORLD.Barrier()
with open("big.c", "r") as fh:
program = fh.read()

if COMM_WORLD.rank == 1:
os.remove("big.c")

fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD)
assert fn is not None


@pytest.mark.parallel(nprocs=2)
def test_two_comms_compile_the_same_code():
new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank)
new_comm.name = "test_two_comms"
code = dedent("""\
#include <stdio.h>\n
void noop(){
printf("Do nothing!\\n");
}
""")

fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD)
assert fn is not None

0 comments on commit 2df3f35

Please sign in to comment.