diff --git a/software/runtime/kernel/mempool_radix4_cfft_q16p.h b/software/runtime/kernel/mempool_radix4_cfft_q16p.h
index 93f8bc2b3..aadf6c3b7 100644
--- a/software/runtime/kernel/mempool_radix4_cfft_q16p.h
+++ b/software/runtime/kernel/mempool_radix4_cfft_q16p.h
@@ -579,8 +579,8 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
       : [CoSi1] "r"(CoSi1), [CoSi2] "r"(CoSi2), [CoSi3] "r"(CoSi3)
       :);
   for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
-    int16_t *pIn = pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * N_BANKS;
-    int16_t *pOut = pDst16 + idx_row * (N_BANKS * 8) + 2 * col_id * N_BANKS;
+    int16_t *pIn = pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * fftLen;
+    int16_t *pOut = pDst16 + idx_row * (N_BANKS * 8) + 2 * col_id * (fftLen / 4);
     radix4_butterfly_first(pIn, pOut, i0, n2, CoSi1, CoSi2, CoSi3, C1, C2,
                            C3);
   }
@@ -598,18 +598,15 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
       n1 = n2;
       n2 >>= 2U;
       n2_store = n2 >> 2U;
-
       for (j = core_id * 4; j < core_id * 4 + 4; j++) {
         CoSi1 = *(v2s *)&pCoef_src[2U * (j)];
         CoSi2 = *(v2s *)&pCoef_src[2U * (j + 1 * N_BANKS)];
         CoSi3 = *(v2s *)&pCoef_src[2U * (j + 2 * N_BANKS)];
         if (j % 4 == 0) {
-
           wing_idx = j % n2;
           offset = (j / n2);
           ic = wing_idx >> 2U;
           ic += offset * n2;
-
           *((v2s *)&pCoef_dst[2U * (ic)]) = CoSi1;
           *((v2s *)&pCoef_dst[2U * (n2_store * 1 + ic)]) = CoSi1;
           *((v2s *)&pCoef_dst[2U * (n2_store * 2 + ic)]) = CoSi1;
@@ -674,32 +671,82 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
   pSrc16 = pDst16;
   pDst16 = pTmp;
   mempool_log_partial_barrier(2, absolute_core_id, N_FFTs_COL * nPE);
-
   mempool_stop_benchmark();
   mempool_start_benchmark();
-
   /* BITREVERSAL */
   // Bitreversal stage stores in the sequential addresses
   if (bitReverseFlag) {
 #ifdef BITREVERSETABLE
-    uint16_t *ptr1 = (uint16_t *)(pSrc16);
-    uint16_t *ptr2 = (uint16_t *)(pDst16 + 2 * col_id * fftLen);
-    for (j = 2 * core_id; j < bitReverseLen; j += 2 * nPE) {
-      v2s addr, tmpa, tmpb;
-      addr = __SRA2(*(v2s *)&pBitRevTable[j], ((v2s){2, 2}));
-      int32_t a0 = addr[0];
-      int32_t a1 = addr[1];
-      int32_t b0 = (a0 % 4) * 2 * N_BANKS + 2 * (a0 / 4);
-      int32_t b1 = (a1 % 4) * 2 * N_BANKS + 2 * (a1 / 4);
-      for (int32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
-        tmpa = *(v2s *)&ptr1[b0 + idx_row * (N_BANKS * 8)];
-        tmpb = *(v2s *)&ptr1[b1 + idx_row * (N_BANKS * 8)];
-        *((v2s *)&ptr2[a0 + idx_row * (N_BANKS * 8)]) = tmpb;
-        *((v2s *)&ptr2[a1 + idx_row * (N_BANKS * 8)]) = tmpa;
+    pSrc16 = pSrc16 + 2 * col_id * (fftLen / 4);
+    pDst16 = pDst16 + 2 * col_id * fftLen;
+    for (j = 8 * core_id; j < bitReverseLen; j += 8 * nPE) {
+      uint32_t addr1, addr2, addr3, addr4;
+      uint32_t tmpa1, tmpa2, tmpa3, tmpa4;
+      uint32_t tmpb1, tmpb2, tmpb3, tmpb4;
+      uint32_t a1, a2, a3, a4;
+      uint32_t b1, b2, b3, b4;
+      uint32_t a1_load, a2_load, a3_load, a4_load;
+      uint32_t b1_load, b2_load, b3_load, b4_load;
+      uint32_t s2 = 0x00020002;
+      addr1 = *(uint32_t *)&pBitRevTable[j];
+      addr2 = *(uint32_t *)&pBitRevTable[j + 2];
+      addr3 = *(uint32_t *)&pBitRevTable[j + 4];
+      addr4 = *(uint32_t *)&pBitRevTable[j + 6];
+      asm volatile(
+          "pv.sra.h %[addr1],%[addr1],%[s2];"
+          "pv.sra.h %[addr2],%[addr2],%[s2];"
+          "pv.sra.h %[addr3],%[addr3],%[s2];"
+          "pv.sra.h %[addr4],%[addr4],%[s2];"
+          "pv.extract.h %[a1],%[addr1],0;"
+          "pv.extract.h %[a2],%[addr2],0;"
+          "pv.extract.h %[a3],%[addr3],0;"
+          "pv.extract.h %[a4],%[addr4],0;"
+          "pv.extract.h %[b1],%[addr1],1;"
+          "pv.extract.h %[b2],%[addr2],1;"
+          "pv.extract.h %[b3],%[addr3],1;"
+          "pv.extract.h %[b4],%[addr4],1;"
+          : [a1] "=r"(a1), [a2] "=r"(a2), [a3] "=r"(a3), [a4] "=r"(a4),
+            [b1] "=r"(b1), [b2] "=r"(b2), [b3] "=r"(b3), [b4] "=r"(b4),
+            [addr1] "+&r"(addr1), [addr2] "+&r"(addr2),
+            [addr3] "+&r"(addr3), [addr4] "+&r"(addr4)
+          : [s2] "r"(s2)
+          :);
+      // Compute the local addresses from the natural order ones
+      a1_load = (a1 % 4) * 2 * N_BANKS + 2 * (a1 / 4);
+      a2_load = (a2 % 4) * 2 * N_BANKS + 2 * (a2 / 4);
+      a3_load = (a3 % 4) * 2 * N_BANKS + 2 * (a3 / 4);
+      a4_load = (a4 % 4) * 2 * N_BANKS + 2 * (a4 / 4);
+      b1_load = (b1 % 4) * 2 * N_BANKS + 2 * (b1 / 4);
+      b2_load = (b2 % 4) * 2 * N_BANKS + 2 * (b2 / 4);
+      b3_load = (b3 % 4) * 2 * N_BANKS + 2 * (b3 / 4);
+      b4_load = (b4 % 4) * 2 * N_BANKS + 2 * (b4 / 4);
+      for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
+        uint16_t *ptr1 = (uint16_t *)(pSrc16 + idx_row * (N_BANKS * 8));
+        uint16_t *ptr2 = (uint16_t *)(pDst16 + idx_row * (N_BANKS * 8));
+        // Load at address a
+        tmpa1 = *(uint32_t *)&ptr1[a1_load];
+        tmpa2 = *(uint32_t *)&ptr1[a2_load];
+        tmpa3 = *(uint32_t *)&ptr1[a3_load];
+        tmpa4 = *(uint32_t *)&ptr1[a4_load];
+        // Load at address b
+        tmpb1 = *(uint32_t *)&ptr1[b1_load];
+        tmpb2 = *(uint32_t *)&ptr1[b2_load];
+        tmpb3 = *(uint32_t *)&ptr1[b3_load];
+        tmpb4 = *(uint32_t *)&ptr1[b4_load];
+        // Swap a with b
+        *((uint32_t *)&ptr2[b1]) = tmpa1;
+        *((uint32_t *)&ptr2[b2]) = tmpa2;
+        *((uint32_t *)&ptr2[b3]) = tmpa3;
+        *((uint32_t *)&ptr2[b4]) = tmpa4;
+        // Swap b with a
+        *((uint32_t *)&ptr2[a1]) = tmpb1;
+        *((uint32_t *)&ptr2[a2]) = tmpb2;
+        *((uint32_t *)&ptr2[a3]) = tmpb3;
+        *((uint32_t *)&ptr2[a4]) = tmpb4;
       }
     }
 #else
-    uint16_t *ptr1 = (uint16_t *)(pSrc16);
+    uint16_t *ptr1 = (uint16_t *)(pSrc16 + 2 * col_id * (fftLen / 4));
     uint16_t *ptr2 = (uint16_t *)(pDst16 + 2 * col_id * fftLen);
     for (j = core_id * 16; j < MIN(core_id * 16 + 16, fftLen >> 2U); j += 4) {
       uint32_t idx0 = j;
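
Note on the new BITREVERSETABLE path (review sketch, not part of the patch):
each 32-bit load from pBitRevTable packs two 16-bit bit-reversed indices, and
the pv.sra.h / pv.extract.h sequence shifts both halves right by two and
unpacks them, so each table word yields one swap pair. The scalar C below is
a minimal equivalent for one table entry, assuming ptr1/ptr2 are the per-row
base pointers exactly as in the patch and that N_BANKS normally comes from
the surrounding runtime headers; bitrev_swap_scalar is a hypothetical helper
for illustration only.

#include <stdint.h>

#ifndef N_BANKS
#define N_BANKS 1024 /* assumption: the real value comes from the runtime */
#endif

static inline void bitrev_swap_scalar(uint16_t *ptr1, uint16_t *ptr2,
                                      const uint16_t *pBitRevTable,
                                      uint32_t j) {
  /* Two packed 16-bit indices per 32-bit table word; >> 2 rescales each
     table value to a complex-sample index (pv.sra.h does both halves in
     one instruction; indices are non-negative, so sra equals srl here). */
  uint32_t a = (uint32_t)(pBitRevTable[j] >> 2U);     /* low halfword  */
  uint32_t b = (uint32_t)(pBitRevTable[j + 1] >> 2U); /* high halfword */
  /* Map each natural-order index to its local banked address: idx % 4
     selects the bank row, idx / 4 the offset within that row. */
  uint32_t a_load = (a % 4) * 2 * N_BANKS + 2 * (a / 4);
  uint32_t b_load = (b % 4) * 2 * N_BANKS + 2 * (b / 4);
  /* Each complex q16 sample is one 32-bit (re, im) pair; swap the two
     samples while storing to the sequentially addressed output. */
  uint32_t tmpa = *(uint32_t *)&ptr1[a_load];
  uint32_t tmpb = *(uint32_t *)&ptr1[b_load];
  *(uint32_t *)&ptr2[b] = tmpa;
  *(uint32_t *)&ptr2[a] = tmpb;
}

The unrolled loop in the patch performs four such swaps per iteration
(table words j, j + 2, j + 4, j + 6), presumably to overlap the independent
loads, which is why the per-core stride grows from 2 to 8 table entries.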