Skip to content

Commit

Permalink
port eddie's stuff into HAL
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Sep 12, 2024
1 parent e32909b commit cb8097f
Show file tree
Hide file tree
Showing 10 changed files with 418 additions and 382 deletions.
107 changes: 70 additions & 37 deletions runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,11 @@
#include "iree/base/api.h"
#include "iree/base/tracing.h"

// State for the HSA implementation of the HAL allocator.
// iree_hal_hsa_allocator_cast reinterprets an iree_hal_allocator_t* as a
// pointer to this struct, which is why |resource| has to be the first member.
typedef struct iree_hal_hsa_allocator_t {
// Abstract resource used for injecting reference counting and vtable;
// must be at offset 0.
iree_hal_resource_t resource;

// HSA agent handle for the device side.
// NOTE(review): presumably the agent passed to iree_hal_hsa_allocator_create;
// confirm against the assignment in that function.
hsa_agent_t hsa_agent;

// CPU agent and one of its memory pools.
// NOTE(review): presumably filled in by the iterate_find_cpu_agent* callbacks
// invoked from iree_hal_hsa_allocator_create; confirm.
hsa_agent_t cpu_agent;
hsa_amd_memory_pool_t cpu_pool;

// One memory pool and region for now
hsa_amd_memory_pool_t buffers_pool;
hsa_region_t kernel_argument_pool;

// Table of dynamically resolved HSA entry points; const, so not modified
// through this pointer.
const iree_hal_hsa_dynamic_symbols_t* symbols;

// Host allocator handle.
// NOTE(review): presumably the allocator this struct itself was allocated
// from; confirm in create/destroy.
iree_allocator_t host_allocator;

// Whether the GPU and CPU can concurrently access HSA managed data in a
// coherent way. We would need to explicitly perform flushing and invalidation
// between GPU and CPU if not.
bool supports_concurrent_managed_access;

// Allocation statistics member, declared through the IREE_STATISTICS macro.
// NOTE(review): presumably compiled out when statistics are disabled; confirm.
IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
} iree_hal_hsa_allocator_t;

// Forward declaration of the allocator vtable so that code above its
// definition (e.g. the type assertion in iree_hal_hsa_allocator_cast) can
// take its address. The anonymous namespace keeps the symbol local to this
// translation unit.
namespace {
extern const iree_hal_allocator_vtable_t iree_hal_hsa_allocator_vtable;
}

static iree_hal_hsa_allocator_t* iree_hal_hsa_allocator_cast(
iree_hal_hsa_allocator_t* iree_hal_hsa_allocator_cast(
iree_hal_allocator_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_allocator_vtable);
return (iree_hal_hsa_allocator_t*)base_value;
Expand Down Expand Up @@ -160,6 +134,51 @@ static hsa_status_t iterate_find_cpu_agent_pool_callback(
return HSA_STATUS_SUCCESS;
}

// Shared implementation for the memory-pool iteration callbacks below.
//
// Stores |pool| into |data| when |pool| lives in the global segment, is
// coarse-grained, and its kernarg flag matches |kernarg|:
//   kernarg == true  -> the pool must carry the kernarg flag;
//   kernarg == false -> the pool must not carry the kernarg flag.
//
// |data| must point to an hsa_amd_memory_pool_t. It is left untouched when
// |pool| does not match. The function never asks the iteration to stop, so if
// several pools match during one iteration the last one visited wins.
//
// Returns the failing status of hsa_amd_memory_pool_get_info if a query
// fails, otherwise HSA_STATUS_SUCCESS.
hsa_status_t get_coarse_global_mem_pool(hsa_amd_memory_pool_t pool, void* data,
                                        bool kernarg) {
  hsa_amd_segment_t segment_type;
  hsa_status_t ret = hsa_amd_memory_pool_get_info(
      pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT, &segment_type);
  if (ret != HSA_STATUS_SUCCESS) {
    return ret;
  }

  // The global flags are only queried for pools in the global segment.
  if (segment_type != HSA_AMD_SEGMENT_GLOBAL) {
    return HSA_STATUS_SUCCESS;
  }

  // The attribute is documented as a uint32_t bit-field of
  // hsa_amd_memory_pool_global_flag_t values, so it is read into an integer
  // (zero-initialized) rather than into an uninitialized enum variable.
  uint32_t global_pool_flags = 0;
  ret = hsa_amd_memory_pool_get_info(
      pool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &global_pool_flags);
  if (ret != HSA_STATUS_SUCCESS) {
    return ret;
  }

  const bool is_coarse_grained =
      (global_pool_flags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_COARSE_GRAINED) != 0;
  const bool has_kernarg_flag =
      (global_pool_flags & HSA_REGION_GLOBAL_FLAG_KERNARG) != 0;

  // A single comparison covers both the "wants kernarg" and the
  // "must not be kernarg" cases.
  if (is_coarse_grained && has_kernarg_flag == kernarg) {
    *static_cast<hsa_amd_memory_pool_t*>(data) = pool;
  }

  return HSA_STATUS_SUCCESS;
}

// Memory-pool iteration callback: records in |data| a coarse-grained global
// pool that does NOT carry the kernarg flag. |data| must point to an
// hsa_amd_memory_pool_t. Forwards the status of get_coarse_global_mem_pool.
hsa_status_t get_coarse_global_dev_mem_pool(hsa_amd_memory_pool_t pool,
                                            void* data) {
  const bool want_kernarg = false;
  return get_coarse_global_mem_pool(pool, data, want_kernarg);
}

// Memory-pool iteration callback: records in |data| a coarse-grained global
// pool that carries the kernarg flag. |data| must point to an
// hsa_amd_memory_pool_t. Forwards the status of get_coarse_global_mem_pool.
hsa_status_t get_coarse_global_kernarg_mem_pool(hsa_amd_memory_pool_t pool,
                                                void* data) {
  const bool want_kernarg = true;
  return get_coarse_global_mem_pool(pool, data, want_kernarg);
}

iree_status_t iree_hal_hsa_allocator_create(
const iree_hal_hsa_dynamic_symbols_t* hsa_symbols, hsa_agent_t agent,
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
Expand Down Expand Up @@ -193,16 +212,30 @@ iree_status_t iree_hal_hsa_allocator_create(
allocator->supports_concurrent_managed_access =
supports_concurrent_managed_access != 0;

hsa_symbols->hsa_agent_iterate_regions(agent, get_kernarg_memory_region,
allocator);
hsa_symbols->hsa_amd_agent_iterate_memory_pools(
agent, get_fine_grained_memory_pool, allocator);

hsa_symbols->hsa_iterate_agents(&iterate_find_cpu_agent_callback,
(void*)allocator);
hsa_symbols->hsa_amd_agent_iterate_memory_pools(
allocator->cpu_agent, &iterate_find_cpu_agent_pool_callback,
(void*)allocator);
// hsa_symbols->hsa_agent_iterate_regions(agent, get_kernarg_memory_region,
// allocator);
// hsa_symbols->hsa_amd_agent_iterate_memory_pools(
// agent, get_fine_grained_memory_pool, allocator);
//
// hsa_symbols->hsa_iterate_agents(&iterate_find_cpu_agent_callback,
// (void*)allocator);
// hsa_symbols->hsa_amd_agent_iterate_memory_pools(
// allocator->cpu_agent, &iterate_find_cpu_agent_pool_callback,
// (void*)allocator);

// Find a pool for DEV BOs. This is a global system memory pool that is
// mapped to the device. Will be used for PDIs and DPU instructions.
hsa_status_t r = hsa_symbols->hsa_amd_agent_iterate_memory_pools(
agent, get_coarse_global_dev_mem_pool, &allocator->cpu_pool);
assert(r == HSA_STATUS_SUCCESS);

// Find a pool that supports kernel args. This is just normal system memory.
// It will be used for commands and input data.
r = hsa_symbols->hsa_amd_agent_iterate_memory_pools(
agent, get_coarse_global_kernarg_mem_pool,
&allocator->kernel_argument_pool);
assert(r == HSA_STATUS_SUCCESS);
assert(allocator->kernel_argument_pool.handle);

*out_allocator = (iree_hal_allocator_t*)allocator;

Expand Down
31 changes: 30 additions & 1 deletion runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,37 @@ iree_status_t iree_hal_hsa_allocator_create(
const iree_hal_hsa_dynamic_symbols_t* hsa_symbols, hsa_agent_t agent,
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator);

// State for the HSA implementation of the HAL allocator.
//
// Declared with a typedef so that the bare name iree_hal_hsa_allocator_t is a
// valid type in C as well as C++; this header is wrapped in extern "C" guards
// and uses the bare name in the iree_hal_hsa_allocator_cast prototype.
typedef struct iree_hal_hsa_allocator_t {
  // Abstract resource used for injecting reference counting and vtable;
  // must be at offset 0.
  iree_hal_resource_t resource;

  // Agent handles.
  // NOTE(review): where each of these is assigned is not visible from this
  // header; confirm in iree_hal_hsa_allocator_create.
  hsa_agent_t hsa_agent;
  hsa_agent_t cpu_agent;
  hsa_agent_t aie_agent;

  // Filled in by iree_hal_hsa_allocator_create via
  // get_coarse_global_dev_mem_pool (coarse-grained, non-kernarg global pool).
  hsa_amd_memory_pool_t cpu_pool;

  // One memory pool and region for now
  hsa_amd_memory_pool_t buffers_pool;
  // NOTE(review): iree_hal_hsa_allocator_create passes the address of this
  // hsa_region_t to get_coarse_global_kernarg_mem_pool, which writes an
  // hsa_amd_memory_pool_t through it. Confirm the two handle types are
  // layout-compatible or change this member's type.
  hsa_region_t kernel_argument_pool;

  // Table of dynamically resolved HSA entry points; const, so not modified
  // through this pointer.
  const iree_hal_hsa_dynamic_symbols_t* symbols;

  // Host allocator handle.
  // NOTE(review): presumably the allocator this struct itself was allocated
  // from; confirm in create/destroy.
  iree_allocator_t host_allocator;

  // Whether the GPU and CPU can concurrently access HSA managed data in a
  // coherent way. We would need to explicitly perform flushing and
  // invalidation between GPU and CPU if not.
  bool supports_concurrent_managed_access;

  // Allocation statistics member, declared through the IREE_STATISTICS macro.
  IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
} iree_hal_hsa_allocator_t;

iree_hal_hsa_allocator_t* iree_hal_hsa_allocator_cast(
iree_hal_allocator_t* base_value);

#ifdef __cplusplus
} // extern "C"
} // extern "C"
#endif // __cplusplus

#endif // IREE_EXPERIMENTAL_HSA_ALLOCATOR_H_
59 changes: 31 additions & 28 deletions runtime/src/iree-amd-aie/driver/hsa/hsa_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ iree_status_t iree_hal_hsa_device_create(

iree_status_t status = iree_hal_hsa_device_check_params(params);

size_t num_queue_packets = 1024;
hsa_queue_type_t queue_type = HSA_QUEUE_TYPE_MULTI;
size_t num_queue_packets = 64;
hsa_queue_type_t queue_type = HSA_QUEUE_TYPE_SINGLE;
void (*callback)(hsa_status_t, hsa_queue_t*, void*) = nullptr;
void* data = nullptr;
uint32_t private_segment_size = 0;
Expand All @@ -181,36 +181,39 @@ iree_status_t iree_hal_hsa_device_create(
agent, dispatch_queue, symbols,
host_allocator, out_device);

iree_event_pool_t* host_event_pool = nullptr;
if (iree_status_is_ok(status)) {
status = iree_event_pool_allocate(params->event_pool_capacity,
host_allocator, &host_event_pool);
}

iree_hal_hsa_event_pool_t* device_event_pool = nullptr;
if (iree_status_is_ok(status)) {
status =
iree_hal_hsa_event_pool_allocate(symbols, params->event_pool_capacity,
host_allocator, &device_event_pool);
}

iree_hal_hsa_timepoint_pool_t* timepoint_pool = nullptr;
if (iree_status_is_ok(status)) {
status = iree_hal_hsa_timepoint_pool_allocate(
host_event_pool, device_event_pool, params->event_pool_capacity,
host_allocator, &timepoint_pool);
}

// iree_event_pool_t* host_event_pool = nullptr;
// if (iree_status_is_ok(status)) {
// status = iree_event_pool_allocate(params->event_pool_capacity,
// host_allocator, &host_event_pool);
// }
//
// iree_hal_hsa_event_pool_t* device_event_pool = nullptr;
// if (iree_status_is_ok(status)) {
// status =
// iree_hal_hsa_event_pool_allocate(symbols,
// params->event_pool_capacity,
// host_allocator,
// &device_event_pool);
// }
//
// iree_hal_hsa_timepoint_pool_t* timepoint_pool = nullptr;
// if (iree_status_is_ok(status)) {
// status = iree_hal_hsa_timepoint_pool_allocate(
// host_event_pool, device_event_pool, params->event_pool_capacity,
// host_allocator, &timepoint_pool);
// }
//
if (iree_status_is_ok(status)) {
iree_hal_hsa_device_t* hsa_device = iree_hal_hsa_device_cast(*out_device);
hsa_device->host_event_pool = host_event_pool;
hsa_device->device_event_pool = device_event_pool;
hsa_device->timepoint_pool = timepoint_pool;
// hsa_device->host_event_pool = host_event_pool;
// hsa_device->device_event_pool = device_event_pool;
// hsa_device->timepoint_pool = timepoint_pool;
} else {
// Release resources we have acquired after HAL device creation.
if (timepoint_pool) iree_hal_hsa_timepoint_pool_free(timepoint_pool);
if (device_event_pool) iree_hal_hsa_event_pool_release(device_event_pool);
if (host_event_pool) iree_event_pool_free(host_event_pool);
// if (timepoint_pool) iree_hal_hsa_timepoint_pool_free(timepoint_pool);
// if (device_event_pool)
// iree_hal_hsa_event_pool_release(device_event_pool); if
// (host_event_pool) iree_event_pool_free(host_event_pool);
// Release other resources via the HAL device.
iree_hal_device_release(*out_device);
}
Expand Down
55 changes: 51 additions & 4 deletions runtime/src/iree-amd-aie/driver/hsa/hsa_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <cstdint>
#include <cstring>
#include <vector>

#include "iree-amd-aie/driver/hsa/api.h"
#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
Expand Down Expand Up @@ -47,6 +48,9 @@ typedef struct iree_hal_hsa_driver_t {
// Number of GPU agents
int num_gpu_agents;

// Number of AIE agents
int num_aie_agents;

// IREE device ID to hsa_agent_t
hsa_agent_t agents[IREE_HAL_HSA_MAX_DEVICES];
} iree_hal_hsa_driver_t;
Expand Down Expand Up @@ -119,6 +123,49 @@ hsa_status_t iterate_populate_gpu_agents_callback(hsa_agent_t agent,
return HSA_STATUS_SUCCESS;
}

// hsa_iterate_agents callback that counts agents whose device type is
// HSA_DEVICE_TYPE_AIE.
//
// |base_driver| is really an iree_hal_hsa_callback_package_t*: its |driver|
// provides the HSA symbol table used for the query and its |return_value|
// points at the int counter to increment.
//
// Returns the status of hsa_agent_get_info if the query fails, otherwise
// HSA_STATUS_SUCCESS.
hsa_status_t iterate_count_aie_agents_callback(hsa_agent_t agent,
                                               void* base_driver) {
  iree_hal_hsa_callback_package_t* pkg =
      (iree_hal_hsa_callback_package_t*)base_driver;

  hsa_device_type_t device_type;
  hsa_status_t query_status = pkg->driver->hsa_symbols.hsa_agent_get_info(
      agent, HSA_AGENT_INFO_DEVICE, &device_type);
  if (query_status != HSA_STATUS_SUCCESS) return query_status;

  if (device_type == HSA_DEVICE_TYPE_AIE) {
    int* aie_count = (int*)pkg->return_value;
    *aie_count += 1;
  }
  return HSA_STATUS_SUCCESS;
}

// hsa_iterate_agents callback that appends every agent whose device type is
// HSA_DEVICE_TYPE_AIE to an output array.
//
// |base_driver| is really an iree_hal_hsa_callback_package_t*: its |driver|
// provides the HSA symbol table used for the query, |return_value| points at
// the hsa_agent_t array to fill, and |index| points at the next free slot,
// which is advanced after each store.
//
// NOTE(review): no bounds check is performed on the destination array here;
// confirm that the caller sizes it for the number of AIE agents.
//
// Returns the status of hsa_agent_get_info if the query fails, otherwise
// HSA_STATUS_SUCCESS.
hsa_status_t iterate_populate_aie_agents_callback(hsa_agent_t agent,
                                                  void* base_driver) {
  iree_hal_hsa_callback_package_t* pkg =
      (iree_hal_hsa_callback_package_t*)base_driver;

  hsa_device_type_t device_type;
  hsa_status_t query_status = pkg->driver->hsa_symbols.hsa_agent_get_info(
      agent, HSA_AGENT_INFO_DEVICE, &device_type);
  if (query_status != HSA_STATUS_SUCCESS) return query_status;

  if (device_type != HSA_DEVICE_TYPE_AIE) return HSA_STATUS_SUCCESS;

  size_t* next_slot = pkg->index;
  hsa_agent_t* out_agents = (hsa_agent_t*)pkg->return_value;
  out_agents[*next_slot] = agent;
  *next_slot += 1;
  return HSA_STATUS_SUCCESS;
}

// Initializes the HSA system.
iree_status_t iree_hal_hsa_init(iree_hal_hsa_driver_t* driver) {
IREE_TRACE_ZONE_BEGIN(z0);
Expand Down Expand Up @@ -160,16 +207,16 @@ static iree_status_t iree_hal_hsa_driver_create_internal(

memcpy(&driver->device_params, device_params, sizeof(driver->device_params));

driver->num_gpu_agents = 0;
driver->num_aie_agents = 0;

// Populate HSA agents
// Query the number of available HSA devices.
iree_hal_hsa_callback_package_t symbols_and_device_count = {
.driver = driver, .return_value = &driver->num_gpu_agents};
.driver = driver, .return_value = &driver->num_aie_agents};

IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(
z0, &driver->hsa_symbols,
hsa_iterate_agents(&iterate_count_gpu_agents_callback,
hsa_iterate_agents(&iterate_count_aie_agents_callback,
&symbols_and_device_count),
"hsa_iterate_agents");

Expand All @@ -179,7 +226,7 @@ static iree_status_t iree_hal_hsa_driver_create_internal(

IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(
z0, &driver->hsa_symbols,
hsa_iterate_agents(&iterate_populate_gpu_agents_callback,
hsa_iterate_agents(&iterate_populate_aie_agents_callback,
&symbols_and_agents),
"hsa_iterate_agents");

Expand Down
Loading

0 comments on commit cb8097f

Please sign in to comment.