Skip to content

Commit

Permalink
Fix test for encrypted gorilla session key
Browse files Browse the repository at this point in the history
We need to be able to choose if the initial testdata added for tests
by prepareServer() or prod data added by Init() in a production setting uses
an encrypted session key or not.
  • Loading branch information
eest committed Dec 30, 2024
1 parent 3f9e62b commit 11d453f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 24 deletions.
17 changes: 13 additions & 4 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ type InitUser struct {
Password string
}

func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {
func Init(logger zerolog.Logger, pgConfig *pgxpool.Config, encryptedSessionKey bool) (InitUser, error) {
err := migrations.Up(logger, pgConfig)
if err != nil {
return InitUser{}, fmt.Errorf("unable to run initial migration: %w", err)
Expand Down Expand Up @@ -1716,9 +1716,18 @@ func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {
Password: password,
}

userSessionKey, err := generateRandomKey(32)
var gorillaSessionEncKey []byte

gorillaSessionAuthKey, err := generateRandomKey(32)
if err != nil {
return InitUser{}, fmt.Errorf("unable to create random user session key: %w", err)
return InitUser{}, fmt.Errorf("unable to create random gorilla session auth key: %w", err)
}

if encryptedSessionKey {
gorillaSessionEncKey, err = generateRandomKey(32)
if err != nil {
return InitUser{}, fmt.Errorf("unable to create random gorilla session encryption key: %w", err)
}
}

err = pgx.BeginFunc(context.Background(), dbPool, func(tx pgx.Tx) error {
Expand Down Expand Up @@ -1749,7 +1758,7 @@ func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {

u.ID = userID

_, err = insertGorillaSessionKey(tx, userSessionKey, nil)
_, err = insertGorillaSessionKey(tx, gorillaSessionAuthKey, gorillaSessionEncKey)
if err != nil {
return fmt.Errorf("unable to INSERT initial user session key: %w", err)
}
Expand Down
49 changes: 29 additions & 20 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestMain(m *testing.M) {
m.Run()
}

func populateTestData(dbPool *pgxpool.Pool) error {
func populateTestData(dbPool *pgxpool.Pool, encryptedSessionKey bool) error {
// use static UUIDs to get known contents for testing
testData := []string{
// Organizations
Expand Down Expand Up @@ -231,10 +231,19 @@ func populateTestData(dbPool *pgxpool.Pool) error {

gorillaAuthKey, err := generateRandomKey(32)
if err != nil {
return fmt.Errorf("unable to create random user session key: %w", err)
return fmt.Errorf("unable to create random gorilla session auth key: %w", err)
}

_, err = insertGorillaSessionKey(tx, gorillaAuthKey, nil)
var gorillaEncKey []byte

if encryptedSessionKey {
gorillaEncKey, err = generateRandomKey(32)
if err != nil {
return fmt.Errorf("unable to create random gorilla session encryption key: %w", err)
}
}

_, err = insertGorillaSessionKey(tx, gorillaAuthKey, gorillaEncKey)
if err != nil {
return fmt.Errorf("unable to INSERT user session key: %w", err)
}
Expand All @@ -248,7 +257,7 @@ func populateTestData(dbPool *pgxpool.Pool) error {
return nil
}

func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {
func prepareServer(encryptedSessionKey bool) (*httptest.Server, *pgxpool.Pool, error) {
pgurl, err := pgt.CreateDatabase(context.Background())
if err != nil {
return nil, nil, err
Expand All @@ -273,7 +282,7 @@ func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {
return nil, nil, err
}

err = populateTestData(dbPool)
err = populateTestData(dbPool, encryptedSessionKey)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -308,7 +317,7 @@ func TestServerInit(t *testing.T) {

fmt.Println(pgConfig.ConnString())

u, err := Init(logger, pgConfig)
u, err := Init(logger, pgConfig, false)
if err != nil {
t.Fatal(err)
}
Expand All @@ -326,7 +335,7 @@ func TestServerInit(t *testing.T) {
}

func TestSessionKeyHandlingNoEnc(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -391,7 +400,7 @@ func TestSessionKeyHandlingNoEnc(t *testing.T) {
}

func TestSessionKeyHandlingWithEnc(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(true)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -461,7 +470,7 @@ func TestSessionKeyHandlingWithEnc(t *testing.T) {
}

func TestGetUsers(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -534,7 +543,7 @@ func TestGetUsers(t *testing.T) {
}

func TestGetUser(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -642,7 +651,7 @@ func TestGetUser(t *testing.T) {
}

func TestPostUsers(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -776,7 +785,7 @@ func TestPostUsers(t *testing.T) {
}

func TestGetOrganizations(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -849,7 +858,7 @@ func TestGetOrganizations(t *testing.T) {
}

func TestGetOrganization(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -957,7 +966,7 @@ func TestGetOrganization(t *testing.T) {
}

func TestPostOrganizations(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1057,7 +1066,7 @@ func TestPostOrganizations(t *testing.T) {
}

func TestGetServices(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1130,7 +1139,7 @@ func TestGetServices(t *testing.T) {
}

func TestGetService(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1249,7 +1258,7 @@ func TestGetService(t *testing.T) {
}

func TestPostServices(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1394,7 +1403,7 @@ func TestPostServices(t *testing.T) {
}

func TestGetServiceVersions(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1463,7 +1472,7 @@ func TestGetServiceVersions(t *testing.T) {
}

func TestPostServiceVersion(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down Expand Up @@ -1735,7 +1744,7 @@ func TestPostServiceVersion(t *testing.T) {
}

func TestGetVcls(t *testing.T) {
ts, dbPool, err := prepareServer()
ts, dbPool, err := prepareServer(false)
if dbPool != nil {
defer dbPool.Close()
}
Expand Down

0 comments on commit 11d453f

Please sign in to comment.