@@ -6648,14 +6648,143 @@ static void ggml_compute_forward_repeat_back(
6648
6648
6649
6649
// ggml_compute_forward_concat
6650
6650
6651
+ static void ggml_compute_forward_concat_any(
6652
+ const struct ggml_compute_params * params,
6653
+ struct ggml_tensor * dst) {
6654
+
6655
+ const struct ggml_tensor * src0 = dst->src[0];
6656
+ const struct ggml_tensor * src1 = dst->src[1];
6657
+
6658
+ const size_t len = ggml_type_size(src0->type);
6659
+
6660
+ const int ith = params->ith;
6661
+ const int nth = params->nth;
6662
+
6663
+ GGML_TENSOR_BINARY_OP_LOCALS
6664
+
6665
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6666
+
6667
+ GGML_ASSERT(dim >= 0 && dim < 4);
6668
+
6669
+ int64_t o[4] = {0, 0, 0, 0};
6670
+ o[dim] = src0->ne[dim];
6671
+
6672
+ const char * x;
6673
+
6674
+ // TODO: smarter multi-theading
6675
+ for (int i3 = 0; i3 < ne3; i3++) {
6676
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6677
+ for (int i1 = 0; i1 < ne1; i1++) {
6678
+ for (int i0 = 0; i0 < ne0; i0++) {
6679
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6680
+ x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
6681
+ } else {
6682
+ x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
6683
+ }
6684
+
6685
+ char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
6686
+
6687
+ memcpy(y, x, len);
6688
+ }
6689
+ }
6690
+ }
6691
+ }
6692
+ }
6693
+
6694
+ static void ggml_compute_forward_concat_i8(
6695
+ const struct ggml_compute_params * params,
6696
+ struct ggml_tensor * dst) {
6697
+
6698
+ const struct ggml_tensor * src0 = dst->src[0];
6699
+ const struct ggml_tensor * src1 = dst->src[1];
6700
+
6701
+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
6702
+
6703
+ const int ith = params->ith;
6704
+ const int nth = params->nth;
6705
+
6706
+ GGML_TENSOR_BINARY_OP_LOCALS
6707
+
6708
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6709
+
6710
+ GGML_ASSERT(dim >= 0 && dim < 4);
6711
+
6712
+ int64_t o[4] = {0, 0, 0, 0};
6713
+ o[dim] = src0->ne[dim];
6714
+
6715
+ const int8_t * x;
6716
+
6717
+ // TODO: smarter multi-theading
6718
+ for (int i3 = 0; i3 < ne3; i3++) {
6719
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6720
+ for (int i1 = 0; i1 < ne1; i1++) {
6721
+ for (int i0 = 0; i0 < ne0; i0++) {
6722
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6723
+ x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6724
+ } else {
6725
+ x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6726
+ }
6727
+
6728
+ int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6729
+
6730
+ *y = *x;
6731
+ }
6732
+ }
6733
+ }
6734
+ }
6735
+ }
6736
+
6737
+ static void ggml_compute_forward_concat_f16(
6738
+ const struct ggml_compute_params * params,
6739
+ struct ggml_tensor * dst) {
6740
+
6741
+ const struct ggml_tensor * src0 = dst->src[0];
6742
+ const struct ggml_tensor * src1 = dst->src[1];
6743
+
6744
+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
6745
+
6746
+ const int ith = params->ith;
6747
+ const int nth = params->nth;
6748
+
6749
+ GGML_TENSOR_BINARY_OP_LOCALS
6750
+
6751
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6752
+
6753
+ GGML_ASSERT(dim >= 0 && dim < 4);
6754
+
6755
+ int64_t o[4] = {0, 0, 0, 0};
6756
+ o[dim] = src0->ne[dim];
6757
+
6758
+ const ggml_fp16_t * x;
6759
+
6760
+ // TODO: smarter multi-theading
6761
+ for (int i3 = 0; i3 < ne3; i3++) {
6762
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6763
+ for (int i1 = 0; i1 < ne1; i1++) {
6764
+ for (int i0 = 0; i0 < ne0; i0++) {
6765
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6766
+ x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6767
+ } else {
6768
+ x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6769
+ }
6770
+
6771
+ ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6772
+
6773
+ *y = *x;
6774
+ }
6775
+ }
6776
+ }
6777
+ }
6778
+ }
6779
+
6651
6780
static void ggml_compute_forward_concat_f32(
6652
6781
const struct ggml_compute_params * params,
6653
6782
struct ggml_tensor * dst) {
6654
6783
6655
6784
const struct ggml_tensor * src0 = dst->src[0];
6656
6785
const struct ggml_tensor * src1 = dst->src[1];
6657
6786
6658
- GGML_ASSERT(src0->nb[0] == sizeof(float));
6787
+ GGML_ASSERT(ggml_type_size( src0->type) == sizeof(float));
6659
6788
6660
6789
const int ith = params->ith;
6661
6790
const int nth = params->nth;
@@ -6698,14 +6827,24 @@ static void ggml_compute_forward_concat(
6698
6827
const struct ggml_tensor * src0 = dst->src[0];
6699
6828
6700
6829
switch (src0->type) {
6830
+ case GGML_TYPE_F16:
6831
+ case GGML_TYPE_BF16:
6832
+ case GGML_TYPE_I16:
6833
+ {
6834
+ ggml_compute_forward_concat_f16(params, dst);
6835
+ } break;
6836
+ case GGML_TYPE_I8:
6837
+ {
6838
+ ggml_compute_forward_concat_i8(params, dst);
6839
+ } break;
6701
6840
case GGML_TYPE_F32:
6702
6841
case GGML_TYPE_I32:
6703
6842
{
6704
6843
ggml_compute_forward_concat_f32(params, dst);
6705
6844
} break;
6706
6845
default:
6707
6846
{
6708
- GGML_ABORT("fatal error" );
6847
+ ggml_compute_forward_concat_any(params, dst );
6709
6848
}
6710
6849
}
6711
6850
}
0 commit comments