diff --git a/device/device.go b/device/device.go index d69776c..49adc94 100644 --- a/device/device.go +++ b/device/device.go @@ -213,14 +213,11 @@ func (d *Device) Connect(result chan<- error) { } // initialize the client - if err = d.initializeMQTTClient(); err != nil { - if result != nil { - result <- err - } - return - } + d.initializeMQTTClient() // Wait for the token - we're in a coroutine anyway + // If AutoReconnect is enabled, reconnections will be handled by + // paho's autoreconnect mechanism. policy.Reset() connectOperation := func() error { connectToken := d.m.Connect().(*mqtt.ConnectToken) diff --git a/device/protocol_mqtt_v1.go b/device/protocol_mqtt_v1.go index 8f5f127..1ed4e88 100644 --- a/device/protocol_mqtt_v1.go +++ b/device/protocol_mqtt_v1.go @@ -17,10 +17,14 @@ package device import ( "bytes" "compress/zlib" + "crypto/tls" "encoding/binary" "errors" "fmt" "io" + "net" + "net/url" + "os" "path/filepath" "reflect" "sort" @@ -31,13 +35,14 @@ import ( mqtt "github.com/ispirata/paho.mqtt.golang" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "golang.org/x/net/proxy" ) func (d *Device) getBaseTopic() string { return fmt.Sprintf("%s/%s", d.realm, d.deviceID) } -func (d *Device) initializeMQTTClient() error { +func (d *Device) initializeMQTTClient() { opts := mqtt.NewClientOptions() opts.AddBroker(d.brokerURL) opts.SetAutoReconnect(d.opts.AutoReconnect) @@ -50,11 +55,8 @@ func (d *Device) initializeMQTTClient() error { opts.SetStore(s) } - tlsConfig, err := d.getTLSConfig() - if err != nil { - return err - } - opts.SetTLSConfig(tlsConfig) + // TLS will be handled in the custom connection function + opts.SetCustomOpenConnectionFn(d.astarteCustomOpenConnectionFn) opts.SetOnConnectHandler(func(client mqtt.Client, sessionPresent bool) { astarteOnConnectHandler(d, sessionPresent) @@ -76,8 +78,6 @@ func (d *Device) initializeMQTTClient() error { opts.SetDefaultPublishHandler(d.astarteGoSDKDefaultPublishHandler) d.m = mqtt.NewClient(opts) - - return nil } func (d *Device) astarteGoSDKDefaultPublishHandler(client mqtt.Client, msg mqtt.Message) { @@ -203,6 +203,67 @@ func (d *Device) handleControlMessages(message string, payload []byte) error { return nil } +// TLS errors on reconnection are not exposed, so we have to make do. +// This code is mostly taken from Paho's openConnection function. It does not carry out any MQTT specific handshakes. +func (d *Device) astarteCustomOpenConnectionFn(uri *url.URL, options mqtt.ClientOptions) (net.Conn, error) { + switch uri.Scheme { + case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps": + // Use our own TLS config as the one in options might be outdated. + tlsc, tlscErr := d.getTLSConfig() + if tlscErr != nil { + return nil, tlscErr + } + allProxy := os.Getenv("all_proxy") + if len(allProxy) == 0 { + // Do not use a backoff, as this code might be called many consecutive times until a valid device certificate is emitted. + dialer := &net.Dialer{Timeout: 10 * time.Second} + // The resulting connection will be up-to-date with our TLS config + conn, err := tls.DialWithDialer(dialer, "tcp", uri.Host, tlsc) + if err != nil { + _ = d.obtainNewCertificate() + // Fail anyway, so the next reconnection will use the new certificate. + return nil, err + } + return conn, nil + } + proxyDialer := proxy.FromEnvironment() + conn, err := proxyDialer.Dial("tcp", uri.Host) + if err != nil { + return nil, err + } + // The resulting connection will be up-to-date with our TLS config + tlsConn := tls.Client(conn, tlsc) + err = tlsConn.Handshake() + if err != nil { + _ = conn.Close() + _ = d.obtainNewCertificate() + //fail anyway, so the next reconnection will use the new certificate. + return nil, err + } + return tlsConn, nil + + //no need for certificates + case "mqtt", "tcp": + allProxy := os.Getenv("all_proxy") + if len(allProxy) == 0 { + dialer := &net.Dialer{Timeout: 10 * time.Second} + conn, err := dialer.Dial("tcp", uri.Host) + if err != nil { + return nil, err + } + return conn, nil + } + proxyDialer := proxy.FromEnvironment() + + conn, err := proxyDialer.Dial("tcp", uri.Host) + if err != nil { + return nil, err + } + return conn, nil + } + return nil, errors.New("unknown protocol") +} + func astarteOnConnectHandler(d *Device, sessionPresent bool) { // Generate Introspection first introspection := d.generateDeviceIntrospection()