Skip to content

Commit

Permalink
[#66501] linux-client: goroutine-safe getDeviceToken method
Browse files Browse the repository at this point in the history
This commit adds a mutexed method for fetching a valid device token that
can be used instead of directly accessing the `deviceToken` field of the
`Device` structure. This allows for multiple independent goroutines to
get a valid device token safely.
  • Loading branch information
Kacper Zienkiewicz committed Oct 3, 2024
1 parent 2f065ec commit 7de58e3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
41 changes: 31 additions & 10 deletions devices/linux-client/daemon/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ const MSG_RECV_TIMEOUT_INTERVALS = 10
const MSG_RECV_INTERVAL_S = 1
const RSA_DEVICE_KEY_SIZE = 4096
const MGMT_LOOP_RECOVERY_INTERVAL = 30

var tokenMutex sync.Mutex
const TOKEN_EXPIRY_MIN_ALLOWED = 5

type Device struct {
name string
Expand All @@ -48,6 +47,7 @@ type Device struct {
macAddr string
rdfmCtx *app.RDFM
deviceToken string
tokenMutex sync.Mutex
httpTransport *http.Transport
}

Expand Down Expand Up @@ -192,12 +192,12 @@ func (d *Device) connect() error {
}

// Get device token
err = d.authenticateDeviceWithServer()
deviceToken, err := d.getDeviceToken()
if err != nil {
return err
}
authHeader := http.Header{
"Authorization": []string{"Bearer token=" + d.deviceToken},
"Authorization": []string{"Bearer token=" + deviceToken},
}

// Open a WebSocket connection
Expand Down Expand Up @@ -269,10 +269,33 @@ func (d *Device) managementWsLoop(done chan bool) {
panic(err)
}

func (d Device) getDeviceToken() string {
tokenMutex.Lock()
defer tokenMutex.Unlock()
return d.deviceToken
func (d *Device) getDeviceToken() (string, error) {
d.tokenMutex.Lock()
defer d.tokenMutex.Unlock()

payload, err := netUtils.ExtractJwtPayload(d.deviceToken)
if err == nil {
// check if reauth needed in the next TOKEN_EXPIRY_MIN_ALLOWED or less seconds
if payload.CreatedAt+payload.Expires-TOKEN_EXPIRY_MIN_ALLOWED <= time.Now().Unix() {
log.Println("Device token expired, reauthenticating...")
err = d.authenticateDeviceWithServer()
if err != nil {
return "", err
}
} else {
log.Println("Device token up to date")
}
} else {
// extraction failed
log.Println("Device token missing or malformed, authenticating...")
err := d.authenticateDeviceWithServer()
if err != nil {
return "", err
}
}

// d.authenticateDeviceWithServer should have taken care of fetching the token if necessary
return d.deviceToken, nil
}

func getPublicKey(privateKey *rsa.PrivateKey) []byte {
Expand Down Expand Up @@ -406,8 +429,6 @@ func (d *Device) authenticateDeviceWithServer() error {
log.Println("Failed to deserialize package metadata", err)
return err
}
tokenMutex.Lock()
defer tokenMutex.Unlock()
d.deviceToken = response["token"].(string)
log.Println("Authorization token expires in", response["expires"], "seconds")
if len(d.deviceToken) == 0 {
Expand Down
26 changes: 6 additions & 20 deletions devices/linux-client/daemon/updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import (
const MIN_RETRY_INTERVAL = 1
const MAX_RETRY_INTERVAL = 60

var NotAuthorizedError = errors.New("Device did not provide authorization data, or the authorization has expired")

func (d *Device) checkUpdate() error {
devType, err := d.rdfmCtx.GetCurrentDeviceType()
if err != nil {
Expand All @@ -45,7 +43,11 @@ func (d *Device) checkUpdate() error {
bytes.NewBuffer(serializedMetadata),
)
req.Header.Set("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer token="+d.deviceToken)
deviceToken, err := d.getDeviceToken()
if err != nil {
return errors.New("Failed to fetch the device token: " + err.Error())
}
req.Header.Add("Authorization", "Bearer token="+deviceToken)

log.Println("Checking updates...")

Expand Down Expand Up @@ -82,7 +84,7 @@ func (d *Device) checkUpdate() error {
case 400:
return errors.New("Device metadata is missing device type and/or software version")
case 401:
return NotAuthorizedError
return errors.New("Device did not provide authorization data, or the authorization has expired")
default:
return errors.New("Unexpected status code from the server: " + res.Status)
}
Expand All @@ -92,7 +94,6 @@ func (d *Device) checkUpdate() error {
func (d *Device) updateCheckerLoop(done chan bool) {
var err error
var info string
var count int

// Recover the goroutine if it panics
defer func() {
Expand All @@ -112,27 +113,12 @@ func (d *Device) updateCheckerLoop(done chan bool) {

for {
err = d.checkUpdate()
if err == NotAuthorizedError {
log.Println(err)
err = d.authenticateDeviceWithServer()
if err == nil {
retryInterval := count * MIN_RETRY_INTERVAL
if retryInterval > MAX_RETRY_INTERVAL {
retryInterval = MAX_RETRY_INTERVAL
}
time.Sleep(time.Duration(retryInterval) * time.Second)
count = (count + 1) * 2
continue
}
err = errors.New("Failed to autheniticate with the server: " + err.Error())
}
if err != nil {
log.Println("Update check failed:", err)
}
updateDuration := time.Duration(d.rdfmCtx.RdfmConfig.UpdatePollIntervalSeconds) * time.Second
log.Printf("Next update check in %s\n", updateDuration)
time.Sleep(time.Duration(updateDuration))
count = 0
}
panic(err)
}

0 comments on commit 7de58e3

Please sign in to comment.