Skip to content

Commit

Permalink
NOISSUE - Add support for interceptor (#51)
Browse files Browse the repository at this point in the history
Signed-off-by: Dusan Borovcanin <[email protected]>
  • Loading branch information
dborovcanin committed Jan 16, 2024
1 parent 0ffbc4f commit 0b102d0
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 54 deletions.
8 changes: 4 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,31 +267,31 @@ func loadConfig() config {

func proxyMQTTWS(cfg WSMQTTConfig, logger mglog.Logger, handler session.Handler, errs chan error) {
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
wp := websocket.New(target, cfg.targetPath, cfg.targetScheme, handler, logger)
wp := websocket.New(target, cfg.targetPath, cfg.targetScheme, handler, nil, logger)
http.Handle(cfg.path, wp.Handler())

errs <- wp.Listen(cfg.port)
}

func proxyMQTTWSS(cfg config, tlsCfg *tls.Config, logger mglog.Logger, handler session.Handler, errs chan error) {
target := fmt.Sprintf("%s:%s", cfg.wsMQTTConfig.targetHost, cfg.wsMQTTConfig.targetPort)
wp := websocket.New(target, cfg.wsMQTTConfig.targetPath, cfg.wsMQTTConfig.targetScheme, handler, logger)
wp := websocket.New(target, cfg.wsMQTTConfig.targetPath, cfg.wsMQTTConfig.targetScheme, handler, nil, logger)
http.Handle(cfg.wsMQTTConfig.wssPath, wp.Handler())
errs <- wp.ListenTLS(tlsCfg, cfg.serverCert, cfg.serverKey, cfg.wsMQTTConfig.wssPort)
}

func proxyMQTT(ctx context.Context, cfg MQTTConfig, logger mglog.Logger, handler session.Handler, errs chan error) {
address := fmt.Sprintf("%s:%s", cfg.host, cfg.port)
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
mp := mqtt.New(address, target, handler, logger)
mp := mqtt.New(address, target, handler, nil, logger)

errs <- mp.Listen(ctx)
}

func proxyMQTTS(ctx context.Context, cfg MQTTConfig, tlsCfg *tls.Config, logger mglog.Logger, handler session.Handler, errs chan error) {
address := fmt.Sprintf("%s:%s", cfg.host, cfg.mqttsPort)
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
mp := mqtt.New(address, target, handler, logger)
mp := mqtt.New(address, target, handler, nil, logger)

errs <- mp.ListenTLS(ctx, tlsCfg)
}
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ toolchain go1.21.4
require (
github.com/absmach/magistrala v0.11.1-0.20231220185538-1fe2e74a741f
github.com/eclipse/paho.mqtt.golang v1.4.3
github.com/google/uuid v1.4.0
github.com/gorilla/websocket v1.5.0
golang.org/x/sync v0.4.0
github.com/google/uuid v1.5.0
github.com/gorilla/websocket v1.5.1
golang.org/x/sync v0.6.0
)

require (
github.com/go-kit/log v0.2.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/net v0.20.0 // indirect
)
16 changes: 8 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU=
github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
30 changes: 16 additions & 14 deletions pkg/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@ import (
mptls "github.com/absmach/mproxy/pkg/tls"
)

// Proxy is main MQTT proxy struct
// Proxy is main MQTT proxy struct.
type Proxy struct {
address string
target string
handler session.Handler
logger logger.Logger
dialer net.Dialer
address string
target string
handler session.Handler
interceptor session.Interceptor
logger logger.Logger
dialer net.Dialer
}

// New returns a new mqtt Proxy instance.
func New(address, target string, handler session.Handler, logger logger.Logger) *Proxy {
// New returns a new MQTT Proxy instance.
func New(address, target string, handler session.Handler, interceptor session.Interceptor, logger logger.Logger) *Proxy {
return &Proxy{
address: address,
target: target,
handler: handler,
logger: logger,
address: address,
target: target,
handler: handler,
logger: logger,
interceptor: interceptor,
}
}

Expand Down Expand Up @@ -59,7 +61,7 @@ func (p Proxy) handle(ctx context.Context, inbound net.Conn) {
return
}

if err = session.Stream(ctx, inbound, outbound, p.handler, clientCert); err != io.EOF {
if err = session.Stream(ctx, inbound, outbound, p.handler, p.interceptor, clientCert); err != io.EOF {
p.logger.Warn(err.Error())
}
}
Expand All @@ -79,7 +81,7 @@ func (p Proxy) Listen(ctx context.Context) error {
return nil
}

// ListenTLS - version of Listen with TLS encryption
// ListenTLS - version of Listen with TLS encryption.
func (p Proxy) ListenTLS(ctx context.Context, tlsCfg *tls.Config) error {
l, err := tls.Listen("tcp", p.address, tlsCfg)
if err != nil {
Expand Down
25 changes: 14 additions & 11 deletions pkg/mqtt/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@ import (

// Proxy represents WS Proxy.
type Proxy struct {
target string
path string
scheme string
event session.Handler
target string
path string
scheme string
handler session.Handler
interceptor session.Interceptor

logger logger.Logger
}

// New - creates new WS proxy
func New(target, path, scheme string, event session.Handler, logger logger.Logger) *Proxy {
func New(target, path, scheme string, handler session.Handler, interceptor session.Interceptor, logger logger.Logger) *Proxy {
return &Proxy{
target: target,
path: path,
scheme: scheme,
event: event,
logger: logger,
target: target,
path: path,
scheme: scheme,
handler: handler,
interceptor: interceptor,
logger: logger,
}
}

Expand Down Expand Up @@ -94,7 +97,7 @@ func (p Proxy) pass(ctx context.Context, in *websocket.Conn) {
return
}

err = session.Stream(ctx, inboundConn, outboundConn, p.event, clientCert)
err = session.Stream(ctx, inboundConn, outboundConn, p.handler, p.interceptor, clientCert)
errc <- err
p.logger.Warn("Broken connection for client with error: " + err.Error())
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/session/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package session

import (
"context"

"github.com/eclipse/paho.mqtt.golang/packets"
)

// Interceptor is an interface for mProxy intercept hook.
type Interceptor interface {
// Intercept is called on every packet flowing through the Proxy.
// Packets can be modified before being sent to the broker or the client.
// If the interceptor returns a non-nil packet, the modified packet is sent.
// The error indicates unsuccessful interception and mProxy is cancelling the packet.
Intercept(ctx context.Context, pkt packets.ControlPacket, dir Direction) (packets.ControlPacket, error)
}
34 changes: 21 additions & 13 deletions pkg/session/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"github.com/eclipse/paho.mqtt.golang/packets"
)

type direction int
type Direction int

const (
up direction = iota
down
Up Direction = iota
Down
)

const unknownID = "unknown"
Expand All @@ -25,26 +25,26 @@ var (
)

// Stream starts proxy between client and broker.
func Stream(ctx context.Context, inbound, outbound net.Conn, handler Handler, cert x509.Certificate) error {
func Stream(ctx context.Context, in, out net.Conn, h Handler, ic Interceptor, cert x509.Certificate) error {
s := Session{
Cert: cert,
}
ctx = NewContext(ctx, &s)
errs := make(chan error, 2)

go stream(ctx, up, inbound, outbound, handler, errs)
go stream(ctx, down, outbound, inbound, handler, errs)
go stream(ctx, Up, in, out, h, ic, errs)
go stream(ctx, Down, out, in, h, ic, errs)

// Handle whichever error happens first.
// The other routine won't be blocked when writing
// to the errors channel because it is buffered.
err := <-errs

handler.Disconnect(ctx)
h.Disconnect(ctx)
return err
}

func stream(ctx context.Context, dir direction, r, w net.Conn, h Handler, errs chan error) {
func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, ic Interceptor, errs chan error) {
for {
// Read from one connection.
pkt, err := packets.ReadPacket(r)
Expand All @@ -53,20 +53,28 @@ func stream(ctx context.Context, dir direction, r, w net.Conn, h Handler, errs c
return
}

if dir == up {
if dir == Up {
if err = authorize(ctx, pkt, h); err != nil {
errs <- wrap(ctx, err, dir)
return
}
}
if ic != nil {
pkt, err = ic.Intercept(ctx, pkt, dir)
if err != nil {
errs <- wrap(ctx, err, dir)
return
}
}

// Send to another.
if err := pkt.Write(w); err != nil {
errs <- wrap(ctx, err, dir)
return
}

if dir == up {
// Notify only for packets sent from client to broker (incoming packets).
if dir == Up {
if err := notify(ctx, pkt, h); err != nil {
errs <- wrap(ctx, err, dir)
}
Expand Down Expand Up @@ -118,7 +126,7 @@ func notify(ctx context.Context, pkt packets.ControlPacket, h Handler) error {
}
}

func wrap(ctx context.Context, err error, dir direction) error {
func wrap(ctx context.Context, err error, dir Direction) error {
if err == io.EOF {
return err
}
Expand All @@ -127,9 +135,9 @@ func wrap(ctx context.Context, err error, dir direction) error {
cid = s.ID
}
switch dir {
case up:
case Up:
return fmt.Errorf(errClient, cid, err.Error())
case down:
case Down:
return fmt.Errorf(errBroker, cid, err.Error())
default:
return err
Expand Down

0 comments on commit 0b102d0

Please sign in to comment.