@@ -408,6 +408,8 @@ struct vk_device_struct {
408
408
bool subgroup_ballot;
409
409
bool subgroup_clustered;
410
410
bool multi_add;
411
+ bool shader_int64;
412
+ bool buffer_device_address;
411
413
412
414
bool add_rms_fusion;
413
415
uint32_t partials_binding_alignment;
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
655
657
vk::MemoryPropertyFlags memory_property_flags;
656
658
void * ptr;
657
659
size_t size = 0;
660
+ vk::DeviceAddress bda_addr {};
658
661
659
662
vk_device device;
660
663
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
987
990
};
988
991
989
992
struct vk_op_im2col_push_constants {
993
+ uint64_t dst_addr;
990
994
uint32_t batch_offset; uint32_t offset_delta;
991
995
uint32_t IC;
992
996
uint32_t IW; uint32_t IH;
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
1000
1004
};
1001
1005
1002
1006
struct vk_op_im2col_3d_push_constants {
1007
+ uint64_t dst_addr;
1003
1008
uint32_t nb10;
1004
1009
uint32_t nb11;
1005
1010
uint32_t nb12;
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
2012
2017
return buf;
2013
2018
}
2014
2019
2020
+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2021
+ vk::MemoryAllocateFlags mem_flags {};
2022
+ if (device->buffer_device_address) {
2023
+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2024
+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2025
+ }
2026
+
2015
2027
vk::BufferCreateInfo buffer_create_info{
2016
2028
vk::BufferCreateFlags(),
2017
2029
size,
2018
- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst ,
2030
+ usage_flags ,
2019
2031
vk::SharingMode::eExclusive,
2020
2032
0,
2021
2033
nullptr,
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
2027
2039
2028
2040
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
2029
2041
2042
+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2043
+
2030
2044
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
2031
2045
const auto & req_flags = *it;
2032
2046
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
2038
2052
buf->memory_property_flags = req_flags;
2039
2053
2040
2054
try {
2041
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
2055
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
2042
2056
break;
2043
2057
} catch (const vk::SystemError& e) {
2044
2058
// loop and retry
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
2066
2080
buf->device = device;
2067
2081
buf->size = size;
2068
2082
2083
+ if (device->buffer_device_address) {
2084
+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2085
+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
2086
+ }
2087
+
2069
2088
#ifdef GGML_VULKAN_MEMORY_DEBUG
2070
2089
device->memory_logger->log_allocation(buf, size);
2071
2090
#endif
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
3532
3551
3533
3552
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3534
3553
3535
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3536
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3537
- if (device->float_controls_rte_fp16) {
3538
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3539
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3554
+ #define IM2COL(bda) \
3555
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3556
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3557
+ if (device->float_controls_rte_fp16) { \
3558
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3559
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3560
+ } else { \
3561
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3562
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3563
+ }
3564
+ if (device->shader_int64 && device->buffer_device_address) {
3565
+ IM2COL(_bda)
3540
3566
} else {
3541
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3542
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3567
+ IM2COL()
3543
3568
}
3544
3569
3545
3570
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
4017
4042
device->vendor_id != VK_VENDOR_ID_INTEL &&
4018
4043
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
4019
4044
4045
+ device->shader_int64 = device_features2.features.shaderInt64;
4046
+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
4047
+
4020
4048
if (device->subgroup_size_control) {
4021
4049
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
4022
4050
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
8635
8663
8636
8664
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8637
8665
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
8666
+ if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
8667
+ // buffer device address path doesn't use dst buffer
8668
+ d_sz = 1;
8669
+ }
8638
8670
// im2col uses only src1 and dst buffers
8639
8671
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8640
8672
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
9486
9518
9487
9519
const uint32_t pelements = OW * KW * KH;
9488
9520
9521
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9522
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9523
+
9524
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9525
+
9489
9526
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9527
+ dst_addr,
9490
9528
batch_offset, offset_delta,
9491
9529
IC, IW, IH, OW, OH, KW, KH,
9492
9530
pelements,
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
9522
9560
const int64_t OH = ne2;
9523
9561
const int64_t OW = ne1;
9524
9562
9563
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9564
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9565
+
9566
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9567
+
9525
9568
vk_op_im2col_3d_push_constants pc {};
9526
9569
9570
+ pc.dst_addr = dst_addr;
9527
9571
pc.nb10 = nb10 / ggml_type_size(src1->type);
9528
9572
pc.nb11 = nb11 / ggml_type_size(src1->type);
9529
9573
pc.nb12 = nb12 / ggml_type_size(src1->type);
0 commit comments