Skip to content

Commit

Permalink
fix: Trim FFT
Browse files Browse the repository at this point in the history
  • Loading branch information
sp301415 committed Nov 4, 2024
1 parent 338b85f commit 597dcf6
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 69 deletions.
34 changes: 17 additions & 17 deletions math/poly/asm_fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
}

// Last Loop
NInv := 2 / float64(N)
scale := float64(N / 2)
Wr := real(twInv[w])
Wi := imag(twInv[w])
for j := 0; j < N/2; j += 8 {
Expand All @@ -386,15 +386,15 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
Vi2 := coeffs[j+N/2+6]
Vi3 := coeffs[j+N/2+7]

coeffs[j+0] = (Ur0 + Vr0) * NInv
coeffs[j+1] = (Ur1 + Vr1) * NInv
coeffs[j+2] = (Ur2 + Vr2) * NInv
coeffs[j+3] = (Ur3 + Vr3) * NInv
coeffs[j+0] = (Ur0 + Vr0) / scale
coeffs[j+1] = (Ur1 + Vr1) / scale
coeffs[j+2] = (Ur2 + Vr2) / scale
coeffs[j+3] = (Ur3 + Vr3) / scale

coeffs[j+4] = (Ui0 + Vi0) * NInv
coeffs[j+5] = (Ui1 + Vi1) * NInv
coeffs[j+6] = (Ui2 + Vi2) * NInv
coeffs[j+7] = (Ui3 + Vi3) * NInv
coeffs[j+4] = (Ui0 + Vi0) / scale
coeffs[j+5] = (Ui1 + Vi1) / scale
coeffs[j+6] = (Ui2 + Vi2) / scale
coeffs[j+7] = (Ui3 + Vi3) / scale

UVr0 := Ur0 - Vr0
UVr1 := Ur1 - Vr1
Expand All @@ -416,14 +416,14 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
UVi2W := UVr2*Wi + UVi2*Wr
UVi3W := UVr3*Wi + UVi3*Wr

coeffs[j+N/2+0] = UVr0W
coeffs[j+N/2+1] = UVr1W
coeffs[j+N/2+2] = UVr2W
coeffs[j+N/2+3] = UVr3W
coeffs[j+N/2+0] = UVr0W / scale
coeffs[j+N/2+1] = UVr1W / scale
coeffs[j+N/2+2] = UVr2W / scale
coeffs[j+N/2+3] = UVr3W / scale

coeffs[j+N/2+4] = UVi0W
coeffs[j+N/2+5] = UVi1W
coeffs[j+N/2+6] = UVi2W
coeffs[j+N/2+7] = UVi3W
coeffs[j+N/2+4] = UVi0W / scale
coeffs[j+N/2+5] = UVi1W / scale
coeffs[j+N/2+6] = UVi2W / scale
coeffs[j+N/2+7] = UVi3W / scale
}
}
38 changes: 19 additions & 19 deletions math/poly/asm_fft_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,13 @@ func fftInPlace(coeffs []float64, tw []complex128) {
}
}

func invFFTInPlaceAVX2(coeffs []float64, twInv []complex128, NInv float64)
func invFFTInPlaceAVX2(coeffs []float64, twInv []complex128, scale float64)

// invfftInPlace is a top-level function for inverse FFT.
// All internal inverse FFT implementations calls this function for performance.
func invFFTInPlace(coeffs []float64, twInv []complex128) {
if cpu.X86.HasAVX2 && cpu.X86.HasFMA {
invFFTInPlaceAVX2(coeffs, twInv, 2/float64(len(coeffs)))
invFFTInPlaceAVX2(coeffs, twInv, float64(len(coeffs)/2))
return
}

Expand Down Expand Up @@ -380,7 +380,7 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
}

// Last Loop
NInv := 2 / float64(N)
scale := float64(N / 2)
Wr := real(twInv[w])
Wi := imag(twInv[w])
for j := 0; j < N/2; j += 8 {
Expand All @@ -404,15 +404,15 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
Vi2 := coeffs[j+N/2+6]
Vi3 := coeffs[j+N/2+7]

coeffs[j+0] = (Ur0 + Vr0) * NInv
coeffs[j+1] = (Ur1 + Vr1) * NInv
coeffs[j+2] = (Ur2 + Vr2) * NInv
coeffs[j+3] = (Ur3 + Vr3) * NInv
coeffs[j+0] = (Ur0 + Vr0) / scale
coeffs[j+1] = (Ur1 + Vr1) / scale
coeffs[j+2] = (Ur2 + Vr2) / scale
coeffs[j+3] = (Ur3 + Vr3) / scale

coeffs[j+4] = (Ui0 + Vi0) * NInv
coeffs[j+5] = (Ui1 + Vi1) * NInv
coeffs[j+6] = (Ui2 + Vi2) * NInv
coeffs[j+7] = (Ui3 + Vi3) * NInv
coeffs[j+4] = (Ui0 + Vi0) / scale
coeffs[j+5] = (Ui1 + Vi1) / scale
coeffs[j+6] = (Ui2 + Vi2) / scale
coeffs[j+7] = (Ui3 + Vi3) / scale

UVr0 := Ur0 - Vr0
UVr1 := Ur1 - Vr1
Expand All @@ -434,14 +434,14 @@ func invFFTInPlace(coeffs []float64, twInv []complex128) {
UVi2W := UVr2*Wi + UVi2*Wr
UVi3W := UVr3*Wi + UVi3*Wr

coeffs[j+N/2+0] = UVr0W
coeffs[j+N/2+1] = UVr1W
coeffs[j+N/2+2] = UVr2W
coeffs[j+N/2+3] = UVr3W
coeffs[j+N/2+0] = UVr0W / scale
coeffs[j+N/2+1] = UVr1W / scale
coeffs[j+N/2+2] = UVr2W / scale
coeffs[j+N/2+3] = UVr3W / scale

coeffs[j+N/2+4] = UVi0W
coeffs[j+N/2+5] = UVi1W
coeffs[j+N/2+6] = UVi2W
coeffs[j+N/2+7] = UVi3W
coeffs[j+N/2+4] = UVi0W / scale
coeffs[j+N/2+5] = UVi1W / scale
coeffs[j+N/2+6] = UVi2W / scale
coeffs[j+N/2+7] = UVi3W / scale
}
}
9 changes: 6 additions & 3 deletions math/poly/asm_fft_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ m_loop_end:
MOVQ CX, DX
SHRQ $1, DX // N/2

VBROADCASTSD NInv+48(FP), Y12
VBROADCASTSD scale+48(FP), Y12
VBROADCASTSD (BX), Y13
VBROADCASTSD 8(BX), Y14

Expand All @@ -391,8 +391,8 @@ last_loop:
VADDPD Y2, Y0, Y4
VADDPD Y3, Y1, Y5

VMULPD Y4, Y12, Y4
VMULPD Y5, Y12, Y5
VDIVPD Y12, Y4, Y4
VDIVPD Y12, Y5, Y5

VSUBPD Y2, Y0, Y6 // UVr
VSUBPD Y3, Y1, Y7 // UVi
Expand All @@ -404,6 +404,9 @@ last_loop:
VMULPD Y6, Y14, Y9
VFMADD231PD Y7, Y13, Y9 // UVi * W

VDIVPD Y12, Y8, Y8
VDIVPD Y12, Y9, Y9

VMOVUPD Y4, (AX)(SI*8)
VMOVUPD Y5, 32(AX)(SI*8)
VMOVUPD Y8, (AX)(DI*8)
Expand Down
50 changes: 20 additions & 30 deletions math/poly/poly_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ const (
// Currently, this is set to 16, because AVX2 implementation of FFT and inverse FFT
// handles first/last two loops separately.
MinDegree = 1 << 4

// splitBound is denotes the maximum bits of N*B1*B2, where B1, B2 is the splitting bound of polynomial multiplication.
// Currently, this is set to 50, which gives failure rate less than 2^-73.
splitBound = 50
)

// Evaluator computes polynomial operations over the N-th cyclotomic ring.
Expand Down Expand Up @@ -116,60 +120,46 @@ func NewEvaluator[T num.Integer](N int) *Evaluator[T] {

// genTwiddleFactors generates twiddle factors for FFT.
func genTwiddleFactors(N int) (tw, twInv []complex128) {
wNj := make([]complex128, N/2)
wNjInv := make([]complex128, N/2)
twFFT := make([]complex128, N/2)
twInvFFT := make([]complex128, N/2)
for j := 0; j < N/2; j++ {
e := -2 * math.Pi * float64(j) / float64(N)
wNj[j] = cmplx.Exp(complex(0, e))
wNjInv[j] = cmplx.Exp(-complex(0, e))
}
vec.BitReverseInPlace(wNj)
vec.BitReverseInPlace(wNjInv)

logN := num.Log2(N)
w4Nj := make([]complex128, logN)
w4NjInv := make([]complex128, logN)
for j := 0; j < logN; j++ {
e := 2 * math.Pi * math.Exp2(float64(j)) / float64(4*N)
w4Nj[j] = cmplx.Exp(complex(0, e))
w4NjInv[j] = cmplx.Exp(-complex(0, e))
twFFT[j] = cmplx.Exp(complex(0, e))
twInvFFT[j] = cmplx.Exp(-complex(0, e))
}
vec.BitReverseInPlace(twFFT)
vec.BitReverseInPlace(twInvFFT)

tw = make([]complex128, N)
twInv = make([]complex128, N)
tw = make([]complex128, 0, N-1)
twInv = make([]complex128, 0, N-1)

w, t := 0, logN
for m := 1; m < N; m <<= 1 {
t--
for m, t := 1, N/2; m < N; m, t = m<<1, t>>1 {
twFold := cmplx.Exp(complex(0, 2*math.Pi*float64(t)/float64(4*N)))
for i := 0; i < m; i++ {
tw[w] = wNj[i] * w4Nj[t]
w++
tw = append(tw, twFFT[i]*twFold)
}
}

w, t = 0, 0
for m := N; m > 1; m >>= 1 {
for m, t := N, 1; m > 1; m, t = m>>1, t<<1 {
twInvFold := cmplx.Exp(complex(0, -2*math.Pi*float64(t)/float64(4*N)))
for i := 0; i < m/2; i++ {
twInv[w] = wNjInv[i] * w4NjInv[t]
w++
twInv = append(twInv, twInvFFT[i]*twInvFold)
}
t++
}
twInv[w-1] /= complex(float64(N), 0)

return tw, twInv
}

// splitParameters generates splitBits and splitCount for [*Evaluator.MulPoly].
func splitParameters[T num.Integer](N int) (splitBits T, splitCount int) {
splitBits = T(50-num.Log2(N)) / 2
splitBits = T(splitBound-num.Log2(N)) / 2
splitCount = int(math.Ceil(float64(num.SizeT[T]()) / float64(splitBits)))
return
}

// splitParametersBinary generates splitBits and splitCount for [*Evaluator.BinaryFourierPolyMulPoly].
func splitParametersBinary[T num.Integer](N int) (splitBits T, splitCount int) {
splitBits = T(50 - num.Log2(N))
splitBits = T(splitBound - num.Log2(N))
splitCount = int(math.Ceil(float64(num.SizeT[T]()) / float64(splitBits)))
return
}
Expand Down

0 comments on commit 597dcf6

Please sign in to comment.