Skip to content

Commit

Permalink
Start storing session keys for cookies in db
Browse files Browse the repository at this point in the history
When trying to add tests for server.Init() it became tricky to manage
the goose migration files across both the "migrations" package as well
as the testdata because all .go migration files would be loaded via
init() calls even if we only expected to load the ones for our
migrations package.

For this reason stop using migration files for filling in testdata and
only use it for managing structural changes to the database.
  • Loading branch information
eest committed Dec 26, 2024
1 parent 832cd25 commit 46721f2
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 343 deletions.
6 changes: 6 additions & 0 deletions pkg/migrations/files/00001_init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ CREATE TABLE users (
name text UNIQUE NOT NULL CONSTRAINT non_empty CHECK(length(name)>=1 AND length(name)<=63)
);

CREATE TABLE user_session_keys (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
ts timestamptz NOT NULL DEFAULT now(),
key bytea UNIQUE NOT NULL CONSTRAINT secure_length CHECK(length(key)>=32 AND length(key)<=32)
);

CREATE TABLE user_argon2keys (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
ts timestamptz NOT NULL DEFAULT now(),
Expand Down
33 changes: 32 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1654,7 +1654,7 @@ type InitUser struct {
func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {
err := migrations.Up(logger, pgConfig)
if err != nil {
return InitUser{}, fmt.Errorf("unable to get run initial migration: %w", err)
return InitUser{}, fmt.Errorf("unable to run initial migration: %w", err)
}

dbPool, err := pgxpool.NewWithConfig(context.Background(), pgConfig)
Expand All @@ -1679,6 +1679,11 @@ func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {
Password: password,
}

userSessionKey, err := generateRandomKey(32)
if err != nil {
return InitUser{}, fmt.Errorf("unable to create random user session key: %w", err)
}

err = pgx.BeginFunc(context.Background(), dbPool, func(tx pgx.Tx) error {
// Verify there are no roles present
var rolesExists bool
Expand Down Expand Up @@ -1707,6 +1712,11 @@ func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {

u.ID = userID

_, err = insertUserSessionKey(tx, userSessionKey)
if err != nil {
return fmt.Errorf("unable to INSERT initial user session key: %w", err)
}

return nil
})
if err != nil {
Expand All @@ -1716,6 +1726,27 @@ func Init(logger zerolog.Logger, pgConfig *pgxpool.Config) (InitUser, error) {
return u, nil
}

func generateRandomKey(length int) ([]byte, error) {
b := make([]byte, length)

_, err := rand.Read(b)
if err != nil {
return nil, fmt.Errorf("failed to read random bytes: %w", err)
}

return b, nil
}

func insertUserSessionKey(tx pgx.Tx, key []byte) (pgtype.UUID, error) {
var sessionKeyID pgtype.UUID
err := tx.QueryRow(context.Background(), "INSERT INTO user_session_keys (key) VALUES ($1) RETURNING id", key).Scan(&sessionKeyID)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("unable to INSERT user session key: %w", err)
}

return sessionKeyID, nil
}

func Run(logger zerolog.Logger) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
247 changes: 232 additions & 15 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package server
import (
"bytes"
"context"
"embed"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
Expand All @@ -16,17 +16,15 @@ import (
"strings"
"testing"

_ "github.com/SUNET/sunet-cdn-manager/pkg/server/testdata/migrations" // needed to run .go migration files
"github.com/SUNET/sunet-cdn-manager/pkg/migrations"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
"github.com/rs/zerolog"
"github.com/stapelberg/postgrestest"
"golang.org/x/crypto/argon2"
)

//go:embed testdata/migrations/*.sql
var embedMigrations embed.FS

var (
pgt *postgrestest.Server
logger zerolog.Logger
Expand All @@ -45,13 +43,198 @@ func TestMain(m *testing.M) {
}
logger = zerolog.New(os.Stderr).With().Timestamp().Caller().Logger()

goose.SetBaseFS(embedMigrations)
m.Run()
}

if err := goose.SetDialect("postgres"); err != nil {
logger.Fatal().Err(err).Msg("unable to goose.SetDialect()")
func populateTestData(dbPool *pgxpool.Pool) error {
// use static UUIDs to get known contents for testing
testData := []string{
// Organizations
"INSERT INTO organizations (id, name) VALUES ('00000002-0000-0000-0000-000000000001', 'org1')",
"INSERT INTO organizations (id, name) VALUES ('00000002-0000-0000-0000-000000000002', 'org2')",
"INSERT INTO organizations (id, name) VALUES ('00000002-0000-0000-0000-000000000003', 'org3')",

// Services
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000001', id, 'org1-service1' FROM organizations WHERE name='org1'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000002', id, 'org1-service2' FROM organizations WHERE name='org1'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000003', id, 'org1-service3' FROM organizations WHERE name='org1'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000004', id, 'org2-service1' FROM organizations WHERE name='org2'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000005', id, 'org2-service2' FROM organizations WHERE name='org2'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000006', id, 'org2-service3' FROM organizations WHERE name='org2'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000007', id, 'org3-service1' FROM organizations WHERE name='org3'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000008', id, 'org3-service2' FROM organizations WHERE name='org3'",
"INSERT INTO services (id, org_id, name) SELECT '00000003-0000-0000-0000-000000000009', id, 'org3-service3' FROM organizations WHERE name='org3'",

// Service versions
// org1, last version is active
"UPDATE services SET version_counter = version_counter + 1 WHERE name='org1-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000001', id, version_counter FROM services WHERE name='org1-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org1-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000002', id, version_counter FROM services WHERE name='org1-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org1-service1'",
"INSERT INTO service_versions (id, service_id, version, active) SELECT '00000004-0000-0000-0000-000000000003', id, version_counter, TRUE FROM services WHERE name='org1-service1'",

// org2, second version is active
"UPDATE services SET version_counter = version_counter + 1 WHERE name='org2-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000004', id, version_counter FROM services WHERE name='org2-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org2-service1'",
"INSERT INTO service_versions (id, service_id, version, active) SELECT '00000004-0000-0000-0000-000000000005', id, version_counter, TRUE FROM services WHERE name='org2-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org2-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000006', id, version_counter FROM services WHERE name='org2-service1'",

// org3, no version is active
"UPDATE services SET version_counter = version_counter + 1 WHERE name='org3-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000007', id, version_counter FROM services WHERE name='org3-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org3-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000008', id, version_counter FROM services WHERE name='org3-service1'",

"UPDATE services SET version_counter = version_counter + 1 WHERE name='org3-service1'",
"INSERT INTO service_versions (id, service_id, version) SELECT '00000004-0000-0000-0000-000000000009', id, version_counter FROM services WHERE name='org3-service1'",

// Roles
"INSERT INTO roles (id, name, superuser) VALUES ('00000005-0000-0000-0000-000000000001', 'admin', TRUE)",
"INSERT INTO roles (id, name) VALUES ('00000005-0000-0000-0000-000000000002', 'customer')",

// Domains
"INSERT INTO service_domains (id, service_version_id, domain) VALUES ('00000008-0000-0000-0000-000000000001', '00000004-0000-0000-0000-000000000003', 'www.example.se')",
"INSERT INTO service_domains (id, service_version_id, domain) VALUES ('00000008-0000-0000-0000-000000000002', '00000004-0000-0000-0000-000000000003', 'www.example.com')",

// Origins
"INSERT INTO service_origins (id, service_version_id, host, port, tls) VALUES ('00000009-0000-0000-0000-000000000001', '00000004-0000-0000-0000-000000000003', 'srv2.example.com', 80, false)",
"INSERT INTO service_origins (id, service_version_id, host, port, tls) VALUES ('00000009-0000-0000-0000-000000000002', '00000004-0000-0000-0000-000000000003', 'srv1.example.se', 443, true)",
}

m.Run()
err := pgx.BeginFunc(context.Background(), dbPool, func(tx pgx.Tx) error {
for _, sql := range testData {
_, err := tx.Exec(context.Background(), sql)
if err != nil {
return err
}
}
localUsers := []struct {
name string
password string
orgName string
role string
superuser bool
id string
}{
{
name: "admin",
password: "adminpass1",
role: "admin",
id: "00000006-0000-0000-0000-000000000001",
},
{
name: "username1",
password: "password1",
role: "customer",
orgName: "org1",
id: "00000006-0000-0000-0000-000000000002",
},
{
name: "username2",
password: "password2",
role: "customer",
orgName: "org2",
id: "00000006-0000-0000-0000-000000000003",
},
{
name: "username3-no-org",
password: "password3",
role: "customer",
id: "00000006-0000-0000-0000-000000000004",
},
}

for _, localUser := range localUsers {
var userID pgtype.UUID
err := userID.Scan(localUser.id)
if err != nil {
return err
}

var orgID *pgtype.UUID // may be nil

if localUser.orgName != "" {
err := tx.QueryRow(context.Background(), "SELECT id FROM organizations WHERE name=$1", localUser.orgName).Scan(&orgID)
if err != nil {
return err
}
}

_, err = tx.Exec(context.Background(), "INSERT INTO users (id, org_id, name, role_id) SELECT $1, $2, $3, id FROM roles WHERE name=$4", userID, orgID, localUser.name, localUser.role)
if err != nil {
return err
}

// Generate 16 byte (128 bit) salt as
// recommended for argon2 in RFC 9106
salt := make([]byte, 16)
_, err = rand.Read(salt)
if err != nil {
return err
}

timeSize := uint32(1)
memorySize := uint32(64 * 1024)
threads := uint8(4)
tagSize := uint32(32)

key := argon2.IDKey([]byte(localUser.password), salt, timeSize, memorySize, threads, tagSize)
_, err = tx.Exec(context.Background(), "INSERT INTO user_argon2keys (user_id, key, salt, time, memory, threads, tag_size) VALUES ($1, $2, $3, $4, $5, $6, $7)", userID, key, salt, timeSize, memorySize, threads, tagSize)
if err != nil {
return err
}
}

vclRcvs := []struct {
id string
file string
serviceVersionID string
}{
{
id: "00000007-0000-0000-0000-000000000001",
serviceVersionID: "00000004-0000-0000-0000-000000000003",
file: "testdata/vcl/vcl_recv/content1.vcl",
},
}

for _, vclRcv := range vclRcvs {
var vclID, serviceVersionID pgtype.UUID
err := vclID.Scan(vclRcv.id)
if err != nil {
return err
}

err = serviceVersionID.Scan(vclRcv.serviceVersionID)
if err != nil {
return err
}

contentBytes, err := os.ReadFile(vclRcv.file)
if err != nil {
return err
}

_, err = tx.Exec(context.Background(), "INSERT INTO service_vcl_recv (id, service_version_id, content) VALUES($1, $2, $3)", vclID, serviceVersionID, contentBytes)
if err != nil {
return err
}
}

return nil
})
if err != nil {
return err
}

return nil
}

func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {
Expand All @@ -64,7 +247,7 @@ func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {

pgConfig, err := pgxpool.ParseConfig(pgurl)
if err != nil {
return nil, nil, errors.New("unable to parse PostgreSQL config string")
return nil, nil, err
}

fmt.Println(pgConfig.ConnString())
Expand All @@ -74,10 +257,14 @@ func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {
return nil, nil, errors.New("unable to create database pool")
}

db := stdlib.OpenDBFromPool(dbPool)
err = migrations.Up(logger, pgConfig)
if err != nil {
return nil, nil, err
}

if err := goose.Up(db, "testdata/migrations"); err != nil {
return nil, dbPool, err
err = populateTestData(dbPool)
if err != nil {
return nil, nil, err
}

router := newChiRouter(logger, dbPool)
Expand All @@ -92,6 +279,36 @@ func prepareServer() (*httptest.Server, *pgxpool.Pool, error) {
return ts, dbPool, nil
}

func TestServerInit(t *testing.T) {
pgurl, err := pgt.CreateDatabase(context.Background())
if err != nil {
t.Fatal(err)
}

pgConfig, err := pgxpool.ParseConfig(pgurl)
if err != nil {
t.Fatal(err)
}

fmt.Println(pgConfig.ConnString())

u, err := Init(logger, pgConfig)
if err != nil {
t.Fatal(err)
}

expectedUsername := "admin"
expectedPasswordLength := 30

if u.Name != expectedUsername {
t.Fatalf("expected initial user '%s', got: '%s'", expectedUsername, u.Name)
}

if len(u.Password) != expectedPasswordLength {
t.Fatalf("expected initial user password length %d, got: %d", expectedPasswordLength, len(u.Password))
}
}

func TestGetUsers(t *testing.T) {
ts, dbPool, err := prepareServer()
if dbPool != nil {
Expand Down
Loading

0 comments on commit 46721f2

Please sign in to comment.