From 597dcf66ba07428a580c481d7150252a40149094 Mon Sep 17 00:00:00 2001 From: Hwang In Tak Date: Mon, 4 Nov 2024 14:01:15 +0000 Subject: [PATCH] fix: Trim FFT --- math/poly/asm_fft.go | 34 ++++++++++++------------- math/poly/asm_fft_amd64.go | 38 ++++++++++++++-------------- math/poly/asm_fft_amd64.s | 9 ++++--- math/poly/poly_evaluator.go | 50 +++++++++++++++---------------------- 4 files changed, 62 insertions(+), 69 deletions(-) diff --git a/math/poly/asm_fft.go b/math/poly/asm_fft.go index 29cb887..12c82dc 100644 --- a/math/poly/asm_fft.go +++ b/math/poly/asm_fft.go @@ -362,7 +362,7 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { } // Last Loop - NInv := 2 / float64(N) + scale := float64(N / 2) Wr := real(twInv[w]) Wi := imag(twInv[w]) for j := 0; j < N/2; j += 8 { @@ -386,15 +386,15 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { Vi2 := coeffs[j+N/2+6] Vi3 := coeffs[j+N/2+7] - coeffs[j+0] = (Ur0 + Vr0) * NInv - coeffs[j+1] = (Ur1 + Vr1) * NInv - coeffs[j+2] = (Ur2 + Vr2) * NInv - coeffs[j+3] = (Ur3 + Vr3) * NInv + coeffs[j+0] = (Ur0 + Vr0) / scale + coeffs[j+1] = (Ur1 + Vr1) / scale + coeffs[j+2] = (Ur2 + Vr2) / scale + coeffs[j+3] = (Ur3 + Vr3) / scale - coeffs[j+4] = (Ui0 + Vi0) * NInv - coeffs[j+5] = (Ui1 + Vi1) * NInv - coeffs[j+6] = (Ui2 + Vi2) * NInv - coeffs[j+7] = (Ui3 + Vi3) * NInv + coeffs[j+4] = (Ui0 + Vi0) / scale + coeffs[j+5] = (Ui1 + Vi1) / scale + coeffs[j+6] = (Ui2 + Vi2) / scale + coeffs[j+7] = (Ui3 + Vi3) / scale UVr0 := Ur0 - Vr0 UVr1 := Ur1 - Vr1 @@ -416,14 +416,14 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { UVi2W := UVr2*Wi + UVi2*Wr UVi3W := UVr3*Wi + UVi3*Wr - coeffs[j+N/2+0] = UVr0W - coeffs[j+N/2+1] = UVr1W - coeffs[j+N/2+2] = UVr2W - coeffs[j+N/2+3] = UVr3W + coeffs[j+N/2+0] = UVr0W / scale + coeffs[j+N/2+1] = UVr1W / scale + coeffs[j+N/2+2] = UVr2W / scale + coeffs[j+N/2+3] = UVr3W / scale - coeffs[j+N/2+4] = UVi0W - coeffs[j+N/2+5] = UVi1W - coeffs[j+N/2+6] = UVi2W - coeffs[j+N/2+7] = UVi3W + coeffs[j+N/2+4] = UVi0W / scale + coeffs[j+N/2+5] = UVi1W / scale + coeffs[j+N/2+6] = UVi2W / scale + coeffs[j+N/2+7] = UVi3W / scale } } diff --git a/math/poly/asm_fft_amd64.go b/math/poly/asm_fft_amd64.go index 9fb1a10..c9cbe44 100644 --- a/math/poly/asm_fft_amd64.go +++ b/math/poly/asm_fft_amd64.go @@ -208,13 +208,13 @@ func fftInPlace(coeffs []float64, tw []complex128) { } } -func invFFTInPlaceAVX2(coeffs []float64, twInv []complex128, NInv float64) +func invFFTInPlaceAVX2(coeffs []float64, twInv []complex128, scale float64) // invfftInPlace is a top-level function for inverse FFT. // All internal inverse FFT implementations calls this function for performance. func invFFTInPlace(coeffs []float64, twInv []complex128) { if cpu.X86.HasAVX2 && cpu.X86.HasFMA { - invFFTInPlaceAVX2(coeffs, twInv, 2/float64(len(coeffs))) + invFFTInPlaceAVX2(coeffs, twInv, float64(len(coeffs)/2)) return } @@ -380,7 +380,7 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { } // Last Loop - NInv := 2 / float64(N) + scale := float64(N / 2) Wr := real(twInv[w]) Wi := imag(twInv[w]) for j := 0; j < N/2; j += 8 { @@ -404,15 +404,15 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { Vi2 := coeffs[j+N/2+6] Vi3 := coeffs[j+N/2+7] - coeffs[j+0] = (Ur0 + Vr0) * NInv - coeffs[j+1] = (Ur1 + Vr1) * NInv - coeffs[j+2] = (Ur2 + Vr2) * NInv - coeffs[j+3] = (Ur3 + Vr3) * NInv + coeffs[j+0] = (Ur0 + Vr0) / scale + coeffs[j+1] = (Ur1 + Vr1) / scale + coeffs[j+2] = (Ur2 + Vr2) / scale + coeffs[j+3] = (Ur3 + Vr3) / scale - coeffs[j+4] = (Ui0 + Vi0) * NInv - coeffs[j+5] = (Ui1 + Vi1) * NInv - coeffs[j+6] = (Ui2 + Vi2) * NInv - coeffs[j+7] = (Ui3 + Vi3) * NInv + coeffs[j+4] = (Ui0 + Vi0) / scale + coeffs[j+5] = (Ui1 + Vi1) / scale + coeffs[j+6] = (Ui2 + Vi2) / scale + coeffs[j+7] = (Ui3 + Vi3) / scale UVr0 := Ur0 - Vr0 UVr1 := Ur1 - Vr1 @@ -434,14 +434,14 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) { UVi2W := UVr2*Wi + UVi2*Wr UVi3W := UVr3*Wi + UVi3*Wr - coeffs[j+N/2+0] = UVr0W - coeffs[j+N/2+1] = UVr1W - coeffs[j+N/2+2] = UVr2W - coeffs[j+N/2+3] = UVr3W + coeffs[j+N/2+0] = UVr0W / scale + coeffs[j+N/2+1] = UVr1W / scale + coeffs[j+N/2+2] = UVr2W / scale + coeffs[j+N/2+3] = UVr3W / scale - coeffs[j+N/2+4] = UVi0W - coeffs[j+N/2+5] = UVi1W - coeffs[j+N/2+6] = UVi2W - coeffs[j+N/2+7] = UVi3W + coeffs[j+N/2+4] = UVi0W / scale + coeffs[j+N/2+5] = UVi1W / scale + coeffs[j+N/2+6] = UVi2W / scale + coeffs[j+N/2+7] = UVi3W / scale } } diff --git a/math/poly/asm_fft_amd64.s b/math/poly/asm_fft_amd64.s index c659754..f850fd0 100644 --- a/math/poly/asm_fft_amd64.s +++ b/math/poly/asm_fft_amd64.s @@ -373,7 +373,7 @@ m_loop_end: MOVQ CX, DX SHRQ $1, DX // N/2 - VBROADCASTSD NInv+48(FP), Y12 + VBROADCASTSD scale+48(FP), Y12 VBROADCASTSD (BX), Y13 VBROADCASTSD 8(BX), Y14 @@ -391,8 +391,8 @@ last_loop: VADDPD Y2, Y0, Y4 VADDPD Y3, Y1, Y5 - VMULPD Y4, Y12, Y4 - VMULPD Y5, Y12, Y5 + VDIVPD Y12, Y4, Y4 + VDIVPD Y12, Y5, Y5 VSUBPD Y2, Y0, Y6 // UVr VSUBPD Y3, Y1, Y7 // UVi @@ -404,6 +404,9 @@ last_loop: VMULPD Y6, Y14, Y9 VFMADD231PD Y7, Y13, Y9 // UVi * W + VDIVPD Y12, Y8, Y8 + VDIVPD Y12, Y9, Y9 + VMOVUPD Y4, (AX)(SI*8) VMOVUPD Y5, 32(AX)(SI*8) VMOVUPD Y8, (AX)(DI*8) diff --git a/math/poly/poly_evaluator.go b/math/poly/poly_evaluator.go index 9404bc8..c32d3ea 100644 --- a/math/poly/poly_evaluator.go +++ b/math/poly/poly_evaluator.go @@ -13,6 +13,10 @@ const ( // Currently, this is set to 16, because AVX2 implementation of FFT and inverse FFT // handles first/last two loops separately. MinDegree = 1 << 4 + + // splitBound is denotes the maximum bits of N*B1*B2, where B1, B2 is the splitting bound of polynomial multiplication. + // Currently, this is set to 50, which gives failure rate less than 2^-73. + splitBound = 50 ) // Evaluator computes polynomial operations over the N-th cyclotomic ring. @@ -116,60 +120,46 @@ func NewEvaluator[T num.Integer](N int) *Evaluator[T] { // genTwiddleFactors generates twiddle factors for FFT. func genTwiddleFactors(N int) (tw, twInv []complex128) { - wNj := make([]complex128, N/2) - wNjInv := make([]complex128, N/2) + twFFT := make([]complex128, N/2) + twInvFFT := make([]complex128, N/2) for j := 0; j < N/2; j++ { e := -2 * math.Pi * float64(j) / float64(N) - wNj[j] = cmplx.Exp(complex(0, e)) - wNjInv[j] = cmplx.Exp(-complex(0, e)) - } - vec.BitReverseInPlace(wNj) - vec.BitReverseInPlace(wNjInv) - - logN := num.Log2(N) - w4Nj := make([]complex128, logN) - w4NjInv := make([]complex128, logN) - for j := 0; j < logN; j++ { - e := 2 * math.Pi * math.Exp2(float64(j)) / float64(4*N) - w4Nj[j] = cmplx.Exp(complex(0, e)) - w4NjInv[j] = cmplx.Exp(-complex(0, e)) + twFFT[j] = cmplx.Exp(complex(0, e)) + twInvFFT[j] = cmplx.Exp(-complex(0, e)) } + vec.BitReverseInPlace(twFFT) + vec.BitReverseInPlace(twInvFFT) - tw = make([]complex128, N) - twInv = make([]complex128, N) + tw = make([]complex128, 0, N-1) + twInv = make([]complex128, 0, N-1) - w, t := 0, logN - for m := 1; m < N; m <<= 1 { - t-- + for m, t := 1, N/2; m < N; m, t = m<<1, t>>1 { + twFold := cmplx.Exp(complex(0, 2*math.Pi*float64(t)/float64(4*N))) for i := 0; i < m; i++ { - tw[w] = wNj[i] * w4Nj[t] - w++ + tw = append(tw, twFFT[i]*twFold) } } - w, t = 0, 0 - for m := N; m > 1; m >>= 1 { + for m, t := N, 1; m > 1; m, t = m>>1, t<<1 { + twInvFold := cmplx.Exp(complex(0, -2*math.Pi*float64(t)/float64(4*N))) for i := 0; i < m/2; i++ { - twInv[w] = wNjInv[i] * w4NjInv[t] - w++ + twInv = append(twInv, twInvFFT[i]*twInvFold) } - t++ } - twInv[w-1] /= complex(float64(N), 0) return tw, twInv } // splitParameters generates splitBits and splitCount for [*Evaluator.MulPoly]. func splitParameters[T num.Integer](N int) (splitBits T, splitCount int) { - splitBits = T(50-num.Log2(N)) / 2 + splitBits = T(splitBound-num.Log2(N)) / 2 splitCount = int(math.Ceil(float64(num.SizeT[T]()) / float64(splitBits))) return } // splitParametersBinary generates splitBits and splitCount for [*Evaluator.BinaryFourierPolyMulPoly]. func splitParametersBinary[T num.Integer](N int) (splitBits T, splitCount int) { - splitBits = T(50 - num.Log2(N)) + splitBits = T(splitBound - num.Log2(N)) splitCount = int(math.Ceil(float64(num.SizeT[T]()) / float64(splitBits))) return }