Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: 2FA recovery codes #5

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions internal/dbsqlc/copyfrom.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions internal/dbsqlc/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions internal/dbsqlc/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions internal/dbsqlc/recovery_code_query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- name: CreateRecoveryCodeBatch :copyfrom
INSERT INTO shield_recovery_codes
(id, user_id, recovery_code_hash, is_consumable)
VALUES
(@id::UUID, @user_id::UUID, @recovery_code_hash, @is_consumable);

-- name: EvictUnconsumedRecoveryCodeBatch :exec
UPDATE shield_recovery_codes
SET
evicted_by = @evicted_by::UUID,
evicted_at = NOW()
WHERE user_id = @user_id AND is_consumable = TRUE;
37 changes: 37 additions & 0 deletions internal/dbsqlc/recovery_code_query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions internal/migrations/1731842623024_recovery_code.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- migration: 1731842623024_recovery_code.sql

CREATE TABLE IF NOT EXISTS shield_recovery_codes (
id UUID NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id UUID NOT NULL,
recovery_code_hash VARCHAR(4095) NOT NULL,
is_consumable BOOL NOT NULL DEFAULT TRUE,
evicted_by UUID NULL,
evicted_at TIMESTAMP NULL DEFAULT NULL,
PRIMARY KEY (id),
FOREIGN KEY (user_id) REFERENCES shield_users (id)
ON DELETE CASCADE,
FOREIGN KEY (evicted_by) REFERENCES shield_users (id)
ON DELETE CASCADE
);

---- create above / drop below ----

DROP TABLE IF EXISTS shield_recovery_codes;
2 changes: 1 addition & 1 deletion internal/random/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func secureBytes(l int) ([]byte, error) {
bytes := make([]byte, l)
_, err := rand.Read(bytes)
if err != nil {
return bytes, fmt.Errorf("random: error reading random bytes: %w", err)
return bytes, fmt.Errorf("shield: error reading random bytes: %w", err)
}
return bytes, nil
}
Expand Down
6 changes: 5 additions & 1 deletion shieldpassword/bcrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ type bcryptPasswordHasher struct {
cost int
}

// NewBcryptPasswordHasher implements a password hashing algorithm with bcrypt.
// NewBcryptPasswordHasher creates a password hasher using the bcrypt algorithm.
//
// Please note that bcrypt has a maximum input length of 72 bytes. For passwords
// requiring more than 72 bytes of data, consider using an alternative algorithm
// such as Argon2.
func NewBcryptPasswordHasher(cost int) PasswordHasher {
return &bcryptPasswordHasher{cost}
}
Expand Down
232 changes: 232 additions & 0 deletions shieldrecoverycode/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package shieldrecoverycode

import (
"cmp"
"context"
"fmt"
"log/slog"

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"go.inout.gg/foundations/debug"
"go.inout.gg/shield"
"go.inout.gg/shield/internal/dbsqlc"
"go.inout.gg/shield/internal/random"
"go.inout.gg/shield/internal/uuidv7"
"go.inout.gg/shield/shieldpassword"
)

var _ Generator = (*generator)(nil)

const (
DefaultRecoveryCodeTotalCount = 16
DefaultRecoveryCodeLength = 16
)

var DefaultGenerator Generator = &generator{}

// Generator provides methods to create a set of unique recovery codes used
// for 2FA authentication recovery.
type Generator interface {
Generate(count int, len int) ([]string, error)
}

type generator struct{}

// Generate creates count number of secure random recovery codes.
// Each code is length bytes long and encoded as a hex string.
func (g *generator) Generate(count, length int) ([]string, error) {
codes := make([]string, count, count)

for i := range count {
code, err := random.SecureHexString(length)
if err != nil {
return nil, err
}

codes[i] = code
}

return codes, nil
}

type Config struct {
Logger *slog.Logger
PasswordHasher shieldpassword.PasswordHasher
Generator Generator
RecoveryCodeTotalCount int
RecoveryCodeLength int
}

func (c *Config) defaults() {
c.PasswordHasher = cmp.Or(c.PasswordHasher, shieldpassword.DefaultPasswordHasher)
c.Logger = cmp.Or(c.Logger, shield.DefaultLogger)
c.RecoveryCodeTotalCount = cmp.Or(c.RecoveryCodeTotalCount, DefaultRecoveryCodeTotalCount)
c.RecoveryCodeLength = cmp.Or(c.RecoveryCodeLength, DefaultRecoveryCodeLength)
c.Generator = cmp.Or(c.Generator, DefaultGenerator)
}

func (c *Config) assert() {
debug.Assert(c.Logger != nil, "expected Logger to be defined")
debug.Assert(c.PasswordHasher != nil, "expected PasswordHasher to be defined")
debug.Assert(c.Generator != nil, "expected Generator to be defined")
}

func NewConfig(opts ...func(*Config)) *Config {
c := &Config{}
for _, opt := range opts {
opt(c)
}

c.defaults()
c.assert()

return c
}

type Handler struct {
config *Config
pool *pgxpool.Pool
}

func New(pool *pgxpool.Pool, config *Config) *Handler {
if config == nil {
config = NewConfig()
}

h := Handler{config, pool}
h.assert()

return &h
}

func (h *Handler) assert() {
h.config.assert()
debug.Assert(h.pool != nil, "expected pool to be defined")
}

func (h *Handler) Generate() ([]string, error) {
codes, err := h.config.Generator.Generate(h.config.RecoveryCodeTotalCount, h.config.RecoveryCodeLength)
if err != nil {
return nil, err
}

hashedCodes := make([]string, len(codes))
for i, code := range codes {
hashedCode, err := h.config.PasswordHasher.Hash(code)
if err != nil {
return nil, err
}

hashedCodes[i] = hashedCode
}

return hashedCodes, nil
}

// CreateRecoveryCodes generates a new set of recovery codes
func (h *Handler) CreateRecoveryCodes(ctx context.Context, userID uuid.UUID) error {
codes, err := h.Generate()
if err != nil {
return err
}

tx, err := h.pool.Begin(ctx)
if err != nil {
return fmt.Errorf("shield/recovery_code: failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

if err := h.CreateRecoveryCodesInTx(ctx, userID, codes, tx); err != nil {
return err
}

if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("shield/recovery_code: failed to commit transaction: %w", err)
}

return nil
}

// ReplaceRecoveryCodes regenerates a new set of recovery codes and replaces
// any previously unconsumed recovery codes with the newly generated set.
//
// userID is the ID of the user to update recovery codes for
func (h *Handler) ReplaceRecoveryCodes(ctx context.Context, userID, replacedBy uuid.UUID) error {
codes, err := h.Generate()
if err != nil {
return err
}

tx, err := h.pool.Begin(ctx)
if err != nil {
return fmt.Errorf("shield/recovery_code: failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

if err := h.ReplaceRecoveryCodesInTx(ctx, userID, replacedBy, codes, tx); err != nil {
return err
}

if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("shield/recovery_code: failed to commit transaction: %w", err)
}

return nil
}

func (h *Handler) ReplaceRecoveryCodesInTx(
ctx context.Context,
userID uuid.UUID,
replacedBy uuid.UUID,
codes []string,
tx pgx.Tx,
) error {
if err := h.EvictRecoveryCodesInTx(ctx, userID, replacedBy, tx); err != nil {
return err
}

if err := h.CreateRecoveryCodesInTx(ctx, userID, codes, tx); err != nil {
return err
}

return nil
}

func (h *Handler) EvictRecoveryCodesInTx(
ctx context.Context,
userID uuid.UUID,
evictedBy uuid.UUID,
tx pgx.Tx,
) error {
arg := dbsqlc.EvictUnconsumedRecoveryCodeBatchParams{UserID: userID, EvictedBy: evictedBy}
if err := dbsqlc.New().EvictUnconsumedRecoveryCodeBatch(ctx, tx, arg); err != nil {
return fmt.Errorf("shield/recovery_code: failed to evict recovery codes: %w", err)
}

return nil
}

func (h *Handler) CreateRecoveryCodesInTx(
ctx context.Context,
userID uuid.UUID,
codes []string,
tx pgx.Tx,
) error {
rows := make([]dbsqlc.CreateRecoveryCodeBatchParams, len(codes))
for i, code := range codes {
rows[i] = dbsqlc.CreateRecoveryCodeBatchParams{
ID: uuidv7.Must(),
IsConsumable: true,
RecoveryCodeHash: code,
UserID: userID,
}
}

if _, err := dbsqlc.New().CreateRecoveryCodeBatch(ctx, tx, rows); err != nil {
return fmt.Errorf("shield/recovery_code: failed to create recovery codes: %w", err)
}

return nil
}
Loading