Wider AVX2 loops and less usage. (#162)
* Experiment with 64 bytes/loop AVX2

* Only reduce when doing 64.

* Use no more than 8 goroutines for avx2 codegen.

```
name                         old speed      new speed      delta
Encode10x2x10000-32          33.3GB/s ± 0%  37.5GB/s ± 1%  +12.49%   (p=0.000 n=9+10)
Encode100x20x10000-32        3.79GB/s ± 5%  3.77GB/s ± 5%     ~     (p=0.853 n=10+10)
Encode17x3x1M-32             78.2GB/s ± 1%  76.0GB/s ± 6%     ~     (p=0.123 n=10+10)
Encode10x4x16M-32            28.3GB/s ± 0%  27.7GB/s ± 2%   -2.32%   (p=0.000 n=8+10)
Encode5x2x1M-32               112GB/s ± 1%   113GB/s ± 1%     ~     (p=0.796 n=10+10)
Encode10x2x1M-32              149GB/s ± 1%   129GB/s ± 3%  -13.24%   (p=0.000 n=9+10)
Encode10x4x1M-32             99.1GB/s ± 1%  91.5GB/s ± 3%   -7.74%  (p=0.000 n=10+10)
Encode50x20x1M-32            19.7GB/s ± 1%  19.8GB/s ± 1%     ~      (p=0.447 n=9+10)
Encode17x3x16M-32            33.4GB/s ± 0%  33.3GB/s ± 1%   -0.46%   (p=0.043 n=10+9)
Encode_8x4x8M-32             30.1GB/s ± 1%  29.4GB/s ± 3%   -2.31%  (p=0.000 n=10+10)
Encode_12x4x12M-32           30.6GB/s ± 0%  30.5GB/s ± 0%     ~      (p=0.720 n=10+9)
Encode_16x4x16M-32           31.5GB/s ± 0%  31.5GB/s ± 0%     ~      (p=0.497 n=10+9)
Encode_16x4x32M-32           31.9GB/s ± 0%  31.5GB/s ± 4%     ~     (p=0.165 n=10+10)
Encode_16x4x64M-32           32.4GB/s ± 0%  32.3GB/s ± 0%     ~       (p=0.321 n=9+8)
Encode_8x5x8M-32             28.4GB/s ± 0%  28.4GB/s ± 1%     ~      (p=0.237 n=10+8)
Encode_8x6x8M-32             27.0GB/s ± 0%  27.2GB/s ± 2%     ~     (p=0.075 n=10+10)
Encode_8x7x8M-32             26.0GB/s ± 1%  25.8GB/s ± 1%   -0.53%   (p=0.003 n=9+10)
Encode_8x9x8M-32             24.6GB/s ± 1%  24.4GB/s ± 1%   -0.63%  (p=0.000 n=10+10)
Encode_8x10x8M-32            23.7GB/s ± 1%  23.7GB/s ± 0%   +0.32%   (p=0.035 n=10+9)
Encode_8x11x8M-32            23.0GB/s ± 1%  22.8GB/s ± 0%   -0.59%    (p=0.000 n=9+8)
Encode_8x8x05M-32            66.4GB/s ± 1%  64.2GB/s ± 1%   -3.32%  (p=0.000 n=10+10)
Encode_8x8x1M-32             56.7GB/s ± 0%  75.7GB/s ± 2%  +33.55%    (p=0.000 n=9+9)
Encode_8x8x8M-32             24.9GB/s ± 0%  24.9GB/s ± 1%     ~      (p=0.146 n=8+10)
Encode_8x8x32M-32            23.8GB/s ± 0%  23.4GB/s ± 0%   -1.42%   (p=0.000 n=9+10)
Encode_24x8x24M-32           29.9GB/s ± 0%  29.9GB/s ± 0%     ~      (p=0.278 n=10+9)
Encode_24x8x48M-32           30.7GB/s ± 1%  30.7GB/s ± 0%     ~       (p=0.351 n=9+7)
StreamEncode10x2x10000-32    15.5GB/s ± 1%  16.5GB/s ± 0%   +6.53%   (p=0.000 n=10+9)
StreamEncode100x20x10000-32  2.09GB/s ± 1%  2.06GB/s ± 2%   -1.78%  (p=0.000 n=10+10)
StreamEncode17x3x1M-32       12.2GB/s ± 2%  12.3GB/s ± 1%   +1.19%   (p=0.008 n=10+9)
StreamEncode10x4x16M-32      8.68GB/s ± 0%  9.47GB/s ± 1%   +9.05%   (p=0.000 n=8+10)
StreamEncode5x2x1M-32        12.3GB/s ± 1%  13.2GB/s ± 1%   +7.61%  (p=0.000 n=10+10)
StreamEncode10x2x1M-32       11.5GB/s ± 4%  13.3GB/s ± 2%  +15.15%   (p=0.000 n=10+7)
```
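
The kernels compared above implement GF(2^8) multiplication with the classic nibble-table method: each input byte is split into its low and high 4-bit halves, each half indexes a 16-entry table of partial products, and the two lookups are XORed together (XOR is addition in GF(2^8)). VPSHUFB performs 32 of these lookups per instruction; the new `_64` kernels issue two per table to handle 64 bytes per loop iteration. A minimal scalar sketch of the per-byte step (the table layout is illustrative, not the library's exact in-memory format):

```
// mulSliceScalar multiplies every byte of in by a fixed GF(2^8)
// constant whose partial products are precomputed in two 16-entry
// nibble tables, writing the results to out. The AVX2 kernels do
// these same lookups 32 bytes at a time; this commit's _64 variants
// process 64 bytes per loop iteration.
func mulSliceScalar(lowTbl, highTbl *[16]byte, in, out []byte) {
	for i, c := range in {
		// XOR combines the two partial products, since XOR is
		// addition in GF(2^8).
		out[i] = lowTbl[c&0x0f] ^ highTbl[c>>4]
	}
}
```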
klauspost authored Jun 21, 2021
1 parent 46e0559 commit 7bd2279
Showing 6 changed files with 30,032 additions and 13,965 deletions.
265 changes: 260 additions & 5 deletions _gen/gen.go
@@ -19,14 +19,11 @@ import (

// Technically we can do slightly bigger, but we stay reasonable.
const inputMax = 10
-const outputMax = 8
+const outputMax = 10

var switchDefs [inputMax][outputMax]string
var switchDefsX [inputMax][outputMax]string

-const perLoopBits = 5
-const perLoop = 1 << perLoopBits

// Prefetch offsets, set to 0 to disable.
// Disabled since they appear to be consistently slower.
const prefetchSrc = 0
@@ -38,10 +35,14 @@ func main() {
Constraint(buildtags.Not("nogen").ToConstraint())
Constraint(buildtags.Term("gc").ToConstraint())

+const perLoopBits = 5
+const perLoop = 1 << perLoopBits

for i := 1; i <= inputMax; i++ {
for j := 1; j <= outputMax; j++ {
//genMulAvx2(fmt.Sprintf("mulAvxTwoXor_%dx%d", i, j), i, j, true)
genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%d", i, j), i, j, false)
+genMulAvx2Sixty64(fmt.Sprintf("mulAvxTwo_%dx%d_64", i, j), i, j, false)
}
}
f, err := os.Create("../galois_gen_switch_amd64.go")
@@ -91,8 +92,10 @@ func galMulSlicesAvx2(matrix []byte, in, out [][]byte, start, stop int) int {
}

func genMulAvx2(name string, inputs int, outputs int, xor bool) {
-total := inputs * outputs
+const perLoopBits = 5
+const perLoop = 1 << perLoopBits

+total := inputs * outputs
doc := []string{
fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs),
}
Expand Down Expand Up @@ -316,3 +319,255 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
Label(name + "_end")
RET()
}

func genMulAvx2Sixty64(name string, inputs int, outputs int, xor bool) {
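// Shapes with 4 or more outputs keep the existing 32-byte kernels; only narrow shapes get a 64-byte variant.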
if outputs >= 4 {
return
}
const perLoopBits = 6
const perLoop = 1 << perLoopBits
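// perLoopBits = 6: each iteration loads, multiplies and stores 64 bytes (two YMM registers) per stream.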

total := inputs * outputs

doc := []string{
fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs),
}
if !xor {
doc = append(doc, "The output is initialized to 0.")
}

// Load shuffle masks on every use.
var loadNone bool
// Keep destination pointers in GP registers.
var regDst = false
var reloadLength = false

// lo, hi, 1 in, 1 out, 2 tmp, 1 mask
est := total*2 + outputs + 5
if outputs == 1 {
// We don't need to keep a copy of the input if only 1 output.
est -= 2
}

if true || est > 16 {
loadNone = true
// We now run out of GP registers first.
if inputs+outputs > 13 {
regDst = false
}
// Save one register by reloading length.
if true || inputs+outputs > 12 && regDst {
reloadLength = true
}
}

TEXT(name, 0, fmt.Sprintf("func(matrix []byte, in [][]byte, out [][]byte, start, n int)"))

// SWITCH DEFINITION:
s := fmt.Sprintf("n = (n>>%d)<<%d\n", perLoopBits, perLoopBits)
s += fmt.Sprintf(" mulAvxTwo_%dx%d_64(matrix, in, out, start, n)\n", inputs, outputs)
s += fmt.Sprintf("\t\t\t\treturn n\n")
switchDefs[inputs-1][outputs-1] = s

if loadNone {
Comment("Loading no tables to registers")
} else {
// loadNone == false
Comment("Loading all tables to registers")
}
if regDst {
Comment("Destination kept in GP registers")
} else {
Comment("Destination kept on stack")
}

Doc(doc...)
Pragma("noescape")
Commentf("Full registers estimated %d YMM used", est)

length := Load(Param("n"), GP64())
matrixBase := GP64()
addr, err := Param("matrix").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, matrixBase)
SHRQ(U8(perLoopBits), length)
TESTQ(length, length)
JZ(LabelRef(name + "_end"))

inLo := make([]reg.VecVirtual, total)
inHi := make([]reg.VecVirtual, total)

for i := range inLo {
if loadNone {
break
}
tableLo := YMM()
tableHi := YMM()
VMOVDQU(Mem{Base: matrixBase, Disp: i * 64}, tableLo)
VMOVDQU(Mem{Base: matrixBase, Disp: i*64 + 32}, tableHi)
inLo[i] = tableLo
inHi[i] = tableHi
}

inPtrs := make([]reg.GPVirtual, inputs)
inSlicePtr := GP64()
addr, err = Param("in").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, inSlicePtr)
for i := range inPtrs {
ptr := GP64()
MOVQ(Mem{Base: inSlicePtr, Disp: i * 24}, ptr)
inPtrs[i] = ptr
}
// Destination
dst := make([]reg.VecVirtual, outputs)
dst2 := make([]reg.VecVirtual, outputs)
dstPtr := make([]reg.GPVirtual, outputs)
addr, err = Param("out").Base().Resolve()
if err != nil {
panic(err)
}
outBase := addr.Addr
outSlicePtr := GP64()
MOVQ(outBase, outSlicePtr)
for i := range dst {
dst[i] = YMM()
dst2[i] = YMM()
if !regDst {
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
dstPtr[i] = ptr
}

offset := GP64()
addr, err = Param("start").Resolve()
if err != nil {
panic(err)
}

MOVQ(addr.Addr, offset)
if regDst {
Comment("Add start offset to output")
for _, ptr := range dstPtr {
ADDQ(offset, ptr)
}
}

Comment("Add start offset to input")
for _, ptr := range inPtrs {
ADDQ(offset, ptr)
}
// offset is no longer needed when regDst is true; otherwise it still indexes the loads and stores below.

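// Broadcast the nibble mask 0x0f into every byte of a YMM register; it isolates the low four bits of each input byte below.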
tmpMask := GP64()
MOVQ(U32(15), tmpMask)
lowMask := YMM()
MOVQ(tmpMask, lowMask.AsX())
VPBROADCASTB(lowMask.AsX(), lowMask)

if reloadLength {
length = Load(Param("n"), GP64())
SHRQ(U8(perLoopBits), length)
}
Label(name + "_loop")
if xor {
Commentf("Load %d outputs", outputs)
} else {
Commentf("Clear %d outputs", outputs)
}
for i := range dst {
if xor {
if regDst {
VMOVDQU(Mem{Base: dstPtr[i]}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
continue
}
ptr := GP64()
MOVQ(outBase, ptr)
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
} else {
VPXOR(dst[i], dst[i], dst[i])
VPXOR(dst2[i], dst2[i], dst2[i])
}
}

lookLow, lookHigh := YMM(), YMM()
lookLow2, lookHigh2 := YMM(), YMM()
inLow, inHigh := YMM(), YMM()
in2Low, in2High := YMM(), YMM()
for i := range inPtrs {
Commentf("Load and process 64 bytes from input %d to %d outputs", i, outputs)
VMOVDQU(Mem{Base: inPtrs[i]}, inLow)
VMOVDQU(Mem{Base: inPtrs[i], Disp: 32}, in2Low)
if prefetchSrc > 0 {
PREFETCHT0(Mem{Base: inPtrs[i], Disp: prefetchSrc})
}
ADDQ(U8(perLoop), inPtrs[i])
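// Split each of the 64 input bytes into its low and high nibble for the table lookups.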
VPSRLQ(U8(4), inLow, inHigh)
VPSRLQ(U8(4), in2Low, in2High)
VPAND(lowMask, inLow, inLow)
VPAND(lowMask, in2Low, in2Low)
VPAND(lowMask, inHigh, inHigh)
VPAND(lowMask, in2High, in2High)
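// Look up the partial products for both 32-byte halves and XOR them into the running output accumulators.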
for j := range dst {
if loadNone {
VMOVDQU(Mem{Base: matrixBase, Disp: 64 * (i*outputs + j)}, lookLow)
VMOVDQU(Mem{Base: matrixBase, Disp: 32 + 64*(i*outputs+j)}, lookHigh)
VPSHUFB(in2Low, lookLow, lookLow2)
VPSHUFB(inLow, lookLow, lookLow)
VPSHUFB(in2High, lookHigh, lookHigh2)
VPSHUFB(inHigh, lookHigh, lookHigh)
} else {
VPSHUFB(inLow, inLo[i*outputs+j], lookLow)
VPSHUFB(in2Low, inLo[i*outputs+j], lookLow2)
VPSHUFB(inHigh, inHi[i*outputs+j], lookHigh)
VPSHUFB(in2High, inHi[i*outputs+j], lookHigh2)
}
VPXOR(lookLow, lookHigh, lookLow)
VPXOR(lookLow2, lookHigh2, lookLow2)
VPXOR(lookLow, dst[j], dst[j])
VPXOR(lookLow2, dst2[j], dst2[j])
}
}
Commentf("Store %d outputs", outputs)
for i := range dst {
if regDst {
VMOVDQU(dst[i], Mem{Base: dstPtr[i]})
VMOVDQU(dst2[i], Mem{Base: dstPtr[i], Disp: 32})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
ADDQ(U8(perLoop), dstPtr[i])
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
VMOVDQU(dst2[i], Mem{Base: ptr, Index: offset, Scale: 1, Disp: 32})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
Comment("Prepare for next loop")
if !regDst {
ADDQ(U8(perLoop), offset)
}
DECQ(length)
JNZ(LabelRef(name + "_loop"))
VZEROUPPER()

Label(name + "_end")
RET()
}
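
For reference, the switchDefs string assembled above becomes one case per matrix shape in the generated galMulSlicesAvx2 dispatcher. A sketch of what such a case does for the 1-input, 1-output shape (reconstructed from the format strings above; the generated file may differ in detail):

```
// dispatch1x1 shows the shape of one generated dispatch case.
// n is rounded down to a whole number of 64-byte iterations
// (perLoopBits = 6), so the kernel never touches a partial block;
// the remaining tail is left for the generic Go fallback.
func dispatch1x1(matrix []byte, in, out [][]byte, start, n int) int {
	n = (n >> 6) << 6
	mulAvxTwo_1x1_64(matrix, in, out, start, n)
	return n
}
```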