Skip to content

Commit

Permalink
Allows passing struct by value, utilizing constant memory on the device
Browse files Browse the repository at this point in the history
  • Loading branch information
rosenrodt committed May 9, 2019
1 parent 36c8913 commit b040362
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/boost/compute/types/struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ inline std::string adapt_struct_insert_member(T (Struct::*)[N], const char *name
) \
<< "}"; \
} \
template<> \
struct set_kernel_arg<type> \
{ \
void operator()(kernel &kernel_, size_t index, const type &c) \
{ \
kernel_.set_arg(index, sizeof(type), &c); \
} \
}; \
}}}

#endif // BOOST_COMPUTE_TYPES_STRUCT_HPP
33 changes: 33 additions & 0 deletions test/test_struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <boost/compute/type_traits/type_definition.hpp>
#include <boost/compute/utility/source.hpp>

#include "check_macros.hpp"

namespace compute = boost::compute;

// example code defining an atom class
Expand Down Expand Up @@ -131,6 +133,37 @@ BOOST_AUTO_TEST_CASE(custom_kernel)
queue.enqueue_1d_range_kernel(custom_kernel, 0, atoms.size(), 1);
}

BOOST_AUTO_TEST_CASE(custom_kernel_set_struct_by_value)
{
std::string source = BOOST_COMPUTE_STRINGIZE_SOURCE(
__kernel void custom_kernel(Atom atom,
__global float *position,
__global int *number)
{
position[0] = atom.x;
position[1] = atom.y;
position[2] = atom.z;
number[0] = atom.number;
}
);
source = compute::type_definition<chemistry::Atom>() + "\n" + source;
compute::program program =
compute::program::build_with_source(source, context);
compute::kernel custom_kernel = program.create_kernel("custom_kernel");

chemistry::Atom atom(1.0f, 2.0f, 3.0f, 4);
compute::vector<float> position(3);
compute::vector<int> number(1);

custom_kernel.set_arg(0, atom);
custom_kernel.set_arg(1, position);
custom_kernel.set_arg(2, number);
queue.enqueue_task(custom_kernel);

CHECK_RANGE_EQUAL(float, 3, position, (1.0f, 2.0f, 3.0f));
BOOST_CHECK_EQUAL(number.at(0), 4);
}

// Creates a StructWithArray containing 'x', 'y', 'z'.
StructWithArray make_struct_with_array(int x, int y, int z)
{
Expand Down

0 comments on commit b040362

Please sign in to comment.