Skip to content

Commit d8359f5

Browse files
authored
vulkan: 64-bit im2col (#16135)
* vulkan: 64-bit im2col Add variants of the im2col shaders that use buffer_device_address/buffer_reference, and use 64-bit address calculations. This is needed for large convolutions used in stable-diffusion.cpp. * fix validation error for large im2col
1 parent 6a2c614 commit d8359f5

File tree

6 files changed

+117
-26
lines changed

6 files changed

+117
-26
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ struct vk_device_struct {
408408
bool subgroup_ballot;
409409
bool subgroup_clustered;
410410
bool multi_add;
411+
bool shader_int64;
412+
bool buffer_device_address;
411413

412414
bool add_rms_fusion;
413415
uint32_t partials_binding_alignment;
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
655657
vk::MemoryPropertyFlags memory_property_flags;
656658
void * ptr;
657659
size_t size = 0;
660+
vk::DeviceAddress bda_addr {};
658661

659662
vk_device device;
660663

@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
987990
};
988991

989992
struct vk_op_im2col_push_constants {
993+
uint64_t dst_addr;
990994
uint32_t batch_offset; uint32_t offset_delta;
991995
uint32_t IC;
992996
uint32_t IW; uint32_t IH;
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
10001004
};
10011005

10021006
struct vk_op_im2col_3d_push_constants {
1007+
uint64_t dst_addr;
10031008
uint32_t nb10;
10041009
uint32_t nb11;
10051010
uint32_t nb12;
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20122017
return buf;
20132018
}
20142019

2020+
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2021+
vk::MemoryAllocateFlags mem_flags {};
2022+
if (device->buffer_device_address) {
2023+
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2024+
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2025+
}
2026+
20152027
vk::BufferCreateInfo buffer_create_info{
20162028
vk::BufferCreateFlags(),
20172029
size,
2018-
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
2030+
usage_flags,
20192031
vk::SharingMode::eExclusive,
20202032
0,
20212033
nullptr,
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20272039

20282040
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
20292041

2042+
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2043+
20302044
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
20312045
const auto & req_flags = *it;
20322046

@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20382052
buf->memory_property_flags = req_flags;
20392053

20402054
try {
2041-
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
2055+
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
20422056
break;
20432057
} catch (const vk::SystemError& e) {
20442058
// loop and retry
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20662080
buf->device = device;
20672081
buf->size = size;
20682082

2083+
if (device->buffer_device_address) {
2084+
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2085+
buf->bda_addr = device->device.getBufferAddress(addressInfo);
2086+
}
2087+
20692088
#ifdef GGML_VULKAN_MEMORY_DEBUG
20702089
device->memory_logger->log_allocation(buf, size);
20712090
#endif
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
35323551

35333552
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
35343553

3535-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3536-
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3537-
if (device->float_controls_rte_fp16) {
3538-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3539-
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3554+
#define IM2COL(bda) \
3555+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3556+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3557+
if (device->float_controls_rte_fp16) { \
3558+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3559+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3560+
} else { \
3561+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3562+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3563+
}
3564+
if (device->shader_int64 && device->buffer_device_address) {
3565+
IM2COL(_bda)
35403566
} else {
3541-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3542-
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3567+
IM2COL()
35433568
}
35443569

35453570
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
40174042
device->vendor_id != VK_VENDOR_ID_INTEL &&
40184043
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
40194044

4045+
device->shader_int64 = device_features2.features.shaderInt64;
4046+
device->buffer_device_address = vk12_features.bufferDeviceAddress;
4047+
40204048
if (device->subgroup_size_control) {
40214049
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
40224050
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
86358663

86368664
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
86378665
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
8666+
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
8667+
// buffer device address path doesn't use dst buffer
8668+
d_sz = 1;
8669+
}
86388670
// im2col uses only src1 and dst buffers
86398671
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
86408672
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
94869518

94879519
const uint32_t pelements = OW * KW * KH;
94889520

9521+
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9522+
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9523+
9524+
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9525+
94899526
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9527+
dst_addr,
94909528
batch_offset, offset_delta,
94919529
IC, IW, IH, OW, OH, KW, KH,
94929530
pelements,
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
95229560
const int64_t OH = ne2;
95239561
const int64_t OW = ne1;
95249562

9563+
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9564+
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9565+
9566+
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9567+
95259568
vk_op_im2col_3d_push_constants pc {};
95269569

9570+
pc.dst_addr = dst_addr;
95279571
pc.nb10 = nb10 / ggml_type_size(src1->type);
95289572
pc.nb11 = nb11 / ggml_type_size(src1->type);
95299573
pc.nb12 = nb12 / ggml_type_size(src1->type);

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
#include "rte.comp"
77

8+
#include "types.comp"
9+
810
layout (push_constant) uniform parameter
911
{
12+
BDA_STORAGE_T dst_addr;
1013
uint batch_offset; uint offset_delta;
1114
uint IC;
1215
uint IW; uint IH;
@@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
1922
int d0; int d1;
2023
} p;
2124

22-
#include "types.comp"
23-
2425
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
2526

2627
const uint NUM_ITER = 512 / BLOCK_SIZE;
@@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3031
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
3132
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
3233

34+
#if BDA
35+
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
36+
#endif
37+
3338
void main() {
3439
const uint gidx = gl_GlobalInvocationID.x;
3540

@@ -38,7 +43,7 @@ void main() {
3843
const uint ic = gl_GlobalInvocationID.z % p.IC;
3944

4045
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
41-
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
46+
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
4247
const int oh_s1 = int(oh) * p.s1;
4348
const uint ksize = p.OW * p.KH;
4449

@@ -50,7 +55,7 @@ void main() {
5055
uint current_ix = rem % p.OW;
5156

5257
A_TYPE values[NUM_ITER];
53-
uint offset_dst[NUM_ITER];
58+
BDA_OFFSET_T offset_dst[NUM_ITER];
5459
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5560
values[idx] = A_TYPE(0);
5661
}
@@ -66,7 +71,7 @@ void main() {
6671
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
6772
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
6873

69-
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
74+
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
7075

7176
if ((iih < p.IH) && (iiw < p.IW)) {
7277
values[idx] = data_a[src_base + iih * p.IW + iiw];
@@ -89,7 +94,11 @@ void main() {
8994
continue;
9095
}
9196

97+
#if BDA
98+
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
99+
dst_addr.d = D_TYPE(values[idx]);
100+
#else
92101
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
102+
#endif
93103
}
94-
95104
}

ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
#include "rte.comp"
88

9+
#include "types.comp"
10+
911
layout (push_constant) uniform parameter
1012
{
13+
BDA_STORAGE_T dst_addr;
1114
uint32_t nb10;
1215
uint32_t nb11;
1316
uint32_t nb12;
@@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
3841
uint32_t misalign_offsets;
3942
} p;
4043

41-
#include "types.comp"
42-
4344
uint get_aoffset() { return p.misalign_offsets >> 16; }
4445
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
4546

@@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
5051
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
5152
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
5253

54+
#if BDA
55+
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
56+
#endif
57+
5358
void main() {
5459
const uint32_t i = gl_GlobalInvocationID.x;
5560

@@ -100,13 +105,22 @@ void main() {
100105
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
101106
const uint32_t iid = iod * s2 + ikd * d2 - p2;
102107

103-
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
108+
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
104109

110+
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
111+
#if BDA
112+
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
113+
if (iih >= IH || iiw >= IW || iid >= ID) {
114+
dst_addr.d = D_TYPE(0.0f);
115+
} else {
116+
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
117+
}
118+
#else
105119
if (iih >= IH || iiw >= IW || iid >= ID) {
106120
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
107121
} else {
108-
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
109122
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
110123
}
124+
#endif
111125
}
112126
}

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
14471447
return uintBitsToFloat(bits);
14481448
}
14491449

1450+
#if BDA
1451+
1452+
#extension GL_EXT_buffer_reference : enable
1453+
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable
1454+
1455+
#define BDA_STORAGE_T uint64_t
1456+
#define BDA_OFFSET_T uint64_t
1457+
1458+
#else
1459+
1460+
#define BDA_STORAGE_T uvec2
1461+
#define BDA_OFFSET_T uint
1462+
1463+
#endif
1464+
14501465
#endif // !defined(GGML_TYPES_COMP)

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -775,13 +775,15 @@ void process_shaders() {
775775
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
776776
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
777777

778-
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
779-
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
780-
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
781-
782-
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
783-
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
784-
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
778+
for (std::string dim_str : {"", "_3d"}) {
779+
for (bool bda : {false, true}) {
780+
std::string bda_str = bda ? "_bda" : "";
781+
std::string bda_def = bda ? "1" : "0";
782+
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
783+
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
784+
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
785+
}
786+
}
785787

786788
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
787789

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5753,6 +5753,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
57535753
}
57545754
}
57555755

5756+
#if 0
5757+
// >4GB im2col destination. Too slow to run by default.
5758+
// Test cases taken from Wan2.1 T2V 1.3B.
5759+
test_cases.emplace_back(new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true));
5760+
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96}, {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false));
5761+
#endif
5762+
57565763
// im2col 1D
57575764
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
57585765
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));

0 commit comments

Comments
 (0)