Skip to content

Commit

Permalink
Add code to generate the SVE and NEON routines for ARM
Browse files Browse the repository at this point in the history
  • Loading branch information
fwessels committed Jun 11, 2024
1 parent c0f2a86 commit d22fb73
Show file tree
Hide file tree
Showing 8 changed files with 1,641 additions and 789 deletions.
415 changes: 415 additions & 0 deletions _gen/gen-arm-neon.go

Large diffs are not rendered by default.

293 changes: 293 additions & 0 deletions _gen/gen-arm-sve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
// Copyright 2024, Klaus Post/Minio Inc. See LICENSE for details.

package main

import (
"bufio"
"bytes"
"fmt"
"log"
"os"
"regexp"
"strconv"
"strings"

avxtwo2sve "github.com/fwessels/avxTwo2sve"
sve_as "github.com/fwessels/sve-as"
)

func patchLabel(line string) string {
return strings.ReplaceAll(line, "AvxTwo", "Sve")
}

func extractRoutine(filename, routine string) (lines []string, err error) {
file, err := os.Open(filename)
if err != nil {
return
}
defer file.Close()

// Create a scanner to read the file line by line
scanner := bufio.NewScanner(file)

// Iterate over each line
collect := false
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, routine) {
collect = true
}
if collect {
lines = append(lines, line)
}
if collect && strings.HasSuffix(line, "RET") {
collect = false
}
}

// Check for any errors that occurred during scanning
err = scanner.Err()
return
}

func addArmInitializations(instructions []string) (processed []string) {
for _, instr := range instructions {
processed = append(processed, instr)
if strings.HasPrefix(instr, "TEXT ·") {
sve := "ptrue p0.d"
opcode, err := sve_as.Assemble(sve)
if err != nil {
processed = append(processed, fmt.Sprintf(" WORD $0x00000000 // %-44s\n", sve))
} else {
processed = append(processed, fmt.Sprintf(" WORD $0x%08x // %-44s\n", opcode, sve))
}
}
}
return
}

// Expand #defines
func expandHashDefines(instructions []string) (processed []string) {
for _, instr := range instructions {
if strings.Contains(instr, "XOR3WAY") {
f := strings.Fields(instr)
if len(f) >= 3 {
dst := strings.ReplaceAll(f[len(f)-1], ")", "")
b := strings.ReplaceAll(f[len(f)-2], ",", "")
a := strings.ReplaceAll(f[len(f)-3], ",", "")

processed = append(processed, fmt.Sprintf("VPXOR %s, %s, %s", a, dst, dst))
processed = append(processed, fmt.Sprintf("VPXOR %s, %s, %s", b, dst, dst))
} else {
log.Fatalf("Not enough arguments for 'XOR3WAY' macro: %d", len(f))
}
} else if !strings.Contains(instr, "VZEROUPPER") {
processed = append(processed, instr)
}
}
return
}

func convertRoutine(asmBuf *bytes.Buffer, instructions []string) {

asmF := func(format string, args ...interface{}) {
(*asmBuf).WriteString(fmt.Sprintf(format, args...))
}

wordOpcode := regexp.MustCompile(`WORD \$0x[0-9a-f]{8}`)

for _, instr := range instructions {
instr = strings.TrimSpace(instr)
if instr == "" {
asmF("\n")
} else if strings.HasPrefix(instr, "TEXT ") { // function header
asmF("%s\n", patchLabel(instr))
} else if wordOpcode.MatchString(instr) { // arm code
asmF(" %s\n", instr)
} else if strings.HasPrefix(instr, "//") { // comment
asmF(" %s\n", instr)
} else if strings.HasSuffix(instr, ":") { // label
asmF("%s\n", patchLabel(instr))
} else {
sve, plan9, err := avxtwo2sve.AvxTwo2Sve(instr, patchLabel)
if err != nil {
panic(err)
} else if !plan9 {
opcode, err := sve_as.Assemble(sve)
if err != nil {
asmF(" WORD $0x00000000 // %-44s\n", sve)
} else {
asmF(" WORD $0x%08x // %-44s\n", opcode, sve)
}
} else {
asmF(" %s\n", sve)
}
}
}
}

func fromAvx2ToSve() {
asmOut, goOut := &bytes.Buffer{}, &bytes.Buffer{}

goOut.WriteString(`// Code generated by command: go generate ` + os.Getenv("GOFILE") + `. DO NOT EDIT.` + "\n\n")
goOut.WriteString("//go:build !noasm && !appengine && !gccgo && !nopshufb\n\n")
goOut.WriteString("package reedsolomon\n\n")

const input = 10
const AVX2_CODE = "../galois_gen_amd64.s"

// Processing 64 bytes variants
for output := 1; output <= 3; output++ {
for op := ""; len(op) <= 3; op += "Xor" {
templName := fmt.Sprintf("mulAvxTwo_%dx%d_64%s", input, output, op)
funcDef := fmt.Sprintf("func %s(matrix []byte, in [][]byte, out [][]byte, start int, n int)", strings.ReplaceAll(templName, "AvxTwo", "Sve"))

// asm first
lines, err := extractRoutine(AVX2_CODE, fmt.Sprintf("TEXT ·%s(SB)", templName))
if err != nil {
log.Fatal(err)
}
lines = expandHashDefines(lines)

convertRoutine(asmOut, lines)

// add newline after RET
asmOut.WriteString("\n")

// golang declaration
goOut.WriteString(fmt.Sprintf("//go:noescape\n%s\n\n", funcDef))
}
}

// Processing 32 bytes variants
for output := 4; output <= 10; output++ {
for op := ""; len(op) <= 3; op += "Xor" {
templName := fmt.Sprintf("mulAvxTwo_%dx%d%s", input, output, op)
funcDef := fmt.Sprintf("func %s(matrix []byte, in [][]byte, out [][]byte, start int, n int)", strings.ReplaceAll(templName, "AvxTwo", "Sve"))

// asm first
lines, err := extractRoutine(AVX2_CODE, fmt.Sprintf("TEXT ·%s(SB)", templName))
if err != nil {
log.Fatal(err)
}
lines = expandHashDefines(lines)

// add additional initialization for SVE
// (for predicated loads and stores in
// case of register shortage)
lines = addArmInitializations(lines)

convertRoutine(asmOut, lines)

// add newline after RET
asmOut.WriteString("\n")

// golang declaration
goOut.WriteString(fmt.Sprintf("//go:noescape\n%s\n\n", funcDef))
}
}

if err := os.WriteFile("../galois_gen_arm64.s", asmOut.Bytes(), 0644); err != nil {
log.Fatal(err)
}
if err := os.WriteFile("../galois_gen_arm64.go", goOut.Bytes(), 0644); err != nil {
log.Fatal(err)
}
}

func insertEarlyExit(lines []string, funcName string, outputs int) (processed []string) {

const reg = "R16"
label := funcName + "_store"

reComment := regexp.MustCompile(fmt.Sprintf(`// Load and process \d* bytes from input (\d*) to %d outputs`, outputs))
reLoop := regexp.MustCompile(`^` + strings.ReplaceAll(label, "store", "loop") + `:`)
reStore := regexp.MustCompile(fmt.Sprintf(`// Store %d outputs`, outputs))

for _, line := range lines {
if matches := reLoop.FindAllStringSubmatch(line, -1); len(matches) == 1 {
lastline := processed[len(processed)-1]
processed = processed[:len(processed)-1]
processed = append(processed, "")
processed = append(processed, fmt.Sprintf(" // Load number of input shards"))
processed = append(processed, fmt.Sprintf(" MOVD in_len+32(FP), %s", reg))
processed = append(processed, lastline)
}

if matches := reComment.FindAllStringSubmatch(line, -1); len(matches) == 1 {
if inputs, err := strconv.Atoi(matches[0][1]); err != nil {
panic(err)
} else {
if inputs > 0 && inputs < 10 {
lastline := processed[len(processed)-1]
processed = processed[:len(processed)-1]
processed = append(processed, fmt.Sprintf(" // Check for early termination"))
processed = append(processed, fmt.Sprintf(" CMP $%d, %s", inputs, reg))
processed = append(processed, fmt.Sprintf(" BEQ %s", label))
processed = append(processed, lastline)
}
}
}

if matches := reStore.FindAllStringSubmatch(line, -1); len(matches) == 1 {
processed = append(processed, fmt.Sprintf("%s:", label))
}

processed = append(processed, line)
}
return
}

func addEarlyExit(arch string) {
const filename = "../galois_gen_arm64.s"
asmOut := &bytes.Buffer{}

asmOut.WriteString(`// Code generated by command: go generate ` + os.Getenv("GOFILE") + `. DO NOT EDIT.` + "\n\n")
asmOut.WriteString("//go:build !appengine && !noasm && !nogen && !nopshufb && gc\n\n")
asmOut.WriteString(`#include "textflag.h"` + "\n\n")

input := 10
for outputs := 1; outputs <= 3; outputs++ {
for op := ""; len(op) <= 3; op += "Xor" {
funcName := fmt.Sprintf("mul%s_%dx%d_64%s", arch, input, outputs, op)
funcDef := fmt.Sprintf("func %s(matrix []byte, in [][]byte, out [][]byte, start int, n int)", funcName)

lines, _ := extractRoutine(filename, fmt.Sprintf("TEXT ·%s(SB)", funcName))

// prepend output with commented out function definition and comment
asmOut.WriteString(fmt.Sprintf("// %s\n", funcDef))
asmOut.WriteString("// Requires: SVE\n")

lines = insertEarlyExit(lines, funcName, outputs)

asmOut.WriteString(strings.Join(lines, "\n"))
asmOut.WriteString("\n\n")
}
}

for outputs := 4; outputs <= 10; outputs++ {
for op := ""; len(op) <= 3; op += "Xor" {
funcName := fmt.Sprintf("mul%s_%dx%d%s", arch, input, outputs, op)
funcDef := fmt.Sprintf("func %s(matrix []byte, in [][]byte, out [][]byte, start int, n int)", funcName)

lines, _ := extractRoutine(filename, fmt.Sprintf("TEXT ·%s(SB)", funcName))

// prepend output with commented out function definition and comment
asmOut.WriteString(fmt.Sprintf("// %s\n", funcDef))
asmOut.WriteString("// Requires: SVE\n")

lines = insertEarlyExit(lines, funcName, outputs)
asmOut.WriteString(strings.Join(lines, "\n"))
asmOut.WriteString("\n\n")
}
}

if err := os.WriteFile("../galois_gen_arm64.s", asmOut.Bytes(), 0644); err != nil {
log.Fatal(err)
}
}

func genArmSve() {
fromAvx2ToSve()
addEarlyExit("Sve")
}
24 changes: 16 additions & 8 deletions _gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ func main() {
genSwitch()
genGF16()
genGF8()

if pshufb {
genArmSve()
genArmNeon()
}
Generate()
}

Expand Down Expand Up @@ -449,7 +454,10 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
for _, ptr := range inPtrs {
ADDQ(offset, ptr)
}
// Offset no longer needed unless not regdst
// Offset no longer needed unless not regDst
if !regDst {
SHRQ(U8(3), offset) // divide by 8 since we'll be scaling it up when loading or storing
}

tmpMask := GP64()
MOVQ(U32(15), tmpMask)
Expand Down Expand Up @@ -478,9 +486,9 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 8}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
}
}
}
Expand Down Expand Up @@ -508,9 +516,9 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
} else {
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: j * 24}, ptr)
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[j])
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 8}, dst[j])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
}
}
}
Expand Down Expand Up @@ -543,14 +551,14 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 8})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
}
}
Comment("Prepare for next loop")
if !regDst {
ADDQ(U8(perLoop), offset)
ADDQ(U8(perLoop>>3), offset)
}
DECQ(length)
JNZ(LabelRef(name + "_loop"))
Expand Down
4 changes: 3 additions & 1 deletion _gen/go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
module github.com/klauspost/reedsolomon/_gen

go 1.19
go 1.21.5

require (
github.com/klauspost/asmfmt v1.3.1
github.com/mmcloughlin/avo v0.5.1-0.20221128045730-bf1d05562091
)

require (
github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471 // indirect
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions _gen/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471 h1:omdgAKxePZxbMC7HZPw99QMPeH7fKh3t2QRSZ0YFA/0=
github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471/go.mod h1:9+ibRsEIs0vLXkalKCGEbZfVS4fafeIvMvM9GvIsdeQ=
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f h1:HQud3yIU82LdkQzHEYiSJs73wCHjprIqeZE9JvSjKbQ=
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f/go.mod h1:j3s7EY79XxNMyjx/54Vo6asZafWU4yijB+KIfj4hrh8=
github.com/klauspost/asmfmt v1.3.1 h1:7xZi1N7s9gTLbqiM8KUv8TLyysavbTRGBT5/ly0bRtw=
github.com/klauspost/asmfmt v1.3.1/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/mmcloughlin/avo v0.5.1-0.20221128045730-bf1d05562091 h1:C2c8ttOBeyhs1SvyCXVPCFd0EqtPiTKGnMWQ+JkM0Lc=
Expand Down
Loading

0 comments on commit d22fb73

Please sign in to comment.