compiler: allow to use multiple returns in inlined functions #2594

Merged: 3 commits, Jul 12, 2022
99 changes: 52 additions & 47 deletions pkg/compiler/codegen.go
@@ -62,9 +62,8 @@ type codegen struct {
labels map[labelWithType]uint16
// A list of nested label names together with evaluation stack depth.
labelList []labelWithStackSize
-// inlineLabelOffsets contains size of labelList at the start of inline call processing.
-// For such calls, we need to drop only the newly created part of stack.
-inlineLabelOffsets []int
+// inlineContext contains info about inlined function calls.
+inlineContext []inlineContextSingle
// globalInlineCount contains the amount of auxiliary variables introduced by
// function inlining during global variables initialization.
globalInlineCount int
@@ -146,6 +145,14 @@ type nameWithLocals struct {
count int
}

+type inlineContextSingle struct {
+// labelOffset contains size of labelList at the start of inline call processing.
+// For such calls, we need to drop only the newly created part of stack.
+labelOffset int
+// returnLabel contains label ID pointing to the first instruction right after the call.
+returnLabel uint16
+}
+
type varType int

const (
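The new inlineContext field acts as a stack with one entry per nested inline call: entering inlineCall pushes an inlineContextSingle, and leaving it pops the entry and truncates labelList back to the recorded labelOffset. A minimal standalone sketch of that discipline (names and values are illustrative, not from the codebase):

package main

import "fmt"

type inlineCtx struct {
	labelOffset int    // len(labelList) at the start of the inline call
	returnLabel uint16 // label bound to the first instruction after the call
}

func main() {
	labelList := []string{"callerLoop"} // the caller's labels stay below labelOffset
	var stack []inlineCtx

	// Entering a nested inline call:
	stack = append(stack, inlineCtx{labelOffset: len(labelList), returnLabel: 7})
	// Compiling the inlined body adds its own labels:
	labelList = append(labelList, "bodyLoop")
	// Leaving the call restores the caller's view:
	top := stack[len(stack)-1]
	labelList = labelList[:top.labelOffset]
	stack = stack[:len(stack)-1]

	fmt.Println(labelList, len(stack)) // [callerLoop] 0
}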
@@ -680,8 +687,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {

cnt := 0
start := 0
-if len(c.inlineLabelOffsets) > 0 {
-start = c.inlineLabelOffsets[len(c.inlineLabelOffsets)-1]
+if len(c.inlineContext) > 0 {
+start = c.inlineContext[len(c.inlineContext)-1].labelOffset
}
for i := start; i < len(c.labelList); i++ {
cnt += c.labelList[i].sz
@@ -711,6 +718,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.saveSequencePoint(n)
if len(c.pkgInfoInline) == 0 {
emit.Opcodes(c.prog.BinWriter, opcode.RET)
+} else {
+emit.Jmp(c.prog.BinWriter, opcode.JMPL, c.inlineContext[len(c.inlineContext)-1].returnLabel)
}
return nil

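The branch above is the heart of the change: a `return` compiled inside an inlined body must not emit RET, because RET would terminate the enclosing function on the VM. Instead it emits a long jump to the pushed returnLabel, which is bound to the first instruction after the call. In Go source terms the rewrite looks roughly like this (a sketch, not actual codegen output):

package main

import "fmt"

// inlinedAbs models an inlinable body with two returns: after inlining,
// each `return` becomes a jump (JMPL) to a label placed right after the
// call site rather than a RET.
func inlinedAbs(x int) (result int) {
	if x < 0 {
		result = -x
		goto ret // stands in for JMPL returnLabel
	}
	result = x
ret:
	return result
}

func main() {
	fmt.Println(inlinedAbs(-5), inlinedAbs(5)) // 5 5
}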
@@ -2211,7 +2220,7 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {

func (c *codegen) writeJumps(b []byte) ([]byte, error) {
ctx := vm.NewContext(b)
-var offsets []int
+var nopOffsets []int
for op, param, err := ctx.Next(); err == nil && ctx.IP() < len(b); op, param, err = ctx.Next() {
switch op {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
@@ -2235,13 +2244,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
return nil, err
}
if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 {
-offsets = append(offsets, ctx.IP())
+if op == opcode.JMPL && offset == 5 {
+copy(b[ctx.IP():], []byte{byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)})
+nopOffsets = append(nopOffsets, ctx.IP(), ctx.IP()+1, ctx.IP()+2, ctx.IP()+3, ctx.IP()+4)
+} else {
+copy(b[ctx.IP():], []byte{byte(toShortForm(op)), byte(offset), byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)})
+nopOffsets = append(nopOffsets, ctx.IP()+2, ctx.IP()+3, ctx.IP()+4)
+}
}
case opcode.INITSLOT:
nextIP := ctx.NextIP()
info := c.reverseOffsetMap[ctx.IP()]
if argCount := b[nextIP-1]; info.count == 0 && argCount == 0 {
-offsets = append(offsets, ctx.IP())
+copy(b[ctx.IP():], []byte{byte(opcode.NOP), byte(opcode.NOP), byte(opcode.NOP)})
+nopOffsets = append(nopOffsets, ctx.IP(), ctx.IP()+1, ctx.IP()+2)
continue
}

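Two rewrite shapes are introduced above. A long jump whose operand is exactly 5 jumps over nothing but itself (JMPL occupies 5 bytes), so the whole instruction degenerates into five NOPs; any other shortenable jump keeps a 2-byte short form and pads the remaining 3 bytes with NOPs. Every padding position is recorded in nopOffsets for later removal. A small demonstration of the degenerate case, reusing the project's opcode package:

package main

import (
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
)

func main() {
	n := byte(opcode.NOP)
	// JMPL with little-endian operand 5 transfers control to the next
	// instruction, i.e. it is a no-op spelled in 5 bytes...
	b := []byte{byte(opcode.JMPL), 5, 0, 0, 0, byte(opcode.PUSH1)}
	// ...so writeJumps rewrites all 5 bytes as NOPs and records
	// offsets 0..4 for removal.
	copy(b, []byte{n, n, n, n, n})
	fmt.Printf("% x\n", b)
}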
@@ -2253,20 +2269,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
}

if c.deployEndOffset >= 0 {
-_, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), offsets)
+_, end := correctRange(uint16(c.initEndOffset+1), uint16(c.deployEndOffset), nopOffsets)
c.deployEndOffset = int(end)
}
if c.initEndOffset > 0 {
-_, end := correctRange(0, uint16(c.initEndOffset), offsets)
+_, end := correctRange(0, uint16(c.initEndOffset), nopOffsets)
c.initEndOffset = int(end)
}

// Correct function ip range.
// Note: indices are sorted in increasing order.
for _, f := range c.funcs {
-f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, offsets)
+f.rng.Start, f.rng.End = correctRange(f.rng.Start, f.rng.End, nopOffsets)
}
-return shortenJumps(b, offsets), nil
+return removeNOPs(b, nopOffsets), nil
}

func correctRange(start, end uint16, offsets []int) (uint16, uint16) {
@@ -2277,10 +2293,10 @@ loop:
case ind > int(end):
break loop
case ind < int(start):
-newStart -= longToShortRemoveCount
-newEnd -= longToShortRemoveCount
+newStart--
+newEnd--
case ind >= int(start):
-newEnd -= longToShortRemoveCount
+newEnd--
}
}
return newStart, newEnd
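correctRange previously subtracted longToShortRemoveCount (3 bytes) per shortened instruction; with NOP-based tracking every recorded offset is a single removable byte, so each one shifts a boundary by exactly one. A standalone re-implementation with worked numbers (the offsets are hypothetical):

package main

import "fmt"

// correctRange mirrors the helper above: every NOP byte before the range
// shifts both boundaries left by one; every NOP inside it shrinks the end.
func correctRange(start, end uint16, nopOffsets []int) (uint16, uint16) {
	newStart, newEnd := start, end
	for _, ind := range nopOffsets {
		switch {
		case ind > int(end): // input is sorted, so nothing else matters
		case ind < int(start):
			newStart--
			newEnd--
		default:
			newEnd--
		}
	}
	return newStart, newEnd
}

func main() {
	// Three NOPs before a function body and two inside it:
	fmt.Println(correctRange(10, 20, []int{2, 3, 4, 12, 13})) // 7 15
}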
@@ -2303,21 +2319,22 @@ func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) {
return offset, nil
}

-// longToShortRemoveCount is a difference between short and long instruction sizes in bytes.
-// By pure coincidence, this is also the size of `INITSLOT` instruction.
-const longToShortRemoveCount = 3
-
-// shortenJumps converts b to a program where all long JMP*/CALL* specified by absolute offsets
+// removeNOPs converts b to a program where all long JMP*/CALL* specified by absolute offsets
// are replaced with their corresponding short counterparts. It panics if either b or offsets are invalid.
// This is done in 2 passes:
// 1. Alter jump offsets taking into account parts to be removed.
// 2. Perform actual removal of jump targets.
// Note: after jump offsets altering, there can appear new candidates for conversion.
// These are ignored for now.
-func shortenJumps(b []byte, offsets []int) []byte {
-if len(offsets) == 0 {
+func removeNOPs(b []byte, nopOffsets []int) []byte {
+if len(nopOffsets) == 0 {
return b
}
+for i := range nopOffsets {
+if b[nopOffsets[i]] != byte(opcode.NOP) {
+panic("NOP offset is invalid")
+}
+}

// 1. Alter existing jump offsets.
ctx := vm.NewContext(b)
@@ -2330,57 +2347,46 @@
opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY:
offset := int(int8(b[nextIP-1]))
-offset += calcOffsetCorrection(ip, ip+offset, offsets)
+offset += calcOffsetCorrection(ip, ip+offset, nopOffsets)
b[nextIP-1] = byte(offset)
case opcode.TRY:
catchOffset := int(int8(b[nextIP-2]))
-catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
+catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets)
b[nextIP-1] = byte(catchOffset)
finallyOffset := int(int8(b[nextIP-1]))
-finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
+finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets)
b[nextIP-1] = byte(finallyOffset)
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL:
arg := b[nextIP-4:]
offset := int(int32(binary.LittleEndian.Uint32(arg)))
-offset += calcOffsetCorrection(ip, ip+offset, offsets)
+offset += calcOffsetCorrection(ip, ip+offset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(offset))
case opcode.TRYL:
arg := b[nextIP-8:]
catchOffset := int(int32(binary.LittleEndian.Uint32(arg)))
-catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
+catchOffset += calcOffsetCorrection(ip, ip+catchOffset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(catchOffset))
arg = b[nextIP-4:]
finallyOffset := int(int32(binary.LittleEndian.Uint32(arg)))
-finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
+finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, nopOffsets)
binary.LittleEndian.PutUint32(arg, uint32(finallyOffset))
}
}

// 2. Convert instructions.
copyOffset := 0
-l := len(offsets)
-if op := opcode.Opcode(b[offsets[0]]); op != opcode.INITSLOT {
-b[offsets[0]] = byte(toShortForm(op))
-}
+l := len(nopOffsets)
for i := 0; i < l; i++ {
-start := offsets[i] + 2
-if b[offsets[i]] == byte(opcode.INITSLOT) {
-start = offsets[i]
-}
-
+start := nopOffsets[i]
end := len(b)
if i != l-1 {
-end = offsets[i+1]
-if op := opcode.Opcode(b[offsets[i+1]]); op != opcode.INITSLOT {
-end += 2
-b[offsets[i+1]] = byte(toShortForm(op))
-}
+end = nopOffsets[i+1]
}
-copy(b[start-copyOffset:], b[start+3:end])
-copyOffset += longToShortRemoveCount
+copy(b[start-copyOffset:], b[start+1:end])
+copyOffset++
}
return b[:len(b)-copyOffset]
}
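Pass 2 no longer rewrites instructions at all; it only slides the non-NOP chunks left over the holes. The copy arithmetic is easiest to check on concrete bytes (a simplified sketch without the jump-offset fixups of pass 1):

package main

import "fmt"

// removeBytes drops the bytes at the given sorted offsets by sliding
// each following chunk left, exactly like pass 2 of removeNOPs.
func removeBytes(b []byte, offsets []int) []byte {
	copyOffset := 0
	l := len(offsets)
	for i := 0; i < l; i++ {
		start := offsets[i]
		end := len(b)
		if i != l-1 {
			end = offsets[i+1]
		}
		copy(b[start-copyOffset:], b[start+1:end])
		copyOffset++
	}
	return b[:len(b)-copyOffset]
}

func main() {
	b := []byte{0xAA, 0x21, 0x21, 0xBB, 0x21, 0xCC}
	fmt.Printf("% x\n", removeBytes(b, []int{1, 2, 4})) // aa bb cc
}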
@@ -2392,9 +2398,8 @@ func calcOffsetCorrection(ip, target int, offsets []int) int {
})
for i := start; i < len(offsets) && (offsets[i] < target || offsets[i] <= ip); i++ {
ind := offsets[i]
-if ip <= ind && ind < target ||
-ind != ip && target <= ind && ind <= ip {
-cnt += longToShortRemoveCount
+if ip <= ind && ind < target || target <= ind && ind < ip {
+cnt++
}
}
if ip < target {
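The tail of calcOffsetCorrection is cut off above. The sketch below reconstructs the whole helper under the assumption (suggested by the visible `if ip < target` line) that forward corrections are returned negated, plus a worked call:

package main

import (
	"fmt"
	"sort"
)

// calcOffsetCorrection counts NOP bytes lying between a jump at ip and
// its target; each such byte moves the target one byte closer. This is a
// sketch of the helper above (the sign convention for forward jumps is
// an assumption, as the original tail is not shown in the diff).
func calcOffsetCorrection(ip, target int, nopOffsets []int) int {
	cnt := 0
	lo := ip
	if target < lo {
		lo = target
	}
	start := sort.SearchInts(nopOffsets, lo)
	for i := start; i < len(nopOffsets) && (nopOffsets[i] < target || nopOffsets[i] <= ip); i++ {
		ind := nopOffsets[i]
		if ip <= ind && ind < target || target <= ind && ind < ip {
			cnt++
		}
	}
	if ip < target {
		return -cnt // forward jump: removed bytes shrink the offset
	}
	return cnt
}

func main() {
	// A forward jump from ip=0 to target=10 over NOPs at 5, 6 and 7
	// must have its operand reduced by 3 after the NOPs are removed.
	fmt.Println(calcOffsetCorrection(0, 10, []int{5, 6, 7})) // -3
}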
14 changes: 9 additions & 5 deletions pkg/compiler/inline.go
@@ -21,12 +21,15 @@ import (
// <inline body of f directly>
// }
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
-labelSz := len(c.labelList)
-offSz := len(c.inlineLabelOffsets)
-c.inlineLabelOffsets = append(c.inlineLabelOffsets, labelSz)
+offSz := len(c.inlineContext)
+c.inlineContext = append(c.inlineContext, inlineContextSingle{
+labelOffset: len(c.labelList),
+returnLabel: c.newLabel(),
+})

defer func() {
-c.inlineLabelOffsets = c.inlineLabelOffsets[:offSz]
-c.labelList = c.labelList[:labelSz]
+c.labelList = c.labelList[:c.inlineContext[offSz].labelOffset]
+c.inlineContext = c.inlineContext[:offSz]
}()

pkg := c.packageCache[f.pkg.Path()]
@@ -113,6 +116,7 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
c.fillImportMap(f.file, pkg)
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
ast.Walk(c, f.decl.Body)
+c.setLabel(c.inlineContext[offSz].returnLabel)
if c.scope.voidCalls[n] {
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
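Note the ordering in the second hunk: the return label is bound only after the whole body has been walked, so every `return` inside it jumps forward to that label; and if the call's results are unused (a voidCall), one DROP per declared result is emitted right after the label. A toy model of that stack bookkeeping (illustrative, not codegen):

package main

import "fmt"

func main() {
	// The VM evaluation stack after an inlined call with two results:
	stack := []int{42, 7} // pushed by the body before jumping to returnLabel
	// The call is used as a statement, so its results are dead; codegen
	// emits NumFields() == 2 DROPs right after the return label:
	for i := 0; i < 2; i++ {
		stack = stack[:len(stack)-1]
	}
	fmt.Println(len(stack)) // 0
}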
43 changes: 43 additions & 0 deletions pkg/compiler/inline_test.go
@@ -374,3 +374,46 @@ func TestInlinedMethodWithPointer(t *testing.T) {
}`
eval(t, src, big.NewInt(100542))
}

func TestInlineConditionalReturn(t *testing.T) {
srcTmpl := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/c"
func Main() int {
x := %d
if c.Is42(x) {
return 100
}
return 10
}`
t.Run("true", func(t *testing.T) {
eval(t, fmt.Sprintf(srcTmpl, 123), big.NewInt(10))
})
t.Run("false", func(t *testing.T) {
eval(t, fmt.Sprintf(srcTmpl, 42), big.NewInt(100))
})
}

func TestInlineDoubleConditionalReturn(t *testing.T) {
srcTmpl := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/c"
func Main() int {
return c.Transform(%d, %d)
}`

testCase := []struct {
name string
a, b, result int
}{
{"true, true, small", 42, 3, 6},
{"true, true, big", 42, 15, 15},
{"true, false", 42, 42, 42},
{"false, true", 3, 11, 6},
{"false, false", 3, 42, 6},
}

for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
eval(t, fmt.Sprintf(srcTmpl, tc.a, tc.b), big.NewInt(int64(tc.result)))
})
}
}
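The helpers come from the testdata package pkg/compiler/testdata/inline/c, whose bodies are not shown in this diff. One plausible shape consistent with the expected results above (hypothetical; the real testdata may differ) is:

package c

// Is42 returns its answer via two distinct return statements, which is
// exactly what inlining could not previously handle.
func Is42(a int) bool {
	if a == 42 {
		return true
	}
	return false
}

// Transform exercises two nested conditional returns; the constants are
// chosen only to satisfy the test table above.
func Transform(a, b int) int {
	if Is42(a) {
		if !Is42(b) && b < 10 {
			return b * 2
		}
		return b
	}
	return a * 2
}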
30 changes: 15 additions & 15 deletions pkg/compiler/jumps_test.go
@@ -12,7 +12,7 @@ func testShortenJumps(t *testing.T, before, after []opcode.Opcode, indices []int
for i := range before {
prog[i] = byte(before[i])
}
-raw := shortenJumps(prog, indices)
+raw := removeNOPs(prog, indices)
actual := make([]opcode.Opcode, len(raw))
for i := range raw {
actual[i] = opcode.Opcode(raw[i])
@@ -36,53 +36,53 @@ func TestShortenJumps(t *testing.T) {
for op, sop := range testCases {
t.Run(op.String(), func(t *testing.T) {
before := []opcode.Opcode{
-op, 6, 0, 0, 0, opcode.PUSH1, opcode.NOP, // <- first jump to here
+sop, 6, opcode.NOP, opcode.NOP, opcode.NOP, opcode.PUSH1, opcode.NOP, // <- first jump to here
op, 9, 12, 0, 0, opcode.PUSH1, opcode.NOP, // <- last jump to here
-op, 255, 0, 0, 0, op, 0xFF - 5, 0xFF, 0xFF, 0xFF,
+sop, 249, opcode.NOP, opcode.NOP, opcode.NOP, sop, 0xFF - 5, opcode.NOP, opcode.NOP, opcode.NOP,
}
after := []opcode.Opcode{
sop, 3, opcode.PUSH1, opcode.NOP,
op, 3, 12, 0, 0, opcode.PUSH1, opcode.NOP,
sop, 249, sop, 0xFF - 2,
}
-testShortenJumps(t, before, after, []int{0, 14, 19})
+testShortenJumps(t, before, after, []int{2, 3, 4, 16, 17, 18, 21, 22, 23})
})
}
t.Run("NoReplace", func(t *testing.T) {
b := []byte{0, 1, 2, 3, 4, 5}
expected := []byte{0, 1, 2, 3, 4, 5}
-require.Equal(t, expected, shortenJumps(b, nil))
+require.Equal(t, expected, removeNOPs(b, nil))
})
t.Run("InvalidIndex", func(t *testing.T) {
before := []byte{byte(opcode.PUSH1), 0, 0, 0, 0}
require.Panics(t, func() {
-shortenJumps(before, []int{0})
+removeNOPs(before, []int{0})
})
})
t.Run("SideConditions", func(t *testing.T) {
t.Run("Forward", func(t *testing.T) {
before := []opcode.Opcode{
-opcode.JMPL, 5, 0, 0, 0,
-opcode.JMPL, 5, 0, 0, 0,
+opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
+opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
}
after := []opcode.Opcode{
opcode.JMP, 2,
opcode.JMP, 2,
}
-testShortenJumps(t, before, after, []int{0, 5})
+testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9})
})
t.Run("Backwards", func(t *testing.T) {
before := []opcode.Opcode{
-opcode.JMPL, 5, 0, 0, 0,
-opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
-opcode.JMPL, 0xFF - 4, 0xFF, 0xFF, 0xFF,
+opcode.JMP, 5, opcode.NOP, opcode.NOP, opcode.NOP,
+opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP,
+opcode.JMP, 0xFF - 4, opcode.NOP, opcode.NOP, opcode.NOP,
}
after := []opcode.Opcode{
-opcode.JMPL, 5, 0, 0, 0,
-opcode.JMP, 0xFF - 4,
+opcode.JMP, 2,
+opcode.JMP, 0xFF - 1,
opcode.JMP, 0xFF - 1,
}
-testShortenJumps(t, before, after, []int{5, 10})
+testShortenJumps(t, before, after, []int{2, 3, 4, 7, 8, 9, 12, 13, 14})
})
})
}