Skip to content

Commit

Permalink
📝 feat: session in header && recover middleware && encrypt session da…
Browse files Browse the repository at this point in the history
…ta (#1004)
  • Loading branch information
garrettladley authored Jun 11, 2024
1 parent cb8f59f commit 8ae7c56
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 26 deletions.
4 changes: 4 additions & 0 deletions backend/Dockerfile.server
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ WORKDIR /app

COPY . .

RUN go install github.com/a-h/templ/cmd/templ@latest

RUN templ generate

RUN go mod download

RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -installsuffix cgo -o bin/sac main.go
Expand Down
24 changes: 24 additions & 0 deletions backend/config/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package config

import m "github.com/garrettladley/mattress"

type SessionSettings struct {
Redis RedisSettings
PassPhrase *m.Secret[string]
}

type intermediateSessionSettings struct {
PassPhrase string `env:"PASS_PHRASE"`
}

func (i *intermediateSessionSettings) into(redis RedisSettings) (*SessionSettings, error) {
passPhrase, err := m.NewSecret(i.PassPhrase)
if err != nil {
return nil, err
}

return &SessionSettings{
Redis: redis,
PassPhrase: passPhrase,
}, nil
}
14 changes: 13 additions & 1 deletion backend/config/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package config
type Settings struct {
Application ApplicationSettings
Database DatabaseSettings
RedisSession RedisSettings
Session SessionSettings
RedisLimiter RedisSettings
SuperUser SuperUserSettings
Calendar CalendarSettings
Expand All @@ -22,6 +22,7 @@ type intermediateSettings struct {
Application ApplicationSettings `envPrefix:"SAC_APPLICATION_"`
Database intermediateDatabaseSettings `envPrefix:"SAC_DB_"`
RedisSession intermediateRedisSettings `envPrefix:"SAC_REDIS_SESSION_"`
Session intermediateSessionSettings `envPrefix:"SAC_SESSION_"`
RedisLimiter intermediateRedisSettings `envPrefix:"SAC_REDIS_LIMITER_"`
SuperUser intermediateSuperUserSettings `envPrefix:"SAC_SUDO_"`
AWS intermediateAWSSettings `envPrefix:"SAC_AWS_"`
Expand All @@ -38,6 +39,16 @@ func (i *intermediateSettings) into() (*Settings, error) {
return nil, err
}

redisSession, err := i.RedisSession.into()
if err != nil {
return nil, err
}

session, err := i.Session.into(*redisSession)
if err != nil {
return nil, err
}

redisLimiter, err := i.RedisLimiter.into()
if err != nil {
return nil, err
Expand Down Expand Up @@ -76,6 +87,7 @@ func (i *intermediateSettings) into() (*Settings, error) {
return &Settings{
Application: i.Application,
Database: *database,
Session: *session,
RedisLimiter: *redisLimiter,
SuperUser: *superUser,
Calendar: *calendar,
Expand Down
3 changes: 3 additions & 0 deletions backend/database/store/storer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package store
import (
"context"
"fmt"
"log/slog"
"runtime"
"time"

Expand Down Expand Up @@ -51,10 +52,12 @@ func NewRedisClient(settings RedisSettings) *RedisClient {
}

func (r *RedisClient) Get(key string) ([]byte, error) {
slog.Info("getting", "key", key)
return r.client.Get(context.Background(), key).Bytes()
}

func (r *RedisClient) Set(key string, val []byte, exp time.Duration) error {
slog.Info("setting", "key", key, "val", string(val), "exp", exp)
return r.client.Set(context.Background(), key, val, exp).Err()
}

Expand Down
8 changes: 7 additions & 1 deletion backend/entities/auth/base/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"net/http"
"net/url"
"time"

"github.com/GenerateNU/sac/backend/integrations/oauth/soth"
Expand Down Expand Up @@ -101,7 +102,12 @@ func (h *Handler) ProviderCallback(c *fiber.Ctx) error {
return err
}

return c.Redirect(c.Cookies("redirect", "/"))
redirect, err := url.PathUnescape(c.Cookies("redirect", "/"))
if err != nil {
return err
}

return c.Redirect(redirect)
}

func (h *Handler) ProviderLogout(c *fiber.Ctx) error {
Expand Down
63 changes: 63 additions & 0 deletions backend/integrations/oauth/crypt/crypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package crypt

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
)

func Encrypt(data string, passphrase string) (string, error) {
block, err := createCipherBlock(passphrase)
if err != nil {
return "", err
}

plaintext := []byte(data)
if len(plaintext) > 1028 {
return "", fmt.Errorf("plaintext too long")
}

ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}

stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

return hex.EncodeToString(ciphertext), nil
}

func Decrypt(encryptedData, passphrase string) (string, error) {
block, err := createCipherBlock(passphrase)
if err != nil {
return "", err
}

ciphertext, _ := hex.DecodeString(encryptedData)
if len(ciphertext) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}

iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]

stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)

return string(ciphertext), nil
}

func createCipherBlock(key string) (cipher.Block, error) {
hash := sha256.Sum256([]byte(key))
block, err := aes.NewCipher(hash[:])
if err != nil {
return nil, err
}
return block, nil
}
49 changes: 31 additions & 18 deletions backend/integrations/oauth/soth/sothic/sothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/url"
"strings"

Expand All @@ -16,41 +17,49 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/session"

"github.com/GenerateNU/sac/backend/integrations/oauth/crypt"
"github.com/GenerateNU/sac/backend/integrations/oauth/soth"
)

type key int

const (
SessionName string = "_sothic_session"
SessionName string = "_sac_session"

// ProviderParamKey can be used as a key in context when passing in a provider
ProviderParamKey key = iota
)

// Session can/should be set by applications using gothic. The default is a cookie store.
// Session can/should be set by applications using sothic.
var (
SessionStore *session.Store
encrypter func(string) (string, error)
decrypter func(string) (string, error)
)

// MUST be called before using the package
func Init(sessionSettings config.RedisSettings) {
func Init(sessionSettings config.SessionSettings) {
config := session.Config{
Storage: store.NewRedisClient(sessionSettings),
KeyLookup: fmt.Sprintf("cookie:%s", SessionName),
// for local
CookieHTTPOnly: true,
// MARK: secure in prod
// TODO: use build tags to set this
Storage: store.NewRedisClient(sessionSettings.Redis),
KeyLookup: fmt.Sprintf("header:%s", SessionName),
}

encrypter = func(s string) (string, error) {
return crypt.Encrypt(s, sessionSettings.PassPhrase.Expose())
}

decrypter = func(s string) (string, error) {
return crypt.Decrypt(s, sessionSettings.PassPhrase.Expose())
}

SessionStore = session.New(config)
}

/*
BeginAuthHandler is a convenience handler for starting the authentication process.
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider".
It expects to be able to get the name of the provider from the path parameter
":provider" or as set by SetProvider.
BeginAuthHandler will redirect the user to the appropriate authentication end-point
for the requested provider.
Expand Down Expand Up @@ -84,7 +93,7 @@ func SetState(c *fiber.Ctx) string {
nonceBytes := make([]byte, 64)
_, err := io.ReadFull(rand.Reader, nonceBytes)
if err != nil {
panic("gothic: source of randomness unavailable: " + err.Error())
panic(fmt.Sprintf("sothic: source of randomness unavailable: %v", err.Error()))
}
return base64.URLEncoding.EncodeToString(nonceBytes)
}
Expand Down Expand Up @@ -131,8 +140,7 @@ func GetAuthURL(c *fiber.Ctx) (string, error) {
return "", err
}

err = StoreInSession(providerName, sess.Marshal(), c)
if err != nil {
if err := StoreInSession(providerName, sess.Marshal(), c); err != nil {
return "", err
}

Expand Down Expand Up @@ -194,8 +202,7 @@ func CompleteUserAuth(c *fiber.Ctx) (soth.User, error) {
return soth.User{}, err
}

err = StoreInSession(providerName, sess.Marshal(), c)
if err != nil {
if err := StoreInSession(providerName, sess.Marshal(), c); err != nil {
return soth.User{}, err
}

Expand Down Expand Up @@ -283,6 +290,7 @@ func SetProvider(c *fiber.Ctx, provider string) {
func StoreInSession(key string, value string, c *fiber.Ctx) error {
session, err := SessionStore.Get(c)
if err != nil {
slog.Info("error getting session", "error", err)
return err
}

Expand Down Expand Up @@ -326,7 +334,7 @@ func getSessionValue(store *session.Session, key string) (string, error) {
return "", err
}

return string(s), nil
return decrypter(string(s))
}

func updateSessionValue(session *session.Session, key, value string) error {
Expand All @@ -342,7 +350,12 @@ func updateSessionValue(session *session.Session, key, value string) error {
return err
}

session.Set(key, b.String())
encrypted, err := encrypter(b.String())
if err != nil {
return err
}

session.Set(key, encrypted)

return nil
}
2 changes: 1 addition & 1 deletion backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func main() {
startBackgroundJobs(ctx, db)

stores := store.ConfigureStores(config.RedisLimiter)
sothic.Init(config.RedisLimiter)
sothic.Init(config.Session)
integrations := configureIntegrations(&config.Integrations)

tp := telemetry.InitTracer()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package auth

import (
"net/url"
"slices"
"time"

"github.com/GenerateNU/sac/backend/locals"
"github.com/GenerateNU/sac/backend/permission"

"github.com/GenerateNU/sac/backend/entities/models"
"github.com/GenerateNU/sac/backend/integrations/oauth/soth/sothic"

"github.com/GenerateNU/sac/backend/locals"
"github.com/GenerateNU/sac/backend/permission"
"github.com/GenerateNU/sac/backend/utilities"

"github.com/gofiber/fiber/v2"
Expand All @@ -21,7 +20,7 @@ func (m *AuthMiddlewareHandler) Authorize(requiredPermissions ...permission.Perm
if err != nil {
c.Cookie(&fiber.Cookie{
Name: "redirect",
Value: c.OriginalURL(),
Value: url.PathEscape(c.OriginalURL()),
Expires: time.Now().Add(5 * time.Minute),
// MARK: secure should be true in prod
// use go build tags to do this
Expand Down
3 changes: 3 additions & 0 deletions backend/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/gofiber/fiber/v2/middleware/compress"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/fiber/v2/middleware/requestid"

"gorm.io/gorm"
Expand Down Expand Up @@ -118,6 +119,8 @@ func newFiberApp(appSettings config.ApplicationSettings) *fiber.App {
ErrorHandler: utilities.ErrorHandler,
})

app.Use(recover.New())

app.Use(cors.New(cors.Config{
AllowOrigins: appSettings.ApplicationURL(),
AllowCredentials: true,
Expand Down
28 changes: 28 additions & 0 deletions backend/tests/crypt/crypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package crypt

import (
"testing"

"github.com/GenerateNU/sac/backend/integrations/oauth/crypt"
)

func TestCryptEncryptDecrypt(t *testing.T) {
t.Parallel()

data := "test data"
passphrase := "test passphrase"

encrypted, err := crypt.Encrypt(data, passphrase)
if err != nil {
t.Fatal(err)
}

decrypted, err := crypt.Decrypt(encrypted, passphrase)
if err != nil {
t.Fatal(err)
}

if decrypted != data {
t.Fatalf("expected %s, got %s", data, decrypted)
}
}
Loading

0 comments on commit 8ae7c56

Please sign in to comment.