Commit

…Link into devel

jdcpni committed Aug 6, 2024
2 parents 98a8387 + e93b787 commit 9e21bd2
Showing 20 changed files with 663 additions and 522 deletions.
2 changes: 1 addition & 1 deletion dev_requirements.txt
@@ -1,6 +1,6 @@
jupyter<1.0.1
packaging<25.0
pytest<8.3.2
pytest<8.3.3
pytest-benchmark<4.0.1
pytest-cov<5.0.1
pytest-forked<1.7.0
psyneulink/core/components/functions/nonstateful/optimizationfunctions.py
@@ -2096,44 +2096,33 @@ def _function(self,
# if ocm is not None and ocm.parameters.comp_execution_mode._get(context) in {"PTX", "LLVM"}:
if ocm is not None and ocm.parameters.comp_execution_mode._get(context) in {"PTX", "LLVM"}:

# If we have a numpy array, convert back to ctypes
if isinstance(all_values, np.ndarray):
ct_values = all_values.flatten().ctypes.data_as(ctypes.POINTER(ctypes.c_double))
num_values = len(all_values.flatten())
else:
ct_values = all_values
num_values = len(ct_values)
ct_values = all_values
num_values = len(ct_values)

# Reduce array of values to min/max
# select_min params are:
# params, state, min_sample_ptr, sample_ptr, min_value_ptr, value_ptr, opt_count_ptr, start, stop
min_tags = frozenset({"select_min", "evaluate_type_objective"})
bin_func = pnlvm.LLVMBinaryFunction.from_obj(self, tags=min_tags)
bin_func = pnlvm.LLVMBinaryFunction.from_obj(self, tags=min_tags, numpy_args=(2, 4, 6))

ct_param = bin_func.byref_arg_types[0](*self._get_param_initializer(context))
ct_state = bin_func.byref_arg_types[1](*self._get_state_initializer(context))
ct_opt_sample = bin_func.byref_arg_types[2](float("NaN"))
ct_alloc = None # NULL for samples
ct_opt_value = bin_func.byref_arg_types[4]()
ct_opt_count = bin_func.byref_arg_types[6](0)
ct_start = bin_func.c_func.argtypes[7](0)
ct_stop = bin_func.c_func.argtypes[8](num_values)

bin_func(ct_param, ct_state, ct_opt_sample, ct_alloc, ct_opt_value,
ct_values, ct_opt_count, ct_start, ct_stop)

optimal_value = ct_opt_value.value
optimal_sample = np.ctypeslib.as_array(ct_opt_sample)

if not isinstance(all_values, np.ndarray):
all_values = np.ctypeslib.as_array(ct_values)

# These are normally stored in the parent function (OptimizationFunction).
# Since this path doesn't call super()._function the way the Python path does,
# save the values here.
if self.parameters.save_samples._get(context):
self.parameters.saved_samples._set(all_samples, context)
if self.parameters.save_values._get(context):
self.parameters.saved_values._set(all_values, context)
optimal_sample = bin_func.np_buffer_for_arg(2)
optimal_value = bin_func.np_buffer_for_arg(4)
number_of_optimal_values = bin_func.np_buffer_for_arg(6, fill_value=0)

bin_func(ct_param,
ct_state,
optimal_sample,
None, # samples; NULL means they are generated by the function
optimal_value,
ct_values,
number_of_optimal_values,
bin_func.c_func.argtypes[7](0), # start
bin_func.c_func.argtypes[8](num_values)) # stop

# Convert outputs to Numpy/Python
all_values = np.ctypeslib.as_array(ct_values)

# Python version
else:
@@ -2153,6 +2142,12 @@ def _function(self,
[all_samples[:,i] for i in range(all_samples.shape[1])])
optimal_value, optimal_sample = next(value_sample_pairs)

# The algorithm below implements "Reservoir sampling"[0]. This
# matches the compiled implementation of "select_min". The
# advantage of reservoir sampling is constant memory requirements
# and a single pass over the evaluated values.
# The disadvantage is multiple calls to the PRNG.
# https://en.wikipedia.org/wiki/Reservoir_sampling
select_randomly = self.parameters.select_randomly_from_optimal_values._get(context)
for value, sample in value_sample_pairs:
if select_randomly and np.allclose(value, optimal_value):
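The reservoir-sampling loop is truncated in this view; as a reference, here is a minimal self-contained sketch of the same selection scheme (a hypothetical helper, not part of the diff):

    import numpy as np

    def select_min_reservoir(values, rng):
        """Return (min value, index), choosing uniformly at random among ties."""
        optimal_value = np.inf
        optimal_index = None
        tie_count = 0
        for i, value in enumerate(values):
            if np.isclose(value, optimal_value):
                # Keep the newcomer with probability 1/tie_count; after one pass,
                # every tied element has been selected with equal probability.
                tie_count += 1
                if rng.random() < 1.0 / tie_count:
                    optimal_index = i
            elif value < optimal_value:
                optimal_value = value
                optimal_index = i
                tie_count = 1
        return optimal_value, optimal_index

    # select_min_reservoir([3., 1., 1., 2.], np.random.default_rng(0))
    # -> (1.0, 1) or (1.0, 2), each with probability 1/2

Memory stays constant and the values are traversed once, at the cost of one PRNG draw per tie, which is exactly the trade-off the comment above describes.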
9 changes: 5 additions & 4 deletions psyneulink/core/compositions/composition.py
@@ -11681,7 +11681,8 @@ def _execute_controller(self,
assert (execution_mode == pnlvm.ExecutionMode.LLVM
or execution_mode & pnlvm.ExecutionMode._Fallback),\
f"PROGRAM ERROR: Unrecognized compiled execution_mode: '{execution_mode}'."
_comp_ex.execute_node(self.controller, context=context)
_comp_ex.freeze_values()
_comp_ex.execute_node(self.controller)

context.remove_flag(ContextFlags.PROCESSING)

@@ -12010,7 +12011,7 @@ def execute(
build_CIM_input = self._build_variable_for_input_CIM(inputs)

if execution_mode & pnlvm.ExecutionMode.COMPILED:
_comp_ex.execute_node(self.input_CIM, inputs, context)
_comp_ex.execute_node(self.input_CIM, inputs)
# FIXME: parameter_CIM should be executed here as well,
# but node execution of nested compositions with
# outside control is not supported yet.
@@ -12295,7 +12296,7 @@ def execute(

# Execute Mechanism
if execution_mode & pnlvm.ExecutionMode.COMPILED:
_comp_ex.execute_node(node, context=context)
_comp_ex.execute_node(node)
else:
if node is not self.controller:
mech_context = copy(context)
@@ -12507,7 +12508,7 @@ def execute(
# Extract result here
if execution_mode & pnlvm.ExecutionMode.COMPILED:
_comp_ex.freeze_values()
_comp_ex.execute_node(self.output_CIM, context=context)
_comp_ex.execute_node(self.output_CIM)
report(self,
PROGRESS_REPORT,
report_num=report_num,
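All four hunks in this file make the same two adjustments: values are frozen before a consumer node (controller or output_CIM) reads them, and execute_node no longer takes a per-call context, since the compiled execution object is already bound to one. A minimal sketch of the revised pattern, assuming _comp_ex is an existing compiled execution object and node is the node about to run:

    _comp_ex.freeze_values()     # snapshot current node outputs for consistent reads
    _comp_ex.execute_node(node)  # context is bound to _comp_ex, not passed per call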
40 changes: 28 additions & 12 deletions psyneulink/core/llvm/__init__.py
@@ -23,7 +23,7 @@

from . import codegen
from .builder_context import *
from .builder_context import _all_modules, _convert_llvm_ir_to_ctype
from .builder_context import _all_modules, _convert_llvm_ir_to_ctype, _convert_llvm_ir_to_dtype
from .debug import debug_env
from .execution import *
from .execution import _tupleize
@@ -123,7 +123,7 @@ def _llvm_build(target_generation=_binary_generation + 1):


class LLVMBinaryFunction:
def __init__(self, name: str):
def __init__(self, name: str, *, numpy_args=()):
self.name = name

self.__c_func = None
@@ -143,17 +143,25 @@ def __init__(self, name: str):
# Create ctype function instance
start = time.perf_counter()
return_type = _convert_llvm_ir_to_ctype(f.return_value.type)
params = [_convert_llvm_ir_to_ctype(a.type) for a in f.args]
args = [_convert_llvm_ir_to_ctype(a.type) for a in f.args]

# '_type_' special attribute stores pointee type for pointers
# https://docs.python.org/3/library/ctypes.html#ctypes._Pointer._type_
self.byref_arg_types = [a._type_ if hasattr(a, "contents") else None for a in args]
self.np_params = [_convert_llvm_ir_to_dtype(getattr(a.type, "pointee", a.type)) for a in f.args]

for a in numpy_args:
assert self.byref_arg_types[a] is not None
args[a] = np.ctypeslib.ndpointer(dtype=self.np_params[a].base, shape=self.np_params[a].shape)

middle = time.perf_counter()
self.__c_func_type = ctypes.CFUNCTYPE(return_type, *params)
self.__c_func_type = ctypes.CFUNCTYPE(return_type, *args)
finish = time.perf_counter()

if "time_stat" in debug_env:
print("Time to create ctype function '{}': {} ({} to create types)".format(
name, finish - start, middle - start))

self.byref_arg_types = [p._type_ for p in params]

@property
def c_func(self):
if self.__c_func is None:
@@ -218,26 +226,34 @@ def cuda_wrap_call(self, *args, **kwargs):
wrap_args = (jit_engine.pycuda.driver.InOut(a) if isinstance(a, np.ndarray) else a for a in args)
self.cuda_call(*wrap_args, **kwargs)

def np_buffer_for_arg(self, arg_num, *, extra_dimensions=(), fill_value=np.nan):

out_base = self.np_params[arg_num].base
out_shape = extra_dimensions + self.np_params[arg_num].shape

# Fill the buffer with the requested poison value (NaN by default)
return np.full(out_shape, fill_value, dtype=out_base)

@staticmethod
@functools.lru_cache(maxsize=32)
def from_obj(obj, *, tags:frozenset=frozenset()):
def from_obj(obj, *, tags:frozenset=frozenset(), numpy_args:tuple=()):
name = LLVMBuilderContext.get_current().gen_llvm_function(obj, tags=tags).name
return LLVMBinaryFunction.get(name)
return LLVMBinaryFunction.get(name, numpy_args=numpy_args)

@staticmethod
@functools.lru_cache(maxsize=32)
def get(name: str):
return LLVMBinaryFunction(name)
def get(name: str, *, numpy_args:tuple=()):
return LLVMBinaryFunction(name, numpy_args=numpy_args)

def get_multi_run(self):
def get_multi_run(self, *, numpy_args=()):
try:
multirun_llvm = _find_llvm_function(self.name + "_multirun")
except ValueError:
function = _find_llvm_function(self.name)
with LLVMBuilderContext.get_current() as ctx:
multirun_llvm = codegen.gen_multirun_wrapper(ctx, function)

return LLVMBinaryFunction.get(multirun_llvm.name)
return LLVMBinaryFunction.get(multirun_llvm.name, numpy_args=numpy_args)


_cpu_engine = None
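Combined with the GridSearch hunk earlier in this commit, the intended workflow for the new numpy_args machinery is roughly the following sketch (the argument indices, obj, tags, and the ct_*/start/stop inputs are illustrative and must match the compiled function's actual signature):

    # Declare pointer args 2, 4, and 6 as numpy ndpointers instead of raw ctypes.
    bin_func = pnlvm.LLVMBinaryFunction.from_obj(obj, tags=tags, numpy_args=(2, 4, 6))

    # Allocate output buffers with the dtype/shape the compiled code expects.
    out_sample = bin_func.np_buffer_for_arg(2)               # NaN-poisoned by default
    out_value = bin_func.np_buffer_for_arg(4)
    out_count = bin_func.np_buffer_for_arg(6, fill_value=0)  # zero-initialized

    # numpy arrays are passed directly; the remaining args stay ctypes values.
    bin_func(ct_params, ct_state, out_sample, None, out_value,
             ct_values, out_count, start, stop)

Because np.ctypeslib.ndpointer validates dtype and shape at call time, a mismatched buffer raises an error instead of silently corrupting memory.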
96 changes: 76 additions & 20 deletions psyneulink/core/llvm/builder_context.py
@@ -52,7 +52,7 @@ def module_count():
'mt_rand_init', 'philox_rand_init'))


class _node_wrapper():
class _node_assembly():
def __init__(self, composition, node):
self._comp = weakref.proxy(composition)
self._node = node
@@ -61,7 +61,7 @@ def __repr__(self):
return "Node wrapper for node '{}' in composition '{}'".format(self._node, self._comp)

def _gen_llvm_function(self, *, ctx, tags:frozenset):
return codegen.gen_node_wrapper(ctx, self._comp, self._node, tags=tags)
return codegen.gen_node_assembly(ctx, self._comp, self._node, tags=tags)

def _comp_cached(func):
@functools.wraps(func)
@@ -349,6 +349,13 @@ def get_state_space(self, builder, component, state_ptr, param):
return helpers.get_state_space(builder, component, state_ptr, param_name)

def check_used_params(self, component, *, tags:frozenset):
"""
This function checks that parameters included in the compiled structures are used in compiled code.
If the assertion in this function triggers the parameter name should be added to the parameter
block list in the Component class.
"""

# Skip the check if the parameter use is not tracked. Some components (like node wrappers)
# don't even have parameters.
if component not in self._component_state_use and component not in self._component_param_use:
@@ -378,12 +385,6 @@ def check_used_params(self, component, *, tags:frozenset):
if hasattr(component, 'evaluate_agent_rep'):
used_param_ids.add('num_trials_per_estimate')

if hasattr(component, 'adapt_scale'):
used_param_ids.add('threshold')
used_param_ids.add('adapt_scale')
used_param_ids.add('adapt_base')
used_param_ids.add('adapt_entropy_weighting')

unused_param_ids = component_param_ids - used_param_ids - initializers
unused_state_ids = component_state_ids - used_state_ids

@@ -504,57 +505,62 @@ def get_data_struct_type(self, component):

return ir.LiteralStructType([])

def get_node_wrapper(self, composition, node):
cache = getattr(composition, '_wrapped_nodes', None)
def get_node_assembly(self, composition, node):
cache = getattr(composition, '_node_assemblies', None)
if cache is None:
cache = weakref.WeakKeyDictionary()
setattr(composition, '_wrapped_nodes', cache)
return cache.setdefault(node, _node_wrapper(composition, node))
setattr(composition, '_node_assemblies', cache)
return cache.setdefault(node, _node_assembly(composition, node))

def convert_python_struct_to_llvm_ir(self, t):
self._stats["types_converted"] += 1
if t is None:
return ir.LiteralStructType([])
elif type(t) is list:
if len(t) == 0:
return ir.LiteralStructType([])
elems_t = [self.convert_python_struct_to_llvm_ir(x) for x in t]
if all(x == elems_t[0] for x in elems_t):
return ir.ArrayType(elems_t[0], len(elems_t))
return ir.LiteralStructType(elems_t)
elif type(t) is tuple:

elif isinstance(t, (list, tuple)):
elems_t = [self.convert_python_struct_to_llvm_ir(x) for x in t]
if len(elems_t) > 0 and all(x == elems_t[0] for x in elems_t):
return ir.ArrayType(elems_t[0], len(elems_t))

return ir.LiteralStructType(elems_t)

elif isinstance(t, enum.Enum):
# FIXME: Consider enums of non-int type
assert all(round(x.value) == x.value for x in type(t))
return self.int32_ty

elif isinstance(t, (int, float, np.floating)):
return self.float_ty

elif isinstance(t, np.integer):
# Python 'int' is handled above as it is the default type for '0'
return ir.IntType(t.nbytes * 8)

elif isinstance(t, np.ndarray):
# 0d uint32 values were likely created from enums (above) and are
# observed here after compilation sync.
# Avoid silent promotion to float (via Python's builtin int-type)
if t.ndim == 0 and t.dtype == np.uint32:
return self.convert_python_struct_to_llvm_ir(t.reshape(1)[0])
return self.convert_python_struct_to_llvm_ir(t.tolist())

elif isinstance(t, np.random.RandomState):
return pnlvm.builtins.get_mersenne_twister_state_struct(self)

elif isinstance(t, np.random.Generator):
assert isinstance(t.bit_generator, np.random.Philox)
return pnlvm.builtins.get_philox_state_struct(self)

elif isinstance(t, Time):
return ir.ArrayType(self.int32_ty, len(TimeScale))

elif isinstance(t, SampleIterator):
if isinstance(t.generator, list):
return ir.ArrayType(self.float_ty, len(t.generator))

# Generic iterator is {start, increment, count}
return ir.LiteralStructType((self.float_ty, self.float_ty, self.int32_ty))

assert False, "Don't know how to convert {}".format(type(t))


@@ -765,3 +771,53 @@ def _convert_llvm_ir_to_ctype(t: ir.Type):
assert False, "Don't know how to convert LLVM type: {}".format(t)

return ret_t

@functools.lru_cache(maxsize=16)
def _convert_llvm_ir_to_dtype(t: ir.Type):

if isinstance(t, ir.IntType):
if t.width == 8:
return np.uint8().dtype

elif t.width == 16:
return np.uint16().dtype

elif t.width == 32:
return np.uint32().dtype

elif t.width == 64:
return np.uint64().dtype

else:
assert False, "Unsupported integer type: {}".format(type(t))

elif isinstance(t, ir.DoubleType):
return np.float64().dtype

elif isinstance(t, ir.FloatType):
return np.float32().dtype

elif isinstance(t, ir.HalfType):
return np.float16().dtype

elif isinstance(t, ir.ArrayType):
element_type = _convert_llvm_ir_to_dtype(t.element)

# Create multidimensional array instead of nesting
if element_type.subdtype is not None:
element_type, shape = element_type.subdtype
else:
shape = ()

ret_t = np.dtype((element_type, (len(t),) + shape))

elif isinstance(t, ir.LiteralStructType):
field_list = []
for i, e in enumerate(t.elements):
field_list.append(("field_" + str(i), _convert_llvm_ir_to_dtype(e)))

ret_t = np.dtype(field_list, align=True)
else:
assert False, "Don't know how to convert LLVM type to dtype: {}".format(t)

return ret_t
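For reference, two representative conversions the new helper should produce, written as a hedged sketch directly against llvmlite types: nested arrays collapse into a single multidimensional subarray dtype, and literal structs become aligned record dtypes with generated field names.

    from llvmlite import ir
    import numpy as np

    # [2 x [3 x double]] -> one multidimensional subarray dtype, not nested dtypes
    assert _convert_llvm_ir_to_dtype(ir.ArrayType(ir.ArrayType(ir.DoubleType(), 3), 2)) \
        == np.dtype((np.float64, (2, 3)))

    # {i32, double} -> aligned record dtype; LLVM integers map to unsigned dtypes
    assert _convert_llvm_ir_to_dtype(ir.LiteralStructType([ir.IntType(32), ir.DoubleType()])) \
        == np.dtype([("field_0", np.uint32), ("field_1", np.float64)], align=True)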