Skip to content

Commit

Permalink
Add support for conditional resets of data when merging
Browse files Browse the repository at this point in the history
  • Loading branch information
dhaavi committed Dec 5, 2023
1 parent 593b420 commit 2d3faed
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 7 deletions.
14 changes: 14 additions & 0 deletions config-example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ databases:
"autonomous_system_organization": string
"autonomous_system_number": uint32
inputs: # Source data and their mapping.
# Inputs are processed as listed. Earlier entries are overwritten by later entries.
- file: "input/asn-ipv4.csv"
fields: ["from", "to", "autonomous_system_number", "autonomous_system_organization"]
- file: "input/geo-whois-asn-country-ipv4.csv"
Expand All @@ -17,6 +18,12 @@ databases:
floatDecimals: 2 # Limit floats (e.g. coordinates) to this number of decimals for smaller DB size. (0=off, set to -1 for no decimals)
forceIPVersion: true # Check IPs and discard IPs with the wrong version. (IPv4 can live in an IPv6 mmdb)
maxPrefix: 0 # Remove any network prefixes greater than maxPrefix for smaller DB size. (0=off)
merge:
conditionalResets: # Reset a set of top-level entries if any entry in another set is changed.
# Reset the location entry when the country is changed.
# If the new entry also has a location, it is kept, but a different country without location resets the location.
- ifChanged: ["country"]
reset: ["location"]

- name: "My IPv6 GeoIP DB"
mmdb:
Expand All @@ -27,6 +34,7 @@ databases:
"autonomous_system_organization": string
"autonomous_system_number": uint32
inputs: # Source data and their mapping.
# Inputs are processed as listed. Earlier entries are overwritten by later entries.
- file: "input/asn-ipv6.csv"
fields: ["from", "to", "autonomous_system_number", "autonomous_system_organization"]
- file: "input/geo-whois-asn-country-ipv6.csv"
Expand All @@ -36,3 +44,9 @@ databases:
floatDecimals: 2 # Limit floats (e.g. coordinates) to this number of decimals for smaller DB size. (0=off, set to -1 for no decimals)
forceIPVersion: true # Check IPs and discard IPs with the wrong version. (IPv4 can live in an IPv6 mmdb)
maxPrefix: 0 # Remove any network prefixes greater than maxPrefix for smaller DB size. (0=off)
merge:
conditionalResets: # Reset a set of top-level entries if any entry in another set is changed.
# Reset the location entry when the country is changed.
# If the new entry also has a location, it is kept, but a different country without location resets the location.
- ifChanged: ["country"]
reset: ["location"]
12 changes: 12 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type DatabaseConfig struct {
Inputs []DatabaseInput `yaml:"inputs"`
Output string `yaml:"output"`
Optimize Optimizations `yaml:"optimize"`
Merge MergeConfig `yaml:"merge"`
}

// MMDBConfig holds mmdb specific config.
Expand All @@ -41,6 +42,17 @@ type Optimizations struct {
MaxPrefix int `yaml:"maxPrefix"`
}

// MergeConfig holds merge configuration.
// It is read from the "merge" key of a database config entry and controls
// how a newly inserted entry is merged with data already present for a network.
type MergeConfig struct {
// ConditionalResets lists the conditional reset rules that are handed to
// ConditionalResetTopLevelMergeWith on every insert. An empty list results
// in a plain top-level merge.
ConditionalResets []ConditionalResetConfig `yaml:"conditionalResets"`
}

// ConditionalResetConfig defines a conditional reset merge config.
// If any of the IfChanged keys is changed by a merge, all Reset keys are reset.
type ConditionalResetConfig struct {
// IfChanged holds the top-level keys to watch. A key only counts as changed
// when it is present in both the existing and the new entry and the two
// values are not equal.
IfChanged []string `yaml:"ifChanged"`
// Reset holds the top-level keys to reset when a change is detected.
// A reset key takes the value of the new entry, or is removed from the
// merged result if the new entry does not set it.
Reset []string `yaml:"reset"`
}

// LoadConfig loads a configuration file.
func LoadConfig(filePath string) (*Config, error) {
data, err := os.ReadFile(filePath)
Expand Down
4 changes: 2 additions & 2 deletions source.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (se SourceEntry) ToMMDBMap(optim Optimizations) (mmdbtype.Map, error) {
// Transform value to mmdb type.
mmdbVal, err := entry.ToMMDBType(optim)
if err != nil {
return nil, fmt.Errorf("failed to transform %s with value %s (of type %s)", key, entry.Value, entry.Type)
return nil, fmt.Errorf("failed to transform %s with value %s (of type %s): %w", key, entry.Value, entry.Type, err)
}

// Get sub map for entry.
Expand All @@ -81,7 +81,7 @@ func (se SourceEntry) ToMMDBMap(optim Optimizations) (mmdbtype.Map, error) {
} else {
mapForEntry, ok = subMapVal.(mmdbtype.Map)
if !ok {
return nil, fmt.Errorf("submap %s already exists but is a %T, and not a map", strings.Join(keyParts[:1], "."), subMapVal)
return nil, fmt.Errorf("failed to transform %s: submap %s already exists but is a %T, and not a map", key, strings.Join(keyParts[:1], "."), subMapVal)
}
}
}
Expand Down
90 changes: 85 additions & 5 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/maxmind/mmdbwriter"
"github.com/maxmind/mmdbwriter/inserter"
"github.com/maxmind/mmdbwriter/mmdbtype"
"go4.org/netipx"
)

Expand Down Expand Up @@ -41,6 +42,10 @@ func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) e
dbConfig.Optimize.ForceIPVersion,
dbConfig.Optimize.MaxPrefix,
))
sendUpdate(updates, fmt.Sprintf(
"conditional resets: %+v",
dbConfig.Merge.ConditionalResets,
))

// Close update channel when finished.
if updates != nil {
Expand Down Expand Up @@ -75,7 +80,7 @@ func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) e

mmdbMap, err := entry.ToMMDBMap(dbConfig.Optimize)
if err != nil {
sendUpdate(updates, fmt.Sprintf("failed to transform %+v: %s", entry, err.Error()))
sendUpdate(updates, fmt.Sprintf("failed to convert %+v to mmdb map: %s", entry, err.Error()))
continue
}

Expand All @@ -95,7 +100,7 @@ func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) e
}
}

err = writer.InsertFunc(entry.Net, inserter.TopLevelMergeWith(mmdbMap))
err = writer.InsertFunc(entry.Net, ConditionalResetTopLevelMergeWith(mmdbMap, dbConfig.Merge.ConditionalResets))
if err != nil {
sendUpdate(updates, fmt.Sprintf("failed to insert %+v: %s", entry, err.Error()))
continue
Expand Down Expand Up @@ -127,7 +132,7 @@ func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) e
continue
}

err = writer.InsertFunc(netipx.PrefixIPNet(subnet), inserter.TopLevelMergeWith(mmdbMap))
err = writer.InsertFunc(netipx.PrefixIPNet(subnet), ConditionalResetTopLevelMergeWith(mmdbMap, dbConfig.Merge.ConditionalResets))
if err != nil {
sendUpdate(updates, fmt.Sprintf("failed to insert %+v: %s", entry, err.Error()))
continue
Expand Down Expand Up @@ -171,16 +176,91 @@ func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) e
fileSize = stat.Size()
}
sendUpdate(updates, fmt.Sprintf(
"---\n%s finished: inserted %d entries in %s, resulting in %dMB",
"---\n%s finished: inserted %d entries in %s, resulting in %.2f MB written to %s",
dbConfig.Name,
totalInserts,
time.Since(totalStartTime).Round(time.Second),
fileSize/1000000,
float64(fileSize)/1000000,
dbConfig.Output,
))

return nil
}

// ConditionalResetTopLevelMergeWith is based on inserter.TopLevelMergeWith,
// but conditionally resets fields as defined in the conditional reset config.
//
// The returned inserter.Func first performs a regular top-level merge: the
// existing map is copied and every top-level key of newValue overwrites the
// key of the same name. Afterwards, every entry of cfg is evaluated: if at
// least one of its IfChanged keys is present in both the existing and the new
// map with unequal values, all of its Reset keys that are not set by the new
// map are removed from the result. Reset keys that are set by the new map
// keep the (new) value they received during the merge.
//
// If there is no existing value, newValue is returned as is.
//
// Both the new and existing value must be a Map. An error will be returned
// otherwise.
func ConditionalResetTopLevelMergeWith(newValue mmdbtype.DataType, cfg []ConditionalResetConfig) inserter.Func {
	return func(existingValue mmdbtype.DataType) (mmdbtype.DataType, error) {
		// Check if both values are maps before we start merging.
		newMap, ok := newValue.(mmdbtype.Map)
		if !ok {
			return nil, fmt.Errorf(
				"the new value is a %T, not a Map; ConditionalResetTopLevelMerge only works if both values are Map values",
				newValue,
			)
		}
		if existingValue == nil {
			// Nothing to merge with, and nothing that could have changed.
			return newValue, nil
		}
		existingMap, ok := existingValue.(mmdbtype.Map)
		if !ok {
			return nil, fmt.Errorf(
				"the existing value is a %T, not a Map; ConditionalResetTopLevelMerge only works if both values are Map values",
				existingValue,
			)
		}

		// First, do a normal top-level merge.
		// Every value taken from newMap is stored as a copy.
		returnMap := existingMap.Copy().(mmdbtype.Map) //nolint:forcetypeassert
		for k, v := range newMap {
			returnMap[k] = v.Copy()
		}

		// Then apply the conditional resets.
		for _, c := range cfg {
			// Check if any of the watched fields changed.
			var changed bool
			for _, key := range c.IfChanged {
				existingSubVal, ok := existingMap[mmdbtype.String(key)]
				if !ok {
					// There is no existing value of that key, so there is no change possible.
					continue
				}
				newSubVal, ok := newMap[mmdbtype.String(key)]
				if !ok {
					// Value of that key is not being set, so there is no change possible.
					continue
				}
				// Compare values if both are set.
				if !newSubVal.Equal(existingSubVal) {
					changed = true
					break
				}
			}
			if !changed {
				continue
			}

			// Reset fields.
			for _, key := range c.Reset {
				// If the new map sets the key, the merge above already stored a
				// copy of the new value. Keep that copy instead of overwriting
				// it with newMap's own value, which would make the result share
				// data with the caller's map.
				if _, ok := newMap[mmdbtype.String(key)]; ok {
					continue
				}
				// Remove left-over existing value if no new value is present.
				delete(returnMap, mmdbtype.String(key))
			}
		}

		return returnMap, nil
	}
}

func sendUpdate(to chan string, msg string) {
if to == nil {
return
Expand Down

0 comments on commit 2d3faed

Please sign in to comment.