Skip to content

Commit c7dfe3d

Browse files
authored
ggml : ggml_compute_forward_concat() for arbitrary tensor type (#1118)
* ggml_compute_forward_concat() for arbitrary tensor type
* Check that the tensors' types match
* ggml-cpu.c: check type of source tensors
* ggml-cpu.c: move tensor type check to ggml_compute_forward_concat()
* ggml.c: check concatenated tensor type
* Remove tensor type check from ggml_compute_forward_concat() in ggml-cpu.c, as it was moved to ggml.c
1 parent c6a6a64 commit c7dfe3d

File tree

2 files changed

+142
-2
lines changed

2 files changed

+142
-2
lines changed

src/ggml-cpu/ggml-cpu.c

+141-2
Original file line numberDiff line numberDiff line change
@@ -6648,14 +6648,143 @@ static void ggml_compute_forward_repeat_back(
66486648

66496649
// ggml_compute_forward_concat
66506650

6651+
static void ggml_compute_forward_concat_any(
6652+
const struct ggml_compute_params * params,
6653+
struct ggml_tensor * dst) {
6654+
6655+
const struct ggml_tensor * src0 = dst->src[0];
6656+
const struct ggml_tensor * src1 = dst->src[1];
6657+
6658+
const size_t len = ggml_type_size(src0->type);
6659+
6660+
const int ith = params->ith;
6661+
const int nth = params->nth;
6662+
6663+
GGML_TENSOR_BINARY_OP_LOCALS
6664+
6665+
const int32_t dim = ggml_get_op_params_i32(dst, 0);
6666+
6667+
GGML_ASSERT(dim >= 0 && dim < 4);
6668+
6669+
int64_t o[4] = {0, 0, 0, 0};
6670+
o[dim] = src0->ne[dim];
6671+
6672+
const char * x;
6673+
6674+
// TODO: smarter multi-theading
6675+
for (int i3 = 0; i3 < ne3; i3++) {
6676+
for (int i2 = ith; i2 < ne2; i2 += nth) {
6677+
for (int i1 = 0; i1 < ne1; i1++) {
6678+
for (int i0 = 0; i0 < ne0; i0++) {
6679+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6680+
x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
6681+
} else {
6682+
x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
6683+
}
6684+
6685+
char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
6686+
6687+
memcpy(y, x, len);
6688+
}
6689+
}
6690+
}
6691+
}
6692+
}
6693+
6694+
static void ggml_compute_forward_concat_i8(
6695+
const struct ggml_compute_params * params,
6696+
struct ggml_tensor * dst) {
6697+
6698+
const struct ggml_tensor * src0 = dst->src[0];
6699+
const struct ggml_tensor * src1 = dst->src[1];
6700+
6701+
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
6702+
6703+
const int ith = params->ith;
6704+
const int nth = params->nth;
6705+
6706+
GGML_TENSOR_BINARY_OP_LOCALS
6707+
6708+
const int32_t dim = ggml_get_op_params_i32(dst, 0);
6709+
6710+
GGML_ASSERT(dim >= 0 && dim < 4);
6711+
6712+
int64_t o[4] = {0, 0, 0, 0};
6713+
o[dim] = src0->ne[dim];
6714+
6715+
const int8_t * x;
6716+
6717+
// TODO: smarter multi-theading
6718+
for (int i3 = 0; i3 < ne3; i3++) {
6719+
for (int i2 = ith; i2 < ne2; i2 += nth) {
6720+
for (int i1 = 0; i1 < ne1; i1++) {
6721+
for (int i0 = 0; i0 < ne0; i0++) {
6722+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6723+
x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6724+
} else {
6725+
x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6726+
}
6727+
6728+
int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6729+
6730+
*y = *x;
6731+
}
6732+
}
6733+
}
6734+
}
6735+
}
6736+
6737+
static void ggml_compute_forward_concat_f16(
6738+
const struct ggml_compute_params * params,
6739+
struct ggml_tensor * dst) {
6740+
6741+
const struct ggml_tensor * src0 = dst->src[0];
6742+
const struct ggml_tensor * src1 = dst->src[1];
6743+
6744+
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
6745+
6746+
const int ith = params->ith;
6747+
const int nth = params->nth;
6748+
6749+
GGML_TENSOR_BINARY_OP_LOCALS
6750+
6751+
const int32_t dim = ggml_get_op_params_i32(dst, 0);
6752+
6753+
GGML_ASSERT(dim >= 0 && dim < 4);
6754+
6755+
int64_t o[4] = {0, 0, 0, 0};
6756+
o[dim] = src0->ne[dim];
6757+
6758+
const ggml_fp16_t * x;
6759+
6760+
// TODO: smarter multi-theading
6761+
for (int i3 = 0; i3 < ne3; i3++) {
6762+
for (int i2 = ith; i2 < ne2; i2 += nth) {
6763+
for (int i1 = 0; i1 < ne1; i1++) {
6764+
for (int i0 = 0; i0 < ne0; i0++) {
6765+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6766+
x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6767+
} else {
6768+
x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6769+
}
6770+
6771+
ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6772+
6773+
*y = *x;
6774+
}
6775+
}
6776+
}
6777+
}
6778+
}
6779+
66516780
static void ggml_compute_forward_concat_f32(
66526781
const struct ggml_compute_params * params,
66536782
struct ggml_tensor * dst) {
66546783

66556784
const struct ggml_tensor * src0 = dst->src[0];
66566785
const struct ggml_tensor * src1 = dst->src[1];
66576786

6658-
GGML_ASSERT(src0->nb[0] == sizeof(float));
6787+
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
66596788

66606789
const int ith = params->ith;
66616790
const int nth = params->nth;
@@ -6698,14 +6827,24 @@ static void ggml_compute_forward_concat(
66986827
const struct ggml_tensor * src0 = dst->src[0];
66996828

67006829
switch (src0->type) {
6830+
case GGML_TYPE_F16:
6831+
case GGML_TYPE_BF16:
6832+
case GGML_TYPE_I16:
6833+
{
6834+
ggml_compute_forward_concat_f16(params, dst);
6835+
} break;
6836+
case GGML_TYPE_I8:
6837+
{
6838+
ggml_compute_forward_concat_i8(params, dst);
6839+
} break;
67016840
case GGML_TYPE_F32:
67026841
case GGML_TYPE_I32:
67036842
{
67046843
ggml_compute_forward_concat_f32(params, dst);
67056844
} break;
67066845
default:
67076846
{
6708-
GGML_ABORT("fatal error");
6847+
ggml_compute_forward_concat_any(params, dst);
67096848
}
67106849
}
67116850
}

src/ggml.c

+1
Original file line numberDiff line numberDiff line change
@@ -2332,6 +2332,7 @@ struct ggml_tensor * ggml_concat(
23322332
struct ggml_tensor * b,
23332333
int dim) {
23342334
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
2335+
GGML_ASSERT(a->type == b->type);
23352336

23362337
int64_t ne[GGML_MAX_DIMS];
23372338
for (int d = 0; d < GGML_MAX_DIMS; ++d) {

0 commit comments

Comments
 (0)