Skip to content

Commit

Permalink
Add restriction on outputing structures that cannot be decoded, fix t…
Browse files Browse the repository at this point in the history
…ests to match
  • Loading branch information
NHAS committed Nov 10, 2024
1 parent 4cb4e38 commit 7e1724e
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 39 deletions.
21 changes: 15 additions & 6 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,19 @@ func GetGeneratedCliFlags[T any](delimiter string) []string {
panic("GetGeneratedEnv(...) only supports configs of Struct type")
}

o := options{}
FromCli(delimiter)(&o)
cp := newCliLoader[T](&o)

var result []string
for _, field := range getFields(true, &a) {
result = append(result, strings.Join(resolvePath(&a, field.path), delimiter))

cliName, ok := determineVariableName(&a, cp.o.cli.delimiter, nil, field)
if !ok {
continue
}

result = append(result, cliName)
}

return result
Expand Down Expand Up @@ -311,11 +321,10 @@ func (cp *ciParser[T]) apply(result *T) (somethingSet bool, err error) {
field.value = field.value.Elem()
}

// if this changes update LoadCli
flagName := strings.Join(resolvePath(dummyCopy, field.path), cp.o.cli.delimiter)
if cp.o.cli.transform != nil {
flagName = cp.o.cli.transform(flagName)
logger.Info("using transform func on cli flag", "before_func", strings.Join(resolvePath(dummyCopy, field.path), cp.o.cli.delimiter), "after_func", flagName)
flagName, ok := determineVariableName(result, cp.o.cli.delimiter, cp.o.cli.transform, field)
if !ok {
// logging done in determine variable
continue
}

logger.Info("resolved confy path", "resolved_path", flagName, "path", strings.Join(field.path, cp.o.cli.delimiter))
Expand Down
1 change: 0 additions & 1 deletion cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ func TestCliHelperMethod(t *testing.T) {

expectedContents := []string{
"Thing",
"Nested",
"Nested.NestedVal",
}

Expand Down
3 changes: 1 addition & 2 deletions entry_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package confy

import (
"log/slog"
"os"
"testing"
)

func TestMain(m *testing.M) {

level.Set(slog.LevelDebug)
level.Set(LoggingDisabled)
code := m.Run()

os.Exit(code)
Expand Down
43 changes: 14 additions & 29 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,19 @@ func GetGeneratedEnv[T any](delimiter string) []string {
panic("GetGeneratedEnv(...) only supports configs of Struct type")
}

o := options{}
FromEnvs(delimiter)(&o)
ep := newEnvLoader[T](&o)

var result []string
for _, field := range getFields(true, &a) {
result = append(result, strings.Join(resolvePath(&a, field.path), delimiter))

envVariable, ok := determineVariableName(&a, ep.o.env.delimiter, nil, field)
if !ok {
continue
}

result = append(result, envVariable)
}

return result
Expand All @@ -76,34 +86,9 @@ func GetGeneratedEnvWithTransform[T any](delimiter string, transformFunc Transfo
func (ep *envParser[T]) apply(result *T) (somethingSet bool, err error) {

for _, field := range getFields(true, result) {
// Update GetGeneratedEnv if this changes
envVariable := strings.Join(resolvePath(result, field.path), ep.o.env.delimiter)
if ep.o.env.transform != nil {
envVariable = ep.o.env.transform(envVariable)
logger.Info("using transform func on env variable", "before_func", strings.Join(resolvePath(result, field.path), ep.o.env.delimiter), "after_func", envVariable)
}

if field.value.Kind() == reflect.Struct {
current := field.value
_, ok := current.Addr().Interface().(encoding.TextUnmarshaler)
if !ok {
logger.Warn("type doesnt implement encoding.TextUnmarshaler skipping looking for an ENV variable for it", "path", strings.Join(field.path, ep.o.env.delimiter))
continue
}
}

if field.value.Kind() == reflect.Array || field.value.Kind() == reflect.Slice {
sliceContentType := field.value.Type().Elem()

switch sliceContentType.Kind() {
case reflect.String, reflect.Int, reflect.Int64, reflect.Float64, reflect.Bool:
default:
inter := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
if !reflect.PointerTo(sliceContentType).Implements(inter) {
logger.Warn("type inside of complex slice did not implement encoding.TextUnmarshaler", "path", strings.Join(field.path, ep.o.env.delimiter))
continue
}
}
envVariable, ok := determineVariableName(result, ep.o.env.delimiter, ep.o.env.transform, field)
if !ok {
continue
}

value, wasSet := os.LookupEnv(envVariable)
Expand Down
1 change: 0 additions & 1 deletion env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func TestEnvHelperMethod(t *testing.T) {

expectedContents := []string{
"Thing",
"Nested",
"Nested_NestedVal",
}

Expand Down
36 changes: 36 additions & 0 deletions reflection_utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package confy

import (
"encoding"
"reflect"
"strings"
)
Expand Down Expand Up @@ -158,3 +159,38 @@ func maskSensitive(value string, tag reflect.StructTag) string {

return printedValue
}

// determineVariableName returns the variable name after resolving and transforming
// ok: bool indicates whether this is a decodable type
func determineVariableName[T any](result *T, delimiter string, transform Transform, field fieldsData) (string, bool) {
variable := strings.Join(resolvePath(result, field.path), delimiter)
if transform != nil {
variable = transform(variable)
logger.Info("using transform func on variable", "before_func", strings.Join(resolvePath(result, field.path), delimiter), "after_func", variable)
}

if field.value.Kind() == reflect.Struct {
current := field.value
_, ok := current.Addr().Interface().(encoding.TextUnmarshaler)
if !ok {
logger.Warn("type doesnt implement encoding.TextUnmarshaler skipping looking for an ENV variable for it", "path", strings.Join(field.path, delimiter))
return "", false
}
}

if field.value.Kind() == reflect.Array || field.value.Kind() == reflect.Slice {
sliceContentType := field.value.Type().Elem()

switch sliceContentType.Kind() {
case reflect.String, reflect.Int, reflect.Int64, reflect.Float64, reflect.Bool:
default:
inter := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
if !reflect.PointerTo(sliceContentType).Implements(inter) {
logger.Warn("type inside of complex slice did not implement encoding.TextUnmarshaler", "path", strings.Join(field.path, delimiter))
return "", false
}
}
}

return variable, true
}

0 comments on commit 7e1724e

Please sign in to comment.