diff --git a/include/boost/compute/types/struct.hpp b/include/boost/compute/types/struct.hpp index 92aeaedf2..85dddddcd 100644 --- a/include/boost/compute/types/struct.hpp +++ b/include/boost/compute/types/struct.hpp @@ -168,6 +168,14 @@ inline std::string adapt_struct_insert_member(T (Struct::*)[N], const char *name ) \ << "}"; \ } \ + template<> \ + struct set_kernel_arg \ + { \ + void operator()(kernel &kernel_, size_t index, const type &c) \ + { \ + kernel_.set_arg(index, sizeof(type), &c); \ + } \ + }; \ }}} #endif // BOOST_COMPUTE_TYPES_STRUCT_HPP diff --git a/test/test_struct.cpp b/test/test_struct.cpp index d7a7b6505..325a7b76e 100644 --- a/test/test_struct.cpp +++ b/test/test_struct.cpp @@ -23,6 +23,8 @@ #include #include +#include "check_macros.hpp" + namespace compute = boost::compute; // example code defining an atom class @@ -131,6 +133,37 @@ BOOST_AUTO_TEST_CASE(custom_kernel) queue.enqueue_1d_range_kernel(custom_kernel, 0, atoms.size(), 1); } +BOOST_AUTO_TEST_CASE(custom_kernel_set_struct_by_value) +{ + std::string source = BOOST_COMPUTE_STRINGIZE_SOURCE( + __kernel void custom_kernel(Atom atom, + __global float *position, + __global int *number) + { + position[0] = atom.x; + position[1] = atom.y; + position[2] = atom.z; + number[0] = atom.number; + } + ); + source = compute::type_definition() + "\n" + source; + compute::program program = + compute::program::build_with_source(source, context); + compute::kernel custom_kernel = program.create_kernel("custom_kernel"); + + chemistry::Atom atom(1.0f, 2.0f, 3.0f, 4); + compute::vector position(3); + compute::vector number(1); + + custom_kernel.set_arg(0, atom); + custom_kernel.set_arg(1, position); + custom_kernel.set_arg(2, number); + queue.enqueue_task(custom_kernel); + + CHECK_RANGE_EQUAL(float, 3, position, (1.0f, 2.0f, 3.0f)); + CHECK_RANGE_EQUAL(int, 1, number, (4)); +} + // Creates a StructWithArray containing 'x', 'y', 'z'. StructWithArray make_struct_with_array(int x, int y, int z) {