Skip to content

Commit 6d4a7a2

Browse files
author
Pascal Minder
committed
clean up code
1 parent de90b13 commit 6d4a7a2

File tree

6 files changed

+377
-60
lines changed

6 files changed

+377
-60
lines changed

cachePersistence.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package geoblock
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"log"
7+
"os"
8+
"path/filepath"
9+
"sync/atomic"
10+
"time"
11+
12+
lru "github.com/PascalMinder/geoblock/lrucache"
13+
)
14+
15+
func InitializeCache(ctx context.Context, logger *log.Logger, name string, cacheSize int, cachePath string) (*lru.LRUCache, *CachePersist) {
16+
cache, err := lru.NewLRUCache(cacheSize)
17+
if err != nil {
18+
logger.Fatal(err)
19+
}
20+
21+
var ipDB *CachePersist // stays nil if disabled
22+
if path, err := ValidatePersistencePath(cachePath); len(path) > 0 {
23+
// load existing cache
24+
if err := cache.ImportFromFile(path); err != nil && !os.IsNotExist(err) {
25+
logger.Printf("%s: could not load IP DB snapshot (%s): %v", name, path, err)
26+
}
27+
28+
ipDB = &CachePersist{
29+
path: path,
30+
persistTicker: time.NewTicker(15 * time.Second),
31+
persistChannel: make(chan struct{}, 1),
32+
cache: cache,
33+
log: logger,
34+
name: name,
35+
}
36+
37+
go func(ctx context.Context, p *CachePersist) {
38+
defer p.persistTicker.Stop()
39+
for {
40+
select {
41+
case <-ctx.Done():
42+
p.snapshotToDisk()
43+
return
44+
case <-p.persistTicker.C:
45+
p.snapshotToDisk()
46+
case <-p.persistChannel:
47+
p.snapshotToDisk()
48+
}
49+
}
50+
}(ctx, ipDB)
51+
52+
logger.Printf("%s: IP database persistence enabled -> %s", name, path)
53+
} else if err != nil {
54+
logger.Printf("%s: IP database persistence disabled: %v", name, err)
55+
} else {
56+
logger.Printf("%s: IP database persistence disabled (no path)", name)
57+
}
58+
59+
return cache, ipDB
60+
}
61+
62+
type CachePersist struct {
63+
path string
64+
persistTicker *time.Ticker
65+
persistChannel chan struct{}
66+
ipDirty uint32 // 0 clean, 1 dirty
67+
68+
cache *lru.LRUCache
69+
log *log.Logger
70+
name string
71+
}
72+
73+
func (p *CachePersist) MarkDirty() {
74+
if p == nil { // feature OFF
75+
return
76+
}
77+
atomic.StoreUint32(&p.ipDirty, 1)
78+
select {
79+
case p.persistChannel <- struct{}{}:
80+
default:
81+
}
82+
}
83+
84+
func (p *CachePersist) snapshotToDisk() {
85+
if p == nil || atomic.LoadUint32(&p.ipDirty) == 0 {
86+
return
87+
}
88+
89+
var buf bytes.Buffer
90+
if err := p.cache.Export(&buf); err != nil {
91+
p.log.Printf("%s: cache snapshot encode error: %v", p.name, err)
92+
return
93+
}
94+
95+
dir := filepath.Dir(p.path)
96+
tmp, err := os.CreateTemp(dir, "ipdb-*.tmp")
97+
if err != nil {
98+
p.log.Printf("%s: snapshot temp file error: %v", p.name, err)
99+
return
100+
}
101+
tmpPath := tmp.Name()
102+
103+
if _, err := tmp.Write(buf.Bytes()); err != nil {
104+
tmp.Close()
105+
os.Remove(tmpPath)
106+
p.log.Printf("%s: snapshot write error: %v", p.name, err)
107+
return
108+
}
109+
if err := tmp.Sync(); err != nil {
110+
tmp.Close()
111+
os.Remove(tmpPath)
112+
p.log.Printf("%s: snapshot fsync error: %v", p.name, err)
113+
return
114+
}
115+
if err := tmp.Close(); err != nil {
116+
os.Remove(tmpPath)
117+
p.log.Printf("%s: snapshot close error: %v", p.name, err)
118+
return
119+
}
120+
if err := os.Rename(tmpPath, p.path); err != nil {
121+
os.Remove(tmpPath)
122+
p.log.Printf("%s: snapshot rename error: %v", p.name, err)
123+
return
124+
}
125+
126+
atomic.StoreUint32(&p.ipDirty, 0)
127+
p.log.Printf("%s: cache snapshot written to %s", p.name, p.path)
128+
}

docker/traefik-config/dynamic-configuration.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ http:
3131
allowUnknownCountries: false
3232
unknownCountryApiResponse: "nil"
3333
logFilePath: "/geoblock/geoblockA.log"
34+
ipDatabaseCachePath: "/geoblock/cacheA.cache"
3435
countries:
3536
- GB
3637
- IS

geoblock.go

Lines changed: 13 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"net"
1111
"net/http"
1212
"os"
13-
"path/filepath"
1413
"strings"
1514
"time"
1615

@@ -49,8 +48,9 @@ type Config struct {
4948
AllowedIPAddresses []string `yaml:"allowedIPAddresses,omitempty"`
5049
AddCountryHeader bool `yaml:"addCountryHeader"`
5150
HTTPStatusCodeDeniedRequest int `yaml:"httpStatusCodeDeniedRequest"`
52-
LogFilePath string `yaml:"logFilePath"`
5351
RedirectURLIfDenied string `yaml:"redirectUrlIfDenied"`
52+
LogFilePath string `yaml:"logFilePath"`
53+
IPDatabaseCachePath string `yaml:"ipDatabaseCachePath"`
5454
}
5555

5656
type ipEntry struct {
@@ -91,6 +91,7 @@ type GeoBlock struct {
9191
redirectURLIfDenied string
9292
name string
9393
infoLogger *log.Logger
94+
ipDatabasePersistence *CachePersist
9495
}
9596

9697
// New created a new GeoBlock plugin.
@@ -130,27 +131,14 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
130131
printConfiguration(name, config, infoLogger)
131132
}
132133

133-
// create LRU cache for IP lookup
134-
cache, err := lru.NewLRUCache(config.CacheSize)
135-
if err != nil {
136-
infoLogger.Fatal(err)
137-
}
138-
139134
// create custom log target if needed
140-
logFile, err := initializeLogFile(name, config.LogFilePath, infoLogger)
135+
logFile, err := CreateCustomLogTarget(ctx, infoLogger, name, config.LogFilePath)
141136
if err != nil {
142-
infoLogger.Printf("%s: Error initializing log file: %v\n", name, err)
137+
infoLogger.Fatal(err)
143138
}
144139

145-
// Set up a goroutine to close the file when the context is done
146-
if logFile != nil {
147-
go func(logger *log.Logger) {
148-
<-ctx.Done() // Wait for context cancellation
149-
logger.SetOutput(os.Stdout)
150-
logFile.Close()
151-
logger.Printf("%s: Log file closed for middleware\n", name)
152-
}(infoLogger)
153-
}
140+
// initialize local IP lookup cache
141+
cache, ipDB := InitializeCache(ctx, infoLogger, name, config.CacheSize, config.IPDatabaseCachePath)
154142

155143
return &GeoBlock{
156144
next: next,
@@ -179,6 +167,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
179167
redirectURLIfDenied: config.RedirectURLIfDenied,
180168
name: name,
181169
infoLogger: infoLogger,
170+
ipDatabasePersistence: ipDB, // may be nil => feature OFF
182171
}, nil
183172
}
184173

@@ -286,6 +275,8 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
286275
}
287276
} else {
288277
entry = cacheEntry.(ipEntry)
278+
// order has changed
279+
a.ipDatabasePersistence.MarkDirty()
289280
}
290281

291282
if a.logAPIRequests {
@@ -351,6 +342,8 @@ func (a *GeoBlock) cachedRequestIP(requestIPAddr *net.IP, req *http.Request) (bo
351342
}
352343
} else {
353344
entry = cacheEntry.(ipEntry)
345+
// order has changed
346+
a.ipDatabasePersistence.MarkDirty()
354347
}
355348

356349
if a.logAPIRequests {
@@ -414,6 +407,7 @@ func (a *GeoBlock) createNewIPEntry(req *http.Request, ipAddressString string) (
414407

415408
entry = ipEntry{Country: country, Timestamp: time.Now()}
416409
a.database.Add(ipAddressString, entry)
410+
a.ipDatabasePersistence.MarkDirty() // new entry in the cache
417411

418412
if a.logAPIRequests {
419413
a.infoLogger.Printf("%s: [%s] added to database: %s", a.name, ipAddressString, entry)
@@ -638,44 +632,3 @@ func printConfiguration(name string, config *Config, logger *log.Logger) {
638632
logger.Printf("%s: Redirect URL on denied requests: %s", name, config.RedirectURLIfDenied)
639633
}
640634
}
641-
642-
func initializeLogFile(name string, logFilePath string, logger *log.Logger) (*os.File, error) {
643-
if len(logFilePath) == 0 {
644-
return nil, nil
645-
}
646-
647-
writeable, err := isFolder(logFilePath)
648-
if err != nil {
649-
logger.Printf("%s: %s", name, err)
650-
return nil, err
651-
} else if !writeable {
652-
logger.Printf("%s: Specified log folder is not writeable: %s", name, logFilePath)
653-
return nil, fmt.Errorf("%s: folder is not writeable: %s", name, logFilePath)
654-
}
655-
656-
logFile, err := os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, filePermissions)
657-
if err != nil {
658-
logger.Printf("%s: Failed to open log file: %v\n", name, err)
659-
return nil, err
660-
}
661-
662-
logger.SetOutput(logFile)
663-
return logFile, nil
664-
}
665-
666-
func isFolder(filePath string) (bool, error) {
667-
dirPath := filepath.Dir(filePath)
668-
info, err := os.Stat(dirPath)
669-
if err != nil {
670-
if os.IsNotExist(err) {
671-
return false, fmt.Errorf("path does not exist")
672-
}
673-
return false, fmt.Errorf("error checking path: %w", err)
674-
}
675-
676-
if !info.IsDir() {
677-
return false, fmt.Errorf("folder does not exist")
678-
}
679-
680-
return true, nil
681-
}

logPersistence.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package geoblock
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
"os"
8+
)
9+
10+
func CreateCustomLogTarget(ctx context.Context, logger *log.Logger, name string, logFilePath string) (*os.File, error) {
11+
logFile, err := initializeLogFile(logFilePath, logger)
12+
if err != nil {
13+
return nil, fmt.Errorf("error initializing log file: %v", err)
14+
}
15+
16+
// Set up a goroutine to close the file when the context is done
17+
if logFile != nil {
18+
go func(logger *log.Logger) {
19+
<-ctx.Done() // Wait for context cancellation
20+
logger.SetOutput(os.Stdout)
21+
logFile.Close()
22+
logger.Printf("%s: Log file closed for middleware\n", name)
23+
}(logger)
24+
}
25+
26+
logger.Printf("%s: Log file opened for middleware\n", name)
27+
return logFile, nil
28+
}
29+
30+
func initializeLogFile(logFilePath string, logger *log.Logger) (*os.File, error) {
31+
if len(logFilePath) == 0 {
32+
return nil, nil
33+
}
34+
35+
path, err := ValidatePersistencePath(logFilePath)
36+
if err != nil {
37+
return nil, err
38+
} else if len(path) == 0 {
39+
return nil, fmt.Errorf("folder is not writeable (%s)", path)
40+
}
41+
42+
logFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, filePermissions)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to open log file (%s): %v", path, err)
45+
}
46+
47+
logger.SetOutput(logFile)
48+
return logFile, nil
49+
}

0 commit comments

Comments
 (0)