Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[samples] Speed up simulation by importing buffers #105

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/samples/nsnet2/nsnet2_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ int run_nsnet2_experiment(
iree_hal_executable_library_query_fn_t implementation) {
if (!snrt_is_dm_core()) return quidditch_dispatch_enter_worker_loop();

double(*data)[161] = malloc(161 * sizeof(double));
double(*data)[161] = aligned_alloc(64, 161 * sizeof(double));

for (int i = 0; i < IREE_ARRAYSIZE(*data); i++) {
(*data)[i] = (i + 1);
Expand Down
30 changes: 19 additions & 11 deletions runtime/samples/util/run_model.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ iree_status_t run_model(const model_config_t* config) {
if (!iree_status_is_ok(result)) goto error_release_context;

for (iree_host_size_t i = 0; i < config->num_inputs; i++) {
iree_const_byte_span_t span = iree_make_const_byte_span(
config->input_data[i],
config->input_sizes[i] *
iree_hal_element_dense_byte_count(config->element_type));
iree_hal_external_buffer_t external_buffer = {
.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION,
.flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE,
.size = config->input_sizes[i] *
iree_hal_element_dense_byte_count(config->element_type),
.handle.host_allocation = {(void*)config->input_data[i]},
};

iree_hal_buffer_params_t params = {
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE,
Expand All @@ -103,15 +106,20 @@ iree_status_t run_model(const model_config_t* config) {
};
iree_hal_buffer_params_canonicalize(&params);

iree_hal_buffer_view_t* buffer = NULL;
result = iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), config->input_ranks[i],
config->input_shapes[i], config->element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params, span, &buffer);
if (!iree_status_is_ok(result)) goto error_release_context;
iree_hal_buffer_t* buffer = NULL;
IREE_CHECK_OK(iree_hal_allocator_import_buffer(
iree_hal_device_allocator(device), params, &external_buffer,
iree_hal_buffer_release_callback_null(), &buffer));

iree_hal_buffer_view_t* buffer_view = NULL;
IREE_CHECK_OK(iree_hal_buffer_view_create(
buffer, config->input_ranks[i], config->input_shapes[i],
config->element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
host_allocator, &buffer_view));
iree_hal_buffer_release(buffer);

iree_vm_ref_t arg_buffer_view_ref;
arg_buffer_view_ref = iree_hal_buffer_view_move_ref(buffer);
arg_buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
result = iree_vm_list_push_ref_retain(inputs, &arg_buffer_view_ref);
if (!iree_status_is_ok(result)) goto error_release_context;
}
Expand Down
3 changes: 2 additions & 1 deletion runtime/samples/util/run_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ typedef struct {

/// Number of input tensors.
iree_host_size_t num_inputs;
/// Input tensor data in dense row major encoding.
/// Input tensor data in dense row major encoding. Must be aligned to 64
/// bytes.
const void** input_data;
/// Number of elements for each input in 'input_data'.
const iree_host_size_t* input_sizes;
Expand Down
2 changes: 1 addition & 1 deletion runtime/samples/vec_multiply/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <util/run_model.h>

int main() {
double data[4];
iree_alignas(64) double data[4];
if (!snrt_is_dm_core()) return quidditch_dispatch_enter_worker_loop();

for (int i = 0; i < IREE_ARRAYSIZE(data); i++) {
Expand Down