Skip to content

Commit

Permalink
llvm: Enable state writeback on all compiled Functions and Mechanisms (
Browse files Browse the repository at this point in the history
…#2938)

Automatically write-back state parameters for compiled tests of both Functions and Mechanisms
  • Loading branch information
jvesely authored Apr 4, 2024
2 parents 05ef92b + a63de1d commit c060405
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
28 changes: 21 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,14 @@ def cuda_param(val):
return pytest.param(val, marks=[pytest.mark.llvm, pytest.mark.cuda])

@pytest.helpers.register
def get_func_execution(func, func_mode, *, writeback:bool=True):
def get_func_execution(func, func_mode):
if func_mode == 'LLVM':
ex = pnlvm.execution.FuncExecution(func)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_state_to_pnl()
ex.writeback_state_to_pnl()

return ex.execute

Expand All @@ -209,8 +208,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_state_to_pnl()
ex.writeback_state_to_pnl()

return ex.cuda_execute

Expand All @@ -222,9 +220,25 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
if mech_mode == 'LLVM':
return pnlvm.execution.MechExecution(mech).execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.execute

elif mech_mode == 'PTX':
return pnlvm.execution.MechExecution(mech).cuda_execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.cuda_execute

elif mech_mode == 'Python':
def mech_wrapper(x):
mech.execute(x)
Expand Down
21 changes: 16 additions & 5 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,21 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
ids=ids,
condition=condition)
else:
# TODO: Reconstruct Python RandomState
if attribute == "random_state":
continue

# TODO: Reconstruct Python memory storage
if attribute == "ring_memory":
continue

# "old_val" is a helper storage in compiled RecurrentTransferMechanism
# to workaround the fact that compiled projections do no pull values
# from their source output ports
# recurrent projection of RTM is not a PNL parameter.
if attribute in {"old_val", "recurrent_projection"}:
continue

# Handle PNL parameters
pnl_param = getattr(component.parameters, attribute)
pnl_value = pnl_param.get(context=context)
Expand All @@ -183,10 +198,6 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
# Writeback parameter value if the condition matches
elif condition(pnl_param):

# TODO: Reconstruct Python RandomState
if attribute == "random_state":
continue

# Replace empty structures with None
if ctypes.sizeof(compiled_attribute_param) == 0:
value = None
Expand All @@ -202,7 +213,7 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
if hasattr(old_value, 'shape'):
value = value.reshape(old_value.shape)

pnl_param.set(value, context=context)
pnl_param.set(value, context=context, override=True)


class CUDAExecution(Execution):
Expand Down
4 changes: 1 addition & 3 deletions tests/functions/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ def test_basic(func, variable, params, expected, benchmark, func_mode):
if variable is philox_var:
f.parameters.random_state.set(_SeededPhilox([module_seed]))

# Do not allow writeback. "ring_memory" used by DictionaryMemory is a
# custom structure, not a PNL parameter
EX = pytest.helpers.get_func_execution(f, func_mode, writeback=False)
EX = pytest.helpers.get_func_execution(f, func_mode)

EX(variable)

Expand Down

0 comments on commit c060405

Please sign in to comment.