diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index 77aaeee6b9..05d349d970 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -851,9 +851,19 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IF /* Objective function come from sum of partitions in m and n. */ /* (n / nthreads_n) + (m / nthreads_m) */ /* = (n * nthreads_m + m * nthreads_n) / (nthreads_n * nthreads_m) */ - while (nthreads_m % 2 == 0 && n * nthreads_m + m * nthreads_n > n * (nthreads_m / 2) + m * (nthreads_n * 2)) { - nthreads_m /= 2; - nthreads_n *= 2; + BLASLONG cost = 0, div = 0; + for (BLASLONG i = 1; i <= sqrt(nthreads_m); i++) { + if (nthreads_m % i) continue; + BLASLONG j = nthreads_m / i; + BLASLONG cost_i = n * j + m * nthreads_n * i; + BLASLONG cost_j = n * i + m * nthreads_n * j; + if (cost == 0 || + cost_i < cost) {cost = cost_i; div = i;} + if (cost_j < cost) {cost = cost_j; div = j;} + } + if (div > 1) { + nthreads_m /= div; + nthreads_n *= div; } }