Skip to content

Commit

Permalink
Remove axpy dense op (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
hodlen authored Apr 2, 2024
1 parent c8e15df commit b478398
Showing 1 changed file with 3 additions and 147 deletions.
150 changes: 3 additions & 147 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -14310,146 +14310,6 @@ static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, cons
#endif
(void)vy;
}
// Spinlock serializing the final reduction of each thread's private
// accumulator into dst->data. NOTE(review): a raw spin on atomic_flag with an
// empty body burns CPU while contended — tolerable only because the critical
// section below is a single O(ne00) vector add.
atomic_flag g_axpy_dense_lock = ATOMIC_FLAG_INIT;
// Dense AXPY-style mat-vec product: accumulates rows of src0, each scaled by
// the matching src1 activation (converted to fp16 via vec_dot_type when
// needed), into dst. Work is distributed dynamically: every thread claims
// chunks of `dr` rows from the shared atomic counter params->aic, sums its
// rows into a thread-private buffer, then folds that buffer into dst under
// g_axpy_dense_lock. Runs in the usual ggml three-phase protocol
// (INIT / COMPUTE / FINALIZE) driven by params->type.
static void ggml_compute_forward_mul_mat_axpy_dense(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);

// Declares ne0*/nb0*/ne1*/nb1*... locals for src0, src1 and dst.
GGML_TENSOR_BINARY_OP_LOCALS;

// const int ith = params->ith;
const int nth = params->nth;

const enum ggml_type type = src0->type;

// const bool src1_cont = ggml_is_contiguous(src1);

// ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
// Converter used to quantize/convert src1 rows into vec_dot_type during INIT.
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

// GGML_ASSERT(ne0 == ne01);
// GGML_ASSERT(ne1 == ne11);
// GGML_ASSERT(ne2 == ne12);
// GGML_ASSERT(ne3 == ne13);

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == sizeof(float));

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);

// broadcast factors
// const int64_t r2 = ne12/ne02;
// const int64_t r3 = ne13/ne03;

// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

// INIT phase (single-threaded by convention): zero dst, convert src1 into
// wdata if its type differs from vec_dot_type, and reset the shared row
// counter so the COMPUTE phase starts claiming rows from 0.
if (params->type == GGML_TASK_INIT) {
ggml_set_zero(dst);
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
}
atomic_store(params->aic, 0);

return;
}

if (params->type == GGML_TASK_FINALIZE) {
return;
}

// COMPUTE phase. wdata holds src1 activations, either in place or as the
// converted copy produced during INIT. NOTE(review): typed as ggml_fp16_t*,
// so this path presumably assumes vec_dot_type is an fp16 layout — confirm
// against the types src0 can take here.
ggml_fp16_t* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
// const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

struct ggml_tensor *src2 = dst->src[2];

// parallelize by src0 rows
// Chunk size per atomic claim: 1/8 of an even split per thread, so faster
// threads naturally steal more chunks. NOTE(review): dr is sized from
// src2->ne[0] but the loop below bounds by ggml_nrows(src0) — verify these
// are meant to coincide for this op.
const int64_t dr = (src2->ne[0] + 8*nth - 1)/(8*nth);
const int nr = ggml_nrows(src0);

// const int64_t ir10 = dr*ith;
// const int64_t ir11 = MIN(ir10 + dr, src2->ne[0]);

// src1 rows
// const int64_t nr1 = ne11*ne12*ne13;
// float *idx = src2->data;
// int *gid = (int *)(dst->src[3]->data);
// printf("down %d up %d ne00 %d\n", ir10, ir11, ne00);

// Thread-private accumulator of ne00*4 floats (MSVC has no VLAs, hence
// _malloca there). NOTE(review): only ne00 floats appear to be read below;
// the 4x size and the ne00*4-BYTE memset (which zeroes exactly ne00 floats,
// i.e. a quarter of the buffer) look intentional but fragile — confirm.
#if defined(_MSC_VER)
float* vec = (float *)_malloca(ne00 * 4 * sizeof(float));
#else
float vec[ne00*4];
#endif
void *vy = vec;
memset(vy, 0, ne00*4);
char* src0_row = (char *) src0->data;
// Dynamic work stealing: claim `dr` rows at a time from the shared counter
// and accumulate row*scale into the private buffer vy.
while(true) {
const int ir0 = atomic_fetch_add(params->aic, dr);
for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
if (ir1 >= nr) break;
// vy += src0[row ir1] * wdata[ir1] (fp16 AXPY kernel)
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
}
if (ir0 + dr >= nr)
break;
}

// acquire the lock
while (atomic_flag_test_and_set(&g_axpy_dense_lock)) {
// spin until the holder releases it
}

float *res = (float *)(dst->data);
float *tmp = (float *)vy;
int i;


// number of trailing elements that don't fill an 8-wide AVX vector
int remainder = ne00 % 8;

#if defined(__AVX2__)
// vectorized reduction: res += tmp, 8 floats per iteration
for (i = 0; i < ne00 - remainder; i += 8) {
__m256 res_vec = _mm256_loadu_ps(res + i); // load 8 floats from res
__m256 tmp_vec = _mm256_loadu_ps(tmp + i); // load 8 floats from tmp
__m256 result = _mm256_add_ps(res_vec, tmp_vec); // elementwise add
_mm256_storeu_ps(res + i, result); // store back into res
}

// handle the remaining tail elements
for (i = ne00 - remainder; i < ne00; i++) {
res[i] += tmp[i];
}
#else
// Scalar fallback. NOTE(review): bounds by dst->ne[0] while the AVX path
// bounds by ne00 — confirm these are equal for this op.
for (i = 0; i < dst->ne[0]; i++) {
res[i] += tmp[i];
}
#endif
atomic_flag_clear(&g_axpy_dense_lock);
#if defined(_MSC_VER)
_freea(vec);
#endif
}

static void ggml_compute_forward_mul_mat_axpy(
const struct ggml_compute_params * params,
Expand Down Expand Up @@ -15052,21 +14912,17 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_AXPY:
{
// printf("here? %d\n", tensor->src[0]->type);
GGML_ASSERT(tensor->src[2] != NULL && "sparse index is required for AXPY");
struct ggml_tensor *src3 = tensor->src[3];
if (tensor->src[2] == NULL) {
ggml_compute_forward_mul_mat_axpy_dense(params, tensor->src[0], tensor->src[1], tensor);
}
else if (src3 != NULL){
if (src3 != NULL){
if (tensor->src[0]->type != GGML_TYPE_Q4_0) {
ggml_compute_forward_mul_mat_axpy(params, tensor->src[0], tensor->src[1], tensor);
}
else {
ggml_compute_forward_mul_mat_axpy_q4_0(params, tensor->src[0], tensor->src[1], tensor);

}
}
else {
} else {
ggml_compute_forward_mul_mat_axpy_head(params, tensor->src[0], tensor->src[1], tensor);
}
// ggml_compute_forward_mul_mat_axpy(params, tensor->src[0], tensor->src[1], tensor);
Expand Down

0 comments on commit b478398

Please sign in to comment.