From 7de58e3e838d2bb951bf3196883e08a779a65356 Mon Sep 17 00:00:00 2001 From: Kacper Zienkiewicz Date: Thu, 3 Oct 2024 12:13:15 +0200 Subject: [PATCH] [#66501] linux-client: goroutine-safe `getDeviceToken` method This commit adds a mutexed method for fetching a valid device token that can be used instead of directly accessing the `deviceToken` field of the `Device` structure. This allows for multiple independent goroutines to get a valid device token safely. --- devices/linux-client/daemon/device.go | 41 +++++++++++++++++++------- devices/linux-client/daemon/updates.go | 26 ++++------------ 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/devices/linux-client/daemon/device.go b/devices/linux-client/daemon/device.go index fee5202..b6d6916 100644 --- a/devices/linux-client/daemon/device.go +++ b/devices/linux-client/daemon/device.go @@ -36,8 +36,7 @@ const MSG_RECV_TIMEOUT_INTERVALS = 10 const MSG_RECV_INTERVAL_S = 1 const RSA_DEVICE_KEY_SIZE = 4096 const MGMT_LOOP_RECOVERY_INTERVAL = 30 - -var tokenMutex sync.Mutex +const TOKEN_EXPIRY_MIN_ALLOWED = 5 type Device struct { name string @@ -48,6 +47,7 @@ type Device struct { macAddr string rdfmCtx *app.RDFM deviceToken string + tokenMutex sync.Mutex httpTransport *http.Transport } @@ -192,12 +192,12 @@ func (d *Device) connect() error { } // Get device token - err = d.authenticateDeviceWithServer() + deviceToken, err := d.getDeviceToken() if err != nil { return err } authHeader := http.Header{ - "Authorization": []string{"Bearer token=" + d.deviceToken}, + "Authorization": []string{"Bearer token=" + deviceToken}, } // Open a WebSocket connection @@ -269,10 +269,33 @@ func (d *Device) managementWsLoop(done chan bool) { panic(err) } -func (d Device) getDeviceToken() string { - tokenMutex.Lock() - defer tokenMutex.Unlock() - return d.deviceToken +func (d *Device) getDeviceToken() (string, error) { + d.tokenMutex.Lock() + defer d.tokenMutex.Unlock() + + payload, err := netUtils.ExtractJwtPayload(d.deviceToken) + if err == nil { + // check if reauth needed in the next TOKEN_EXPIRY_MIN_ALLOWED or less seconds + if payload.CreatedAt+payload.Expires-TOKEN_EXPIRY_MIN_ALLOWED <= time.Now().Unix() { + log.Println("Device token expired, reauthenticating...") + err = d.authenticateDeviceWithServer() + if err != nil { + return "", err + } + } else { + log.Println("Device token up to date") + } + } else { + // extraction failed + log.Println("Device token missing or malformed, authenticating...") + err := d.authenticateDeviceWithServer() + if err != nil { + return "", err + } + } + + // d.authenticateDeviceWithServer should have taken care of fetching the token if necessary + return d.deviceToken, nil } func getPublicKey(privateKey *rsa.PrivateKey) []byte { @@ -406,8 +429,6 @@ func (d *Device) authenticateDeviceWithServer() error { log.Println("Failed to deserialize package metadata", err) return err } - tokenMutex.Lock() - defer tokenMutex.Unlock() d.deviceToken = response["token"].(string) log.Println("Authorization token expires in", response["expires"], "seconds") if len(d.deviceToken) == 0 { diff --git a/devices/linux-client/daemon/updates.go b/devices/linux-client/daemon/updates.go index f7bc300..9fc6daa 100644 --- a/devices/linux-client/daemon/updates.go +++ b/devices/linux-client/daemon/updates.go @@ -17,8 +17,6 @@ import ( const MIN_RETRY_INTERVAL = 1 const MAX_RETRY_INTERVAL = 60 -var NotAuthorizedError = errors.New("Device did not provide authorization data, or the authorization has expired") - func (d *Device) checkUpdate() error { devType, err := d.rdfmCtx.GetCurrentDeviceType() if err != nil { @@ -45,7 +43,11 @@ func (d *Device) checkUpdate() error { bytes.NewBuffer(serializedMetadata), ) req.Header.Set("Content-Type", "application/json") - req.Header.Add("Authorization", "Bearer token="+d.deviceToken) + deviceToken, err := d.getDeviceToken() + if err != nil { + return errors.New("Failed to fetch the device token: " + err.Error()) + } + req.Header.Add("Authorization", "Bearer token="+deviceToken) log.Println("Checking updates...") @@ -82,7 +84,7 @@ func (d *Device) checkUpdate() error { case 400: return errors.New("Device metadata is missing device type and/or software version") case 401: - return NotAuthorizedError + return errors.New("Device did not provide authorization data, or the authorization has expired") default: return errors.New("Unexpected status code from the server: " + res.Status) } @@ -92,7 +94,6 @@ func (d *Device) checkUpdate() error { func (d *Device) updateCheckerLoop(done chan bool) { var err error var info string - var count int // Recover the goroutine if it panics defer func() { @@ -112,27 +113,12 @@ func (d *Device) updateCheckerLoop(done chan bool) { for { err = d.checkUpdate() - if err == NotAuthorizedError { - log.Println(err) - err = d.authenticateDeviceWithServer() - if err == nil { - retryInterval := count * MIN_RETRY_INTERVAL - if retryInterval > MAX_RETRY_INTERVAL { - retryInterval = MAX_RETRY_INTERVAL - } - time.Sleep(time.Duration(retryInterval) * time.Second) - count = (count + 1) * 2 - continue - } - err = errors.New("Failed to autheniticate with the server: " + err.Error()) - } if err != nil { log.Println("Update check failed:", err) } updateDuration := time.Duration(d.rdfmCtx.RdfmConfig.UpdatePollIntervalSeconds) * time.Second log.Printf("Next update check in %s\n", updateDuration) time.Sleep(time.Duration(updateDuration)) - count = 0 } panic(err) }