Skip to content

Commit

Permalink
Merge pull request #556 from openziti/fix.token.exchange
Browse files Browse the repository at this point in the history
fixes openziti/ziti#1960, fixes openziti/ziti#1964 token exchange fails
  • Loading branch information
andrewpmartinez authored Apr 23, 2024
2 parents 8e592e3 + 11ebb94 commit 57ca067
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 32 deletions.
65 changes: 49 additions & 16 deletions edge-apis/authwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-openapi/runtime"
"github.com/go-openapi/strfmt"
"github.com/go-resty/resty/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/openziti/edge-api/rest_client_api_client"
clientAuth "github.com/openziti/edge-api/rest_client_api_client/authentication"
clientControllers "github.com/openziti/edge-api/rest_client_api_client/controllers"
Expand Down Expand Up @@ -74,6 +75,9 @@ type ApiSession interface {

//GetId returns the id of the ApiSession
GetId() string

//RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections.
RequiresRouterTokenUpdate() bool
}

var _ ApiSession = (*ApiSessionLegacy)(nil)
Expand All @@ -85,6 +89,10 @@ type ApiSessionLegacy struct {
Detail *rest_model.CurrentAPISessionDetail
}

func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool {
return false
}

func (a *ApiSessionLegacy) GetId() string {
return stringz.OrEmpty(a.Detail.ID)
}
Expand Down Expand Up @@ -146,6 +154,10 @@ type ApiSessionOidc struct {
OidcTokens *oidc.Tokens[*oidc.IDTokenClaims]
}

func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool {
return true
}

func (a *ApiSessionOidc) GetAccessClaims() (*ApiAccessClaims, error) {
claims := &ApiAccessClaims{}

Expand Down Expand Up @@ -491,53 +503,74 @@ func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenC
}

func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
subjectToken := curTokens.RefreshToken
subjectTokenType := oidc.RefreshTokenType

// if subjectToken is "", then we don't have a refresh token, attempt to exchange a non-expired access token
if subjectToken == "" {
if curTokens.Expiry.Before(time.Now()) {
return nil, errors.New("cannot exchange token: refresh token not found, access token expired")
}

if curTokens.AccessToken == "" {
return nil, errors.New("cannot exchange token: refresh token not found, access token not found")
}
subjectToken = curTokens.AccessToken
subjectTokenType = oidc.AccessTokenType
}

var outTokens *oidc.Tokens[*oidc.IDTokenClaims]

_, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) {
apiHost := transport.ApiUrl.Host
te, err := tokenexchange.NewTokenExchanger(apiHost, tokenexchange.WithHTTPClient(client))
issuer := "https://" + apiHost + "/oidc"
tokenEndpoint := "https://" + apiHost + "/oidc/oauth/token"

te, err := tokenexchange.NewTokenExchangerClientCredentials(issuer, "native", "", tokenexchange.WithHTTPClient(client), tokenexchange.WithStaticTokenEndpoint(issuer, tokenEndpoint))

if err != nil {
return nil, err
}

accessResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.AccessTokenType)
var tokenResponse *oidc.TokenExchangeResponse

if err != nil {
return nil, err
}
now := time.Now()

//TODO: be smarter, only refresh refresh token if the new access token lives beyond refresh
refreshResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType)
switch subjectTokenType {
case oidc.RefreshTokenType:
tokenResponse, err = tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType)
case oidc.AccessTokenType:
tokenResponse, err = tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.AccessTokenType)
}

if err != nil {
return nil, err
}

idResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.IDTokenType)
idResp, err := tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.IDTokenType)

if err != nil {
return nil, err
}

idClaims := &oidc.IDTokenClaims{}
idClaims := &IdClaims{}

err = json.Unmarshal([]byte(idResp.AccessToken), idClaims)
//access token is used to hold id token per zitadel comments
_, _, err = jwt.NewParser().ParseUnverified(idResp.AccessToken, idClaims)

if err != nil {
return nil, err
}

outTokens = &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: accessResp.AccessToken,
TokenType: accessResp.TokenType,
RefreshToken: refreshResp.RefreshToken,
Expiry: time.Time{},
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
Expiry: now.Add(time.Duration(tokenResponse.ExpiresIn)),
},
IDTokenClaims: idClaims,
IDToken: idResp.AccessToken, //access token is used to hold id token per zitadel comments
IDTokenClaims: &idClaims.IDTokenClaims,
IDToken: idResp.AccessToken, //access token field is used to hold id token per zitadel comments
}

return outTokens, nil
Expand Down
32 changes: 32 additions & 0 deletions edge-apis/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,38 @@ type ApiAccessClaims struct {
Scopes []string `json:"scopes,omitempty"`
}

var _ jwt.Claims = (*IdClaims)(nil)

// IdClaims wraps oidc.IDToken claims to fulfill the jwt.Claims interface
type IdClaims struct {
oidc.IDTokenClaims
}

func (r *IdClaims) GetExpirationTime() (*jwt.NumericDate, error) {
return &jwt.NumericDate{Time: r.TokenClaims.GetExpiration()}, nil
}

func (r *IdClaims) GetNotBefore() (*jwt.NumericDate, error) {
notBefore := r.TokenClaims.NotBefore.AsTime()
return &jwt.NumericDate{Time: notBefore}, nil
}

func (r *IdClaims) GetIssuedAt() (*jwt.NumericDate, error) {
return &jwt.NumericDate{Time: r.TokenClaims.GetIssuedAt()}, nil
}

func (r *IdClaims) GetIssuer() (string, error) {
return r.TokenClaims.Issuer, nil
}

func (r *IdClaims) GetSubject() (string, error) {
return r.TokenClaims.Issuer, nil
}

func (r *IdClaims) GetAudience() (jwt.ClaimStrings, error) {
return jwt.ClaimStrings(r.TokenClaims.Audience), nil
}

type localRpServer struct {
Server *http.Server
Port string
Expand Down
1 change: 1 addition & 0 deletions example/chat/chat-server/chat-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,5 @@ func main() {
logger.Infof("new connection")
go server.handleChat(conn)
}

}
16 changes: 7 additions & 9 deletions ziti/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,8 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/openziti/foundation/v2/genext"
"github.com/openziti/transport/v2"
"github.com/pkg/errors"
"strings"
"time"

"github.com/go-openapi/strfmt"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/edge-api/rest_client_api_client/authentication"
Expand All @@ -43,10 +37,14 @@ import (
"github.com/openziti/edge-api/rest_client_api_client/session"
"github.com/openziti/edge-api/rest_model"
"github.com/openziti/edge-api/rest_util"
"github.com/openziti/foundation/v2/genext"
nfPem "github.com/openziti/foundation/v2/pem"
"github.com/openziti/identity"
apis "github.com/openziti/sdk-golang/edge-apis"
"github.com/openziti/sdk-golang/ziti/edge/posture"
"github.com/openziti/transport/v2"
"github.com/pkg/errors"
"strings"
)

// CtrlClient is a stateful version of ZitiEdgeClient that simplifies operations
Expand All @@ -72,7 +70,7 @@ func (self *CtrlClient) GetCurrentApiSession() apis.ApiSession {
}

// Refresh will contact the controller extending the current ApiSession for legacy API Sessions
func (self *CtrlClient) Refresh() (*time.Time, error) {
func (self *CtrlClient) Refresh() (apis.ApiSession, error) {
if apiSession := self.GetCurrentApiSession(); apiSession != nil {
newApiSession, err := self.API.RefreshApiSession(apiSession, self.HttpClient)

Expand All @@ -82,7 +80,7 @@ func (self *CtrlClient) Refresh() (*time.Time, error) {

self.ApiSession.Store(&newApiSession)

return newApiSession.GetExpiresAt(), nil
return newApiSession, nil
}

return nil, errors.New("no api session")
Expand Down
2 changes: 1 addition & 1 deletion ziti/edge/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type RouterClient interface {

//UpdateToken will attempt to send token updates to the connected router. A success/failure response is expected
//within the timeout period.
UpdateToken(token string, timeout time.Duration) error
UpdateToken(token []byte, timeout time.Duration) error
}

type RouterConn interface {
Expand Down
4 changes: 2 additions & 2 deletions ziti/edge/network/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func (conn *routerConn) NewDialConn(service *rest_model.ServiceDetail) *edgeConn
return edgeCh
}

func (conn *routerConn) UpdateToken(token string, timeout time.Duration) error {
msg := edge.NewUpdateTokenMsg([]byte(token))
func (conn *routerConn) UpdateToken(token []byte, timeout time.Duration) error {
msg := edge.NewUpdateTokenMsg(token)
resp, err := msg.WithTimeout(timeout).SendForReply(conn.ch)

if err != nil {
Expand Down
39 changes: 35 additions & 4 deletions ziti/ziti.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,20 @@ func (context *ContextImpl) RefreshService(serviceName string) (*rest_model.Serv
return serviceDetail, nil
}

func (context *ContextImpl) updateTokenOnAllErs(apiSession apis.ApiSession) {
if apiSession.RequiresRouterTokenUpdate() {
for tpl := range context.routerConnections.IterBuffered() {
erConn := tpl.Val
erKey := tpl.Key
go func() {
if err := erConn.UpdateToken(apiSession.GetToken(), 10*time.Second); err != nil {
pfxlog.Logger().WithError(err).WithField("er", erKey).Warn("error updating apiSession token to connected ER")
}
}()
}
}
}

func (context *ContextImpl) runRefreshes() {
log := pfxlog.Logger()
svcRefreshInterval := context.options.RefreshInterval
Expand All @@ -768,8 +782,9 @@ func (context *ContextImpl) runRefreshes() {
defer sessionRefreshTick.Stop()

refreshAt := time.Now().Add(30 * time.Second)

if currentApiSession := context.CtrlClt.GetCurrentApiSession(); currentApiSession != nil && currentApiSession.GetExpiresAt() != nil {
refreshAt = time.Time(*currentApiSession.GetExpiresAt()).Add(-10 * time.Second)
refreshAt = (*currentApiSession.GetExpiresAt()).Add(-10 * time.Second)
}

for {
Expand All @@ -778,14 +793,25 @@ func (context *ContextImpl) runRefreshes() {
return

case <-time.After(time.Until(refreshAt)):
exp, err := context.CtrlClt.Refresh()
apiSession := context.CtrlClt.GetCurrentApiSession()

if apiSession == nil {
pfxlog.Logger().Warn("could not refresh api session, current api session is nil")
continue
}

newApiSession, err := context.CtrlClt.Refresh()

if err != nil {
log.Errorf("could not refresh apiSession: %v", err)

refreshAt = time.Now().Add(5 * time.Second)
} else {
exp := newApiSession.GetExpiresAt()
refreshAt = exp.Add(-10 * time.Second)
log.Debugf("apiSession refreshed, new expiration[%s]", *exp)

context.updateTokenOnAllErs(newApiSession)
}

case <-svcRefreshTick.C:
Expand Down Expand Up @@ -926,8 +952,9 @@ func (context *ContextImpl) RefreshApiSessionWithBackoff() error {
expBackoff.MaxElapsedTime = 24 * time.Hour

operation := func() error {
_, err := context.CtrlClt.Refresh()
newApiSession, err := context.CtrlClt.Refresh()
if err == nil {
context.updateTokenOnAllErs(newApiSession)
return nil
}

Expand Down Expand Up @@ -990,9 +1017,13 @@ func (context *ContextImpl) authenticateMfa(code string) error {
return err
}

if _, err := context.CtrlClt.Refresh(); err != nil {
newApiSession, err := context.CtrlClt.Refresh()

if err != nil {
return err
}
context.updateTokenOnAllErs(newApiSession)

apiSession := context.CtrlClt.GetCurrentApiSession()

if apiSession != nil && len(apiSession.GetAuthQueries()) == 0 {
Expand Down

0 comments on commit 57ca067

Please sign in to comment.