Skip to content

Commit

Permalink
optimize fanin and fanout nodes to use less reflection (#1250)
Browse files Browse the repository at this point in the history
  • Loading branch information
lovromazgon authored Oct 24, 2023
1 parent bbdb074 commit 3111797
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 24 deletions.
56 changes: 34 additions & 22 deletions pkg/pipeline/stream/fanin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package stream

import (
"context"
"reflect"
)

type FaninNode struct {
Expand Down Expand Up @@ -49,31 +48,14 @@ func (n *FaninNode) Run(ctx context.Context) error {
n.running = false
}()

cases := make([]reflect.SelectCase, len(n.in)+1)
cases[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}
for i, ch := range n.in {
cases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
}
trigger := n.trigger(ctx)

for {
chosen, value, ok := reflect.Select(cases)
// ok will be true if the channel has not been closed.
if !ok {
if chosen == 0 {
// context is done
return ctx.Err()
}
// one of the in channels is closed, remove it from select case
cases = append(cases[:chosen], cases[chosen+1:]...)
if len(cases) == 1 {
// only context is left, we're done
return nil
}
continue
msg, err := trigger()
if err != nil || msg == nil {
return err
}

msg := value.Interface().(*Message)

select {
case <-ctx.Done():
return msg.Nack(ctx.Err(), n.ID())
Expand All @@ -82,6 +64,36 @@ func (n *FaninNode) Run(ctx context.Context) error {
}
}

func (n *FaninNode) trigger(ctx context.Context) func() (*Message, error) {
in := make([]<-chan *Message, len(n.in))
copy(in, n.in)

f := n.chooseSelectFunc(ctx, in)

return func() (*Message, error) {
for {
chosen, msg, ok := f()
// ok will be true if the channel has not been closed.
if !ok {
if chosen == 0 {
// context is done
return nil, ctx.Err()
}
// one of the in channels is closed, remove it from select case
in = append(in[:chosen-1], in[chosen:]...)
if len(in) == 0 {
// only context is left, we're done
return nil, nil
}

f = n.chooseSelectFunc(ctx, in)
continue // keep selecting with new select func
}
return msg, nil
}
}
}

func (n *FaninNode) Sub(in <-chan *Message) {
n.in = append(n.in, in)
}
Expand Down
168 changes: 168 additions & 0 deletions pkg/pipeline/stream/fanin_select.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright © 2023 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stream

import (
"context"
"reflect"
)

func (n *FaninNode) chooseSelectFunc(ctx context.Context, in []<-chan *Message) func() (int, *Message, bool) {
switch len(in) {
case 1:
return func() (int, *Message, bool) { return n.select1(ctx, in[0]) }
case 2:
return func() (int, *Message, bool) { return n.select2(ctx, in[0], in[1]) }
case 3:
return func() (int, *Message, bool) { return n.select3(ctx, in[0], in[1], in[2]) }
case 4:
return func() (int, *Message, bool) { return n.select4(ctx, in[0], in[1], in[2], in[3]) }
case 5:
return func() (int, *Message, bool) { return n.select5(ctx, in[0], in[1], in[2], in[3], in[4]) }
case 6:
return func() (int, *Message, bool) { return n.select6(ctx, in[0], in[1], in[2], in[3], in[4], in[5]) }
default:
// use reflection for more channels
cases := make([]reflect.SelectCase, len(in)+1)
cases[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}
for i, ch := range in {
cases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
}
return func() (int, *Message, bool) {
chosen, value, ok := reflect.Select(cases)
if !ok { // a channel was closed
return chosen, nil, ok
}
return chosen, value.Interface().(*Message), ok
}
}
}

func (*FaninNode) select1(
ctx context.Context,
c1 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
}
}

func (*FaninNode) select2(
ctx context.Context,
c1 <-chan *Message,
c2 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
case val, ok := <-c2:
return 2, val, ok
}
}

func (*FaninNode) select3(
ctx context.Context,
c1 <-chan *Message,
c2 <-chan *Message,
c3 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
case val, ok := <-c2:
return 2, val, ok
case val, ok := <-c3:
return 3, val, ok
}
}

func (*FaninNode) select4(
ctx context.Context,
c1 <-chan *Message,
c2 <-chan *Message,
c3 <-chan *Message,
c4 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
case val, ok := <-c2:
return 2, val, ok
case val, ok := <-c3:
return 3, val, ok
case val, ok := <-c4:
return 4, val, ok
}
}

func (*FaninNode) select5(
ctx context.Context,
c1 <-chan *Message,
c2 <-chan *Message,
c3 <-chan *Message,
c4 <-chan *Message,
c5 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
case val, ok := <-c2:
return 2, val, ok
case val, ok := <-c3:
return 3, val, ok
case val, ok := <-c4:
return 4, val, ok
case val, ok := <-c5:
return 5, val, ok
}
}

func (*FaninNode) select6(
ctx context.Context,
c1 <-chan *Message,
c2 <-chan *Message,
c3 <-chan *Message,
c4 <-chan *Message,
c5 <-chan *Message,
c6 <-chan *Message,
) (int, *Message, bool) {
select {
case <-ctx.Done():
return 0, nil, false
case val, ok := <-c1:
return 1, val, ok
case val, ok := <-c2:
return 2, val, ok
case val, ok := <-c3:
return 3, val, ok
case val, ok := <-c4:
return 4, val, ok
case val, ok := <-c5:
return 5, val, ok
case val, ok := <-c6:
return 6, val, ok
}
}
25 changes: 25 additions & 0 deletions pkg/pipeline/stream/fanout.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func (n *FanoutNode) Run(ctx context.Context) error {
n.running = false
}()

if len(n.out) == 1 {
// shortcut if there's only 1 destination
return n.select1(ctx)
}

var wg sync.WaitGroup
for {
select {
Expand Down Expand Up @@ -141,6 +146,26 @@ func (n *FanoutNode) Run(ctx context.Context) error {
}
}

func (n *FanoutNode) select1(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case msg, ok := <-n.in:
if !ok {
// pipeline closed
return nil
}
select {
case <-ctx.Done():
return msg.Nack(ctx.Err(), n.ID())
case n.out[0] <- msg:
// all good
}
}
}
}

// wrapAckHandler modifies the ack handler, so it's called with the original
// message received by FanoutNode instead of the new message created by
// FanoutNode.
Expand Down
3 changes: 1 addition & 2 deletions pkg/pipeline/stream/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package stream

import (
"context"
"fmt"
"sync"

"github.com/conduitio/conduit/pkg/foundation/cerrors"
Expand Down Expand Up @@ -131,7 +130,7 @@ func (m *Message) init() {
// ID returns a string representing a unique ID of this message. This is meant
// only for logging purposes.
func (m *Message) ID() string {
return fmt.Sprintf("%s/%s", m.SourceID, m.Record.Position)
return m.SourceID + "/" + string(m.Record.Position)
}

func (m *Message) ControlMessageType() ControlMessageType {
Expand Down

0 comments on commit 3111797

Please sign in to comment.