Skip to content

Commit

Permalink
llvm: Move learnable matrices from RO params to RW params (#2933)
Browse files Browse the repository at this point in the history
Cleanup the function that filters compilable parameters.
Restrict the writeback function to always work on the compiled state (RW params).
  • Loading branch information
jvesely authored Mar 26, 2024
2 parents f58e2c5 + 69a339f commit 05ef92b
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 47 deletions.
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_params_to_pnl()
ex.writeback_state_to_pnl()

return ex.execute

Expand All @@ -210,7 +210,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_params_to_pnl()
ex.writeback_state_to_pnl()

return ex.cuda_execute

Expand Down
76 changes: 46 additions & 30 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,30 +1292,55 @@ def __deepcopy__(self, memo):
# ------------------------------------------------------------------------------------------------------------------
# Compilation support
# ------------------------------------------------------------------------------------------------------------------
def _is_compilable_param(self, p):

# User only parameters are not compiled.
if p.read_only and p.getter is not None:
return False

# Shared and aliased parameters are for user conveniecne and not compiled.
if isinstance(p, (ParameterAlias, SharedParameter)):
return False

# TODO this should use default value
val = p.get()

# Strings, builtins, functions, and methods are not compilable
return not isinstance(val, (str,
type(max),
type(np.sum),
type(make_parameter_property),
type(self._get_compilation_params)))


def _get_compilation_state(self):
# FIXME: MAGIC LIST, Use stateful tag for this
whitelist = {"previous_time", "previous_value", "previous_v",
"previous_w", "random_state",
"input_ports", "output_ports",
"adjustment_cost", "intensity_cost", "duration_cost",
"intensity"}

# Prune subcomponents (which are enabled by type rather than a list)
# that should be omitted
blacklist = { "objective_mechanism", "agent_rep", "projections", "shadow_inputs"}

# Only mechanisms use "value" state, can execute 'until finished',
# and need to track executions
# Mechanisms;
# * use "value" state
# * can execute 'until finished'
# * need to track number of executions
if hasattr(self, 'ports'):
whitelist.update({"value", "num_executions_before_finished",
"num_executions", "is_finished_flag"})

# If both the mechanism and its functoin use random_state it's DDM
# with integrator function. The mechanism's random_state is not used.
# If both the mechanism and its function use random_state.
# it's DDM with integrator function.
# The mechanism's random_state is not used.
if hasattr(self.parameters, 'random_state') and hasattr(self.function.parameters, 'random_state'):
whitelist.remove('random_state')


# Only mechanisms and compositions need 'num_executions'
# Compositions need to track number of executions
if hasattr(self, 'nodes'):
whitelist.add("num_executions")

Expand All @@ -1341,11 +1366,15 @@ def _get_compilation_state(self):
if hasattr(self.parameters, 'duplicate_keys'):
blacklist.add("previous_value")

# Matrices of learnable projections are stateful
if getattr(self, 'owner', None) and getattr(self.owner, 'learnable', False):
whitelist.add('matrix')

def _is_compilation_state(p):
# FIXME: This should use defaults instead of 'p.get'
return p.name not in blacklist and \
not isinstance(p, (ParameterAlias, SharedParameter)) and \
(p.name in whitelist or isinstance(p.get(), Component))
(p.name in whitelist or isinstance(p.get(), Component)) and \
self._is_compilable_param(p)

return filter(_is_compilation_state, self.parameters)

Expand All @@ -1362,9 +1391,10 @@ def llvm_state_ids(self):

def _get_state_initializer(self, context):
def _convert(p):
# FIXME: This should use defaults instead of 'p.get'
x = p.get(context)
if isinstance(x, np.random.RandomState):
if p.name == 'matrix': # Flatten matrix
val = tuple(np.asfarray(x).flatten())
elif isinstance(x, np.random.RandomState):
# Skip first element of random state (id string)
val = pnlvm._tupleize((*x.get_state()[1:], x.used_seed[0]))
elif isinstance(x, np.random.Generator):
Expand Down Expand Up @@ -1432,11 +1462,11 @@ def _get_compilation_params(self):
"learning_results", "learning_signal", "learning_signals",
"error_matrix", "error_signal", "activation_input",
"activation_output", "error_sources", "covariates_sources",
"target", "sample",
"target", "sample", "learning_function"
}
# Mechanism's need few extra entries:
# * matrix -- is never used directly, and is flatened below
# * integration rate -- shape mismatch with param port input
# * integration_rate -- shape mismatch with param port input
# * initializer -- only present on DDM and never used
# * search_space -- duplicated between OCM and its function
if hasattr(self, 'ports'):
Expand Down Expand Up @@ -1466,26 +1496,12 @@ def _get_compilation_params(self):
if cost_functions.DURATION not in cost_functions:
blacklist.add('duration_cost_fct')

def _is_compilation_param(p):
def _is_user_only_param(p):
if p.read_only and p.getter is not None:
return True
if isinstance(p, (ParameterAlias, SharedParameter)):
return True

return False

# Matrices of learnable projections are stateful
if getattr(self, 'owner', None) and getattr(self.owner, 'learnable', False):
blacklist.add('matrix')

if p.name not in blacklist and not _is_user_only_param(p):
# FIXME: this should use defaults
val = p.get()
# Check if the value type is valid for compilation
return not isinstance(val, (str, ComponentsMeta,
type(max),
type(np.sum),
type(_is_compilation_param),
type(self._get_compilation_params)))
return False
def _is_compilation_param(p):
return p.name not in blacklist and self._is_compilable_param(p)

return filter(_is_compilation_param, self.parameters)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3886,7 +3886,7 @@ def instantiate_matrix(self, specification, context=None):
return np.array(specification)


def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, tags:frozenset):
def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
# Restrict to 1d arrays
if self.defaults.variable.ndim != 1:
warnings.warn("Shape mismatch: {} (in {}) got 2D input: {}".format(
Expand All @@ -3899,7 +3899,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, t
pnlvm.PNLCompilerWarning)
arg_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])

matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params)
matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params, state_struct_ptr=state)
normalize = ctx.get_param_or_state_ptr(builder, self, NORMALIZE, param_struct_ptr=params)

# Convert array pointer to pointer to the fist element
Expand Down
6 changes: 2 additions & 4 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11251,10 +11251,8 @@ def run(
self.parameters.results._set(results, context)

if self._is_learning(context):
# copies back matrix to pnl from param struct (after learning)
_comp_ex.writeback_params_to_pnl(params=_comp_ex._param_struct,
ids="llvm_param_ids",
condition=lambda p: p.name == "matrix")
# copies back matrix to pnl from state struct after learning
_comp_ex.writeback_state_to_pnl(condition=lambda p: p.name == "matrix")

self._propagate_most_recent_context(context)

Expand Down
2 changes: 2 additions & 0 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def _state_struct(p):
return self.get_state_struct_type(val)
if isinstance(val, ContentAddressableList):
return ir.LiteralStructType(self.get_state_struct_type(x) for x in val)
if p.name == 'matrix': # Flatten matrix
val = np.asfarray(val).flatten()
struct = self.convert_python_struct_to_llvm_ir(val)
return ir.ArrayType(struct, p.history_min_length + 1)

Expand Down
15 changes: 6 additions & 9 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,13 @@ def _get_compilation_param(self, name, init_method, arg):
return struct


def writeback_params_to_pnl(self, params=None, ids:Optional[str]=None, condition:Callable=lambda p: True):
def writeback_state_to_pnl(self, condition:Callable=lambda p: True):

assert (params is None) == (ids is None), "Either both 'params' and 'ids' have to be set or neither"

if params is None:
# Default to stateful params
params = self._state_struct
ids = "llvm_state_ids"

self._copy_params_to_pnl(self._execution_contexts[0], self._obj, params, ids, condition)
self._copy_params_to_pnl(self._execution_contexts[0],
self._obj,
self._state_struct,
"llvm_state_ids",
condition)


def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Callable):
Expand Down

0 comments on commit 05ef92b

Please sign in to comment.