From dbd24fdc27624e76546143340cdfb70b784ba823 Mon Sep 17 00:00:00 2001 From: Dillon Smith Date: Tue, 22 Dec 2020 15:34:29 -0500 Subject: [PATCH 01/27] docs: expand docstring for `reset_stateful_functions_when` argument of `run` --- psyneulink/core/compositions/composition.py | 25 ++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py index 6374d19e61e..3d95f68679b 100644 --- a/psyneulink/core/compositions/composition.py +++ b/psyneulink/core/compositions/composition.py @@ -8029,16 +8029,21 @@ def run( a cycle is not specified, it is assigned its `default values ` when initialized (see `Composition_Cycles_and_Feedback` additional details). - reset_stateful_functions_to : Dict { Node : Object | iterable [Object] } : default None - object or iterable of objects to be passed as arguments to nodes' reset methods when their - respective reset_stateful_function_when conditions are met. These are used to seed the stateful attributes - of Mechanisms that have stateful functions. If a node's reset_stateful_function_when condition is set to - Never, but they are listed in the reset_stateful_functions_to dict, then they will be reset once at the - beginning of the run, using the provided values. - - reset_stateful_functions_when : Condition : default Never() - sets the reset_stateful_function_when condition for all nodes in the Composition that currently have their - reset_stateful_function_when condition set to `Never ` for the duration of the run. + reset_stateful_functions_to : Dict { Node : Object | iterable [Object] } : default None + object or iterable of objects to be passed as arguments to nodes' reset methods when their + respective reset_stateful_function_when conditions are met. These are used to seed the stateful attributes + of Mechanisms that have stateful functions. If a node's reset_stateful_function_when condition is set to + Never, but they are listed in the reset_stateful_functions_to dict, then they will be reset once at the + beginning of the run, using the provided values. For a more in depth explanation of this argument, see + `Resetting Parameters of StatefulFunctions `. + + reset_stateful_functions_when : Dict { Node: Condition } | Condition : default Never() + if type is dict, sets the reset_stateful_function_when attribute for each key Node to its corresponding value + Condition. if type is Condition, sets the reset_stateful_function_when attribute for all nodes in the + Composition that currently have their reset_stateful_function_when conditions set to `Never `. + in either case, the specified Conditions persist only for the duration of the run, after which the nodes' + reset_stateful_functions_when attributes are returned to their previous Conditions. For a more in depth + explanation of this argument, see `Resetting Parameters of StatefulFunctions `. skip_initialization : bool : default False From 7e737da3a75ae0504192380b1a49634221d38828 Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Sat, 2 Jan 2021 17:46:46 -0500 Subject: [PATCH 02/27] github-actions: Install numpy before the rest of the dependencies Some deps (like GPy) otherwise pull the latest numpy rc resulting in ABI mismatch. Bump cache key to replace corrupted caches. Signed-off-by: Jan Vesely --- .github/workflows/pnl-ci.yml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pnl-ci.yml b/.github/workflows/pnl-ci.yml index 90f9e3f63fd..ce07edba612 100644 --- a/.github/workflows/pnl-ci.yml +++ b/.github/workflows/pnl-ci.yml @@ -27,24 +27,24 @@ jobs: if: startsWith(runner.os, 'Linux') with: path: ~/.cache/pip/wheels - key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-${{ github.sha }} - restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels + key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2-${{ github.sha }} + restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2 - name: MacOS wheels cache uses: actions/cache@v2.1.3 if: startsWith(runner.os, 'macOS') with: path: ~/Library/Caches/pip/wheels - key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-${{ github.sha }} - restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels + key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2-${{ github.sha }} + restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2 - name: Windows wheels cache uses: actions/cache@v2.1.3 if: startsWith(runner.os, 'Windows') with: path: ~\AppData\Local\pip\Cache\wheels - key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-${{ github.sha }} - restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels + key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2-${{ github.sha }} + restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ matrix.python-architecture }}-pip-wheels-v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2.2.1 @@ -72,7 +72,8 @@ jobs: - name: Shared dependencies shell: bash run: | - python -m pip install --upgrade pip wheel + # explicitly install numpy (https://github.com/pypa/pip/issues/9239) + python -m pip install --upgrade pip wheel $(grep numpy requirements.txt) pip install -e .[dev] - name: Cleanup old wheels From 0a1e8934993b6e42178bae3a5ce91dacc79db4a6 Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Sat, 2 Jan 2021 19:13:52 -0500 Subject: [PATCH 03/27] github-actions/windows: Use pytorch version constraint from requirements.txt Install pytorch after shared requirements Signed-off-by: Jan Vesely --- .github/workflows/pnl-ci.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pnl-ci.yml b/.github/workflows/pnl-ci.yml index ce07edba612..5d9ace58ee8 100644 --- a/.github/workflows/pnl-ci.yml +++ b/.github/workflows/pnl-ci.yml @@ -64,11 +64,6 @@ jobs: run: choco install --no-progress -y graphviz --version=2.38.0.20190211 if: startsWith(runner.os, 'Windows') - - name: Windows pytorch - run: | - python -m pip install --upgrade pip wheel - pip install torch -f https://download.pytorch.org/whl/cpu/torch_stable.html - if: startsWith(runner.os, 'Windows') && matrix.python-architecture != 'x86' - name: Shared dependencies shell: bash run: | @@ -76,6 +71,12 @@ jobs: python -m pip install --upgrade pip wheel $(grep numpy requirements.txt) pip install -e .[dev] + - name: Windows pytorch + shell: bash + run: | + pip install $(grep -o 'torch[0-9<=\.]*' requirements.txt) -f https://download.pytorch.org/whl/cpu/torch_stable.html + if: startsWith(runner.os, 'Windows') && matrix.python-architecture != 'x86' + - name: Cleanup old wheels shell: bash run: | From 63334dd0d07e4bf730724f8cca932b61b41bd55f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Jan 2021 06:04:47 +0000 Subject: [PATCH 04/27] github-actions(deps): bump actions/upload-artifact from v2.2.1 to v2.2.2 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from v2.2.1 to v2.2.2. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v2.2.1...e448a9b857ee2131e752b06002bf0e093c65e571) Signed-off-by: dependabot[bot] --- .github/workflows/pnl-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pnl-ci.yml b/.github/workflows/pnl-ci.yml index 5d9ace58ee8..69c14b933b8 100644 --- a/.github/workflows/pnl-ci.yml +++ b/.github/workflows/pnl-ci.yml @@ -103,7 +103,7 @@ jobs: run: pytest --junit-xml=tests_out.xml --verbosity=0 -n auto --maxprocesses=2 - name: Upload test results - uses: actions/upload-artifact@v2.2.1 + uses: actions/upload-artifact@v2.2.2 with: name: test-results-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.python-architecture }} path: tests_out.xml @@ -117,7 +117,7 @@ jobs: if: contains(github.ref, 'tags') - name: Upload dist packages - uses: actions/upload-artifact@v2.2.1 + uses: actions/upload-artifact@v2.2.2 with: name: dist-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.python-architecture }} path: dist/ From 881af92fd9711e8fbd937f74f8d8cf364b282491 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:18:53 -0500 Subject: [PATCH 05/27] llvm/context: Add `bool_ty` field --- psyneulink/core/llvm/builder_context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/psyneulink/core/llvm/builder_context.py b/psyneulink/core/llvm/builder_context.py index de5a40d1fbc..a76ee79f372 100644 --- a/psyneulink/core/llvm/builder_context.py +++ b/psyneulink/core/llvm/builder_context.py @@ -86,6 +86,7 @@ class LLVMBuilderContext: _llvm_generation = 0 int32_ty = ir.IntType(32) float_ty = ir.DoubleType() + bool_ty = ir.IntType(1) def __init__(self): self._modules = [] From 98f84cdc6fa92f38b85834b9d1893d7fbca24494 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:18:59 -0500 Subject: [PATCH 06/27] llvm/helpers: Add recursive array iterator --- psyneulink/core/llvm/helpers.py | 28 ++++++++++++-------- tests/llvm/test_helpers.py | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index d04a15ff8f5..82566c750ff 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -158,17 +158,6 @@ def csch(ctx, builder, x): den = builder.fsub(e2x, e2x.type(1)) return builder.fdiv(num, den) -def call_elementwise_operation(ctx, builder, x, operation, output_ptr): - """Recurse through an array structure and call operation on each scalar element of the structure. Store result in output_ptr""" - if isinstance(x.type.pointee, ir.ArrayType): - with array_ptr_loop(builder, x, str(x) + "_elementwise_op") as (b1, idx): - element_ptr = b1.gep(x, [ctx.int32_ty(0), idx]) - output_element_ptr = b1.gep(output_ptr, [ctx.int32_ty(0), idx]) - call_elementwise_operation(ctx, b1, element_ptr, operation, output_ptr=output_element_ptr) - else: - val = operation(ctx, builder, builder.load(x)) - builder.store(val, output_ptr) - def is_close(builder, val1, val2, rtol=1e-05, atol=1e-08): diff = builder.fsub(val1, val2, "is_close_diff") diff_neg = fneg(builder, diff, "is_close_fneg_diff") @@ -239,6 +228,23 @@ def is_boolean(x): type_t = x.type.pointee return isinstance(type_t, ir.IntType) and type_t.width == 1 +def recursive_iterate_arrays(ctx, builder, u, *args): + """Recursively iterates over all elements in scalar arrays of the same shape""" + assert isinstance(u.type.pointee, ir.ArrayType), "Can only iterate over arrays!" + assert all(len(u.type.pointee) == len(v.type.pointee) for v in args), "Tried to iterate over differing lengths!" + with array_ptr_loop(builder, u, str(u) + "," + str(args) + "_recursive_zip") as (b, idx): + u_ptr = b.gep(u, [ctx.int32_ty(0), idx]) + arg_ptrs = (b.gep(v, [ctx.int32_ty(0), idx]) for v in args) + if is_scalar(u_ptr): + yield (u_ptr, *arg_ptrs) + else: + yield from recursive_iterate_arrays(ctx, b, u_ptr, *arg_ptrs) + +def call_elementwise_operation(ctx, builder, x, operation, output_ptr): + """Recurse through an array structure and call operation on each scalar element of the structure. Store result in output_ptr""" + for (inp_ptr, out_ptr) in recursive_iterate_arrays(ctx, builder, x, output_ptr): + builder.store(operation(ctx, builder, builder.load(inp_ptr)), out_ptr) + def printf(builder, fmt, *args, override_debug=False): if "print_values" not in debug_env and not override_debug: return diff --git a/tests/llvm/test_helpers.py b/tests/llvm/test_helpers.py index 665df1d33c5..b68b3c90f2d 100644 --- a/tests/llvm/test_helpers.py +++ b/tests/llvm/test_helpers.py @@ -590,3 +590,49 @@ def test_helper_elementwise_op(mode, var, expected): bin_f.cuda_wrap_call(var, res) assert np.array_equal(res, expected) + +@pytest.mark.llvm +@pytest.mark.parametrize('mode', ['CPU', + pytest.param('PTX', marks=pytest.mark.cuda)]) +@pytest.mark.parametrize('var1,var2,expected', [ + (np.array([1.,2.,3.]), np.array([1.,2.,3.]), np.array([2.,4.,6.])), + (np.array([1.,2.,3.]), np.array([0.,1.,2.]), np.array([1.,3.,5.])), + (np.array([[1.,2.,3.], + [4.,5.,6.], + [7.,8.,9.]]), + np.array([[10.,11.,12.], + [13.,14.,15.], + [16.,17.,18.]]), + np.array([[11.,13.,15.], + [17.,19.,21.], + [23.,25.,27.]])), +]) +def test_helper_recursive_iterate_arrays(mode, var1, var2, expected): + with pnlvm.LLVMBuilderContext() as ctx: + arr_ptr_ty = ctx.convert_python_struct_to_llvm_ir(var1).as_pointer() + + func_ty = ir.FunctionType(ir.VoidType(), [arr_ptr_ty, arr_ptr_ty, arr_ptr_ty]) + + custom_name = ctx.get_unique_name("elementwise_op") + function = ir.Function(ctx.module, func_ty, name=custom_name) + u, v, out = function.args + block = function.append_basic_block(name="entry") + builder = ir.IRBuilder(block) + + for (a_ptr, b_ptr, o_ptr) in pnlvm.helpers.recursive_iterate_arrays(ctx, builder, u, v, out): + a = builder.load(a_ptr) + b = builder.load(b_ptr) + builder.store(builder.fadd(a,b), o_ptr) + builder.ret_void() + + bin_f = pnlvm.LLVMBinaryFunction.get(custom_name) + if mode == 'CPU': + ct_vec = np.ctypeslib.as_ctypes(var1) + ct_vec_2 = np.ctypeslib.as_ctypes(var2) + res = bin_f.byref_arg_types[2]() + bin_f(ct_vec, ct_vec_2, ctypes.byref(res)) + else: + res = copy.deepcopy(var1) + bin_f.cuda_wrap_call(var1, var2, res) + + assert np.array_equal(res, expected) From 7da910044525a5856ef8e175d433049a4d65bb07 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:19:04 -0500 Subject: [PATCH 07/27] llvm/helpers: Add array shape helpers --- psyneulink/core/llvm/helpers.py | 19 +++++++++++++++++++ tests/llvm/test_helpers.py | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index 82566c750ff..ee569ac1d80 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -228,6 +228,25 @@ def is_boolean(x): type_t = x.type.pointee return isinstance(type_t, ir.IntType) and type_t.width == 1 +def get_array_shape(x): + x_ty = x.type + if is_pointer(x): + x_ty = x_ty.pointee + + assert isinstance(x_ty, ir.ArrayType), f"Tried to get shape of non-array type: {x_ty}" + dimensions = [] + while hasattr(x_ty, "count"): + dimensions.append(x_ty.count) + x_ty = x_ty.element + + return dimensions + +def array_from_shape(shape, element_ty): + array_ty = element_ty + for dim in reversed(shape): + array_ty = ir.ArrayType(array_ty, dim) + return array_ty + def recursive_iterate_arrays(ctx, builder, u, *args): """Recursively iterates over all elements in scalar arrays of the same shape""" assert isinstance(u.type.pointee, ir.ArrayType), "Can only iterate over arrays!" diff --git a/tests/llvm/test_helpers.py b/tests/llvm/test_helpers.py index b68b3c90f2d..7ce013332ae 100644 --- a/tests/llvm/test_helpers.py +++ b/tests/llvm/test_helpers.py @@ -519,6 +519,25 @@ def test_helper_is_boolean(self, mode, ir_type, expected): assert res == expected + @pytest.mark.llvm + @pytest.mark.parametrize('ir_type,expected', [ + (DOUBLE_VECTOR_TYPE, [1]), + (DOUBLE_VECTOR_PTR_TYPE, [1]), + (DOUBLE_MATRIX_TYPE, [1, 1]), + (DOUBLE_MATRIX_PTR_TYPE, [1, 1]), + ], ids=str) + def test_helper_get_array_shape(self, ir_type, expected): + assert pnlvm.helpers.get_array_shape(ir_type(None)) == expected + + @pytest.mark.llvm + @pytest.mark.parametrize('ir_type,shape', [ + (DOUBLE_VECTOR_TYPE, (1,)), + (DOUBLE_MATRIX_TYPE, (1,1)), + ], ids=str) + def test_helper_array_from_shape(self, ir_type, shape): + with pnlvm.LLVMBuilderContext() as ctx: + assert ir_type == pnlvm.helpers.array_from_shape(shape, ctx.float_ty) + @pytest.mark.llvm @pytest.mark.parametrize('mode', ['CPU', pytest.param('PTX', marks=pytest.mark.cuda)]) From fb14aa149f001de4c01de3ad4238efef964cd373 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:19:10 -0500 Subject: [PATCH 08/27] llvm/udf: Generalize `sum` to support multidim arguments --- psyneulink/core/llvm/codegen.py | 23 +++++++++++++++-------- tests/functions/test_user_defined_func.py | 23 +++++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index b1e842268f9..586d5429344 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -31,13 +31,20 @@ def __init__(self, ctx, builder, func_globals, func_params, arg_in, arg_out): self.register = {} #setup default functions - def _vec_sum(x): - dim = len(x.type.pointee) - output_scalar = builder.alloca(ctx.float_ty) - # Get the pointer to the first element of the array to convert from [? x double]* -> double* - vec_u = builder.gep(x, [ctx.int32_ty(0), ctx.int32_ty(0)]) - builder.call(ctx.import_llvm_function("__pnl_builtin_vec_sum"), [vec_u, ctx.int32_ty(dim), output_scalar]) - return output_scalar + def _list_sum(x): + # HACK: Obtain polymorphic addition function by visiting the add node + # this should ideally be moved to an explicit helper + add_func = self.visit_Add(None) + + total_sum = builder.alloca(x.type.pointee.element) + builder.store(total_sum.type.pointee(None), total_sum) + with helpers.array_ptr_loop(builder, x, "list_sum") as (b, idx): + curr_val = b.gep(x, [ctx.int32_ty(0), idx]) + tmp = add_func(total_sum, curr_val) + if helpers.is_pointer(tmp): + tmp = builder.load(tmp) + b.store(tmp, total_sum) + return total_sum def _tanh(x): output_ptr = builder.alloca(x.type.pointee) @@ -49,7 +56,7 @@ def _exp(x): helpers.call_elementwise_operation(self.ctx, self.builder, x, helpers.exp, output_ptr) return output_ptr - self.register['sum'] = _vec_sum + self.register['sum'] = _list_sum # setup numpy numpy_handlers = { diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index 044e9f080f1..dc556efb944 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -367,6 +367,29 @@ def myFunction(variable, param1, param2): val = benchmark(e, [-1, 2, 3, 4]) assert np.allclose(val, [[10]]) + @pytest.mark.parametrize("op,variable,expected", [ # parameter is string since compiled udf doesn't support closures as of present + ("SUM", [1.0, 3.0], 4), + ("SUM", [[1.0], [3.0]], [4.0]), + ]) + @pytest.mark.parametrize("bin_execute", ['Python', + pytest.param('LLVM', marks=pytest.mark.llvm), + pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), + ]) + @pytest.mark.benchmark(group="Function UDF") + def test_user_def_func_builtin(self, op, variable, expected, bin_execute, benchmark): + if op == "SUM": + def myFunction(variable): + return sum(variable) + + U = UserDefinedFunction(custom_function=myFunction, default_variable=variable) + if bin_execute == 'LLVM': + e = pnlvm.execution.FuncExecution(U).execute + elif bin_execute == 'PTX': + e = pnlvm.execution.FuncExecution(U).cuda_execute + else: + e = U + val = benchmark(e, variable) + assert np.allclose(val, expected) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), From 63e1fb5012c6ff76ceb8afea64742044565d14e6 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:19:17 -0500 Subject: [PATCH 09/27] llvm/udf: Add `len` --- psyneulink/core/llvm/codegen.py | 7 +++++++ tests/functions/test_user_defined_func.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 586d5429344..05c0986f409 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -46,6 +46,12 @@ def _list_sum(x): b.store(tmp, total_sum) return total_sum + def _len(x): + x_ty = x.type + if helpers.is_pointer(x): + x_ty = x_ty.pointee + return ctx.float_ty(len(x_ty)) + def _tanh(x): output_ptr = builder.alloca(x.type.pointee) helpers.call_elementwise_operation(self.ctx, self.builder, x, helpers.tanh, output_ptr) @@ -57,6 +63,7 @@ def _exp(x): return output_ptr self.register['sum'] = _list_sum + self.register['len'] = _len # setup numpy numpy_handlers = { diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index dc556efb944..ca1eaf1f368 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -370,6 +370,9 @@ def myFunction(variable, param1, param2): @pytest.mark.parametrize("op,variable,expected", [ # parameter is string since compiled udf doesn't support closures as of present ("SUM", [1.0, 3.0], 4), ("SUM", [[1.0], [3.0]], [4.0]), + ("LEN", [1.0, 3.0], 2), + ("LEN", [[1.0], [3.0]], 2), + ("LEN_TUPLE", [0, 0], 2), ]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), @@ -380,6 +383,12 @@ def test_user_def_func_builtin(self, op, variable, expected, bin_execute, benchm if op == "SUM": def myFunction(variable): return sum(variable) + elif op == "LEN": + def myFunction(variable): + return len(variable) + elif op == "LEN_TUPLE": + def myFunction(variable): + return len((1,2)) U = UserDefinedFunction(custom_function=myFunction, default_variable=variable) if bin_execute == 'LLVM': From cfa7c99a0fe29627c95fa6c3587afbaab013d709 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 06:19:25 -0500 Subject: [PATCH 10/27] llvm/udf: Add support for `shape` attribute of numpy arrays --- psyneulink/core/llvm/codegen.py | 6 ++++++ tests/functions/test_user_defined_func.py | 16 ++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 05c0986f409..13135ee933c 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -274,6 +274,12 @@ def visit_Name(self, node): def visit_Attribute(self, node): val = self.visit(node.value) + + # special case numpy attributes + if node.attr == "shape": + shape = helpers.get_array_shape(val) + return ir.LiteralStructType([self.ctx.float_ty] * len(shape))(shape) + return val[node.attr] def visit_Num(self, node): diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index ca1eaf1f368..9cc7001a5ea 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -318,25 +318,29 @@ def myFunction(variable): val = benchmark(e, 0) assert np.allclose(val, expected) - @pytest.mark.parametrize("op,expected", [ # parameter is string since compiled udf doesn't support closures as of present - ("TANH", [0.76159416, 0.99505475]), - ("EXP", [2.71828183, 20.08553692]), + @pytest.mark.parametrize("op,variable,expected", [ # parameter is string since compiled udf doesn't support closures as of present + ("TANH", [[1, 3]], [0.76159416, 0.99505475]), + ("EXP", [[1, 3]], [2.71828183, 20.08553692]), + ("SHAPE", [1, 2], [2]), + ("SHAPE", [[1, 3]], [1, 2]), ]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), ]) @pytest.mark.benchmark(group="Function UDF") - def test_user_def_func_numpy(self, op, expected, bin_execute, benchmark): - variable = [[1, 3]] + def test_user_def_func_numpy(self, op, variable, expected, bin_execute, benchmark): if op == "TANH": def myFunction(variable): return np.tanh(variable) elif op == "EXP": def myFunction(variable): return np.exp(variable) + elif op == "SHAPE": + def myFunction(variable): + return variable.shape - U = UserDefinedFunction(custom_function=myFunction, default_variable=[[0, 0]]) + U = UserDefinedFunction(custom_function=myFunction, default_variable=variable) if bin_execute == 'LLVM': e = pnlvm.execution.FuncExecution(U).execute elif bin_execute == 'PTX': From ce8b01ff2f6995b504b32569e86b7a617fda244e Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 18:47:09 -0500 Subject: [PATCH 11/27] llvm/udf: Refactor register syntax + add float --- psyneulink/core/llvm/codegen.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 13135ee933c..2383f59dcb8 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -28,7 +28,6 @@ def __init__(self, ctx, builder, func_globals, func_params, arg_in, arg_out): self.func_params = func_params self.arg_in = arg_in self.arg_out = arg_out - self.register = {} #setup default functions def _list_sum(x): @@ -62,8 +61,11 @@ def _exp(x): helpers.call_elementwise_operation(self.ctx, self.builder, x, helpers.exp, output_ptr) return output_ptr - self.register['sum'] = _list_sum - self.register['len'] = _len + self.register = { + "sum": _list_sum, + "len": _len, + "float": ctx.float_ty, + } # setup numpy numpy_handlers = { From 3f14314b66a1881d0bc7fb82bffa6a6ebc67392d Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 18:44:33 -0500 Subject: [PATCH 12/27] llvm/udf: Add `astype` method --- psyneulink/core/llvm/codegen.py | 22 ++++++++++++++++++++++ tests/functions/test_user_defined_func.py | 9 +++++++++ 2 files changed, 31 insertions(+) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 2383f59dcb8..9543a639fd7 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -65,6 +65,7 @@ def _exp(x): "sum": _list_sum, "len": _len, "float": ctx.float_ty, + "int": ctx.int32_ty, } # setup numpy @@ -281,6 +282,27 @@ def visit_Attribute(self, node): if node.attr == "shape": shape = helpers.get_array_shape(val) return ir.LiteralStructType([self.ctx.float_ty] * len(shape))(shape) + elif node.attr == "astype": + def astype(ty): + def _convert(ctx, builder, x): + if helpers.is_pointer(x): + x = builder.load(x) + if helpers.is_integer(x) and ty is ctx.float_ty: + if helpers.is_boolean(x): + return builder.uitofp(x, ty) + return builder.sitofp(x, ty) + elif helpers.is_floating_point(x) and ty is self.register["int"]: + return builder.fptosi(x, ty) + elif (helpers.is_floating_point(x) and ty is ctx.float_ty): + return x + if helpers.is_scalar(val): + return _convert(self.ctx, self.builder, val) + else: + output_ptr = self.builder.alloca(helpers.array_from_shape(helpers.get_array_shape(val), ty)) + helpers.call_elementwise_operation(self.ctx, self.builder, val, _convert, output_ptr) + return output_ptr + # we only support float types + return astype return val[node.attr] diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index 9cc7001a5ea..c77881d34bd 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -323,6 +323,8 @@ def myFunction(variable): ("EXP", [[1, 3]], [2.71828183, 20.08553692]), ("SHAPE", [1, 2], [2]), ("SHAPE", [[1, 3]], [1, 2]), + ("ASTYPE_FLOAT", [1], [1.0]), + ("ASTYPE_INT", [-1.5], [-1.0]), ]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), @@ -339,6 +341,13 @@ def myFunction(variable): elif op == "SHAPE": def myFunction(variable): return variable.shape + elif op == "ASTYPE_FLOAT": + def myFunction(variable): + return variable.astype(float) + elif op == "ASTYPE_INT": + # return types cannot be integers, so we cast back to float and check for truncation + def myFunction(variable): + return variable.astype(int).astype(float) U = UserDefinedFunction(custom_function=myFunction, default_variable=variable) if bin_execute == 'LLVM': From 2de5b729673f7ca891c7bdecd648427115ea75ba Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 18:55:26 -0500 Subject: [PATCH 13/27] llvm/udf: Add multidim comparisons --- psyneulink/core/llvm/codegen.py | 95 ++++++++++------ tests/functions/test_user_defined_func.py | 131 ++++++++++++++++------ 2 files changed, 157 insertions(+), 69 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 9543a639fd7..32d03dc2c83 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -71,7 +71,13 @@ def _exp(x): # setup numpy numpy_handlers = { 'tanh': _tanh, - 'exp': _exp + 'exp': _exp, + 'equal': self._generate_fcmp_handler(self.ctx, self.builder, "=="), + 'not_equal': self._generate_fcmp_handler(self.ctx, self.builder, "!="), + 'less': self._generate_fcmp_handler(self.ctx, self.builder, "<"), + 'less_equal': self._generate_fcmp_handler(self.ctx, self.builder, "<="), + 'greater': self._generate_fcmp_handler(self.ctx, self.builder, ">"), + 'greater_equal': self._generate_fcmp_handler(self.ctx, self.builder, ">="), } for k, v in func_globals.items(): @@ -392,53 +398,68 @@ def _or(x, y): return _or - def visit_Eq(self, node): - def _eq(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('==', x, y)) + def _generate_fcmp_handler(self, ctx, builder, cmp): + def _cmp_array(ctx, builder, u, v): + assert u.type == v.type + shape = helpers.get_array_shape(u) + output_ptr = builder.alloca(helpers.array_from_shape(shape, ctx.bool_ty)) - return _eq + for (u_ptr, v_ptr, out_ptr) in helpers.recursive_iterate_arrays(ctx, builder, u, v, output_ptr): + u_val = builder.load(u_ptr) + v_val = builder.load(v_ptr) + builder.store(builder.fcmp_ordered(cmp, u_val, v_val), out_ptr) - def visit_NotEq(self, node): - def _neq(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('!=', x, y)) + return output_ptr - return _neq + def _cmp_array_scalar(ctx, builder, array, s): + shape = helpers.get_array_shape(array) + output_ptr = builder.alloca(helpers.array_from_shape(shape, ctx.bool_ty)) + helpers.call_elementwise_operation(ctx, builder, array, lambda ctx, builder, x: builder.fcmp_ordered(cmp, x, s), output_ptr) - def visit_Lt(self, node): - def _lt(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('<', x, y)) + return output_ptr - return _lt + def _cmp_scalar_array(ctx, builder, s, array): + shape = helpers.get_array_shape(array) + output_ptr = builder.alloca(helpers.array_from_shape(shape, ctx.bool_ty)) + helpers.call_elementwise_operation(ctx, builder, array, lambda ctx, builder, x: builder.fcmp_ordered(cmp, s, x), output_ptr) - def visit_LtE(self, node): - def _lte(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('<=', x, y)) + return output_ptr + + def _cmp(x, y): + if helpers.is_floating_point(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered(cmp, x, y)) + elif helpers.is_vector(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, _cmp_array_scalar) + elif helpers.is_floating_point(x) and helpers.is_vector(y): + return self._generate_binop(x, y, _cmp_scalar_array) + elif helpers.is_2d_matrix(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, _cmp_array_scalar) + elif helpers.is_floating_point(x) and helpers.is_2d_matrix(y): + return self._generate_binop(x, y, _cmp_scalar_array) + elif helpers.is_vector(x) and helpers.is_vector(y): + return self._generate_binop(x, y, _cmp_array) + elif helpers.is_2d_matrix(x) and helpers.is_2d_matrix(y): + return self._generate_binop(x, y, _cmp_array) - return _lte + return _cmp - def visit_Gt(self, node): - def _gt(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('>', x, y)) + def visit_Eq(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, "==") - return _gt + def visit_NotEq(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, "!=") - def visit_GtE(self, node): - def _gte(x, y): - assert helpers.is_floating_point(x), f"{x} is not a floating point type!" - assert helpers.is_floating_point(y), f"{y} is not a floating point type!" - return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fcmp_ordered('>=', x, y)) + def visit_Lt(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, "<") + + def visit_LtE(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, "<=") - return _gte + def visit_Gt(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, ">") + + def visit_GtE(self, node): + return self._generate_fcmp_handler(self.ctx, self.builder, ">=") def visit_Compare(self, node): comp_val = self.visit(node.left) diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index c77881d34bd..aaddce0aa33 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -109,72 +109,59 @@ def myFunction(variable): val = benchmark(e, [0]) assert val == 1.0 - @pytest.mark.parametrize("op", [ # parameter is string since compiled udf doesn't support closures as of present - "Eq", - "NotEq", - "Lt", - "LtE", - "Gt", - "GtE", + @pytest.mark.parametrize("op,var1,var2,expected", [ # parameter is string since compiled udf doesn't support closures as of present + ("Eq", 1.0, 2.0, 0.0), + ("NotEq", 1.0, 2.0, 1.0), + ("Lt", 1.0, 2.0, 1.0), + ("LtE", 1.0, 2.0, 1.0), + ("Gt", 1.0, 2.0, 0.0), + ("GtE", 1.0, 2.0, 0.0), ]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), ]) @pytest.mark.benchmark(group="Function UDF") - def test_user_def_func_cmpop(self, op, bin_execute, benchmark): + def test_user_def_func_cmpop(self, op, var1, var2, expected, bin_execute, benchmark): + # we explicitly use np here to ensure that the result is castable to float in the scalar-scalar case if op == "Eq": - def myFunction(variable): - var1 = 1.0 - var2 = 1.0 + def myFunction(variable, var1, var2): if var1 == var2: return 1.0 else: return 0.0 elif op == "NotEq": - def myFunction(variable): - var1 = 1.0 - var2 = 2.0 + def myFunction(variable, var1, var2): if var1 != var2: return 1.0 else: return 0.0 elif op == "Lt": - def myFunction(variable): - var1 = 1.0 - var2 = 2.0 + def myFunction(variable, var1, var2): if var1 < var2: return 1.0 else: return 0.0 elif op == "LtE": - def myFunction(variable): - var1 = 1.0 - var2 = 2.0 - var3 = 1.0 - if var1 <= var2 and var1 <= var3: + def myFunction(variable, var1, var2): + if var1 <= var2: return 1.0 else: return 0.0 elif op == "Gt": - def myFunction(variable): - var1 = 2.0 - var2 = 1.0 + def myFunction(variable, var1, var2): if var1 > var2: return 1.0 else: return 0.0 elif op == "GtE": - def myFunction(variable): - var1 = 3.0 - var2 = 2.0 - var3 = 3.0 - if var1 >= var2 and var1 >= var3: + def myFunction(variable, var1, var2): + if var1 >= var2: return 1.0 else: return 0.0 - U = UserDefinedFunction(custom_function=myFunction, default_variable=[0]) + U = UserDefinedFunction(custom_function=myFunction, default_variable=[0], var1=var1, var2=var2) if bin_execute == 'LLVM': e = pnlvm.execution.FuncExecution(U).execute elif bin_execute == 'PTX': @@ -182,7 +169,87 @@ def myFunction(variable): else: e = U val = benchmark(e, [0]) - assert val == 1.0 + assert np.allclose(expected, val) + + @pytest.mark.parametrize("op,var1,var2,expected", [ # parameter is string since compiled udf doesn't support closures as of present + ("Eq", 1.0, 2.0, 0.0), + ("Eq", [1.0, 2.0], [1.0, 2.0], [1.0, 1.0]), + ("Eq", 1.0, [1.0, 2.0], [1.0, 0.0]), + ("Eq", [2.0, 1.0], 1.0, [0.0, 1.0]), + ("Eq", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[1.0, 0.0], [0.0, 0.0]]), + ("Eq", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 0.0]]), + ("Eq", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]), + ("NotEq", 1.0, 2.0, 1.0), + ("NotEq", [1.0, 2.0], [1.0, 2.0], [0.0, 0.0]), + ("NotEq", 1.0, [1.0, 2.0], [0.0, 1.0]), + ("NotEq", [2.0, 1.0], 1.0, [1.0, 0.0]), + ("NotEq", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[0.0, 1.0], [1.0, 1.0]]), + ("NotEq", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 1.0]]), + ("NotEq", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]), + ("Lt", 1.0, 2.0, 1.0), + ("Lt", [1.0, 2.0], [1.0, 2.0], [0.0, 0.0]), + ("Lt", 1.0, [1.0, 2.0], [0.0, 1.0]), + ("Lt", [2.0, 1.0], 1.0, [0.0, 0.0]), + ("Lt", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[0.0, 0.0], [0.0, 0.0]]), + ("Lt", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 1.0]]), + ("Lt", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]), + ("LtE", 1.0, 2.0, 1.0), + ("LtE", [1.0, 2.0], [1.0, 2.0], [1.0, 1.0]), + ("LtE", 1.0, [1.0, 2.0], [1.0, 1.0]), + ("LtE", [2.0, 1.0], 1.0, [0.0, 1.0]), + ("LtE", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[1.0, 0.0], [0.0, 0.0]]), + ("LtE", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]), + ("LtE", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]), + ("Gt", 1.0, 2.0, 0.0), + ("Gt", [1.0, 2.0], [1.0, 2.0], [0.0, 0.0]), + ("Gt", 1.0, [1.0, 2.0], [0.0, 0.0]), + ("Gt", [2.0, 1.0], 1.0, [1.0, 0.0]), + ("Gt", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[0.0, 1.0], [1.0, 1.0]]), + ("Gt", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]), + ("Gt", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]), + ("GtE", 1.0, 2.0, 0.0), + ("GtE", [1.0, 2.0], [1.0, 2.0], [1.0, 1.0]), + ("GtE", 1.0, [1.0, 2.0], [1.0, 0.0]), + ("GtE", [2.0, 1.0], 1.0, [1.0, 1.0]), + ("GtE", [[1.0, 2.0], [3.0, 4.0]], 1.0, [[1.0, 1.0], [1.0, 1.0]]), + ("GtE", 1.0, [[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 0.0]]), + ("GtE", [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]), + ]) + @pytest.mark.parametrize("bin_execute", ['Python', + pytest.param('LLVM', marks=pytest.mark.llvm), + pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), + ]) + @pytest.mark.benchmark(group="Function UDF") + def test_user_def_func_cmpop_numpy(self, op, var1, var2, expected, bin_execute, benchmark): + # we explicitly use np here to ensure that the result is castable to float in the scalar-scalar case + if op == "Eq": + def myFunction(variable, var1, var2): + return np.equal(var1, var2).astype(float) + elif op == "NotEq": + def myFunction(variable, var1, var2): + return np.not_equal(var1, var2).astype(float) + elif op == "Lt": + def myFunction(variable, var1, var2): + return np.less(var1, var2).astype(float) + elif op == "LtE": + def myFunction(variable, var1, var2): + return np.less_equal(var1, var2).astype(float) + elif op == "Gt": + def myFunction(variable, var1, var2): + return np.greater(var1, var2).astype(float) + elif op == "GtE": + def myFunction(variable, var1, var2): + return np.greater_equal(var1, var2).astype(float) + + U = UserDefinedFunction(custom_function=myFunction, default_variable=[0], var1=var1, var2=var2) + if bin_execute == 'LLVM': + e = pnlvm.execution.FuncExecution(U).execute + elif bin_execute == 'PTX': + e = pnlvm.execution.FuncExecution(U).cuda_execute + else: + e = U + val = benchmark(e, [0]) + assert np.allclose(expected, val) class TestUserDefFunc: From 806d64921b5c491bc4acf21501d80fa33ee18500 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 23:53:17 -0500 Subject: [PATCH 14/27] llvm/udf: Add division --- psyneulink/core/llvm/codegen.py | 45 +++++++++++++++++++++++ tests/functions/test_user_defined_func.py | 31 ++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 32d03dc2c83..a2ba90928cd 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -257,6 +257,51 @@ def _mul(x, y): return _mul + def visit_Div(self, node): + def _div_array(ctx, builder, u, v): + assert u.type == v.type + output_ptr = builder.alloca(u.type.pointee) + + for (u_ptr, v_ptr, out_ptr) in helpers.recursive_iterate_arrays(ctx, builder, u, v, output_ptr): + u_val = builder.load(u_ptr) + v_val = builder.load(v_ptr) + builder.store(builder.fdiv(u_val, v_val), out_ptr) + + return output_ptr + + def _div_array_scalar(ctx, builder, array, s): + output_ptr = builder.alloca(array.type.pointee) + helpers.call_elementwise_operation(ctx, builder, array, lambda ctx, builder, x: builder.fdiv(x, s), output_ptr) + + return output_ptr + + def _div_scalar_array(ctx, builder, s, array): + output_ptr = builder.alloca(array.type.pointee) + helpers.call_elementwise_operation(ctx, builder, array, lambda ctx, builder, x: builder.fdiv(s, x), output_ptr) + + return output_ptr + + def _div(x, y): + if helpers.is_floating_point(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, lambda ctx, builder, x, y: builder.fdiv(x, y)) + elif helpers.is_floating_point(x) and (helpers.is_2d_matrix(y) or helpers.is_vector(y)): + return self._generate_binop(x, y, _div_scalar_array) + elif helpers.is_vector(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, _div_array_scalar) + elif helpers.is_2d_matrix(x) and helpers.is_floating_point(y): + return self._generate_binop(x, y, _div_array_scalar) + elif helpers.is_vector(x) and helpers.is_vector(y): + if x.type != y.type: + # Special case: Cast y into scalar if it can be done + if helpers.get_array_shape(y) == [1]: + y = self.builder.gep(y, [self.ctx.int32_ty(0), self.ctx.int32_ty(0)]) + return self._generate_binop(x, y, _div_array_scalar) + return self._generate_binop(x, y, _div_array) + elif helpers.is_2d_matrix(x) and helpers.is_2d_matrix(y): + return self._generate_binop(x, y, _div_array) + assert False, f"Unable to divide arguments {x}, {y}" + return _div + def _generate_unop(self, x, callback): if helpers.is_floating_point(x) and helpers.is_pointer(x): x = self.builder.load(x) diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index aaddce0aa33..b6db2514382 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -70,6 +70,37 @@ def myFunction(_, param1, param2): val = benchmark(e, 0) assert np.allclose(val, param1 * param2) + @pytest.mark.parametrize("param1, param2", [ + (1, 2), + (np.ones(2), 2), + (2, np.ones(2)), + (np.ones((2, 2)), 2), + (2, np.ones((2, 2))), + (np.ones(2), np.array([1, 2])), + (np.ones(2), np.array([2.])), + (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), + ], ids=["scalar-scalar", "vec-scalar", "scalar-vec", "mat-scalar", "scalar-mat", "vec-vec", "vec-vec-differing", "mat-mat"]) + @pytest.mark.parametrize("bin_execute", ['Python', + pytest.param('LLVM', marks=pytest.mark.llvm), + pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), + ]) + @pytest.mark.benchmark(group="Function UDF") + def test_user_def_func_div(self, param1, param2, bin_execute, benchmark): + # default val is same shape as expected output + def myFunction(_, param1, param2): + # we only use param1 and param2 to avoid automatic shape changes of the variable + return param1 / param2 + + U = UserDefinedFunction(custom_function=myFunction, param1=param1, param2=param2) + if bin_execute == 'LLVM': + e = pnlvm.execution.FuncExecution(U).execute + elif bin_execute == 'PTX': + e = pnlvm.execution.FuncExecution(U).cuda_execute + else: + e = U + val = benchmark(e, 0) + assert np.allclose(val, np.divide(param1, param2)) + @pytest.mark.parametrize("op", [ # parameter is string since compiled udf doesn't support closures as of present "AND", "OR", From da97c1710c2507e8c1e0cd11fa2ade2b08dab1e3 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 23:54:55 -0500 Subject: [PATCH 15/27] llvm/udf: Remove unused branch --- psyneulink/core/llvm/codegen.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index a2ba90928cd..7d74ee84c92 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -368,8 +368,6 @@ def _assign_target(target, value): self.register[id] = value else: to_store = value - if helpers.is_pointer(value): - to_store = builder.load(value) target = self.visit(target) self.builder.store(to_store, target) From dcdbc15aeba6dfb9530eb61b97be4a88af8035d1 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 30 Nov 2020 23:57:27 -0500 Subject: [PATCH 16/27] llvm/udf: Add `max` --- psyneulink/core/llvm/codegen.py | 13 +++++++++++++ tests/functions/test_user_defined_func.py | 6 ++++++ 2 files changed, 19 insertions(+) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 7d74ee84c92..df0f9b738f8 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -61,11 +61,23 @@ def _exp(x): helpers.call_elementwise_operation(self.ctx, self.builder, x, helpers.exp, output_ptr) return output_ptr + def _max(x): + assert helpers.is_vector(x) or helpers.is_2d_matrix(x), "Attempted to call max on invalid variable! Only 1-d and 2-d lists are supported!" + curr = builder.alloca(ctx.float_ty) + builder.store(ctx.float_ty('NaN'), curr) + for (element_ptr,) in helpers.recursive_iterate_arrays(ctx, builder, x): + element = builder.load(element_ptr) + greater = builder.fcmp_unordered('>', element, builder.load(curr)) + with builder.if_then(greater): + builder.store(element, curr) + return curr + self.register = { "sum": _list_sum, "len": _len, "float": ctx.float_ty, "int": ctx.int32_ty, + "max": _max, } # setup numpy @@ -78,6 +90,7 @@ def _exp(x): 'less_equal': self._generate_fcmp_handler(self.ctx, self.builder, "<="), 'greater': self._generate_fcmp_handler(self.ctx, self.builder, ">"), 'greater_equal': self._generate_fcmp_handler(self.ctx, self.builder, ">="), + "max": _max, } for k, v in func_globals.items(): diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index b6db2514382..dd55ff700d1 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -484,6 +484,9 @@ def myFunction(variable, param1, param2): ("LEN", [1.0, 3.0], 2), ("LEN", [[1.0], [3.0]], 2), ("LEN_TUPLE", [0, 0], 2), + ("MAX", [0.0, 0.0], 0), + ("MAX", [1.0, 2.0], 2), + ("MAX", [[2.0, 1.0], [6.0, 2.0]], 6), ]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), @@ -500,6 +503,9 @@ def myFunction(variable): elif op == "LEN_TUPLE": def myFunction(variable): return len((1,2)) + elif op == "MAX": + def myFunction(variable): + return np.max(variable) U = UserDefinedFunction(custom_function=myFunction, default_variable=variable) if bin_execute == 'LLVM': From 3d9ba45eca2cae4c6253d3b168c5e4adaeb10a43 Mon Sep 17 00:00:00 2001 From: SamKG Date: Tue, 1 Dec 2020 00:50:10 -0500 Subject: [PATCH 17/27] llvm/udf: Add reward function to UDF tests --- tests/functions/test_user_defined_func.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index dd55ff700d1..a36b1b1a8b5 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -372,6 +372,25 @@ def myFunction(variable, param): val = benchmark(e, variable) assert np.allclose(val, -variable) + @pytest.mark.parametrize("bin_execute", ['Python', + pytest.param('LLVM', marks=pytest.mark.llvm), + pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), + ]) + @pytest.mark.benchmark(group="Function UDF") + def test_user_def_reward_func(self, bin_execute, benchmark): + variable = [[1,2,3,4]] + def myFunction(x,t0=0.48): + return (x[0][0]>0).astype(float) * (x[0][2]>0).astype(float) / (np.max([x[0][1],x[0][3]]) + t0) + U = UserDefinedFunction(custom_function=myFunction, default_variable=variable, param=variable) + if bin_execute == 'LLVM': + e = pnlvm.execution.FuncExecution(U).execute + elif bin_execute == 'PTX': + e = pnlvm.execution.FuncExecution(U).cuda_execute + else: + e = U + val = benchmark(e, variable) + assert np.allclose(val, 0.2232142857142857) + @pytest.mark.parametrize("dtype, expected", [ # parameter is string since compiled udf doesn't support closures as of present ("SCALAR", 1.0), ("VECTOR", [1,2]), From 01800cfffff32e6202a7f8e383994f9ba4ac6397 Mon Sep 17 00:00:00 2001 From: SamKG Date: Fri, 18 Dec 2020 21:45:03 -0500 Subject: [PATCH 18/27] tests/udf: Add test ids --- tests/functions/test_user_defined_func.py | 40 +++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index a36b1b1a8b5..45bbe47665e 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -11,14 +11,14 @@ class TestBinaryOperations: @pytest.mark.parametrize("param1, param2", [ - (1, 2), # scalar - scalar - (np.ones(2), 2), # vec - scalar - (2, np.ones(2)), # scalar - vec - (np.ones((2, 2)), 2), # mat - scalar - (2, np.ones((2, 2))), # scalar - mat - (np.ones(2), np.array([1, 2])), # vec - vec - (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), # mat - mat - ]) + (1, 2), + (np.ones(2), 2), + (2, np.ones(2)), + (np.ones((2, 2)), 2), + (2, np.ones((2, 2))), + (np.ones(2), np.array([1, 2])), + (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), + ], ids=["scalar-scalar", "vec-scalar", "scalar-vec", "mat-scalar", "scalar-mat", "vec-vec", "mat-mat"]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), @@ -41,14 +41,14 @@ def myFunction(_, param1, param2): assert np.allclose(val, param1 + param2) @pytest.mark.parametrize("param1, param2", [ - (1, 2), # scalar - scalar - (np.ones(2), 2), # vec - scalar - (2, np.ones(2)), # scalar - vec - (np.ones((2, 2)), 2), # mat - scalar - (2, np.ones((2, 2))), # scalar - mat - (np.ones(2), np.array([1, 2])), # vec - vec - (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), # mat - mat - ]) + (1, 2), + (np.ones(2), 2), + (2, np.ones(2)), + (np.ones((2, 2)), 2), + (2, np.ones((2, 2))), + (np.ones(2), np.array([1, 2])), + (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), + ], ids=["scalar-scalar", "vec-scalar", "scalar-vec", "mat-scalar", "scalar-mat", "vec-vec", "mat-mat"]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), @@ -349,10 +349,10 @@ def myFunction(variable): assert np.allclose(val, [[6, 10]]) @pytest.mark.parametrize("variable", [ - (1), # scalar - (np.ones((2))), # vec-2d - (np.ones((2, 2))) # mat - ]) + (1), + (np.ones((2))), + (np.ones((2, 2))) + ], ids=["scalar", "vec-2d", "mat"]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), From 2ea1f7272f7e0f587940fa3aac98ba0cbd1a2bf1 Mon Sep 17 00:00:00 2001 From: SamKG Date: Sun, 20 Dec 2020 21:36:16 -0500 Subject: [PATCH 19/27] llvm/helpers: Mark `call_elementwise_operation` for removal --- psyneulink/core/llvm/helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index ee569ac1d80..ca0401a32e6 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -259,6 +259,7 @@ def recursive_iterate_arrays(ctx, builder, u, *args): else: yield from recursive_iterate_arrays(ctx, b, u_ptr, *arg_ptrs) +# TODO: Remove this function. Can be replaced by `recursive_iterate_arrays` def call_elementwise_operation(ctx, builder, x, operation, output_ptr): """Recurse through an array structure and call operation on each scalar element of the structure. Store result in output_ptr""" for (inp_ptr, out_ptr) in recursive_iterate_arrays(ctx, builder, x, output_ptr): From 97465428c075bd3bf8c4cd5ea60fdb20fe4c0967 Mon Sep 17 00:00:00 2001 From: SamKG Date: Mon, 4 Jan 2021 20:34:23 -0500 Subject: [PATCH 20/27] llvm/udf: Support single element vector multiplication --- psyneulink/core/llvm/codegen.py | 5 +++++ tests/functions/test_user_defined_func.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index df0f9b738f8..5928f3d98b7 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -264,6 +264,11 @@ def _mul(x, y): elif helpers.is_floating_point(x) and helpers.is_2d_matrix(y): return self._generate_binop(y, x, _mul_mat_scalar) elif helpers.is_vector(x) and helpers.is_vector(y): + if x.type != y.type: + # Special case: Cast y into scalar if it can be done + if helpers.get_array_shape(y) == [1]: + y = self.builder.gep(y, [self.ctx.int32_ty(0), self.ctx.int32_ty(0)]) + return self._generate_binop(x, y, _mul_vec_scalar) return self._generate_binop(x, y, _mul_vec) elif helpers.is_2d_matrix(x) and helpers.is_2d_matrix(y): return self._generate_binop(x, y, _mul_mat) diff --git a/tests/functions/test_user_defined_func.py b/tests/functions/test_user_defined_func.py index 45bbe47665e..9573c2cd0ba 100644 --- a/tests/functions/test_user_defined_func.py +++ b/tests/functions/test_user_defined_func.py @@ -47,8 +47,9 @@ def myFunction(_, param1, param2): (np.ones((2, 2)), 2), (2, np.ones((2, 2))), (np.ones(2), np.array([1, 2])), + (np.ones(2), np.array([2.])), (np.ones((2, 2)), np.array([[1, 2], [3, 4]])), - ], ids=["scalar-scalar", "vec-scalar", "scalar-vec", "mat-scalar", "scalar-mat", "vec-vec", "mat-mat"]) + ], ids=["scalar-scalar", "vec-scalar", "scalar-vec", "mat-scalar", "scalar-mat", "vec-vec", "vec-vec-differing", "mat-mat"]) @pytest.mark.parametrize("bin_execute", ['Python', pytest.param('LLVM', marks=pytest.mark.llvm), pytest.param('PTX', marks=[pytest.mark.llvm, pytest.mark.cuda]), From a21a4f09d3c781ea3a2fc939365afa45f17dbd33 Mon Sep 17 00:00:00 2001 From: SamKG Date: Tue, 5 Jan 2021 13:32:18 -0500 Subject: [PATCH 21/27] llvm: Allow for tuples to be created as arrays when possible --- psyneulink/core/llvm/builder_context.py | 4 +++- psyneulink/core/llvm/codegen.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/psyneulink/core/llvm/builder_context.py b/psyneulink/core/llvm/builder_context.py index a76ee79f372..ba89b4dec56 100644 --- a/psyneulink/core/llvm/builder_context.py +++ b/psyneulink/core/llvm/builder_context.py @@ -319,7 +319,9 @@ def convert_python_struct_to_llvm_ir(self, t): return ir.ArrayType(elems_t[0], len(elems_t)) return ir.LiteralStructType(elems_t) elif type(t) is tuple: - elems_t = (self.convert_python_struct_to_llvm_ir(x) for x in t) + elems_t = [self.convert_python_struct_to_llvm_ir(x) for x in t] + if len(elems_t) > 0 and all(x == elems_t[0] for x in elems_t): + return ir.ArrayType(elems_t[0], len(elems_t)) return ir.LiteralStructType(elems_t) elif isinstance(t, enum.Enum): # FIXME: Consider enums of non-int type diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 5928f3d98b7..b9adbbd08e9 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -350,7 +350,7 @@ def visit_Attribute(self, node): # special case numpy attributes if node.attr == "shape": shape = helpers.get_array_shape(val) - return ir.LiteralStructType([self.ctx.float_ty] * len(shape))(shape) + return ir.ArrayType(self.ctx.float_ty, len(shape))(shape) elif node.attr == "astype": def astype(ty): def _convert(ctx, builder, x): @@ -404,7 +404,10 @@ def visit_Tuple(self, node): elements = [self.builder.load(element) if helpers.is_pointer(element) else element for element in elements] element_types = [element.type for element in elements] - ret_list = self.builder.alloca(ir.LiteralStructType(element_types)) + if len(element_types) > 0 and all(x == element_types[0] for x in element_types): + ret_list = self.builder.alloca(ir.ArrayType(element_types[0], len(element_types))) + else: + ret_list = self.builder.alloca(ir.LiteralStructType(element_types)) for idx, element in enumerate(elements): self.builder.store(element, self.builder.gep(ret_list, [self.ctx.int32_ty(0), self.ctx.int32_ty(idx)])) From 232fc32cd37e182e6e298ecfa9559584f2957e96 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Jan 2021 06:06:42 +0000 Subject: [PATCH 22/27] requirements: update numpy requirement from <1.19.5 to <1.19.6 Updates the requirements on [numpy](https://github.com/numpy/numpy) to permit the latest version. - [Release notes](https://github.com/numpy/numpy/releases) - [Changelog](https://github.com/numpy/numpy/blob/master/doc/HOWTO_RELEASE.rst.txt) - [Commits](https://github.com/numpy/numpy/compare/v0.2.0...v1.19.5) Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6c36f01afec..e18a7b7b068 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ grpcio-tools<1.35.0 llvmlite<0.36 matplotlib<3.3.4 networkx<2.6 -numpy<1.19.5 +numpy<1.19.6 pillow<8.1.0 toposort<1.7 torch<1.8.0; sys_platform != 'win32' and platform_machine == 'x86_64' and platform_python_implementation == 'CPython' From 0a28145073c6fc4472637715fa7629b3f537d997 Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Thu, 7 Jan 2021 11:56:25 -0500 Subject: [PATCH 23/27] llvm/helpers: Simplify array loop id to avoid illegal characters Fixes compilation failures on arm64/ppc64/pypy in cmpop_numpy tests. Fixes: 0a1fefd6fcfcc2c5c4479046579c353f0daab88b Signed-off-by: Jan Vesely --- psyneulink/core/llvm/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index ca0401a32e6..0752cfbe8df 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -251,7 +251,7 @@ def recursive_iterate_arrays(ctx, builder, u, *args): """Recursively iterates over all elements in scalar arrays of the same shape""" assert isinstance(u.type.pointee, ir.ArrayType), "Can only iterate over arrays!" assert all(len(u.type.pointee) == len(v.type.pointee) for v in args), "Tried to iterate over differing lengths!" - with array_ptr_loop(builder, u, str(u) + "," + str(args) + "_recursive_zip") as (b, idx): + with array_ptr_loop(builder, u, "recursive_iteration") as (b, idx): u_ptr = b.gep(u, [ctx.int32_ty(0), idx]) arg_ptrs = (b.gep(v, [ctx.int32_ty(0), idx]) for v in args) if is_scalar(u_ptr): From d2b7a3b9b8e0d6721afb475a7915bc0290148e1a Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Thu, 7 Jan 2021 16:30:33 -0500 Subject: [PATCH 24/27] travis: Disable ppc64le runs Signed-off-by: Jan Vesely --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9a3aa390d28..29f638cf4d7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,8 @@ dist: bionic arch: - amd64 - arm64 - - ppc64le +# Disabled due to intermittent failures and long running times +# - ppc64le # Disabled until grpcio works with s390x # https://github.com/grpc/grpc/issues/23797 # - s390x From 11eef932c72ca7827b6b628f72cb9872ca0e2ecf Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Fri, 18 Dec 2020 20:57:37 -0500 Subject: [PATCH 25/27] utilities: add contains_type returns True if an iterable contains an item of a certain type in any sub-iterable --- psyneulink/core/globals/utilities.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/psyneulink/core/globals/utilities.py b/psyneulink/core/globals/utilities.py index 0dec1452b8b..fb90ee4d198 100644 --- a/psyneulink/core/globals/utilities.py +++ b/psyneulink/core/globals/utilities.py @@ -132,6 +132,7 @@ 'scalar_distance', 'sinusoid', 'tensor_power', 'TEST_CONDTION', 'type_match', 'underscore_to_camelCase', 'UtilitiesError', 'unproxy_weakproxy', 'create_union_set', 'merge_dictionaries', + 'contains_type' ] logger = logging.getLogger(__name__) @@ -1789,3 +1790,25 @@ def gen_friendly_comma_str(items): divider = f',{divider}' return f"{', '.join(items[:-1])}{divider}{items[-1]}" + + +def contains_type( + arr: collections.Iterable, + typ: typing.Union[type, typing.Tuple[type, ...]] +) -> bool: + """ + Returns: + True if **arr** is a possibly nested Iterable that contains + an instance of **typ** (or one type in **typ** if tuple) + + Note: `isinstance(**arr**, **typ**)` should be used to check + **arr** itself if needed + """ + try: + for a in arr: + if isinstance(a, typ) or (a is not arr and contains_type(a, typ)): + return True + except TypeError: + pass + + return False From 0cbd6e34848cfe1081326b7d13c3c3ead70821bd Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Fri, 18 Dec 2020 21:28:44 -0500 Subject: [PATCH 26/27] StatefulFunction: validate for uninitialized noise functions in list --- .../statefulfunctions/statefulfunction.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/psyneulink/core/components/functions/statefulfunctions/statefulfunction.py b/psyneulink/core/components/functions/statefulfunctions/statefulfunction.py index 78fbbb7e97d..4c2ab566931 100644 --- a/psyneulink/core/components/functions/statefulfunctions/statefulfunction.py +++ b/psyneulink/core/components/functions/statefulfunctions/statefulfunction.py @@ -17,6 +17,7 @@ """ import abc +import collections import typecheck as tc import warnings import numbers @@ -24,12 +25,12 @@ import numpy as np from psyneulink.core import llvm as pnlvm -from psyneulink.core.components.component import DefaultsFlexibility, _has_initializers_setter +from psyneulink.core.components.component import DefaultsFlexibility, _has_initializers_setter, ComponentsMeta from psyneulink.core.components.functions.function import Function_Base, FunctionError from psyneulink.core.components.functions.distributionfunctions import DistributionFunction from psyneulink.core.globals.keywords import STATEFUL_FUNCTION_TYPE, STATEFUL_FUNCTION, NOISE, RATE from psyneulink.core.globals.parameters import Parameter -from psyneulink.core.globals.utilities import parameter_spec, iscompatible, object_has_single_value, convert_to_np_array +from psyneulink.core.globals.utilities import parameter_spec, iscompatible, object_has_single_value, convert_to_np_array, contains_type from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set from psyneulink.core.globals.context import ContextFlags, handle_external_context @@ -196,6 +197,15 @@ class Parameters(Function_Base.Parameters): initializer = Parameter(np.array([0]), pnl_internal=True) has_initializers = Parameter(True, setter=_has_initializers_setter, pnl_internal=True) + def _validate_noise(self, noise): + if ( + isinstance(noise, collections.Iterable) + # assume ComponentsMeta are functions + and contains_type(noise, ComponentsMeta) + ): + # TODO: make this validation unnecessary by handling automatically? + return 'functions in a list must be instantiated and have the desired noise variable shape' + @handle_external_context() @tc.typecheck def __init__(self, From 20eef5b83f65d5273ef98756c89e8566777ac7f7 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Thu, 7 Jan 2021 21:02:47 -0500 Subject: [PATCH 27/27] Log: fix crash when any logged item lacks eid entry of any other item if an execution_id was logged for any component, all components must have logged an entry for that execution_id or Log.nparray and Log.nparray_dictionary would crash --- psyneulink/core/globals/log.py | 11 +++++++++-- tests/log/test_log.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/psyneulink/core/globals/log.py b/psyneulink/core/globals/log.py index 2ac04763ac6..69a20de7e50 100644 --- a/psyneulink/core/globals/log.py +++ b/psyneulink/core/globals/log.py @@ -1401,7 +1401,10 @@ def nparray(self, # If any time values are empty, revert to indexing the entries; # this requires that all entries have the same length else: - max_len = max([len(self.logged_entries[e][eid]) for e in entries]) + try: + max_len = max([len(self.logged_entries[e][eid]) for e in entries]) + except KeyError: + max_len = 0 # If there are no time values, only support entries of the same length # Must dealias both e and zeroth entry because either/both of these could be 'value' @@ -1751,7 +1754,11 @@ def _assemble_entry_data(self, entry, time_values, execution_id=None): # entry = self._dealias_owner_name(entry) row = [] time_col = iter(time_values) - data = self.logged_entries[entry][execution_id] + try: + data = self.logged_entries[entry][execution_id] + except KeyError: + return [None] + time = next(time_col, None) for i in range(len(self.logged_entries[entry][execution_id])): # iterate through log entry tuples: diff --git a/tests/log/test_log.py b/tests/log/test_log.py index 78c05c732aa..30c7017d65a 100644 --- a/tests/log/test_log.py +++ b/tests/log/test_log.py @@ -1002,6 +1002,15 @@ def test_log_multi_calls_single_timestep(self, scheduler_conditions, multi_run): assert log_dict['Run'] == [[0], [0], [0], [1], [1], [1]] assert np.allclose(log_dict['value'], [[[0.52466739, 0.47533261]] * 6]) + def test_log_with_non_full_execution_id_entries(self): + t = pnl.TransferMechanism() + + t.parameters.noise.set(0, context=1, override=True) + t.parameters.value.set(0, context=2, override=True) + + t.log.nparray() + t.log.nparray_dictionary() + class TestClearLog: