diff --git a/cmd/oidc-token-verifier/main.go b/cmd/oidc-token-verifier/main.go index 79bb6606842d..993ae9bc82bb 100644 --- a/cmd/oidc-token-verifier/main.go +++ b/cmd/oidc-token-verifier/main.go @@ -1,9 +1,11 @@ package main import ( + "errors" "fmt" "os" + "github.com/coreos/go-oidc/v3/oidc" "github.com/kyma-project/test-infra/pkg/logging" tioidc "github.com/kyma-project/test-infra/pkg/oidc" "github.com/spf13/cobra" @@ -107,8 +109,8 @@ func (opts *options) extractClaims() error { var ( zapLogger *zap.Logger err error - tokenExpiredError *tioidc.TokenExpiredError - token Token + tokenExpiredError *oidc.TokenExpiredError + token *tioidc.Token ) if opts.debug { zapLogger, err = zap.NewDevelopment() @@ -165,28 +167,29 @@ func (opts *options) extractClaims() error { verifier := provider.NewVerifier(logger, verifyConfig) logger.Infow("New verifier created") - token, err = verifier.VerifyToken(ctx, opts.token) + // Verify the token + token, err = verifier.Verify(ctx, opts.token) if errors.As(err, &tokenExpiredError) { - // Verify the token expiration time using the extended expiration time. - err = verifier.VerifyExtendedExpiration(err.(tioidc.TokenExpiredError).Expiry, opts.oidcTokenExpirationTime) + err = verifier.VerifyExtendedExpiration(err.(*oidc.TokenExpiredError).Expiry, opts.oidcTokenExpirationTime) if err != nil { return err } verifyConfig.SkipExpiryCheck = false verifierWithoutExpiration := provider.NewVerifier(logger, verifyConfig) - token, err = verifierWithoutExpiration.VerifyToken(ctx, opts.token) + token, err = verifierWithoutExpiration.Verify(ctx, opts.token) } if err != nil { return err } logger.Infow("Token verified successfully") - // claims will store the extracted claim values from the token. + // Create claims claims := tioidc.NewClaims(logger) logger.Infow("Verifying token claims") - // Verifies if custom claims has expected values. - // Extract the claim values from the token into the claims struct. - err = tokenProcessor.ValidateClaims(ctx, &claims) + + // Pass the token to ValidateClaims + err = tokenProcessor.ValidateClaims(&claims, token) + if err != nil { return err } diff --git a/pkg/oidc/oidc.go b/pkg/oidc/oidc.go index 3bbbf3eade93..598c584cb326 100644 --- a/pkg/oidc/oidc.go +++ b/pkg/oidc/oidc.go @@ -205,20 +205,20 @@ func NewVerifierConfig(logger LoggerInterface, clientID string, options ...Verif // Verify verifies the raw OIDC token. // It returns a Token struct which contains the verified token if successful. -func (tokenVerifier *TokenVerifier) Verify(ctx context.Context, rawToken string) (Token, error) { +func (tokenVerifier *TokenVerifier) Verify(ctx context.Context, rawToken string) (*Token, error) { logger := tokenVerifier.Logger logger.Debugw("Verifying token") logger.Debugw("Got raw token value", "rawToken", maskToken(rawToken)) idToken, err := tokenVerifier.Verifier.Verify(ctx, rawToken) if err != nil { token := Token{} - return token, fmt.Errorf("failed to verify token: %w", err) + return &token, fmt.Errorf("failed to verify token: %w", err) } logger.Debugw("Token verified successfully") token := Token{ Token: idToken, } - return token, nil + return &token, nil } // VerifyExtendedExpiration checks the OIDC token expiration timestamp against the provided expiration time. @@ -229,7 +229,7 @@ func (tokenVerifier *TokenVerifier) VerifyExtendedExpiration(expirationTimestamp logger.Debugw("Verifying token expiration time", "expirationTimestamp", expirationTimestamp, "gracePeriodMinutes", gracePeriodMinutes) now := time.Now() elapsed := now.Sub(expirationTimestamp) - gracePeriod := *time.Minute + gracePeriod := time.Minute if elapsed <= gracePeriod { return nil } @@ -402,15 +402,21 @@ func (tokenProcessor *TokenProcessor) Issuer() string { // It uses the provided verifier to verify the token signature and expiration time. // It verifies if the token claims have expected values. // It unmarshal the claims into the provided claims struct. -func (tokenProcessor *TokenProcessor) ValidateClaims(claims ClaimsInterface) error { +func (tokenProcessor *TokenProcessor) ValidateClaims(claims ClaimsInterface, token *Token) error { logger := tokenProcessor.logger + // Ensure that the token is initialized + if token.Token == nil { + return fmt.Errorf("failed to verify token: token validation failed") + } + logger.Debugw("Getting claims from token") - err = token.Claims(claims) + err := token.Claims(claims) if err != nil { return fmt.Errorf("failed to get claims from token: %w", err) } logger.Debugw("Got claims from token", "claims", fmt.Sprintf("%+v", claims)) + err = claims.validateExpectations(tokenProcessor.issuer) if err != nil { return fmt.Errorf("failed to validate claims: %w", err) diff --git a/pkg/oidc/oidc_test.go b/pkg/oidc/oidc_test.go index 643e4589fae8..2e1a2f87fba7 100644 --- a/pkg/oidc/oidc_test.go +++ b/pkg/oidc/oidc_test.go @@ -3,12 +3,11 @@ package oidc_test import ( "errors" "fmt" - - // "time" - // "fmt" "os" + // "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/go-jose/go-jose/v4/jwt" tioidc "github.com/kyma-project/test-infra/pkg/oidc" @@ -22,8 +21,8 @@ import ( var _ = Describe("OIDC", func() { var ( - err error - ctx context.Context + err error + logger *zap.SugaredLogger trustedIssuers map[string]tioidc.Issuer rawToken []byte @@ -90,7 +89,6 @@ var _ = Describe("OIDC", func() { JWKSURL: "https://fakedings.dev-gcp.nais.io/fake/jwks", }, } - ctx = context.Background() }) When("issuer is trusted", func() { It("should return a new TokenProcessor", func() { @@ -176,8 +174,6 @@ var _ = Describe("OIDC", func() { Expect(err).NotTo(HaveOccurred()) Expect(tokenProcessor).NotTo(BeNil()) - ctx = context.Background() - trustedIssuers = map[string]tioidc.Issuer{ "https://fakedings.dev-gcp.nais.io/fake": { Name: "github", @@ -215,7 +211,7 @@ var _ = Describe("OIDC", func() { verifier.On("Verify", mock.AnythingOfType("backgroundCtx"), string(rawToken)).Return(token, nil) // Run - err = tokenProcessor.VerifyAndExtractClaims(ctx, verifier, &claims) + err = tokenProcessor.ValidateClaims(&claims, &token) // Verify Expect(err).NotTo(HaveOccurred()) @@ -239,17 +235,17 @@ var _ = Describe("OIDC", func() { verifier.On("Verify", mock.AnythingOfType("backgroundCtx"), string(rawToken)).Return(token, nil) // Run - err = tokenProcessor.VerifyAndExtractClaims(ctx, verifier, &claims) + err = tokenProcessor.ValidateClaims(&claims, &token) // Verify Expect(err).To(HaveOccurred()) Expect(err).To(MatchError("failed to validate claims: job_workflow_ref claim expected value validation failed, expected: kyma-project/test-infra/.github/workflows/unexpected.yml@refs/heads/main, provided: kyma-project/test-infra/.github/workflows/verify-oidc-token.yml@refs/heads/main")) }) It("should return an error when token was not verified", func() { - verifier.On("Verify", mock.AnythingOfType("backgroundCtx"), string(rawToken)).Return(token, fmt.Errorf("token validation failed")) + verifier.On("Verify", mock.AnythingOfType("backgroundCtx"), string(rawToken)).Return(tioidc.Token{}, fmt.Errorf("token validation failed")) // Run - err = tokenProcessor.VerifyAndExtractClaims(ctx, verifier, &claims) + err = tokenProcessor.ValidateClaims(&claims, &token) // Verify Expect(err).To(HaveOccurred()) @@ -263,7 +259,7 @@ var _ = Describe("OIDC", func() { Token.On("Claims", &claims).Return(fmt.Errorf("claims are not set")) // Run - err = tokenProcessor.VerifyAndExtractClaims(ctx, verifier, &claims) + err = tokenProcessor.ValidateClaims(&claims, &token) // Verify Expect(err).To(HaveOccurred()) @@ -278,7 +274,7 @@ var _ = Describe("OIDC", func() { tokenVerifier tioidc.TokenVerifier verifier *oidcmocks.MockVerifier ctx context.Context - token tioidc.Token + token *tioidc.Token ) BeforeEach(func() { verifier = &oidcmocks.MockVerifier{} @@ -295,7 +291,7 @@ var _ = Describe("OIDC", func() { verifier.On("Verify", mock.AnythingOfType("backgroundCtx"), string(rawToken)).Return(&oidc.IDToken{}, nil) token, err = tokenVerifier.Verify(ctx, string(rawToken)) Expect(err).NotTo(HaveOccurred()) - Expect(token).To(BeAssignableToTypeOf(tioidc.Token{})) + Expect(token).To(BeAssignableToTypeOf(&tioidc.Token{})) }) }) }) @@ -311,7 +307,6 @@ var _ = Describe("OIDC", func() { provider = tioidc.Provider{ VerifierProvider: oidcProvider, } - ctx = context.Background() verifierConfig, err = tioidc.NewVerifierConfig(logger, clientID) Expect(err).NotTo(HaveOccurred()) })