Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve config import and export utils #219

Merged
merged 10 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/gorilla/mux"

"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
Expand Down Expand Up @@ -461,7 +461,11 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var v interface{}
v, err = e.StructFunc(apiRequest)
if err == nil && v != nil {
responseData, err = json.Marshal(v)
var mimeType string
responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept"))
if err == nil {
w.Header().Set("Content-Type", mimeType)
}
}

case e.RecordFunc != nil:
Expand All @@ -482,7 +486,6 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Check for handler error.
if err != nil {
// if statusProvider, ok := err.(HTTPStatusProvider); ok {
var statusProvider HTTPStatusProvider
if errors.As(err, &statusProvider) {
http.Error(w, err.Error(), statusProvider.HTTPStatus())
Expand All @@ -498,8 +501,12 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// Set content type if not yet set.
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
}

// Write response.
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(responseData)
Expand Down
4 changes: 2 additions & 2 deletions config/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func parseAndReplaceConfig(jsonData string) error {
return err
}

validationErrors := replaceConfig(m)
validationErrors, _ := ReplaceConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}
Expand All @@ -27,7 +27,7 @@ func parseAndReplaceDefaultConfig(jsonData string) error {
return err
}

validationErrors := replaceDefaultConfig(m)
validationErrors, _ := ReplaceDefaultConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}
Expand Down
2 changes: 1 addition & 1 deletion config/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func start() error {

err = loadConfig(false)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
return fmt.Errorf("failed to load config file: %w", err)
}
return nil
}
Expand Down
48 changes: 45 additions & 3 deletions config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"fmt"
"reflect"
"regexp"
"sync"

Expand Down Expand Up @@ -108,11 +109,13 @@ const (
// requirement. The type of RequiresAnnotation is []ValueRequirement
// or ValueRequirement.
RequiresAnnotation = "safing/portbase:config:requires"
// RequiresFeaturePlan can be used to mark a setting as only available
// RequiresFeatureIDAnnotation can be used to mark a setting as only available
// when the user has a certain feature ID in the subscription plan.
// The type is []string or string.
RequiresFeatureID = "safing/portmaster:ui:config:requires-feature"

RequiresFeatureIDAnnotation = "safing/portmaster:ui:config:requires-feature"
// SettablePerAppAnnotation can be used to mark a setting as settable per-app and
// is a boolean.
SettablePerAppAnnotation = "safing/portmaster:settable-per-app"
// RequiresUIReloadAnnotation can be used to inform the UI that changing the value
// of the annotated setting requires a full reload of the user interface.
// The value of this annotation does not matter as the sole presence of
Expand Down Expand Up @@ -308,6 +311,22 @@ func (option *Option) GetAnnotation(key string) (interface{}, bool) {
return val, ok
}

// AnnotationEquals returns whether the annotation of the given key matches the
// given value.
func (option *Option) AnnotationEquals(key string, value any) bool {
option.Lock()
defer option.Unlock()

if option.Annotations == nil {
return false
}
setValue, ok := option.Annotations[key]
if !ok {
return false
}
return reflect.DeepEqual(value, setValue)
}

// copyOrNil returns a copy of the option, or nil if copying failed.
func (option *Option) copyOrNil() *Option {
copied, err := copystructure.Copy(option)
Expand All @@ -325,6 +344,29 @@ func (option *Option) IsSetByUser() bool {
return option.activeValue != nil
}

// UserValue returns the value set by the user or nil if the value has not
// been changed from the default.
func (option *Option) UserValue() any {
option.Lock()
defer option.Unlock()

if option.activeValue == nil {
return nil
}
return option.activeValue.getData(option)
}

// ValidateValue checks if the given value is valid for the option.
func (option *Option) ValidateValue(value any) error {
option.Lock()
defer option.Unlock()

if _, err := validateValue(option, value); err != nil {
return err
}
return nil
}

// Export expors an option to a Record.
func (option *Option) Export() (record.Record, error) {
option.Lock()
Expand Down
2 changes: 1 addition & 1 deletion config/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func loadConfig(requireValidConfig bool) error {
return err
}

validationErrors := replaceConfig(newValues)
validationErrors, _ := ReplaceConfig(newValues)
if requireValidConfig && len(validationErrors) > 0 {
return fmt.Errorf("encountered %d validation errors during config loading", len(validationErrors))
}
Expand Down
110 changes: 76 additions & 34 deletions config/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,70 +37,112 @@ func signalChanges() {
module.TriggerEvent(ChangeEvent, nil)
}

// replaceConfig sets the (prioritized) user defined config.
func replaceConfig(newValues map[string]interface{}) []*ValidationError {
var validationErrors []*ValidationError
// ValidateConfig validates the given configuration and returns all validation
// errors as well as whether the given configuration contains unknown keys.
func ValidateConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool, containsUnknown bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only checking the
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()

var checked int
for key, option := range options {
newValue, ok := newValues[key]
if ok {
checked++

func() {
option.Lock()
defer option.Unlock()

_, err := validateValue(option, newValue)
if err != nil {
validationErrors = append(validationErrors, err)
}

if option.RequiresRestart {
requiresRestart = true
}
}()
}
}

return validationErrors, requiresRestart, checked < len(newValues)
}

// ReplaceConfig sets the (prioritized) user defined config.
func ReplaceConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()

for key, option := range options {
newValue, ok := newValues[key]

option.Lock()
option.activeValue = nil

if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
validationErrors = append(validationErrors, err)
func() {
option.Lock()
defer option.Unlock()

option.activeValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
}
handleOptionUpdate(option, true)

handleOptionUpdate(option, true)
option.Unlock()
if option.RequiresRestart {
requiresRestart = true
}
}()
}

signalChanges()

return validationErrors
return validationErrors, requiresRestart
}

// replaceDefaultConfig sets the (fallback) default config.
func replaceDefaultConfig(newValues map[string]interface{}) []*ValidationError {
var validationErrors []*ValidationError

// ReplaceDefaultConfig sets the (fallback) default config.
func ReplaceDefaultConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()

for key, option := range options {
newValue, ok := newValues[key]

option.Lock()
option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
validationErrors = append(validationErrors, err)
func() {
option.Lock()
defer option.Unlock()

option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
}
handleOptionUpdate(option, true)
option.Unlock()
handleOptionUpdate(option, true)

if option.RequiresRestart {
requiresRestart = true
}
}()
}

signalChanges()

return validationErrors
return validationErrors, requiresRestart
}

// SetConfigOption sets a single value in the (prioritized) user defined config.
Expand Down
2 changes: 1 addition & 1 deletion config/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestLayersGetters(t *testing.T) { //nolint:paralleltest
t.Fatal(err)
}

validationErrors := replaceConfig(mapData)
validationErrors, _ := ReplaceConfig(mapData)
if len(validationErrors) > 0 {
t.Fatalf("%d errors, first: %s", len(validationErrors), validationErrors[0].Error())
}
Expand Down
12 changes: 12 additions & 0 deletions formats/dsd/dsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"

"github.com/fxamacker/cbor/v2"
"github.com/ghodss/yaml"
"github.com/vmihailenco/msgpack/v5"

"github.com/safing/portbase/formats/varint"
Expand Down Expand Up @@ -41,6 +42,12 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (err error) {
return fmt.Errorf("dsd: failed to unpack json: %w, data: %s", err, utils.SafeFirst16Bytes(data))
}
return nil
case YAML:
err = yaml.Unmarshal(data, t)
if err != nil {
return fmt.Errorf("dsd: failed to unpack yaml: %w, data: %s", err, utils.SafeFirst16Bytes(data))
}
return nil
case CBOR:
err = cbor.Unmarshal(data, t)
if err != nil {
Expand Down Expand Up @@ -121,6 +128,11 @@ func dumpWithoutIdentifier(t interface{}, format uint8, indent string) ([]byte,
if err != nil {
return nil, err
}
case YAML:
data, err = yaml.Marshal(t)
if err != nil {
return nil, err
}
case CBOR:
data, err = cbor.Marshal(t)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions formats/dsd/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
GenCode = 71 // G
JSON = 74 // J
MsgPack = 77 // M
YAML = 89 // Y

// Compression types.
GZIP = 90 // Z
Expand Down Expand Up @@ -48,6 +49,8 @@ func ValidateSerializationFormat(format uint8) (validatedFormat uint8, ok bool)
return format, true
case JSON:
return format, true
case YAML:
return format, true
case MsgPack:
return format, true
default:
Expand Down
Loading
Loading