diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go index 879ed591f85..1eea423999e 100644 --- a/go/callgraph/vta/graph.go +++ b/go/callgraph/vta/graph.go @@ -172,6 +172,26 @@ func (f function) String() string { return fmt.Sprintf("Function(%s)", f.f.Name()) } +// resultVar represents the result +// variable of a function, whether +// named or not. +type resultVar struct { + f *ssa.Function + index int // valid index into result var tuple +} + +func (o resultVar) Type() types.Type { + return o.f.Signature.Results().At(o.index).Type() +} + +func (o resultVar) String() string { + v := o.f.Signature.Results().At(o.index) + if n := v.Name(); n != "" { + return fmt.Sprintf("Return(%s[%s])", o.f.Name(), n) + } + return fmt.Sprintf("Return(%s[%d])", o.f.Name(), o.index) +} + // nestedPtrInterface node represents all references and dereferences // of locals and globals that have a nested pointer to interface type. // We merge such constructs into a single node for simplicity and without @@ -580,6 +600,24 @@ func (b *builder) call(c ssa.CallInstruction) { siteCallees(c, b.callGraph)(func(f *ssa.Function) bool { addArgumentFlows(b, c, f) + + site, ok := c.(ssa.Value) + if !ok { + return true // go or defer + } + + results := f.Signature.Results() + if results.Len() == 1 { + // When there is only one return value, the destination register does not + // have a tuple type. + b.addInFlowEdge(resultVar{f: f, index: 0}, b.nodeFromVal(site)) + } else { + tup := site.Type().(*types.Tuple) + for i := 0; i < results.Len(); i++ { + local := indexedLocal{val: site, typ: tup.At(i).Type(), index: i} + b.addInFlowEdge(resultVar{f: f, index: i}, local) + } + } return true }) } @@ -624,37 +662,11 @@ func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) { } } -// rtrn produces flows between values of r and c where -// c is a call instruction that resolves to the enclosing -// function of r based on b.callGraph. +// rtrn creates flow edges from the operands of the return +// statement to the result variables of the enclosing function. func (b *builder) rtrn(r *ssa.Return) { - n := b.callGraph.Nodes[r.Parent()] - // n != nil when b.callgraph is sound, but the client can - // pass any callgraph, including an underapproximate one. - if n == nil { - return - } - - for _, e := range n.In { - if cv, ok := e.Site.(ssa.Value); ok { - addReturnFlows(b, r, cv) - } - } -} - -func addReturnFlows(b *builder, r *ssa.Return, site ssa.Value) { - results := r.Results - if len(results) == 1 { - // When there is only one return value, the destination register does not - // have a tuple type. - b.addInFlowEdge(b.nodeFromVal(results[0]), b.nodeFromVal(site)) - return - } - - tup := site.Type().(*types.Tuple) - for i, r := range results { - local := indexedLocal{val: site, typ: tup.At(i).Type(), index: i} - b.addInFlowEdge(b.nodeFromVal(r), local) + for i, rs := range r.Results { + b.addInFlowEdge(b.nodeFromVal(rs), resultVar{f: r.Parent(), index: i}) } } @@ -795,7 +807,7 @@ func (b *builder) representative(n node) node { return field{StructType: canonicalize(i.StructType, &b.canon), index: i.index} case indexedLocal: return indexedLocal{typ: t, val: i.val, index: i.index} - case local, global, panicArg, recoverReturn, function: + case local, global, panicArg, recoverReturn, function, resultVar: return n default: panic(fmt.Errorf("canonicalizing unrecognized node %v", n)) diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go index d26416ca3ec..8ce4079c693 100644 --- a/go/callgraph/vta/graph_test.go +++ b/go/callgraph/vta/graph_test.go @@ -24,6 +24,7 @@ func TestNodeInterface(t *testing.T) { // - basic type int // - struct X with two int fields a and b // - global variable "gl" + // - "foo" function // - "main" function and its // - first register instruction t0 := *gl prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) @@ -33,6 +34,7 @@ func TestNodeInterface(t *testing.T) { pkg := prog.AllPackages()[0] main := pkg.Func("main") + foo := pkg.Func("foo") reg := firstRegInstr(main) // t0 := *gl X := pkg.Type("X").Type() gl := pkg.Var("gl") @@ -64,6 +66,7 @@ func TestNodeInterface(t *testing.T) { {local{val: reg}, "Local(t0)", bint}, {indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X}, {function{f: main}, "Function(main)", voidFunc}, + {resultVar{f: foo, index: 0}, "Return(foo[r])", bint}, {nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i}, {nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc}, {panicArg{}, "Panic", nil}, diff --git a/go/callgraph/vta/testdata/src/dynamic_calls.go b/go/callgraph/vta/testdata/src/dynamic_calls.go index f8f88983dce..da37a0d55d3 100644 --- a/go/callgraph/vta/testdata/src/dynamic_calls.go +++ b/go/callgraph/vta/testdata/src/dynamic_calls.go @@ -43,6 +43,8 @@ var g *B = &B{} // ensure *B.foo is created. // type flow that gets merged together during stringification. // WANT: +// Return(doWork[0]) -> Local(t2) +// Return(close[0]) -> Local(t2) // Local(t0) -> Local(ai), Local(ai), Local(bi), Local(bi) -// Constant(testdata.I) -> Local(t2) +// Constant(testdata.I) -> Return(close[0]), Return(doWork[0]) // Local(x) -> Local(t0) diff --git a/go/callgraph/vta/testdata/src/maps.go b/go/callgraph/vta/testdata/src/maps.go index f5f51a3d687..69709b56e36 100644 --- a/go/callgraph/vta/testdata/src/maps.go +++ b/go/callgraph/vta/testdata/src/maps.go @@ -41,5 +41,5 @@ func Baz(m map[I]I, b1, b2 B, n map[string]*J) *J { // Local(b2) -> Local(t1) // Local(t1) -> MapValue(testdata.I) // Local(t0) -> MapKey(testdata.I) -// Local(t3) -> MapValue(*testdata.J) +// Local(t3) -> MapValue(*testdata.J), Return(Baz[0]) // MapValue(*testdata.J) -> Local(t3) diff --git a/go/callgraph/vta/testdata/src/returns.go b/go/callgraph/vta/testdata/src/returns.go index b11b4321ba7..27bc418851e 100644 --- a/go/callgraph/vta/testdata/src/returns.go +++ b/go/callgraph/vta/testdata/src/returns.go @@ -51,7 +51,9 @@ func Baz(i I) *I { // WANT: // Local(i) -> Local(ii), Local(j) // Local(ii) -> Local(iii) -// Local(iii) -> Local(t0[0]), Local(t0[1]) -// Local(t1) -> Local(t0[0]) -// Local(t2) -> Local(t0[1]) -// Local(t0) -> Local(t1) +// Local(iii) -> Return(Foo[0]), Return(Foo[1]) +// Local(t1) -> Return(Baz[0]) +// Local(t1) -> Return(Bar[0]) +// Local(t2) -> Return(Bar[1]) +// Local(t0) -> Return(Do[0]) +// Return(Do[0]) -> Local(t1) diff --git a/go/callgraph/vta/testdata/src/simple.go b/go/callgraph/vta/testdata/src/simple.go index d3bfbe79284..71ddbe37163 100644 --- a/go/callgraph/vta/testdata/src/simple.go +++ b/go/callgraph/vta/testdata/src/simple.go @@ -16,3 +16,5 @@ type X struct { func main() { print(gl) } + +func foo() (r int) { return gl } diff --git a/go/callgraph/vta/testdata/src/static_calls.go b/go/callgraph/vta/testdata/src/static_calls.go index 74a27c166ad..e44ab68979d 100644 --- a/go/callgraph/vta/testdata/src/static_calls.go +++ b/go/callgraph/vta/testdata/src/static_calls.go @@ -38,4 +38,6 @@ func Baz(inp I) { // Local(inp) -> Local(i) // Local(t1) -> Local(iii) // Local(t2) -> Local(ii) -// Local(i) -> Local(t0[0]), Local(t0[1]) +// Local(i) -> Return(foo[0]), Return(foo[1]) +// Return(foo[0]) -> Local(t0[0]) +// Return(foo[1]) -> Local(t0[1])