[software] Unroll bitreversal in parallel scheduled fft
mbertuletti committed Aug 4, 2023
1 parent 63b1874 commit 8dce826
Showing 1 changed file with 69 additions and 22 deletions: software/runtime/kernel/mempool_radix4_cfft_q16p.h
@@ -579,8 +579,8 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
: [CoSi1] "r"(CoSi1), [CoSi2] "r"(CoSi2), [CoSi3] "r"(CoSi3)
:);
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
int16_t *pIn = pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * N_BANKS;
int16_t *pOut = pDst16 + idx_row * (N_BANKS * 8) + 2 * col_id * N_BANKS;
int16_t *pIn = pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * fftLen;
int16_t *pOut = pDst16 + idx_row * (N_BANKS * 8) + 2 * col_id * (fftLen / 4);
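// Each column's input sits fftLen elements apart; the first stage writes
// its output at a per-column offset of fftLen / 4.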
radix4_butterfly_first(pIn, pOut, i0, n2, CoSi1, CoSi2, CoSi3, C1, C2,
C3);
}
@@ -598,18 +598,15 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
n1 = n2;
n2 >>= 2U;
n2_store = n2 >> 2U;
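// Next radix-4 stage: the butterfly span shrinks by a factor of 4.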

for (j = core_id * 4; j < core_id * 4 + 4; j++) {
CoSi1 = *(v2s *)&pCoef_src[2U * (j)];
CoSi2 = *(v2s *)&pCoef_src[2U * (j + 1 * N_BANKS)];
CoSi3 = *(v2s *)&pCoef_src[2U * (j + 2 * N_BANKS)];
if (j % 4 == 0) {
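// One j in four re-stores its twiddles for the next stage.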

wing_idx = j % n2;
offset = (j / n2);
ic = wing_idx >> 2U;
ic += offset * n2;

*((v2s *)&pCoef_dst[2U * (ic)]) = CoSi1;
*((v2s *)&pCoef_dst[2U * (n2_store * 1 + ic)]) = CoSi1;
*((v2s *)&pCoef_dst[2U * (n2_store * 2 + ic)]) = CoSi1;
@@ -674,32 +671,82 @@ void mempool_radix4_cfft_q16p_scheduler(uint32_t col_id, int16_t *pSrc16,
pSrc16 = pDst16;
pDst16 = pTmp;
mempool_log_partial_barrier(2, absolute_core_id, N_FFTs_COL * nPE);

mempool_stop_benchmark();
mempool_start_benchmark();

/* BITREVERSAL */
// The bitreversal stage stores its output at sequential addresses
if (bitReverseFlag) {
#ifdef BITREVERSETABLE
uint16_t *ptr1 = (uint16_t *)(pSrc16);
uint16_t *ptr2 = (uint16_t *)(pDst16 + 2 * col_id * fftLen);
for (j = 2 * core_id; j < bitReverseLen; j += 2 * nPE) {
v2s addr, tmpa, tmpb;
addr = __SRA2(*(v2s *)&pBitRevTable[j], ((v2s){2, 2}));
int32_t a0 = addr[0];
int32_t a1 = addr[1];
int32_t b0 = (a0 % 4) * 2 * N_BANKS + 2 * (a0 / 4);
int32_t b1 = (a1 % 4) * 2 * N_BANKS + 2 * (a1 / 4);
for (int32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
tmpa = *(v2s *)&ptr1[b0 + idx_row * (N_BANKS * 8)];
tmpb = *(v2s *)&ptr1[b1 + idx_row * (N_BANKS * 8)];
*((v2s *)&ptr2[a0 + idx_row * (N_BANKS * 8)]) = tmpb;
*((v2s *)&ptr2[a1 + idx_row * (N_BANKS * 8)]) = tmpa;
pSrc16 = pSrc16 + 2 * col_id * (fftLen / 4);
pDst16 = pDst16 + 2 * col_id * fftLen;
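// The intermediate result lies fftLen / 4 per column apart; the
// bitreversed output is written with the full fftLen per-column stride.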
for (j = 8 * core_id; j < bitReverseLen; j += 8 * nPE) {
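// Unrolled by 4: each iteration processes four (a, b) swap pairs,
// i.e. eight bitreversal-table entries.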
uint32_t addr1, addr2, addr3, addr4;
uint32_t tmpa1, tmpa2, tmpa3, tmpa4;
uint32_t tmpb1, tmpb2, tmpb3, tmpb4;
uint32_t a1, a2, a3, a4;
uint32_t b1, b2, b3, b4;
uint32_t a1_load, a2_load, a3_load, a4_load;
uint32_t b1_load, b2_load, b3_load, b4_load;
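// Shift amount 2 replicated in both halfwords for the pv.sra.h shifts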
uint32_t s2 = 0x00020002;
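// Each 32-bit load fetches one packed (a, b) swap pair from the table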
addr1 = *(uint32_t *)&pBitRevTable[j];
addr2 = *(uint32_t *)&pBitRevTable[j + 2];
addr3 = *(uint32_t *)&pBitRevTable[j + 4];
addr4 = *(uint32_t *)&pBitRevTable[j + 6];
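// SIMD-shift each packed pair right by 2 (the table stores offsets
// scaled by 4), then unpack the a (low) and b (high) halfwords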
asm volatile(
"pv.sra.h %[addr1],%[addr1],%[s2];"
"pv.sra.h %[addr2],%[addr2],%[s2];"
"pv.sra.h %[addr3],%[addr3],%[s2];"
"pv.sra.h %[addr4],%[addr4],%[s2];"
"pv.extract.h %[a1],%[addr1],0;"
"pv.extract.h %[a2],%[addr2],0;"
"pv.extract.h %[a3],%[addr3],0;"
"pv.extract.h %[a4],%[addr4],0;"
"pv.extract.h %[b1],%[addr1],1;"
"pv.extract.h %[b2],%[addr2],1;"
"pv.extract.h %[b3],%[addr3],1;"
"pv.extract.h %[b4],%[addr4],1;"
: [a1] "=r"(a1), [a2] "=r"(a2), [a3] "=r"(a3), [a4] "=r"(a4),
[b1] "=r"(b1), [b2] "=r"(b2), [b3] "=r"(b3), [b4] "=r"(b4),
[addr1] "+&r"(addr1), [addr2] "+&r"(addr2),
[addr3] "+&r"(addr3), [addr4] "+&r"(addr4)
: [s2] "r"(s2)
:);
// Compute the local addresses from the natural order ones
a1_load = (a1 % 4) * 2 * N_BANKS + 2 * (a1 / 4);
a2_load = (a2 % 4) * 2 * N_BANKS + 2 * (a2 / 4);
a3_load = (a3 % 4) * 2 * N_BANKS + 2 * (a3 / 4);
a4_load = (a4 % 4) * 2 * N_BANKS + 2 * (a4 / 4);
b1_load = (b1 % 4) * 2 * N_BANKS + 2 * (b1 / 4);
b2_load = (b2 % 4) * 2 * N_BANKS + 2 * (b2 / 4);
b3_load = (b3 % 4) * 2 * N_BANKS + 2 * (b3 / 4);
b4_load = (b4 % 4) * 2 * N_BANKS + 2 * (b4 / 4);
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
uint16_t *ptr1 = (uint16_t *)(pSrc16 + idx_row * (N_BANKS * 8));
uint16_t *ptr2 = (uint16_t *)(pDst16 + idx_row * (N_BANKS * 8));
// Load at address a
tmpa1 = *(uint32_t *)&ptr1[a1_load];
tmpa2 = *(uint32_t *)&ptr1[a2_load];
tmpa3 = *(uint32_t *)&ptr1[a3_load];
tmpa4 = *(uint32_t *)&ptr1[a4_load];
// Load at address b
tmpb1 = *(uint32_t *)&ptr1[b1_load];
tmpb2 = *(uint32_t *)&ptr1[b2_load];
tmpb3 = *(uint32_t *)&ptr1[b3_load];
tmpb4 = *(uint32_t *)&ptr1[b4_load];
// Store the elements loaded from the a addresses at the b addresses
*((uint32_t *)&ptr2[b1]) = tmpa1;
*((uint32_t *)&ptr2[b2]) = tmpa2;
*((uint32_t *)&ptr2[b3]) = tmpa3;
*((uint32_t *)&ptr2[b4]) = tmpa4;
// Store the elements loaded from the b addresses at the a addresses
*((uint32_t *)&ptr2[a1]) = tmpb1;
*((uint32_t *)&ptr2[a2]) = tmpb2;
*((uint32_t *)&ptr2[a3]) = tmpb3;
*((uint32_t *)&ptr2[a4]) = tmpb4;
}
}
#else
uint16_t *ptr1 = (uint16_t *)(pSrc16);
uint16_t *ptr1 = (uint16_t *)(pSrc16 + 2 * col_id * (fftLen / 4));
uint16_t *ptr2 = (uint16_t *)(pDst16 + 2 * col_id * fftLen);
for (j = core_id * 16; j < MIN(core_id * 16 + 16, fftLen >> 2U); j += 4) {
uint32_t idx0 = j;
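For reference, a minimal host-side sketch of the address arithmetic the bitreversal loop relies on. The N_BANKS value and the sample (a, b) index pair below are illustrative assumptions, not the kernel's configuration:

#include <stdint.h>
#include <stdio.h>

#define N_BANKS 256 /* illustrative value, not the kernel's actual one */

/* Natural-order halfword index -> local bank-interleaved offset,
   the same mapping used for a*_load/b*_load above. */
static inline uint32_t local_offset(uint32_t a) {
  return (a % 4) * 2 * N_BANKS + 2 * (a / 4);
}

int main(void) {
  uint32_t a = 1, b = 8; /* one hypothetical (a, b) pair from the table */
  /* The kernel loads the pair from these scrambled source offsets ... */
  printf("load a from %u, b from %u\n", local_offset(a), local_offset(b));
  /* ... and stores the two values swapped at the sequential
     destination offsets b and a. */
  printf("store a's value at %u, b's value at %u\n", b, a);
  return 0;
}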
