Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions cachePersistence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package geoblock

import (
"bytes"
"context"
"fmt"
"log"
"os"
"path/filepath"
"sync/atomic"
"time"

lru "github.com/PascalMinder/geoblock/lrucache"
)

const defaultWriteCycle = 15

// Options configures cache initialization and persistence behavior.
type Options struct {
// CacheSize specifies the number of entries to store in memory.
CacheSize int

// CachePath is the file path used for on-disk persistence.
// Leave empty to disable persistence.
CachePath string

// PersistInterval defines how often the cache is automatically
// flushed to disk. Defaults to 15 seconds if zero.
PersistInterval time.Duration

// Logger is used for diagnostic messages.
Logger *log.Logger

// Name is included in log messages to identify the cache instance.
Name string
}

// InitializeCache creates a new LRU cache and, if a valid persistence
// path is provided, starts a background goroutine to periodically
// save snapshots to disk.
//
// The returned `CachePersist` can be used to mark the cache as dirty when
// data changes. If persistence is disabled, it is a no-op.
//
// Callers should cancel the provided context to stop persistence
// and ensure a final snapshot is written.
func InitializeCache(ctx context.Context, opts Options) (*lru.LRUCache, *CachePersist, error) {
if opts.PersistInterval <= 0 {
opts.PersistInterval = defaultWriteCycle * time.Second
}

cache, err := lru.NewLRUCache(opts.CacheSize)
if err != nil {
return nil, nil, fmt.Errorf("create LRU cache: %w", err)
}

var ipDB *CachePersist // stays nil if disabled
if path, err := ValidatePersistencePath(opts.CachePath); len(path) > 0 {
// load existing cache
if err := cache.ImportFromFile(path); err != nil && !os.IsNotExist(err) {
opts.Logger.Printf("%s: could not load IP DB snapshot (%s): %v", opts.Name, path, err)
}

ipDB = &CachePersist{
path: path,
persistTicker: time.NewTicker(opts.PersistInterval),
persistChannel: make(chan struct{}, 1),
cache: cache,
log: opts.Logger,
name: opts.Name,
}

go func(ctx context.Context, p *CachePersist) {
defer p.persistTicker.Stop()
for {
select {
case <-ctx.Done():
p.snapshotToDisk()
return
case <-p.persistTicker.C:
p.snapshotToDisk()
case <-p.persistChannel:
p.snapshotToDisk()
}
}
}(ctx, ipDB)

opts.Logger.Printf("%s: IP database persistence enabled -> %s", opts.Name, path)
} else if err != nil {
opts.Logger.Printf("%s: IP database persistence disabled: %v", opts.Name, err)
} else {
opts.Logger.Printf("%s: IP database persistence disabled (no path)", opts.Name)
}

return cache, ipDB, nil
}

// CachePersist periodically snapshots a cache to disk.
type CachePersist struct {
path string
persistTicker *time.Ticker
persistChannel chan struct{}
ipDirty uint32 // 0 clean, 1 dirty

cache *lru.LRUCache
log *log.Logger
name string
}

// MarkDirty marks the cache as modified and schedules a snapshot.
func (p *CachePersist) MarkDirty() {
if p == nil { // feature OFF
return
}
atomic.StoreUint32(&p.ipDirty, 1)
select {
case p.persistChannel <- struct{}{}:
default:
}
}

// Snapshot writes the cache to disk if it has been marked dirty.
func (p *CachePersist) snapshotToDisk() {
if p == nil || atomic.LoadUint32(&p.ipDirty) == 0 {
return
}

var buf bytes.Buffer
if err := p.cache.Export(&buf); err != nil {
p.log.Printf("%s: cache snapshot encode error: %v", p.name, err)
return
}

dir := filepath.Dir(p.path)
tmp, err := os.CreateTemp(dir, "ipdb-*.tmp")
if err != nil {
p.log.Printf("%s: snapshot temp file error: %v", p.name, err)
return
}
tmpPath := tmp.Name()

if _, err := tmp.Write(buf.Bytes()); err != nil {
tmp.Close()
os.Remove(tmpPath)
p.log.Printf("%s: snapshot write error: %v", p.name, err)
return
}
if err := tmp.Sync(); err != nil {
tmp.Close()
os.Remove(tmpPath)
p.log.Printf("%s: snapshot fsync error: %v", p.name, err)
return
}
if err := tmp.Close(); err != nil {
os.Remove(tmpPath)
p.log.Printf("%s: snapshot close error: %v", p.name, err)
return
}
if err := os.Rename(tmpPath, p.path); err != nil {
os.Remove(tmpPath)
p.log.Printf("%s: snapshot rename error: %v", p.name, err)
return
}

atomic.StoreUint32(&p.ipDirty, 0)
p.log.Printf("%s: cache snapshot written to %s", p.name, p.path)
}
26 changes: 24 additions & 2 deletions docker/dev-geoblock/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
traefik:
image: traefik:v3.1
image: traefik:v3.5.3

volumes:
- /var/run/docker.sock:/var/run/docker.sock
Expand All @@ -18,14 +18,36 @@ services:
- "80:80"

hello:
image: containous/whoami
image: traefik/whoami
labels:
- traefik.enable=true
- traefik.http.routers.hello.entrypoints=http
- traefik.http.routers.hello.rule=Host(`hello.localhost`)
- traefik.http.services.hello.loadbalancer.server.port=80
- traefik.http.routers.hello.middlewares=my-plugin@file

hello1:
image: traefik/whoami
command:
- --port=81
labels:
- traefik.enable=true
- traefik.http.routers.hello1.entrypoints=http
- traefik.http.routers.hello1.rule=Host(`hello1.localhost`)
- traefik.http.services.hello1.loadbalancer.server.port=81
- traefik.http.routers.hello1.middlewares=my-plugin@file

hello2:
image: traefik/whoami
command:
- --port=82
labels:
- traefik.enable=true
- traefik.http.routers.hello2.entrypoints=http
- traefik.http.routers.hello2.rule=Host(`hello2.localhost`)
- traefik.http.services.hello2.loadbalancer.server.port=82
- traefik.http.routers.hello2.middlewares=my-plugin@file

whoami:
image: jwilder/whoami
labels:
Expand Down
1 change: 1 addition & 0 deletions docker/traefik-config/dynamic-configuration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ http:
allowUnknownCountries: false
unknownCountryApiResponse: "nil"
logFilePath: "/geoblock/geoblockA.log"
ipDatabaseCachePath: "/geoblock/cacheA.cache"
countries:
- GB
- IS
90 changes: 28 additions & 62 deletions geoblock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ import (
"context"
"fmt"
"io"
"io/fs"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"

Expand All @@ -25,7 +23,7 @@ const (
unknownCountryCode = "AA"
countryCodeLength = 2
defaultDeniedRequestHTTPStatusCode = 403
filePermissions = fs.FileMode(0666)
defaultCacheWriteCycle = 15
)

// Config the plugin configuration.
Expand All @@ -49,8 +47,9 @@ type Config struct {
AllowedIPAddresses []string `yaml:"allowedIPAddresses,omitempty"`
AddCountryHeader bool `yaml:"addCountryHeader"`
HTTPStatusCodeDeniedRequest int `yaml:"httpStatusCodeDeniedRequest"`
LogFilePath string `yaml:"logFilePath"`
RedirectURLIfDenied string `yaml:"redirectUrlIfDenied"`
LogFilePath string `yaml:"logFilePath"`
IPDatabaseCachePath string `yaml:"ipDatabaseCachePath"`
}

type ipEntry struct {
Expand Down Expand Up @@ -91,6 +90,7 @@ type GeoBlock struct {
redirectURLIfDenied string
name string
infoLogger *log.Logger
ipDatabasePersistence *CachePersist
}

// New created a new GeoBlock plugin.
Expand Down Expand Up @@ -130,26 +130,27 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
printConfiguration(name, config, infoLogger)
}

// create LRU cache for IP lookup
cache, err := lru.NewLRUCache(config.CacheSize)
if err != nil {
infoLogger.Fatal(err)
}

// create custom log target if needed
logFile, err := initializeLogFile(name, config.LogFilePath, infoLogger)
if err != nil {
infoLogger.Printf("%s: Error initializing log file: %v\n", name, err)
var logFile *os.File
if len(config.LogFilePath) > 0 {
logTarget, err := CreateCustomLogTarget(ctx, infoLogger, name, config.LogFilePath)
if err != nil {
infoLogger.Fatal(err)
}
logFile = logTarget
}

// Set up a goroutine to close the file when the context is done
if logFile != nil {
go func(logger *log.Logger) {
<-ctx.Done() // Wait for context cancellation
logger.SetOutput(os.Stdout)
logFile.Close()
logger.Printf("%s: Log file closed for middleware\n", name)
}(infoLogger)
// initialize local IP lookup cache
cacheOptions := Options{
CacheSize: config.CacheSize,
CachePath: config.IPDatabaseCachePath,
PersistInterval: defaultCacheWriteCycle,
Logger: infoLogger,
Name: name,
}
cache, ipDB, err := InitializeCache(ctx, cacheOptions)
if err != nil {
infoLogger.Fatal(err)
}

return &GeoBlock{
Expand Down Expand Up @@ -179,6 +180,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
redirectURLIfDenied: config.RedirectURLIfDenied,
name: name,
infoLogger: infoLogger,
ipDatabasePersistence: ipDB, // may be nil => feature OFF
}, nil
}

Expand Down Expand Up @@ -286,6 +288,8 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
}
} else {
entry = cacheEntry.(ipEntry)
// order has changed
a.ipDatabasePersistence.MarkDirty()
}

if a.logAPIRequests {
Expand Down Expand Up @@ -351,6 +355,8 @@ func (a *GeoBlock) cachedRequestIP(requestIPAddr *net.IP, req *http.Request) (bo
}
} else {
entry = cacheEntry.(ipEntry)
// order has changed
a.ipDatabasePersistence.MarkDirty()
}

if a.logAPIRequests {
Expand Down Expand Up @@ -414,6 +420,7 @@ func (a *GeoBlock) createNewIPEntry(req *http.Request, ipAddressString string) (

entry = ipEntry{Country: country, Timestamp: time.Now()}
a.database.Add(ipAddressString, entry)
a.ipDatabasePersistence.MarkDirty() // new entry in the cache

if a.logAPIRequests {
a.infoLogger.Printf("%s: [%s] added to database: %s", a.name, ipAddressString, entry)
Expand Down Expand Up @@ -638,44 +645,3 @@ func printConfiguration(name string, config *Config, logger *log.Logger) {
logger.Printf("%s: Redirect URL on denied requests: %s", name, config.RedirectURLIfDenied)
}
}

func initializeLogFile(name string, logFilePath string, logger *log.Logger) (*os.File, error) {
if len(logFilePath) == 0 {
return nil, nil
}

writeable, err := isFolder(logFilePath)
if err != nil {
logger.Printf("%s: %s", name, err)
return nil, err
} else if !writeable {
logger.Printf("%s: Specified log folder is not writeable: %s", name, logFilePath)
return nil, fmt.Errorf("%s: folder is not writeable: %s", name, logFilePath)
}

logFile, err := os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, filePermissions)
if err != nil {
logger.Printf("%s: Failed to open log file: %v\n", name, err)
return nil, err
}

logger.SetOutput(logFile)
return logFile, nil
}

func isFolder(filePath string) (bool, error) {
dirPath := filepath.Dir(filePath)
info, err := os.Stat(dirPath)
if err != nil {
if os.IsNotExist(err) {
return false, fmt.Errorf("path does not exist")
}
return false, fmt.Errorf("error checking path: %w", err)
}

if !info.IsDir() {
return false, fmt.Errorf("folder does not exist")
}

return true, nil
}
Loading