Skip to content

Commit

Permalink
Merge branch 'v3' into remove-automatic-migration
Browse files Browse the repository at this point in the history
  • Loading branch information
kian99 authored Jun 10, 2024
2 parents 6daf9fb + c9c56cb commit 40ca4e7
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 21 deletions.
2 changes: 2 additions & 0 deletions cmd/jimmsrv/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ func start(ctx context.Context, s *service.Service) error {
s.OnShutdown(func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

zapctx.Warn(ctx, "server shutdown triggered")
httpsrv.Shutdown(ctx)
jimmsvc.Cleanup()
})
s.Go(httpsrv.ListenAndServe)
zapctx.Info(ctx, "Successfully started JIMM server")
Expand Down
5 changes: 5 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
package jimm

var NewOpenFGAClient = newOpenFGAClient

// GetCleanups export `Service.cleanups` field for testing purposes.
func (s *Service) GetCleanups() []func() error {
return s.cleanups
}
4 changes: 2 additions & 2 deletions internal/dbmodel/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package dbmodel

import (
"database/sql"
"fmt"
"net"
"strconv"

jujuparams "github.com/juju/juju/rpc/params"
"github.com/juju/names/v5"
Expand Down Expand Up @@ -112,7 +112,7 @@ func (c Controller) ToAPIControllerInfo() apiparams.ControllerInfo {
ci.PublicAddress = c.PublicAddress
for _, hps := range c.Addresses {
for _, hp := range hps {
ci.APIAddresses = append(ci.APIAddresses, fmt.Sprintf("%s:%d", hp.Value, hp.Port))
ci.APIAddresses = append(ci.APIAddresses, net.JoinHostPort(hp.Value, strconv.Itoa(hp.Port)))
}
}
ci.CACertificate = c.CACertificate
Expand Down
63 changes: 46 additions & 17 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package jimm
import (
"context"
"crypto/rand"
"database/sql"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -180,7 +179,8 @@ type Params struct {
type Service struct {
jimm jimm.JIMM

mux *chi.Mux
mux *chi.Mux
cleanups []func() error
}

func (s *Service) JIMM() *jimm.JIMM {
Expand Down Expand Up @@ -225,6 +225,22 @@ func (s *Service) StartJWKSRotator(ctx context.Context, checkRotateRequired <-ch
return s.jimm.JWKService.StartJWKSRotator(ctx, checkRotateRequired, initialRotateRequiredTime)
}

// Cleanup cleans up resources that need to be released on shutdown.
func (s *Service) Cleanup() {
// Iterating over clean up function in reverse-order to avoid early clean ups.
for i := len(s.cleanups) - 1; i >= 0; i-- {
f := s.cleanups[i]
if err := f(); err != nil {
zapctx.Error(context.Background(), "cleanup failed", zap.Error(err))
}
}
}

// AddCleanup adds a clean up function to be run at service shutdown.
func (s *Service) AddCleanup(f func() error) {
s.cleanups = append(s.cleanups, f)
}

// NewService creates a new Service using the given params.
func NewService(ctx context.Context, p Params) (*Service, error) {
const op = errors.Op("NewService")
Expand Down Expand Up @@ -256,10 +272,6 @@ func NewService(ctx context.Context, p Params) (*Service, error) {
if err := s.jimm.Database.Migrate(ctx, false); err != nil {
return nil, errors.E(op, err)
}
sqlDb, err := s.jimm.Database.DB.DB()
if err != nil {
return nil, errors.E(op, err)
}

if p.AuditLogRetentionPeriodInDays != "" {
period, err := strconv.Atoi(p.AuditLogRetentionPeriodInDays)
Expand Down Expand Up @@ -287,15 +299,11 @@ func NewService(ctx context.Context, p Params) (*Service, error) {
return nil, errors.E(op, err)
}

// Setup browser session store
sessionStore, err := setupSessionStore(ctx, sqlDb, s.jimm.CredentialStore)
sessionStore, err := s.setupSessionStore(ctx)
if err != nil {
return nil, errors.E(op, err)
}

// Cleanup expired session every 30 minutes
sessionStore.Cleanup(time.Minute * 30)

redirectUrl := p.PublicDNSName + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
if !strings.HasPrefix(redirectUrl, "https://") || !strings.HasPrefix(redirectUrl, "http://") {
redirectUrl = "https://" + redirectUrl
Expand Down Expand Up @@ -418,21 +426,42 @@ func (s *Service) setupDischarger(p Params) (*discharger.MacaroonDischarger, err
return MacaroonDischarger, nil
}

func setupSessionStore(ctx context.Context, db *sql.DB, credStore jimmcreds.CredentialStore) (*pgstore.PGStore, error) {
oauthSessionStoreSecret, err := credStore.GetOAuthSessionStoreSecret(ctx)
func (s *Service) setupSessionStore(ctx context.Context) (*pgstore.PGStore, error) {
const op = errors.Op("setupSessionStore")

if s.jimm.CredentialStore == nil {
return nil, errors.E(op, "credential store is not configured")
}

sqlDb, err := s.jimm.Database.DB.DB()
if err != nil {
return nil, errors.E(op, err)
}

oauthSessionStoreSecret, err := s.jimm.CredentialStore.GetOAuthSessionStoreSecret(ctx)
if err == nil {
zapctx.Info(ctx, "detected existing OAuth session store secret")
} else if errors.ErrorCode(err) == errors.CodeNotFound {
oauthSessionStoreSecret, err = generateOAuthSessionStoreSecret(ctx, credStore)
oauthSessionStoreSecret, err = generateOAuthSessionStoreSecret(ctx, s.jimm.CredentialStore)
if err != nil {
return nil, err
}
} else {
return nil, errors.E(err, "failed to read session store secret")
return nil, errors.E(op, err, "failed to read session store secret")
}

store, err := pgstore.NewPGStoreFromPool(db, oauthSessionStoreSecret)
return store, err
store, err := pgstore.NewPGStoreFromPool(sqlDb, oauthSessionStoreSecret)
if err != nil {
return nil, errors.E(op, err, "failed to create session store")
}

// Cleanup expired session every 30 minutes
cleanupQuit, cleanupDone := store.Cleanup(time.Minute * 30)
s.AddCleanup(func() error {
store.StopCleanup(cleanupQuit, cleanupDone)
return nil
})
return store, nil
}

func generateOAuthSessionStoreSecret(ctx context.Context, store jimmcreds.CredentialStore) ([]byte, error) {
Expand Down
41 changes: 39 additions & 2 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func TestDefaultService(t *testing.T) {
p.InsecureSecretStorage = true
svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()
rr := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/debug/info", nil)
c.Assert(err, qt.IsNil)
Expand Down Expand Up @@ -76,6 +77,7 @@ func TestAuthenticator(t *testing.T) {
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

err = svc.JIMM().GetCredentialStore().PutOAuthSecret(ctx, []byte(jimmtest.JWTTestSecret))
c.Assert(err, qt.IsNil)
Expand Down Expand Up @@ -140,6 +142,7 @@ func TestVault(t *testing.T) {
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
svc, err := jimm.NewService(ctx, p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

err = svc.JIMM().GetCredentialStore().PutOAuthSecret(ctx, []byte(jimmtest.JWTTestSecret))
c.Assert(err, qt.IsNil)
Expand Down Expand Up @@ -198,8 +201,9 @@ func TestPostgresSecretStore(t *testing.T) {
p := jimmtest.NewTestJimmParams(c)
p.InsecureSecretStorage = true
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
_, err = jimm.NewService(context.Background(), p)
svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()
}

func TestOpenFGA(t *testing.T) {
Expand All @@ -215,6 +219,7 @@ func TestOpenFGA(t *testing.T) {
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
svc, err := jimm.NewService(ctx, p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

err = svc.JIMM().GetCredentialStore().PutOAuthSecret(ctx, []byte(jimmtest.JWTTestSecret))
c.Assert(err, qt.IsNil)
Expand Down Expand Up @@ -264,6 +269,7 @@ func TestPublicKey(t *testing.T) {
p.InsecureSecretStorage = true
svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

srv := httptest.NewTLSServer(svc)
c.Cleanup(srv.Close)
Expand Down Expand Up @@ -339,6 +345,7 @@ func TestThirdPartyCaveatDischarge(t *testing.T) {
p.InsecureSecretStorage = true
svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

srv := httptest.NewTLSServer(svc)
c.Cleanup(srv.Close)
Expand Down Expand Up @@ -406,8 +413,8 @@ func TestDisableOAuthEndpointsWhenDashboardRedirectURLNotSet(t *testing.T) {
p.InsecureSecretStorage = true
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
svc, err := jimm.NewService(context.Background(), p)

c.Assert(err, qt.IsNil)
defer svc.Cleanup()

srv := httptest.NewTLSServer(svc)
c.Cleanup(srv.Close)
Expand All @@ -430,6 +437,7 @@ func TestEnableOAuthEndpointsWhenDashboardRedirectURLSet(t *testing.T) {

svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)
defer svc.Cleanup()

srv := httptest.NewTLSServer(svc)
c.Cleanup(srv.Close)
Expand All @@ -452,3 +460,32 @@ func cofgaParamsToJIMMOpenFGAParams(cofgaParams cofga.OpenFGAParams) jimm.OpenFG
AuthModel: cofgaParams.AuthModelID,
}
}

func TestCleanup(t *testing.T) {
c := qt.New(t)

outputs := make(chan string, 2)
service := jimm.Service{}
service.AddCleanup(func() error { outputs <- "first"; return nil })
service.AddCleanup(func() error { outputs <- "second"; return nil })
service.Cleanup()
c.Assert([]string{<-outputs, <-outputs}, qt.DeepEquals, []string{"second", "first"})
}

func TestCleanupDoesNotPanic_SessionStoreRelatedCleanups(t *testing.T) {
c := qt.New(t)

_, _, cofgaParams, err := jimmtest.SetupTestOFGAClient(c.Name())
c.Assert(err, qt.IsNil)
p := jimmtest.NewTestJimmParams(c)
p.OpenFGAParams = cofgaParamsToJIMMOpenFGAParams(*cofgaParams)
p.InsecureSecretStorage = true

svc, err := jimm.NewService(context.Background(), p)
c.Assert(err, qt.IsNil)

// Make sure `cleanups` is not empty.
c.Assert(len(svc.GetCleanups()) > 0, qt.IsTrue)

svc.Cleanup()
}

0 comments on commit 40ca4e7

Please sign in to comment.