diff --git a/iotdevice/transport/mqtt/mqtt.go b/iotdevice/transport/mqtt/mqtt.go index 73c3b66..3bf168f 100644 --- a/iotdevice/transport/mqtt/mqtt.go +++ b/iotdevice/transport/mqtt/mqtt.go @@ -66,6 +66,18 @@ func WithModelID(modelID string) TransportOption { } } +func WithConnectCallback(callback func(*Transport)) TransportOption { + return func(tr *Transport) { + tr.connectCallback = callback + } +} + +func WithDisconnectCallback(callback func(*Transport)) TransportOption { + return func(tr *Transport) { + tr.disconnectCallback = callback + } +} + // New returns new Transport transport. // See more: https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-mqtt-support func New(opts ...TransportOption) *Transport { @@ -96,6 +108,9 @@ type Transport struct { cocfg func(opts *mqtt.ClientOptions) webSocket bool + + connectCallback func(tr *Transport) + disconnectCallback func(tr *Transport) } type resp struct { @@ -162,9 +177,15 @@ func (tr *Transport) Connect(ctx context.Context, creds transport.Credentials) e } } tr.subm.RUnlock() + if tr.connectCallback != nil { + tr.connectCallback(tr) + } }) o.SetConnectionLostHandler(func(_ mqtt.Client, err error) { tr.logger.Debugf("connection lost: %v", err) + if tr.disconnectCallback != nil { + tr.disconnectCallback(tr) + } }) if tr.cocfg != nil {