diff --git a/docker-compose.yml b/docker-compose.yml index 9eca54be12..d1a98be87e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,13 @@ version: "3.8" services: + nats: + image: nats:2.9-alpine + container_name: nats_server + ports: + - "4222:4222" + - "8222:8222" + postgres: image: postgres:15 ports: diff --git a/go.mod b/go.mod index 9ddf73b3b4..a87bb52d50 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/FZambia/statik v0.1.2-0.20180217151304-b9f012bb2a1b github.com/FZambia/tarantool v0.3.1 github.com/FZambia/viper-lite v0.0.0-20220110144934-1899f66c7d0e - github.com/centrifugal/centrifuge v0.32.3-0.20240619053500-4023c34a5ae5 - github.com/centrifugal/protocol v0.13.3 + github.com/centrifugal/centrifuge v0.32.3-0.20240703050444-a08d816282a0 + github.com/centrifugal/protocol v0.13.4-0.20240702174651-e8db704aa2d2 github.com/cristalhq/jwt/v5 v5.4.0 github.com/gobwas/glob v0.2.3 github.com/google/uuid v1.6.0 @@ -43,7 +43,7 @@ require ( golang.org/x/sync v0.7.0 golang.org/x/time v0.5.0 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 + google.golang.org/protobuf v1.34.2 ) require ( @@ -88,7 +88,7 @@ require ( github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect - github.com/redis/rueidis v1.0.38 // indirect + github.com/redis/rueidis v1.0.40 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/segmentio/encoding v0.4.0 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/go.sum b/go.sum index 2332997e67..d7365605d3 100644 --- a/go.sum +++ b/go.sum @@ -12,10 +12,10 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/centrifugal/centrifuge v0.32.3-0.20240619053500-4023c34a5ae5 h1:V0pY6N7/HeulbQtUUVIFjWrMrn5krzLPZ2cGBR7smZE= -github.com/centrifugal/centrifuge v0.32.3-0.20240619053500-4023c34a5ae5/go.mod h1:ELAYx5oUb/E42IAMlMAvd3Zl4lZTjdnZAa0nXRetkF4= -github.com/centrifugal/protocol v0.13.3 h1:Ryt5uIYCz5wpJOHc0+L2dC1ty2OQzwdU4TV3pmPOfnA= -github.com/centrifugal/protocol v0.13.3/go.mod h1:7V5vI30VcoxJe4UD87xi7bOsvI0bmEhvbQuMjrFM2L4= +github.com/centrifugal/centrifuge v0.32.3-0.20240703050444-a08d816282a0 h1:v7UG9tQc9nLmlKKNY5r45nnYCw080zGr44y9XA+NX+w= +github.com/centrifugal/centrifuge v0.32.3-0.20240703050444-a08d816282a0/go.mod h1:b3areBhskWFOpJdYOOHeIoChKelRlZkxC00Lo8aFCJo= +github.com/centrifugal/protocol v0.13.4-0.20240702174651-e8db704aa2d2 h1:U339eI0wzXpO5gRF43br85ZIxw061GHP0SZ7rjwOxiY= +github.com/centrifugal/protocol v0.13.4-0.20240702174651-e8db704aa2d2/go.mod h1:7V5vI30VcoxJe4UD87xi7bOsvI0bmEhvbQuMjrFM2L4= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -158,8 +158,8 @@ github.com/quic-go/webtransport-go v0.8.0 h1:HxSrwun11U+LlmwpgM1kEqIqH90IT4N8auv github.com/quic-go/webtransport-go v0.8.0/go.mod h1:N99tjprW432Ut5ONql/aUhSLT0YVSlwHohQsuac9WaM= github.com/rakutentech/jwk-go v1.1.3 h1:PiLwepKyUaW+QFG3ki78DIO2+b4IVK3nMhlxM70zrQ4= github.com/rakutentech/jwk-go v1.1.3/go.mod h1:LtzSv4/+Iti1nnNeVQiP6l5cI74GBStbhyXCYvgPZFk= -github.com/redis/rueidis v1.0.38 h1:ZlEBumHM+ECCMgf/zQZImLfmxb/sxGKnBP0R0AxoH/Y= -github.com/redis/rueidis v1.0.38/go.mod h1:bnbkk4+CkXZgDPEbUtSos/o55i4RhFYYesJ4DS2zmq0= +github.com/redis/rueidis v1.0.40 h1:zoC+GUTdNHhP7ZHrnMiIDcP16DUEVcxsPThQsvD7yDg= +github.com/redis/rueidis v1.0.40/go.mod h1:bnbkk4+CkXZgDPEbUtSos/o55i4RhFYYesJ4DS2zmq0= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -299,8 +299,8 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/envconfig/LICENSE b/internal/envconfig/LICENSE new file mode 100644 index 0000000000..4bfa7a84d8 --- /dev/null +++ b/internal/envconfig/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2013 Kelsey Hightower + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/internal/envconfig/MAINTAINERS b/internal/envconfig/MAINTAINERS new file mode 100644 index 0000000000..6527a9f2cc --- /dev/null +++ b/internal/envconfig/MAINTAINERS @@ -0,0 +1,2 @@ +Kelsey Hightower kelsey.hightower@gmail.com github.com/kelseyhightower +Travis Parker travis.parker@gmail.com github.com/teepark diff --git a/internal/envconfig/README.md b/internal/envconfig/README.md new file mode 100644 index 0000000000..2a1bc67cc8 --- /dev/null +++ b/internal/envconfig/README.md @@ -0,0 +1,24 @@ +# envconfig + +This is a fork of https://github.com/kelseyhightower/envconfig, original license left unchanged. + +First reason is this issue: https://github.com/kelseyhightower/envconfig/issues/148. + +Basically, we are not using ALT names here, so the main change is: + +``` +name := strings.ToUpper(ftype.Tag.Get("envconfig")) +if name == "" { + name = ftype.Name +} + +// Capture information about the config variable +info := varInfo{ + Name: name, + Field: f, + Tags: ftype.Tag, + //Alt: strings.ToUpper(ftype.Tag.Get("envconfig")), +} +``` + +The second reason of fork is that we use exported `VarInfo` here instead of `varInfo`. This helps Centrifugo to find unknown configuration options. diff --git a/internal/envconfig/doc.go b/internal/envconfig/doc.go new file mode 100644 index 0000000000..f28561cd1c --- /dev/null +++ b/internal/envconfig/doc.go @@ -0,0 +1,8 @@ +// Copyright (c) 2013 Kelsey Hightower. All rights reserved. +// Use of this source code is governed by the MIT License that can be found in +// the LICENSE file. + +// Package envconfig implements decoding of environment variables based on a user +// defined specification. A typical use is using environment variables for +// configuration settings. +package envconfig diff --git a/internal/envconfig/env_os.go b/internal/envconfig/env_os.go new file mode 100644 index 0000000000..ca1243b7be --- /dev/null +++ b/internal/envconfig/env_os.go @@ -0,0 +1,7 @@ +//go:build appengine || go1.5 + +package envconfig + +import "os" + +var lookupEnv = os.LookupEnv diff --git a/internal/envconfig/env_syscall.go b/internal/envconfig/env_syscall.go new file mode 100644 index 0000000000..221ff7fadf --- /dev/null +++ b/internal/envconfig/env_syscall.go @@ -0,0 +1,8 @@ +//go:build !appengine && !go1.5 +// +build !appengine,!go1.5 + +package envconfig + +import "syscall" + +var lookupEnv = syscall.Getenv diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go new file mode 100644 index 0000000000..d0da32fccd --- /dev/null +++ b/internal/envconfig/envconfig.go @@ -0,0 +1,386 @@ +// Copyright (c) 2013 Kelsey Hightower. All rights reserved. +// Use of this source code is governed by the MIT License that can be found in +// the LICENSE file. + +package envconfig + +import ( + "encoding" + "errors" + "fmt" + "os" + "reflect" + "regexp" + "strconv" + "strings" + "time" +) + +// ErrInvalidSpecification indicates that a specification is of the wrong type. +var ErrInvalidSpecification = errors.New("specification must be a struct pointer") + +var gatherRegexp = regexp.MustCompile("([^A-Z]+|[A-Z]+[^A-Z]+|[A-Z]+)") +var acronymRegexp = regexp.MustCompile("([A-Z]+)([A-Z][^A-Z]+)") + +// A ParseError occurs when an environment variable cannot be converted to +// the type required by a struct field during assignment. +type ParseError struct { + KeyName string + FieldName string + TypeName string + Value string + Err error +} + +// Decoder has the same semantics as Setter, but takes higher precedence. +// It is provided for historical compatibility. +type Decoder interface { + Decode(value string) error +} + +// Setter is implemented by types can self-deserialize values. +// Any type that implements flag.Value also implements Setter. +type Setter interface { + Set(value string) error +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("envconfig.Process: assigning %[1]s to %[2]s: converting '%[3]s' to type %[4]s. details: %[5]s", e.KeyName, e.FieldName, e.Value, e.TypeName, e.Err) +} + +// VarInfo maintains information about the configuration variable +type VarInfo struct { + Name string + Alt string + Key string + Field reflect.Value + Tags reflect.StructTag +} + +// GatherInfo gathers information about the specified struct +func gatherInfo(prefix string, spec interface{}) ([]VarInfo, error) { + s := reflect.ValueOf(spec) + + if s.Kind() != reflect.Ptr { + return nil, ErrInvalidSpecification + } + s = s.Elem() + if s.Kind() != reflect.Struct { + return nil, ErrInvalidSpecification + } + typeOfSpec := s.Type() + + // over allocate an info array, we will extend if needed later + infos := make([]VarInfo, 0, s.NumField()) + for i := 0; i < s.NumField(); i++ { + f := s.Field(i) + ftype := typeOfSpec.Field(i) + if !f.CanSet() || isTrue(ftype.Tag.Get("ignored")) { + continue + } + + for f.Kind() == reflect.Ptr { + if f.IsNil() { + if f.Type().Elem().Kind() != reflect.Struct { + // nil pointer to a non-struct: leave it alone + break + } + // nil pointer to struct: create a zero instance + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + + name := strings.ToUpper(ftype.Tag.Get("envconfig")) + if name == "" { + name = ftype.Name + } + + // Capture information about the config variable + info := VarInfo{ + Name: name, + Field: f, + Tags: ftype.Tag, + //Alt: strings.ToUpper(ftype.Tag.Get("envconfig")), + } + + // Default to the field name as the env var name (will be upcased) + info.Key = info.Name + + // Best effort to un-pick camel casing as separate words + if isTrue(ftype.Tag.Get("split_words")) { + words := gatherRegexp.FindAllStringSubmatch(ftype.Name, -1) + if len(words) > 0 { + var name []string + for _, words := range words { + if m := acronymRegexp.FindStringSubmatch(words[0]); len(m) == 3 { + name = append(name, m[1], m[2]) + } else { + name = append(name, words[0]) + } + } + + info.Key = strings.Join(name, "_") + } + } + if info.Alt != "" { + info.Key = info.Alt + } + if prefix != "" { + info.Key = fmt.Sprintf("%s_%s", prefix, info.Key) + } + info.Key = strings.ToUpper(info.Key) + infos = append(infos, info) + + if f.Kind() == reflect.Struct { + // honor Decode if present + if decoderFrom(f) == nil && setterFrom(f) == nil && textUnmarshaler(f) == nil && binaryUnmarshaler(f) == nil { + innerPrefix := prefix + if !ftype.Anonymous { + innerPrefix = info.Key + } + + embeddedPtr := f.Addr().Interface() + embeddedInfos, err := gatherInfo(innerPrefix, embeddedPtr) + if err != nil { + return nil, err + } + infos = append(infos[:len(infos)-1], embeddedInfos...) + + continue + } + } + } + return infos, nil +} + +// CheckDisallowed checks that no environment variables with the prefix are set +// that we don't know how or want to parse. This is likely only meaningful with +// a non-empty prefix. +func CheckDisallowed(prefix string, spec interface{}) error { + infos, err := gatherInfo(prefix, spec) + if err != nil { + return err + } + + vars := make(map[string]struct{}) + for _, info := range infos { + vars[info.Key] = struct{}{} + } + + if prefix != "" { + prefix = strings.ToUpper(prefix) + "_" + } + + for _, env := range os.Environ() { + if !strings.HasPrefix(env, prefix) { + continue + } + v := strings.SplitN(env, "=", 2)[0] + if _, found := vars[v]; !found { + return fmt.Errorf("unknown environment variable %s", v) + } + } + + return nil +} + +// Process populates the specified struct based on environment variables +func Process(prefix string, spec interface{}) ([]VarInfo, error) { + infos, err := gatherInfo(prefix, spec) + + for _, info := range infos { + // `os.Getenv` cannot differentiate between an explicitly set empty value + // and an unset value. `os.LookupEnv` is preferred to `syscall.Getenv`, + // but it is only available in go1.5 or newer. We're using Go build tags + // here to use os.LookupEnv for >=go1.5 + value, ok := lookupEnv(info.Key) + if !ok && info.Alt != "" { + value, ok = lookupEnv(info.Alt) + } + + def := info.Tags.Get("default") + if def != "" && !ok { + value = def + } + + req := info.Tags.Get("required") + if !ok && def == "" { + if isTrue(req) { + key := info.Key + if info.Alt != "" { + key = info.Alt + } + return nil, fmt.Errorf("required key %s missing value", key) + } + continue + } + + err = processField(value, info.Field) + if err != nil { + return nil, &ParseError{ + KeyName: info.Key, + FieldName: info.Name, + TypeName: info.Field.Type().String(), + Value: value, + Err: err, + } + } + } + + return infos, err +} + +// MustProcess is the same as Process but panics if an error occurs +func MustProcess(prefix string, spec interface{}) { + if _, err := Process(prefix, spec); err != nil { + panic(err) + } +} + +func processField(value string, field reflect.Value) error { + typ := field.Type() + + decoder := decoderFrom(field) + if decoder != nil { + return decoder.Decode(value) + } + // look for Set method if Decode not defined + setter := setterFrom(field) + if setter != nil { + return setter.Set(value) + } + + if t := textUnmarshaler(field); t != nil { + return t.UnmarshalText([]byte(value)) + } + + if b := binaryUnmarshaler(field); b != nil { + return b.UnmarshalBinary([]byte(value)) + } + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + if field.IsNil() { + field.Set(reflect.New(typ)) + } + field = field.Elem() + } + + switch typ.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var ( + val int64 + err error + ) + if field.Kind() == reflect.Int64 && typ.PkgPath() == "time" && typ.Name() == "Duration" { + var d time.Duration + d, err = time.ParseDuration(value) + val = int64(d) + } else { + val, err = strconv.ParseInt(value, 0, typ.Bits()) + } + if err != nil { + return err + } + + field.SetInt(val) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val, err := strconv.ParseUint(value, 0, typ.Bits()) + if err != nil { + return err + } + field.SetUint(val) + case reflect.Bool: + val, err := strconv.ParseBool(value) + if err != nil { + return err + } + field.SetBool(val) + case reflect.Float32, reflect.Float64: + val, err := strconv.ParseFloat(value, typ.Bits()) + if err != nil { + return err + } + field.SetFloat(val) + case reflect.Slice: + sl := reflect.MakeSlice(typ, 0, 0) + if typ.Elem().Kind() == reflect.Uint8 { + sl = reflect.ValueOf([]byte(value)) + } else if strings.TrimSpace(value) != "" { + vals := strings.Split(value, ",") + sl = reflect.MakeSlice(typ, len(vals), len(vals)) + for i, val := range vals { + err := processField(val, sl.Index(i)) + if err != nil { + return err + } + } + } + field.Set(sl) + case reflect.Map: + mp := reflect.MakeMap(typ) + if strings.TrimSpace(value) != "" { + pairs := strings.Split(value, ",") + for _, pair := range pairs { + kvpair := strings.Split(pair, ":") + if len(kvpair) != 2 { + return fmt.Errorf("invalid map item: %q", pair) + } + k := reflect.New(typ.Key()).Elem() + err := processField(kvpair[0], k) + if err != nil { + return err + } + v := reflect.New(typ.Elem()).Elem() + err = processField(kvpair[1], v) + if err != nil { + return err + } + mp.SetMapIndex(k, v) + } + } + field.Set(mp) + } + + return nil +} + +func interfaceFrom(field reflect.Value, fn func(interface{}, *bool)) { + // it may be impossible for a struct field to fail this check + if !field.CanInterface() { + return + } + var ok bool + fn(field.Interface(), &ok) + if !ok && field.CanAddr() { + fn(field.Addr().Interface(), &ok) + } +} + +func decoderFrom(field reflect.Value) (d Decoder) { + interfaceFrom(field, func(v interface{}, ok *bool) { d, *ok = v.(Decoder) }) + return d +} + +func setterFrom(field reflect.Value) (s Setter) { + interfaceFrom(field, func(v interface{}, ok *bool) { s, *ok = v.(Setter) }) + return s +} + +func textUnmarshaler(field reflect.Value) (t encoding.TextUnmarshaler) { + interfaceFrom(field, func(v interface{}, ok *bool) { t, *ok = v.(encoding.TextUnmarshaler) }) + return t +} + +func binaryUnmarshaler(field reflect.Value) (b encoding.BinaryUnmarshaler) { + interfaceFrom(field, func(v interface{}, ok *bool) { b, *ok = v.(encoding.BinaryUnmarshaler) }) + return b +} + +func isTrue(s string) bool { + b, _ := strconv.ParseBool(s) + return b +} diff --git a/internal/envconfig/envconfig_1.8_test.go b/internal/envconfig/envconfig_1.8_test.go new file mode 100644 index 0000000000..b530cf3c36 --- /dev/null +++ b/internal/envconfig/envconfig_1.8_test.go @@ -0,0 +1,68 @@ +//go:build go1.8 + +package envconfig + +import ( + "errors" + "net/url" + "os" + "testing" +) + +type SpecWithURL struct { + UrlValue url.URL + UrlPointer *url.URL +} + +func TestParseURL(t *testing.T) { + var s SpecWithURL + + os.Clearenv() + os.Setenv("ENV_CONFIG_URLVALUE", "https://github.com/kelseyhightower/envconfig") + os.Setenv("ENV_CONFIG_URLPOINTER", "https://github.com/kelseyhightower/envconfig") + + _, err := Process("env_config", &s) + if err != nil { + t.Fatal("unexpected error:", err) + } + + u, err := url.Parse("https://github.com/kelseyhightower/envconfig") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.UrlValue != *u { + t.Errorf("expected %q, got %q", u, s.UrlValue.String()) + } + + if *s.UrlPointer != *u { + t.Errorf("expected %q, got %q", u, s.UrlPointer) + } +} + +func TestParseURLError(t *testing.T) { + var s SpecWithURL + + os.Clearenv() + os.Setenv("ENV_CONFIG_URLPOINTER", "http_://foo") + + _, err := Process("env_config", &s) + + v, ok := err.(*ParseError) + if !ok { + t.Fatalf("expected ParseError, got %T %v", err, err) + } + if v.FieldName != "UrlPointer" { + t.Errorf("expected %s, got %v", "UrlPointer", v.FieldName) + } + + expectedUnerlyingError := url.Error{ + Op: "parse", + URL: "http_://foo", + Err: errors.New("first path segment in URL cannot contain colon"), + } + + if v.Err.Error() != expectedUnerlyingError.Error() { + t.Errorf("expected %q, got %q", expectedUnerlyingError, v.Err) + } +} diff --git a/internal/envconfig/envconfig_test.go b/internal/envconfig/envconfig_test.go new file mode 100644 index 0000000000..f203a5c660 --- /dev/null +++ b/internal/envconfig/envconfig_test.go @@ -0,0 +1,872 @@ +// Copyright (c) 2013 Kelsey Hightower. All rights reserved. +// Use of this source code is governed by the MIT License that can be found in +// the LICENSE file. + +package envconfig + +import ( + "flag" + "fmt" + "net/url" + "os" + "testing" + "time" +) + +type HonorDecodeInStruct struct { + Value string +} + +func (h *HonorDecodeInStruct) Decode(env string) error { + h.Value = "decoded" + return nil +} + +type CustomURL struct { + Value *url.URL +} + +func (cu *CustomURL) UnmarshalBinary(data []byte) error { + u, err := url.Parse(string(data)) + cu.Value = u + return err +} + +type Specification struct { + Embedded `desc:"can we document a struct"` + EmbeddedButIgnored `ignored:"true"` + Debug bool + Port int + Rate float32 + User string + TTL uint32 + Timeout time.Duration + AdminUsers []string + MagicNumbers []int + EmptyNumbers []int + ByteSlice []byte + ColorCodes map[string]int + MultiWordVar string + MultiWordVarWithAutoSplit uint32 `split_words:"true"` + MultiWordACRWithAutoSplit uint32 `split_words:"true"` + SomePointer *string + SomePointerWithDefault *string `default:"foo2baz" desc:"foorbar is the word"` + MultiWordVarWithAlt string `envconfig:"MULTI_WORD_VAR_WITH_ALT" desc:"what alt"` + MultiWordVarWithLowerCaseAlt string `envconfig:"multi_word_var_with_lower_case_alt"` + NoPrefixWithAlt string `envconfig:"SERVICE_HOST"` + DefaultVar string `default:"foobar"` + RequiredVar string `required:"True"` + NoPrefixDefault string `envconfig:"BROKER" default:"127.0.0.1"` + RequiredDefault string `required:"true" default:"foo2bar"` + Ignored string `ignored:"true"` + NestedSpecification struct { + Property string `envconfig:"inner"` + PropertyWithDefault string `default:"fuzzybydefault"` + } `envconfig:"outer"` + AfterNested string + DecodeStruct HonorDecodeInStruct `envconfig:"honor"` + Datetime time.Time + MapField map[string]string `default:"one:two,three:four"` + UrlValue CustomURL + UrlPointer *CustomURL +} + +type Embedded struct { + Enabled bool `desc:"some embedded value"` + EmbeddedPort int + MultiWordVar string + MultiWordVarWithAlt string `envconfig:"MULTI_WITH_DIFFERENT_ALT"` + EmbeddedAlt string `envconfig:"EMBEDDED_WITH_ALT"` + EmbeddedIgnored string `ignored:"true"` +} + +type EmbeddedButIgnored struct { + FirstEmbeddedButIgnored string + SecondEmbeddedButIgnored string +} + +func TestProcess(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("ENV_CONFIG_PORT", "8080") + os.Setenv("ENV_CONFIG_RATE", "0.5") + os.Setenv("ENV_CONFIG_USER", "Kelsey") + os.Setenv("ENV_CONFIG_TIMEOUT", "2m") + os.Setenv("ENV_CONFIG_ADMINUSERS", "John,Adam,Will") + os.Setenv("ENV_CONFIG_MAGICNUMBERS", "5,10,20") + os.Setenv("ENV_CONFIG_EMPTYNUMBERS", "") + os.Setenv("ENV_CONFIG_BYTESLICE", "this is a test value") + os.Setenv("ENV_CONFIG_COLORCODES", "red:1,green:2,blue:3") + os.Setenv("SERVICE_HOST", "127.0.0.1") + os.Setenv("ENV_CONFIG_TTL", "30") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + os.Setenv("ENV_CONFIG_IGNORED", "was-not-ignored") + os.Setenv("ENV_CONFIG_OUTER_INNER", "iamnested") + os.Setenv("ENV_CONFIG_AFTERNESTED", "after") + os.Setenv("ENV_CONFIG_HONOR", "honor") + os.Setenv("ENV_CONFIG_DATETIME", "2016-08-16T18:57:05Z") + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT", "24") + os.Setenv("ENV_CONFIG_MULTI_WORD_ACR_WITH_AUTO_SPLIT", "25") + os.Setenv("ENV_CONFIG_URLVALUE", "https://github.com/kelseyhightower/envconfig") + os.Setenv("ENV_CONFIG_URLPOINTER", "https://github.com/kelseyhightower/envconfig") + _, err := Process("env_config", &s) + if err != nil { + t.Error(err.Error()) + } + //if s.NoPrefixWithAlt != "127.0.0.1" { + // t.Errorf("expected %v, got %v", "127.0.0.1", s.NoPrefixWithAlt) + //} + if !s.Debug { + t.Errorf("expected %v, got %v", true, s.Debug) + } + if s.Port != 8080 { + t.Errorf("expected %d, got %v", 8080, s.Port) + } + if s.Rate != 0.5 { + t.Errorf("expected %f, got %v", 0.5, s.Rate) + } + if s.TTL != 30 { + t.Errorf("expected %d, got %v", 30, s.TTL) + } + if s.User != "Kelsey" { + t.Errorf("expected %s, got %s", "Kelsey", s.User) + } + if s.Timeout != 2*time.Minute { + t.Errorf("expected %s, got %s", 2*time.Minute, s.Timeout) + } + if s.RequiredVar != "foo" { + t.Errorf("expected %s, got %s", "foo", s.RequiredVar) + } + if len(s.AdminUsers) != 3 || + s.AdminUsers[0] != "John" || + s.AdminUsers[1] != "Adam" || + s.AdminUsers[2] != "Will" { + t.Errorf("expected %#v, got %#v", []string{"John", "Adam", "Will"}, s.AdminUsers) + } + if len(s.MagicNumbers) != 3 || + s.MagicNumbers[0] != 5 || + s.MagicNumbers[1] != 10 || + s.MagicNumbers[2] != 20 { + t.Errorf("expected %#v, got %#v", []int{5, 10, 20}, s.MagicNumbers) + } + if len(s.EmptyNumbers) != 0 { + t.Errorf("expected %#v, got %#v", []int{}, s.EmptyNumbers) + } + expected := "this is a test value" + if string(s.ByteSlice) != expected { + t.Errorf("expected %v, got %v", expected, string(s.ByteSlice)) + } + if s.Ignored != "" { + t.Errorf("expected empty string, got %#v", s.Ignored) + } + + if len(s.ColorCodes) != 3 || + s.ColorCodes["red"] != 1 || + s.ColorCodes["green"] != 2 || + s.ColorCodes["blue"] != 3 { + t.Errorf( + "expected %#v, got %#v", + map[string]int{ + "red": 1, + "green": 2, + "blue": 3, + }, + s.ColorCodes, + ) + } + + if s.NestedSpecification.Property != "iamnested" { + t.Errorf("expected '%s' string, got %#v", "iamnested", s.NestedSpecification.Property) + } + + if s.NestedSpecification.PropertyWithDefault != "fuzzybydefault" { + t.Errorf("expected default '%s' string, got %#v", "fuzzybydefault", s.NestedSpecification.PropertyWithDefault) + } + + if s.AfterNested != "after" { + t.Errorf("expected default '%s' string, got %#v", "after", s.AfterNested) + } + + if s.DecodeStruct.Value != "decoded" { + t.Errorf("expected default '%s' string, got %#v", "decoded", s.DecodeStruct.Value) + } + + if expected := time.Date(2016, 8, 16, 18, 57, 05, 0, time.UTC); !s.Datetime.Equal(expected) { + t.Errorf("expected %s, got %s", expected.Format(time.RFC3339), s.Datetime.Format(time.RFC3339)) + } + + if s.MultiWordVarWithAutoSplit != 24 { + t.Errorf("expected %q, got %q", 24, s.MultiWordVarWithAutoSplit) + } + + if s.MultiWordACRWithAutoSplit != 25 { + t.Errorf("expected %d, got %d", 25, s.MultiWordACRWithAutoSplit) + } + + u, err := url.Parse("https://github.com/kelseyhightower/envconfig") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if *s.UrlValue.Value != *u { + t.Errorf("expected %q, got %q", u, s.UrlValue.Value.String()) + } + + if *s.UrlPointer.Value != *u { + t.Errorf("expected %q, got %q", u, s.UrlPointer.Value.String()) + } +} + +func TestParseErrorBool(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "string") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + _, err := Process("env_config", &s) + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "Debug" { + t.Errorf("expected %s, got %v", "Debug", v.FieldName) + } + if s.Debug != false { + t.Errorf("expected %v, got %v", false, s.Debug) + } +} + +func TestParseErrorFloat32(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_RATE", "string") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + _, err := Process("env_config", &s) + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "Rate" { + t.Errorf("expected %s, got %v", "Rate", v.FieldName) + } + if s.Rate != 0 { + t.Errorf("expected %v, got %v", 0, s.Rate) + } +} + +func TestParseErrorInt(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_PORT", "string") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + _, err := Process("env_config", &s) + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "Port" { + t.Errorf("expected %s, got %v", "Port", v.FieldName) + } + if s.Port != 0 { + t.Errorf("expected %v, got %v", 0, s.Port) + } +} + +func TestParseErrorUint(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_TTL", "-30") + _, err := Process("env_config", &s) + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "TTL" { + t.Errorf("expected %s, got %v", "TTL", v.FieldName) + } + if s.TTL != 0 { + t.Errorf("expected %v, got %v", 0, s.TTL) + } +} + +func TestParseErrorSplitWords(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT", "shakespeare") + _, err := Process("env_config", &s) + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "MultiWordVarWithAutoSplit" { + t.Errorf("expected %s, got %v", "", v.FieldName) + } + if s.MultiWordVarWithAutoSplit != 0 { + t.Errorf("expected %v, got %v", 0, s.MultiWordVarWithAutoSplit) + } +} + +func TestErrInvalidSpecification(t *testing.T) { + m := make(map[string]string) + _, err := Process("env_config", &m) + if err != ErrInvalidSpecification { + t.Errorf("expected %v, got %v", ErrInvalidSpecification, err) + } +} + +func TestUnsetVars(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("USER", "foo") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + // If the var is not defined the non-prefixed version should not be used + // unless the struct tag says so + if s.User != "" { + t.Errorf("expected %q, got %q", "", s.User) + } +} + +func TestAlternateVarNames(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR", "foo") + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_ALT", "bar") + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_LOWER_CASE_ALT", "baz") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + // Setting the alt version of the var in the environment has no effect if + // the struct tag is not supplied + if s.MultiWordVar != "" { + t.Errorf("expected %q, got %q", "", s.MultiWordVar) + } + + // Setting the alt version of the var in the environment correctly sets + // the value if the struct tag IS supplied + if s.MultiWordVarWithAlt != "bar" { + t.Errorf("expected %q, got %q", "bar", s.MultiWordVarWithAlt) + } + + // Alt value is not case sensitive and is treated as all uppercase + if s.MultiWordVarWithLowerCaseAlt != "baz" { + t.Errorf("expected %q, got %q", "baz", s.MultiWordVarWithLowerCaseAlt) + } +} + +func TestRequiredVar(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foobar") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.RequiredVar != "foobar" { + t.Errorf("expected %s, got %s", "foobar", s.RequiredVar) + } +} + +func TestRequiredMissing(t *testing.T) { + var s Specification + os.Clearenv() + + _, err := Process("env_config", &s) + if err == nil { + t.Error("no failure when missing required variable") + } +} + +func TestBlankDefaultVar(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "requiredvalue") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.DefaultVar != "foobar" { + t.Errorf("expected %s, got %s", "foobar", s.DefaultVar) + } + + if *s.SomePointerWithDefault != "foo2baz" { + t.Errorf("expected %s, got %s", "foo2baz", *s.SomePointerWithDefault) + } +} + +func TestNonBlankDefaultVar(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEFAULTVAR", "nondefaultval") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "requiredvalue") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.DefaultVar != "nondefaultval" { + t.Errorf("expected %s, got %s", "nondefaultval", s.DefaultVar) + } +} + +func TestExplicitBlankDefaultVar(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEFAULTVAR", "") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "") + + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.DefaultVar != "" { + t.Errorf("expected %s, got %s", "\"\"", s.DefaultVar) + } +} + +func TestAlternateNameDefaultVar(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("BROKER", "betterbroker") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + //if s.NoPrefixDefault != "betterbroker" { + // t.Errorf("expected %q, got %q", "betterbroker", s.NoPrefixDefault) + //} + + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.NoPrefixDefault != "127.0.0.1" { + t.Errorf("expected %q, got %q", "127.0.0.1", s.NoPrefixDefault) + } +} + +func TestRequiredDefault(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.RequiredDefault != "foo2bar" { + t.Errorf("expected %q, got %q", "foo2bar", s.RequiredDefault) + } +} + +func TestPointerFieldBlank(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.SomePointer != nil { + t.Errorf("expected , got %q", *s.SomePointer) + } +} + +func TestEmptyMapFieldOverride(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + os.Setenv("ENV_CONFIG_MAPFIELD", "") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if s.MapField == nil { + t.Error("expected empty map, got ") + } + + if len(s.MapField) != 0 { + t.Errorf("expected empty map, got map of size %d", len(s.MapField)) + } +} + +func TestMustProcess(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("ENV_CONFIG_PORT", "8080") + os.Setenv("ENV_CONFIG_RATE", "0.5") + os.Setenv("ENV_CONFIG_USER", "Kelsey") + os.Setenv("SERVICE_HOST", "127.0.0.1") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + MustProcess("env_config", &s) + + defer func() { + if err := recover(); err != nil { + return + } + + t.Error("expected panic") + }() + m := make(map[string]string) + MustProcess("env_config", &m) +} + +func TestEmbeddedStruct(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "required") + os.Setenv("ENV_CONFIG_ENABLED", "true") + os.Setenv("ENV_CONFIG_EMBEDDEDPORT", "1234") + os.Setenv("ENV_CONFIG_MULTIWORDVAR", "foo") + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_ALT", "bar") + os.Setenv("ENV_CONFIG_MULTI_WITH_DIFFERENT_ALT", "baz") + os.Setenv("ENV_CONFIG_EMBEDDED_WITH_ALT", "foobar") + os.Setenv("ENV_CONFIG_SOMEPOINTER", "foobaz") + os.Setenv("ENV_CONFIG_EMBEDDED_IGNORED", "was-not-ignored") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + if !s.Enabled { + t.Errorf("expected %v, got %v", true, s.Enabled) + } + if s.EmbeddedPort != 1234 { + t.Errorf("expected %d, got %v", 1234, s.EmbeddedPort) + } + if s.MultiWordVar != "foo" { + t.Errorf("expected %s, got %s", "foo", s.MultiWordVar) + } + if s.Embedded.MultiWordVar != "foo" { + t.Errorf("expected %s, got %s", "foo", s.Embedded.MultiWordVar) + } + if s.MultiWordVarWithAlt != "bar" { + t.Errorf("expected %s, got %s", "bar", s.MultiWordVarWithAlt) + } + if s.Embedded.MultiWordVarWithAlt != "baz" { + t.Errorf("expected %s, got %s", "baz", s.Embedded.MultiWordVarWithAlt) + } + if s.EmbeddedAlt != "foobar" { + t.Errorf("expected %s, got %s", "foobar", s.EmbeddedAlt) + } + if *s.SomePointer != "foobaz" { + t.Errorf("expected %s, got %s", "foobaz", *s.SomePointer) + } + if s.EmbeddedIgnored != "" { + t.Errorf("expected empty string, got %#v", s.Ignored) + } +} + +func TestEmbeddedButIgnoredStruct(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "required") + os.Setenv("ENV_CONFIG_FIRSTEMBEDDEDBUTIGNORED", "was-not-ignored") + os.Setenv("ENV_CONFIG_SECONDEMBEDDEDBUTIGNORED", "was-not-ignored") + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + if s.FirstEmbeddedButIgnored != "" { + t.Errorf("expected empty string, got %#v", s.Ignored) + } + if s.SecondEmbeddedButIgnored != "" { + t.Errorf("expected empty string, got %#v", s.Ignored) + } +} + +func TestNonPointerFailsProperly(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "snap") + + _, err := Process("env_config", s) + if err != ErrInvalidSpecification { + t.Errorf("non-pointer should fail with ErrInvalidSpecification, was instead %s", err) + } +} + +func TestCustomValueFields(t *testing.T) { + var s struct { + Foo string + Bar bracketed + Baz quoted + Struct setterStruct + } + + // Set would panic when the receiver is nil, + // so make sure it has an initial value to replace. + s.Baz = quoted{new(bracketed)} + + os.Clearenv() + os.Setenv("ENV_CONFIG_FOO", "foo") + os.Setenv("ENV_CONFIG_BAR", "bar") + os.Setenv("ENV_CONFIG_BAZ", "baz") + os.Setenv("ENV_CONFIG_STRUCT", "inner") + + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if want := "foo"; s.Foo != want { + t.Errorf("foo: got %#q, want %#q", s.Foo, want) + } + + if want := "[bar]"; s.Bar.String() != want { + t.Errorf("bar: got %#q, want %#q", s.Bar, want) + } + + if want := `["baz"]`; s.Baz.String() != want { + t.Errorf(`baz: got %#q, want %#q`, s.Baz, want) + } + + if want := `setterstruct{"inner"}`; s.Struct.Inner != want { + t.Errorf(`Struct.Inner: got %#q, want %#q`, s.Struct.Inner, want) + } +} + +func TestCustomPointerFields(t *testing.T) { + var s struct { + Foo string + Bar *bracketed + Baz *quoted + Struct *setterStruct + } + + // Set would panic when the receiver is nil, + // so make sure they have initial values to replace. + s.Bar = new(bracketed) + s.Baz = "ed{new(bracketed)} + + os.Clearenv() + os.Setenv("ENV_CONFIG_FOO", "foo") + os.Setenv("ENV_CONFIG_BAR", "bar") + os.Setenv("ENV_CONFIG_BAZ", "baz") + os.Setenv("ENV_CONFIG_STRUCT", "inner") + + if _, err := Process("env_config", &s); err != nil { + t.Error(err.Error()) + } + + if want := "foo"; s.Foo != want { + t.Errorf("foo: got %#q, want %#q", s.Foo, want) + } + + if want := "[bar]"; s.Bar.String() != want { + t.Errorf("bar: got %#q, want %#q", s.Bar, want) + } + + if want := `["baz"]`; s.Baz.String() != want { + t.Errorf(`baz: got %#q, want %#q`, s.Baz, want) + } + + if want := `setterstruct{"inner"}`; s.Struct.Inner != want { + t.Errorf(`Struct.Inner: got %#q, want %#q`, s.Struct.Inner, want) + } +} + +func TestEmptyPrefixUsesFieldNames(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("REQUIREDVAR", "foo") + + _, err := Process("", &s) + if err != nil { + t.Errorf("Process failed: %s", err) + } + + if s.RequiredVar != "foo" { + t.Errorf( + `RequiredVar not populated correctly: expected "foo", got %q`, + s.RequiredVar, + ) + } +} + +type Nested struct { + NestedSpecification struct { + Property string `envconfig:"inner"` + PropertyWithDefault string `default:"fuzzybydefault" envconfig:"property_with_default"` + } `envconfig:"outer"` +} + +func TestNestedStructVarName(t *testing.T) { + var s Nested + os.Clearenv() + os.Setenv("ENV_CONFIG_OUTER_INNER", "required") + _, err := Process("ENV_CONFIG", &s) + if err != nil { + t.Error(err.Error()) + } + if s.NestedSpecification.Property != "required" { + t.Errorf("expected %s, got %s", "required", s.NestedSpecification.Property) + } + if s.NestedSpecification.PropertyWithDefault != "fuzzybydefault" { + t.Errorf("expected %s, got %s", "fuzzybydefault", s.NestedSpecification.PropertyWithDefault) + } +} + +func TestTextUnmarshalerError(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + os.Setenv("ENV_CONFIG_DATETIME", "I'M NOT A DATE") + + _, err := Process("env_config", &s) + + v, ok := err.(*ParseError) + if !ok { + t.Errorf("expected ParseError, got %v", v) + } + if v.FieldName != "Datetime" { + t.Errorf("expected %s, got %v", "Datetime", v.FieldName) + } + + expectedLowLevelError := time.ParseError{ + Layout: time.RFC3339, + Value: "I'M NOT A DATE", + LayoutElem: "2006", + ValueElem: "I'M NOT A DATE", + } + + if v.Err.Error() != expectedLowLevelError.Error() { + t.Errorf("expected %s, got %s", expectedLowLevelError, v.Err) + } +} + +func TestBinaryUnmarshalerError(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + os.Setenv("ENV_CONFIG_URLPOINTER", "http://%41:8080/") + + _, err := Process("env_config", &s) + + v, ok := err.(*ParseError) + if !ok { + t.Fatalf("expected ParseError, got %T %v", err, err) + } + if v.FieldName != "UrlPointer" { + t.Errorf("expected %s, got %v", "UrlPointer", v.FieldName) + } + + // To be compatible with go 1.5 and lower we should do a very basic check, + // because underlying error message varies in go 1.5 and go 1.6+. + + ue, ok := v.Err.(*url.Error) + if !ok { + t.Errorf("expected error type to be \"*url.Error\", got %T", v.Err) + } + + if ue.Op != "parse" { + t.Errorf("expected error op to be \"parse\", got %q", ue.Op) + } +} + +func TestCheckDisallowedOnlyAllowed(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("UNRELATED_ENV_VAR", "true") + err := CheckDisallowed("env_config", &s) + if err != nil { + t.Errorf("expected no error, got %s", err) + } +} + +func TestCheckDisallowedMispelled(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("ENV_CONFIG_ZEBUG", "false") + err := CheckDisallowed("env_config", &s) + if experr := "unknown environment variable ENV_CONFIG_ZEBUG"; err.Error() != experr { + t.Errorf("expected %s, got %s", experr, err) + } +} + +func TestCheckDisallowedIgnored(t *testing.T) { + var s Specification + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("ENV_CONFIG_IGNORED", "false") + err := CheckDisallowed("env_config", &s) + if experr := "unknown environment variable ENV_CONFIG_IGNORED"; err.Error() != experr { + t.Errorf("expected %s, got %s", experr, err) + } +} + +//func TestErrorMessageForRequiredAltVar(t *testing.T) { +// var s struct { +// Foo string `envconfig:"BAR" required:"true"` +// } +// +// os.Clearenv() +// _, err := Process("env_config", &s) +// +// if err == nil { +// t.Error("no failure when missing required variable") +// } +// +// if !strings.Contains(err.Error(), " BAR ") { +// t.Errorf("expected error message to contain BAR, got \"%v\"", err) +// } +//} + +type bracketed string + +func (b *bracketed) Set(value string) error { + *b = bracketed("[" + value + "]") + return nil +} + +func (b bracketed) String() string { + return string(b) +} + +// quoted is used to test the precedence of Decode over Set. +// The sole field is a flag.Value rather than a setter to validate that +// all flag.Value implementations are also Setter implementations. +type quoted struct{ flag.Value } + +func (d quoted) Decode(value string) error { + return d.Set(`"` + value + `"`) +} + +type setterStruct struct { + Inner string +} + +func (ss *setterStruct) Set(value string) error { + ss.Inner = fmt.Sprintf("setterstruct{%q}", value) + return nil +} + +func BenchmarkGatherInfo(b *testing.B) { + os.Clearenv() + os.Setenv("ENV_CONFIG_DEBUG", "true") + os.Setenv("ENV_CONFIG_PORT", "8080") + os.Setenv("ENV_CONFIG_RATE", "0.5") + os.Setenv("ENV_CONFIG_USER", "Kelsey") + os.Setenv("ENV_CONFIG_TIMEOUT", "2m") + os.Setenv("ENV_CONFIG_ADMINUSERS", "John,Adam,Will") + os.Setenv("ENV_CONFIG_MAGICNUMBERS", "5,10,20") + os.Setenv("ENV_CONFIG_COLORCODES", "red:1,green:2,blue:3") + os.Setenv("SERVICE_HOST", "127.0.0.1") + os.Setenv("ENV_CONFIG_TTL", "30") + os.Setenv("ENV_CONFIG_REQUIREDVAR", "foo") + os.Setenv("ENV_CONFIG_IGNORED", "was-not-ignored") + os.Setenv("ENV_CONFIG_OUTER_INNER", "iamnested") + os.Setenv("ENV_CONFIG_AFTERNESTED", "after") + os.Setenv("ENV_CONFIG_HONOR", "honor") + os.Setenv("ENV_CONFIG_DATETIME", "2016-08-16T18:57:05Z") + os.Setenv("ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT", "24") + for i := 0; i < b.N; i++ { + var s Specification + _, _ = gatherInfo("env_config", &s) + } +} diff --git a/internal/envconfig/testdata/custom.txt b/internal/envconfig/testdata/custom.txt new file mode 100644 index 0000000000..04d2f5d0ec --- /dev/null +++ b/internal/envconfig/testdata/custom.txt @@ -0,0 +1,36 @@ +ENV_CONFIG_ENABLED=some.embedded.value +ENV_CONFIG_EMBEDDEDPORT= +ENV_CONFIG_MULTIWORDVAR= +ENV_CONFIG_MULTI_WITH_DIFFERENT_ALT= +ENV_CONFIG_EMBEDDED_WITH_ALT= +ENV_CONFIG_DEBUG= +ENV_CONFIG_PORT= +ENV_CONFIG_RATE= +ENV_CONFIG_USER= +ENV_CONFIG_TTL= +ENV_CONFIG_TIMEOUT= +ENV_CONFIG_ADMINUSERS= +ENV_CONFIG_MAGICNUMBERS= +ENV_CONFIG_EMPTYNUMBERS= +ENV_CONFIG_BYTESLICE= +ENV_CONFIG_COLORCODES= +ENV_CONFIG_MULTIWORDVAR= +ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT= +ENV_CONFIG_MULTI_WORD_ACR_WITH_AUTO_SPLIT= +ENV_CONFIG_SOMEPOINTER= +ENV_CONFIG_SOMEPOINTERWITHDEFAULT=foorbar.is.the.word +ENV_CONFIG_MULTI_WORD_VAR_WITH_ALT=what.alt +ENV_CONFIG_MULTI_WORD_VAR_WITH_LOWER_CASE_ALT= +ENV_CONFIG_SERVICE_HOST= +ENV_CONFIG_DEFAULTVAR= +ENV_CONFIG_REQUIREDVAR= +ENV_CONFIG_BROKER= +ENV_CONFIG_REQUIREDDEFAULT= +ENV_CONFIG_OUTER_INNER= +ENV_CONFIG_OUTER_PROPERTYWITHDEFAULT= +ENV_CONFIG_AFTERNESTED= +ENV_CONFIG_HONOR= +ENV_CONFIG_DATETIME= +ENV_CONFIG_MAPFIELD= +ENV_CONFIG_URLVALUE= +ENV_CONFIG_URLPOINTER= diff --git a/internal/envconfig/testdata/default_list.txt b/internal/envconfig/testdata/default_list.txt new file mode 100644 index 0000000000..fb0eced775 --- /dev/null +++ b/internal/envconfig/testdata/default_list.txt @@ -0,0 +1,183 @@ +This.application.is.configured.via.the.environment..The.following.environment +variables.can.be.used: + +ENV_CONFIG_ENABLED +..[description].some.embedded.value +..[type]........True.or.False +..[default]..... +..[required].... +ENV_CONFIG_EMBEDDEDPORT +..[description]. +..[type]........Integer +..[default]..... +..[required].... +ENV_CONFIG_MULTIWORDVAR +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_MULTI_WITH_DIFFERENT_ALT +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_EMBEDDED_WITH_ALT +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_DEBUG +..[description]. +..[type]........True.or.False +..[default]..... +..[required].... +ENV_CONFIG_PORT +..[description]. +..[type]........Integer +..[default]..... +..[required].... +ENV_CONFIG_RATE +..[description]. +..[type]........Float +..[default]..... +..[required].... +ENV_CONFIG_USER +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_TTL +..[description]. +..[type]........Unsigned.Integer +..[default]..... +..[required].... +ENV_CONFIG_TIMEOUT +..[description]. +..[type]........Duration +..[default]..... +..[required].... +ENV_CONFIG_ADMINUSERS +..[description]. +..[type]........Comma-separated.list.of.String +..[default]..... +..[required].... +ENV_CONFIG_MAGICNUMBERS +..[description]. +..[type]........Comma-separated.list.of.Integer +..[default]..... +..[required].... +ENV_CONFIG_EMPTYNUMBERS +..[description]. +..[type]........Comma-separated.list.of.Integer +..[default]..... +..[required].... +ENV_CONFIG_BYTESLICE +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_COLORCODES +..[description]. +..[type]........Comma-separated.list.of.String:Integer.pairs +..[default]..... +..[required].... +ENV_CONFIG_MULTIWORDVAR +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT +..[description]. +..[type]........Unsigned.Integer +..[default]..... +..[required].... +ENV_CONFIG_MULTI_WORD_ACR_WITH_AUTO_SPLIT +..[description]. +..[type]........Unsigned.Integer +..[default]..... +..[required].... +ENV_CONFIG_SOMEPOINTER +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_SOMEPOINTERWITHDEFAULT +..[description].foorbar.is.the.word +..[type]........String +..[default].....foo2baz +..[required].... +ENV_CONFIG_MULTI_WORD_VAR_WITH_ALT +..[description].what.alt +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_MULTI_WORD_VAR_WITH_LOWER_CASE_ALT +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_SERVICE_HOST +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_DEFAULTVAR +..[description]. +..[type]........String +..[default].....foobar +..[required].... +ENV_CONFIG_REQUIREDVAR +..[description]. +..[type]........String +..[default]..... +..[required]....true +ENV_CONFIG_BROKER +..[description]. +..[type]........String +..[default].....127.0.0.1 +..[required].... +ENV_CONFIG_REQUIREDDEFAULT +..[description]. +..[type]........String +..[default].....foo2bar +..[required]....true +ENV_CONFIG_OUTER_INNER +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_OUTER_PROPERTYWITHDEFAULT +..[description]. +..[type]........String +..[default].....fuzzybydefault +..[required].... +ENV_CONFIG_AFTERNESTED +..[description]. +..[type]........String +..[default]..... +..[required].... +ENV_CONFIG_HONOR +..[description]. +..[type]........HonorDecodeInStruct +..[default]..... +..[required].... +ENV_CONFIG_DATETIME +..[description]. +..[type]........Time +..[default]..... +..[required].... +ENV_CONFIG_MAPFIELD +..[description]. +..[type]........Comma-separated.list.of.String:String.pairs +..[default].....one:two,three:four +..[required].... +ENV_CONFIG_URLVALUE +..[description]. +..[type]........CustomURL +..[default]..... +..[required].... +ENV_CONFIG_URLPOINTER +..[description]. +..[type]........CustomURL +..[default]..... +..[required].... diff --git a/internal/envconfig/testdata/default_table.txt b/internal/envconfig/testdata/default_table.txt new file mode 100644 index 0000000000..65c9b445e1 --- /dev/null +++ b/internal/envconfig/testdata/default_table.txt @@ -0,0 +1,40 @@ +This.application.is.configured.via.the.environment..The.following.environment +variables.can.be.used: + +KEY..............................................TYPE............................................DEFAULT...............REQUIRED....DESCRIPTION +ENV_CONFIG_ENABLED...............................True.or.False.....................................................................some.embedded.value +ENV_CONFIG_EMBEDDEDPORT..........................Integer........................................................................... +ENV_CONFIG_MULTIWORDVAR..........................String............................................................................ +ENV_CONFIG_MULTI_WITH_DIFFERENT_ALT..............String............................................................................ +ENV_CONFIG_EMBEDDED_WITH_ALT.....................String............................................................................ +ENV_CONFIG_DEBUG.................................True.or.False..................................................................... +ENV_CONFIG_PORT..................................Integer........................................................................... +ENV_CONFIG_RATE..................................Float............................................................................. +ENV_CONFIG_USER..................................String............................................................................ +ENV_CONFIG_TTL...................................Unsigned.Integer.................................................................. +ENV_CONFIG_TIMEOUT...............................Duration.......................................................................... +ENV_CONFIG_ADMINUSERS............................Comma-separated.list.of.String.................................................... +ENV_CONFIG_MAGICNUMBERS..........................Comma-separated.list.of.Integer................................................... +ENV_CONFIG_EMPTYNUMBERS..........................Comma-separated.list.of.Integer................................................... +ENV_CONFIG_BYTESLICE.............................String............................................................................ +ENV_CONFIG_COLORCODES............................Comma-separated.list.of.String:Integer.pairs...................................... +ENV_CONFIG_MULTIWORDVAR..........................String............................................................................ +ENV_CONFIG_MULTI_WORD_VAR_WITH_AUTO_SPLIT........Unsigned.Integer.................................................................. +ENV_CONFIG_MULTI_WORD_ACR_WITH_AUTO_SPLIT........Unsigned.Integer.................................................................. +ENV_CONFIG_SOMEPOINTER...........................String............................................................................ +ENV_CONFIG_SOMEPOINTERWITHDEFAULT................String..........................................foo2baz...........................foorbar.is.the.word +ENV_CONFIG_MULTI_WORD_VAR_WITH_ALT...............String............................................................................what.alt +ENV_CONFIG_MULTI_WORD_VAR_WITH_LOWER_CASE_ALT....String............................................................................ +ENV_CONFIG_SERVICE_HOST..........................String............................................................................ +ENV_CONFIG_DEFAULTVAR............................String..........................................foobar............................ +ENV_CONFIG_REQUIREDVAR...........................String................................................................true........ +ENV_CONFIG_BROKER................................String..........................................127.0.0.1......................... +ENV_CONFIG_REQUIREDDEFAULT.......................String..........................................foo2bar...............true........ +ENV_CONFIG_OUTER_INNER...........................String............................................................................ +ENV_CONFIG_OUTER_PROPERTYWITHDEFAULT.............String..........................................fuzzybydefault.................... +ENV_CONFIG_AFTERNESTED...........................String............................................................................ +ENV_CONFIG_HONOR.................................HonorDecodeInStruct............................................................... +ENV_CONFIG_DATETIME..............................Time.............................................................................. +ENV_CONFIG_MAPFIELD..............................Comma-separated.list.of.String:String.pairs.....one:two,three:four................ +ENV_CONFIG_URLVALUE..............................CustomURL......................................................................... +ENV_CONFIG_URLPOINTER............................CustomURL......................................................................... diff --git a/internal/envconfig/testdata/fault.txt b/internal/envconfig/testdata/fault.txt new file mode 100644 index 0000000000..b525ff12d2 --- /dev/null +++ b/internal/envconfig/testdata/fault.txt @@ -0,0 +1,36 @@ +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} +{.Key} diff --git a/internal/envconfig/usage.go b/internal/envconfig/usage.go new file mode 100644 index 0000000000..ba2d136ce9 --- /dev/null +++ b/internal/envconfig/usage.go @@ -0,0 +1,164 @@ +// Copyright (c) 2016 Kelsey Hightower and others. All rights reserved. +// Use of this source code is governed by the MIT License that can be found in +// the LICENSE file. + +package envconfig + +import ( + "encoding" + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "text/tabwriter" + "text/template" +) + +const ( + // DefaultListFormat constant to use to display usage in a list format + DefaultListFormat = `This application is configured via the environment. The following environment +variables can be used: +{{range .}} +{{usage_key .}} + [description] {{usage_description .}} + [type] {{usage_type .}} + [default] {{usage_default .}} + [required] {{usage_required .}}{{end}} +` + // DefaultTableFormat constant to use to display usage in a tabular format + DefaultTableFormat = `This application is configured via the environment. The following environment +variables can be used: + +KEY TYPE DEFAULT REQUIRED DESCRIPTION +{{range .}}{{usage_key .}} {{usage_type .}} {{usage_default .}} {{usage_required .}} {{usage_description .}} +{{end}}` +) + +var ( + decoderType = reflect.TypeOf((*Decoder)(nil)).Elem() + setterType = reflect.TypeOf((*Setter)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() +) + +func implementsInterface(t reflect.Type) bool { + return t.Implements(decoderType) || + reflect.PointerTo(t).Implements(decoderType) || + t.Implements(setterType) || + reflect.PointerTo(t).Implements(setterType) || + t.Implements(textUnmarshalerType) || + reflect.PointerTo(t).Implements(textUnmarshalerType) || + t.Implements(binaryUnmarshalerType) || + reflect.PointerTo(t).Implements(binaryUnmarshalerType) +} + +// toTypeDescription converts Go types into a human readable description +func toTypeDescription(t reflect.Type) string { + switch t.Kind() { + case reflect.Array, reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + return "String" + } + return fmt.Sprintf("Comma-separated list of %s", toTypeDescription(t.Elem())) + case reflect.Map: + return fmt.Sprintf( + "Comma-separated list of %s:%s pairs", + toTypeDescription(t.Key()), + toTypeDescription(t.Elem()), + ) + case reflect.Ptr: + return toTypeDescription(t.Elem()) + case reflect.Struct: + if implementsInterface(t) && t.Name() != "" { + return t.Name() + } + return "" + case reflect.String: + name := t.Name() + if name != "" && name != "string" { + return name + } + return "String" + case reflect.Bool: + name := t.Name() + if name != "" && name != "bool" { + return name + } + return "True or False" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + name := t.Name() + if name != "" && !strings.HasPrefix(name, "int") { + return name + } + return "Integer" + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + name := t.Name() + if name != "" && !strings.HasPrefix(name, "uint") { + return name + } + return "Unsigned Integer" + case reflect.Float32, reflect.Float64: + name := t.Name() + if name != "" && !strings.HasPrefix(name, "float") { + return name + } + return "Float" + } + return fmt.Sprintf("%+v", t) +} + +// Usage writes usage information to stdout using the default header and table format +func Usage(prefix string, spec interface{}) error { + // The default is to output the usage information as a table + // Create tabwriter instance to support table output + tabs := tabwriter.NewWriter(os.Stdout, 1, 0, 4, ' ', 0) + + err := Usagef(prefix, spec, tabs, DefaultTableFormat) + tabs.Flush() + return err +} + +// Usagef writes usage information to the specified io.Writer using the specified template specification +func Usagef(prefix string, spec interface{}, out io.Writer, format string) error { + + // Specify the default usage template functions + functions := template.FuncMap{ + "usage_key": func(v VarInfo) string { return v.Key }, + "usage_description": func(v VarInfo) string { return v.Tags.Get("desc") }, + "usage_type": func(v VarInfo) string { return toTypeDescription(v.Field.Type()) }, + "usage_default": func(v VarInfo) string { return v.Tags.Get("default") }, + "usage_required": func(v VarInfo) (string, error) { + req := v.Tags.Get("required") + if req != "" { + reqB, err := strconv.ParseBool(req) + if err != nil { + return "", err + } + if reqB { + req = "true" + } + } + return req, nil + }, + } + + tmpl, err := template.New("envconfig").Funcs(functions).Parse(format) + if err != nil { + return err + } + + return Usaget(prefix, spec, out, tmpl) +} + +// Usaget writes usage information to the specified io.Writer using the specified template +func Usaget(prefix string, spec interface{}, out io.Writer, tmpl *template.Template) error { + // gather first + infos, err := gatherInfo(prefix, spec) + if err != nil { + return err + } + + return tmpl.Execute(out, infos) +} diff --git a/internal/envconfig/usage_test.go b/internal/envconfig/usage_test.go new file mode 100644 index 0000000000..d0d3c0ab7c --- /dev/null +++ b/internal/envconfig/usage_test.go @@ -0,0 +1,153 @@ +// Copyright (c) 2016 Kelsey Hightower and others. All rights reserved. +// Use of this source code is governed by the MIT License that can be found in +// the LICENSE file. + +package envconfig + +import ( + "bytes" + "io" + "log" + "os" + "strings" + "testing" + "text/tabwriter" +) + +var testUsageTableResult, testUsageListResult, testUsageCustomResult, testUsageBadFormatResult string + +func TestMain(m *testing.M) { + // Load the expected test results from a text file + data, err := os.ReadFile("testdata/default_table.txt") + if err != nil { + log.Fatal(err) + } + testUsageTableResult = string(data) + + data, err = os.ReadFile("testdata/default_list.txt") + if err != nil { + log.Fatal(err) + } + testUsageListResult = string(data) + + data, err = os.ReadFile("testdata/custom.txt") + if err != nil { + log.Fatal(err) + } + testUsageCustomResult = string(data) + + data, err = os.ReadFile("testdata/fault.txt") + if err != nil { + log.Fatal(err) + } + testUsageBadFormatResult = string(data) + + retCode := m.Run() + os.Exit(retCode) +} + +func compareUsage(want, got string, t *testing.T) { + got = strings.ReplaceAll(got, " ", ".") + if want != got { + shortest := len(want) + if len(got) < shortest { + shortest = len(got) + } + if len(want) != len(got) { + t.Errorf("expected result length of %d, found %d", len(want), len(got)) + } + for i := 0; i < shortest; i++ { + if want[i] != got[i] { + t.Errorf("difference at index %d, expected '%c' (%v), found '%c' (%v)\n", + i, want[i], want[i], got[i], got[i]) + break + } + } + t.Errorf("Complete Expected:\n'%s'\nComplete Found:\n'%s'\n", want, got) + } +} + +func TestUsageDefault(t *testing.T) { + var s Specification + os.Clearenv() + save := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + err := Usage("env_config", &s) + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + outC <- buf.String() + }() + _ = w.Close() + os.Stdout = save // restoring the real stdout + out := <-outC + + if err != nil { + t.Error(err.Error()) + } + compareUsage(testUsageTableResult, out, t) +} + +func TestUsageTable(t *testing.T) { + var s Specification + os.Clearenv() + buf := new(bytes.Buffer) + tabs := tabwriter.NewWriter(buf, 1, 0, 4, ' ', 0) + err := Usagef("env_config", &s, tabs, DefaultTableFormat) + _ = tabs.Flush() + if err != nil { + t.Error(err.Error()) + } + compareUsage(testUsageTableResult, buf.String(), t) +} + +func TestUsageList(t *testing.T) { + var s Specification + os.Clearenv() + buf := new(bytes.Buffer) + err := Usagef("env_config", &s, buf, DefaultListFormat) + if err != nil { + t.Error(err.Error()) + } + compareUsage(testUsageListResult, buf.String(), t) +} + +func TestUsageCustomFormat(t *testing.T) { + var s Specification + os.Clearenv() + buf := new(bytes.Buffer) + err := Usagef("env_config", &s, buf, "{{range .}}{{usage_key .}}={{usage_description .}}\n{{end}}") + if err != nil { + t.Error(err.Error()) + } + compareUsage(testUsageCustomResult, buf.String(), t) +} + +func TestUsageUnknownKeyFormat(t *testing.T) { + var s Specification + unknownError := "template: envconfig:1:2: executing \"envconfig\" at <.UnknownKey>" + os.Clearenv() + buf := new(bytes.Buffer) + err := Usagef("env_config", &s, buf, "{{.UnknownKey}}") + if err == nil { + t.Fatalf("expected 'unknown key' error, but got no error") + } + if !strings.Contains(err.Error(), unknownError) { + t.Errorf("expected '%s', but got '%s'", unknownError, err.Error()) + } +} + +func TestUsageBadFormat(t *testing.T) { + var s Specification + os.Clearenv() + // If you don't use two {{}} then you get a lieteral + buf := new(bytes.Buffer) + err := Usagef("env_config", &s, buf, "{{range .}}{.Key}\n{{end}}") + if err != nil { + t.Error(err.Error()) + } + compareUsage(testUsageBadFormatResult, buf.String(), t) +} diff --git a/internal/natsbroker/broker.go b/internal/natsbroker/broker.go index 08b9252b94..f2f80ebf16 100644 --- a/internal/natsbroker/broker.go +++ b/internal/natsbroker/broker.go @@ -3,11 +3,14 @@ package natsbroker import ( "context" + "crypto/tls" "fmt" "strings" "sync" "time" + "github.com/rs/zerolog/log" + "github.com/centrifugal/centrifuge" "github.com/centrifugal/protocol" "github.com/nats-io/nats.go" @@ -20,10 +23,57 @@ type ( // Config of NatsBroker. type Config struct { - URL string - Prefix string - DialTimeout time.Duration + // URL is a Nats server URL. + URL string + // Prefix allows customizing channel prefix in Nats to work with a single Nats from different + // unrelated Centrifugo setups. + Prefix string + // DialTimeout is a timeout for establishing connection to Nats. + DialTimeout time.Duration + // WriteTimeout is a timeout for write operation to Nats. WriteTimeout time.Duration + // TLS for the Nats connection. TLS is not used if nil. + TLS *tls.Config + + // AllowWildcards allows to enable wildcard subscriptions. By default, wildcard subscriptions + // are not allowed. Using wildcard subscriptions can't be combined with join/leave events and presence + // because subscriptions do not belong to a concrete channel after with wildcards, while join/leave events + // require concrete channel to be published. And presence does not make a lot of sense for wildcard + // subscriptions - there could be subscribers which use different mask, but still receive subset of updates. + // It's required to use channels without wildcards to for mentioned features to work properly. When + // using wildcard subscriptions a special care is needed regarding security - pay additional + // attention to a proper permission management. + AllowWildcards bool + + // RawMode allows enabling raw communication with Nats. When on, Centrifugo subscribes to channels + // without adding any prefixes to channel name. Proper prefixes must be managed by the application in this + // case. Data consumed from Nats is sent directly to subscribers without any processing. When publishing + // to Nats Centrifugo does not add any prefixes to channel names also. Centrifugo features like Publication + // tags, Publication ClientInfo, join/leave events are not supported in raw mode. + RawMode RawModeConfig +} + +type RawModeConfig struct { + // Enabled enables raw mode when true. + Enabled bool + + // ChannelReplacements is a map where keys are strings to replace and values are replacements. + // For example, you have Centrifugo namespace "chat" and using channel "chat:index", but you want to + // use channel "chat.index" in Nats. Then you can define SymbolReplacements map like this: {":": "."}. + // In this case Centrifugo will replace all ":" symbols in channel name with "." before sending to Nats. + // Broker keeps reverse mapping to the original channel to broadcast to proper channels when processing + // messages received from Nats. + ChannelReplacements map[string]string + + // Prefix is a string that will be added to all channels when publishing messages to Nats, subscribing + // to channels in Nats. It's also stripped from channel name when processing messages received from Nats. + // By default, no prefix is used. + Prefix string +} + +type subWrapper struct { + sub *nats.Subscription + origChannel string } // NatsBroker is a broker on top of Nats messaging system. @@ -31,10 +81,12 @@ type NatsBroker struct { node *centrifuge.Node config Config - nc *nats.Conn - subsMu sync.Mutex - subs map[channelID]*nats.Subscription - eventHandler centrifuge.BrokerEventHandler + nc *nats.Conn + subsMu sync.RWMutex + subs map[channelID]subWrapper + eventHandler centrifuge.BrokerEventHandler + clientChannelPrefix string + rawModeReplacer *strings.Replacer } var _ centrifuge.Broker = (*NatsBroker)(nil) @@ -42,9 +94,20 @@ var _ centrifuge.Broker = (*NatsBroker)(nil) // New creates NatsBroker. func New(n *centrifuge.Node, conf Config) (*NatsBroker, error) { b := &NatsBroker{ - node: n, - config: conf, - subs: make(map[channelID]*nats.Subscription), + node: n, + config: conf, + subs: make(map[channelID]subWrapper), + clientChannelPrefix: conf.Prefix + ".client.", + } + if conf.RawMode.Enabled { + log.Info().Str("rawModePrefix", conf.RawMode.Prefix).Msg("Nats raw mode enabled") + if len(conf.RawMode.ChannelReplacements) > 0 { + var replacerArgs []string + for k, v := range conf.RawMode.ChannelReplacements { + replacerArgs = append(replacerArgs, k, v) + } + b.rawModeReplacer = strings.NewReplacer(replacerArgs...) + } } return b, nil } @@ -58,7 +121,13 @@ func (b *NatsBroker) nodeChannel(nodeID string) channelID { } func (b *NatsBroker) clientChannel(ch string) channelID { - return channelID(b.config.Prefix + ".client." + ch) + if b.config.RawMode.Enabled { + if b.rawModeReplacer != nil { + ch = b.rawModeReplacer.Replace(ch) + } + return channelID(b.config.RawMode.Prefix + ch) + } + return channelID(b.clientChannelPrefix + ch) } // Run runs engine after node initialized. @@ -68,13 +137,16 @@ func (b *NatsBroker) Run(h centrifuge.BrokerEventHandler) error { if url == "" { url = nats.DefaultURL } - nc, err := nats.Connect( - url, + options := []nats.Option{ nats.ReconnectBufSize(-1), nats.MaxReconnects(-1), nats.Timeout(b.config.DialTimeout), nats.FlusherTimeout(b.config.WriteTimeout), - ) + } + if b.config.TLS != nil { + options = append(options, nats.Secure(b.config.TLS)) + } + nc, err := nats.Connect(url, options...) if err != nil { return fmt.Errorf("error connecting to %s: %w", url, err) } @@ -91,21 +163,38 @@ func (b *NatsBroker) Run(h centrifuge.BrokerEventHandler) error { return nil } -// Close is not implemented. +// Close ... func (b *NatsBroker) Close(_ context.Context) error { + b.nc.Close() return nil } -func IsUnsupportedChannel(ch string) bool { - return strings.Contains(ch, "*") || strings.Contains(ch, ">") +func (b *NatsBroker) IsSupportedSubscribeChannel(ch string) bool { + if b.config.AllowWildcards { + return true + } + if strings.Contains(ch, "*") || strings.Contains(ch, ">") { + return false + } + return true +} + +func (b *NatsBroker) IsSupportedPublishChannel(ch string) bool { + if strings.Contains(ch, "*") || strings.Contains(ch, ">") { + return false + } + return true } // Publish - see Broker interface description. func (b *NatsBroker) Publish(ch string, data []byte, opts centrifuge.PublishOptions) (centrifuge.StreamPosition, bool, error) { - if IsUnsupportedChannel(ch) { + if !b.IsSupportedPublishChannel(ch) { // Do not support wildcard subscriptions. return centrifuge.StreamPosition{}, false, centrifuge.ErrorBadRequest } + if b.config.RawMode.Enabled { + return centrifuge.StreamPosition{}, false, b.nc.Publish(b.config.RawMode.Prefix+ch, data) + } push := &protocol.Push{ Channel: ch, Pub: &protocol.Publication{ @@ -127,10 +216,14 @@ func (b *NatsBroker) Publish(ch string, data []byte, opts centrifuge.PublishOpti const epochTagsKey = "__centrifugo_epoch" func (b *NatsBroker) PublishWithStreamPosition(ch string, data []byte, opts centrifuge.PublishOptions, sp centrifuge.StreamPosition) error { - if IsUnsupportedChannel(ch) { + if !b.IsSupportedPublishChannel(ch) { // Do not support wildcard subscriptions. return centrifuge.ErrorBadRequest } + if b.config.RawMode.Enabled { + // Do not support stream positions in raw mode. + return centrifuge.ErrorBadRequest + } tags := opts.Tags if tags == nil { tags = map[string]string{} @@ -155,6 +248,9 @@ func (b *NatsBroker) PublishWithStreamPosition(ch string, data []byte, opts cent // PublishJoin - see Broker interface description. func (b *NatsBroker) PublishJoin(ch string, info *centrifuge.ClientInfo) error { + if b.config.RawMode.Enabled { + return nil + } push := &protocol.Push{ Channel: ch, Join: &protocol.Join{ @@ -170,6 +266,9 @@ func (b *NatsBroker) PublishJoin(ch string, info *centrifuge.ClientInfo) error { // PublishLeave - see Broker interface description. func (b *NatsBroker) PublishLeave(ch string, info *centrifuge.ClientInfo) error { + if b.config.RawMode.Enabled { + return nil + } push := &protocol.Push{ Channel: ch, Leave: &protocol.Leave{ @@ -204,14 +303,36 @@ func (b *NatsBroker) RemoveHistory(_ string) error { return centrifuge.ErrorNotAvailable } -func (b *NatsBroker) handleClientMessage(data []byte) { +func (b *NatsBroker) handleClientMessage(subject string, data []byte, sub *nats.Subscription) { + if b.config.RawMode.Enabled { + b.subsMu.RLock() + subWrap, ok := b.subs[channelID(sub.Subject)] + b.subsMu.RUnlock() + if !ok { + return + } + channel := subWrap.origChannel + _ = b.eventHandler.HandlePublication( + channel, + ¢rifuge.Publication{Data: data, Channel: strings.TrimPrefix(subject, b.config.RawMode.Prefix)}, + centrifuge.StreamPosition{}, false, nil) + return + } + var push protocol.Push err := push.UnmarshalVT(data) if err != nil { b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelWarn, "can't unmarshal push from Nats", map[string]any{"error": err.Error()})) return } + if push.Pub != nil { + var subChannel = push.Channel + var specificChannel string + if b.config.AllowWildcards { + subChannel = strings.TrimPrefix(sub.Subject, b.clientChannelPrefix) + specificChannel = push.Channel + } sp := centrifuge.StreamPosition{} if push.Pub.Offset > 0 && push.Pub.Tags != nil { sp.Offset = push.Pub.Offset @@ -219,7 +340,7 @@ func (b *NatsBroker) handleClientMessage(data []byte) { } delta := push.Pub.Delta push.Pub.Delta = false - _ = b.eventHandler.HandlePublication(push.Channel, pubFromProto(push.Pub), sp, delta, nil) + _ = b.eventHandler.HandlePublication(subChannel, pubFromProto(push.Pub, specificChannel), sp, delta, nil) } else if push.Join != nil { _ = b.eventHandler.HandleJoin(push.Channel, infoFromProto(push.Join.Info)) } else if push.Leave != nil { @@ -230,7 +351,7 @@ func (b *NatsBroker) handleClientMessage(data []byte) { } func (b *NatsBroker) handleClient(m *nats.Msg) { - b.handleClientMessage(m.Data) + b.handleClientMessage(m.Subject, m.Data, m.Sub) } func (b *NatsBroker) handleControl(m *nats.Msg) { @@ -239,40 +360,48 @@ func (b *NatsBroker) handleControl(m *nats.Msg) { // Subscribe - see Broker interface description. func (b *NatsBroker) Subscribe(ch string) error { - if IsUnsupportedChannel(ch) { + if !b.IsSupportedSubscribeChannel(ch) { // Do not support wildcard subscriptions. return centrifuge.ErrorBadRequest } - b.subsMu.Lock() - defer b.subsMu.Unlock() clientChannel := b.clientChannel(ch) - if _, ok := b.subs[clientChannel]; ok { + b.subsMu.RLock() + _, ok := b.subs[clientChannel] + b.subsMu.RUnlock() + if ok { return nil } - subClient, err := b.nc.Subscribe(string(b.clientChannel(ch)), b.handleClient) + subscription, err := b.nc.Subscribe(string(clientChannel), b.handleClient) if err != nil { return err } - b.subs[clientChannel] = subClient + b.subsMu.Lock() + defer b.subsMu.Unlock() + b.subs[clientChannel] = subWrapper{ + sub: subscription, + origChannel: ch, + } return nil } // Unsubscribe - see Broker interface description. func (b *NatsBroker) Unsubscribe(ch string) error { - b.subsMu.Lock() - defer b.subsMu.Unlock() - if sub, ok := b.subs[b.clientChannel(ch)]; ok { - _ = sub.Unsubscribe() - delete(b.subs, b.clientChannel(ch)) + clientChannel := b.clientChannel(ch) + b.subsMu.RLock() + subWrap, ok := b.subs[clientChannel] + b.subsMu.RUnlock() + if ok { + err := subWrap.sub.Unsubscribe() + if err != nil { + return err + } + b.subsMu.Lock() + defer b.subsMu.Unlock() + delete(b.subs, clientChannel) } return nil } -// Channels - see Broker interface description. -func (b *NatsBroker) Channels() ([]string, error) { - return nil, nil -} - func infoFromProto(v *protocol.ClientInfo) *centrifuge.ClientInfo { if v == nil { return nil @@ -307,13 +436,14 @@ func infoToProto(v *centrifuge.ClientInfo) *protocol.ClientInfo { return info } -func pubFromProto(pub *protocol.Publication) *centrifuge.Publication { +func pubFromProto(pub *protocol.Publication, specificChannel string) *centrifuge.Publication { if pub == nil { return nil } return ¢rifuge.Publication{ - Offset: pub.GetOffset(), - Data: pub.Data, - Info: infoFromProto(pub.GetInfo()), + Offset: pub.GetOffset(), + Data: pub.Data, + Info: infoFromProto(pub.GetInfo()), + Channel: specificChannel, } } diff --git a/internal/natsbroker/broker_test.go b/internal/natsbroker/broker_test.go index 8e3d03d981..18850d4899 100644 --- a/internal/natsbroker/broker_test.go +++ b/internal/natsbroker/broker_test.go @@ -1,14 +1,21 @@ +//go:build integration + package natsbroker import ( + "context" + "math/rand" "strconv" + "sync/atomic" "testing" + "time" "github.com/centrifugal/centrifuge" + "github.com/stretchr/testify/require" ) func newTestNatsBroker() *NatsBroker { - return NewTestNatsBrokerWithPrefix("centrifuge-test") + return NewTestNatsBrokerWithPrefix("centrifugo-test") } func NewTestNatsBrokerWithPrefix(prefix string) *NatsBroker { @@ -77,3 +84,173 @@ func BenchmarkNatsEngineSubscribeParallel(b *testing.B) { } }) } + +type natsTest struct { + Name string + BrokerConfig Config + // Not all configurations support join/leave messages. + TestJoinLeave bool + // For WildcardChannel case we subscribe once to a wildcard channel instead of individual channels. + WildcardChannel bool +} + +var natsTests = []natsTest{ + {"default_mode", Config{}, true, false}, + {"raw_mode", Config{RawMode: RawModeConfig{Enabled: true}}, false, false}, + {"raw_mode_wildcards", Config{AllowWildcards: true, RawMode: RawModeConfig{Enabled: true}}, false, true}, +} + +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randString(n int) string { + random := rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[random.Intn(len(letterRunes))] + } + return string(b) +} + +func getUniquePrefix() string { + return "centrifugo-test-" + randString(3) + "-" + strconv.FormatInt(time.Now().UnixNano(), 10) +} + +func stopNatsBroker(b *NatsBroker) { + _ = b.Close(context.Background()) +} + +func TestNatsPubSubTwoNodes(t *testing.T) { + for _, tt := range natsTests { + t.Run(tt.Name, func(t *testing.T) { + prefix := getUniquePrefix() + tt.BrokerConfig.Prefix = prefix + tt.BrokerConfig.RawMode.Prefix = prefix + + node1, err := centrifuge.New(centrifuge.Config{}) + require.NoError(t, err) + b1, _ := New(node1, tt.BrokerConfig) + node1.SetBroker(b1) + defer func() { _ = node1.Shutdown(context.Background()) }() + defer stopNatsBroker(b1) + + msgNum := 10 + var numPublications int64 + var numJoins int64 + var numLeaves int64 + pubCh := make(chan struct{}) + joinCh := make(chan struct{}) + leaveCh := make(chan struct{}) + brokerEventHandler := &testBrokerEventHandler{ + HandleControlFunc: func(bytes []byte) error { + return nil + }, + HandlePublicationFunc: func(ch string, pub *centrifuge.Publication, sp centrifuge.StreamPosition, delta bool, prevPub *centrifuge.Publication) error { + c := atomic.AddInt64(&numPublications, 1) + if c == int64(msgNum) { + close(pubCh) + } + return nil + }, + HandleJoinFunc: func(ch string, info *centrifuge.ClientInfo) error { + c := atomic.AddInt64(&numJoins, 1) + if c == int64(msgNum) { + close(joinCh) + } + return nil + }, + HandleLeaveFunc: func(ch string, info *centrifuge.ClientInfo) error { + c := atomic.AddInt64(&numLeaves, 1) + if c == int64(msgNum) { + close(leaveCh) + } + return nil + }, + } + _ = b1.Run(brokerEventHandler) + + if tt.WildcardChannel { + require.NoError(t, b1.Subscribe("test.*")) + } else { + for i := 0; i < msgNum; i++ { + require.NoError(t, b1.Subscribe("test."+strconv.Itoa(i))) + } + } + + node2, _ := centrifuge.New(centrifuge.Config{}) + + b2, _ := New(node2, tt.BrokerConfig) + node2.SetBroker(b2) + _ = node2.Run() + defer func() { _ = node2.Shutdown(context.Background()) }() + defer stopNatsBroker(b2) + + for i := 0; i < msgNum; i++ { + _, err := node2.Publish("test."+strconv.Itoa(i), []byte("123")) + require.NoError(t, err) + if tt.TestJoinLeave { + err = b2.PublishJoin("test."+strconv.Itoa(i), ¢rifuge.ClientInfo{}) + require.NoError(t, err) + err = b2.PublishLeave("test."+strconv.Itoa(i), ¢rifuge.ClientInfo{}) + require.NoError(t, err) + } + } + + select { + case <-pubCh: + case <-time.After(time.Second): + require.Fail(t, "timeout waiting for PUB/SUB message") + } + if tt.TestJoinLeave { + select { + case <-joinCh: + case <-time.After(time.Second): + require.Fail(t, "timeout waiting for PUB/SUB join message") + } + select { + case <-leaveCh: + case <-time.After(time.Second): + require.Fail(t, "timeout waiting for PUB/SUB leave message") + } + } + }) + } +} + +type testBrokerEventHandler struct { + // Publication must register callback func to handle Publications received. + HandlePublicationFunc func(ch string, pub *centrifuge.Publication, sp centrifuge.StreamPosition, delta bool, prevPub *centrifuge.Publication) error + // Join must register callback func to handle Join messages received. + HandleJoinFunc func(ch string, info *centrifuge.ClientInfo) error + // Leave must register callback func to handle Leave messages received. + HandleLeaveFunc func(ch string, info *centrifuge.ClientInfo) error + // Control must register callback func to handle Control data received. + HandleControlFunc func([]byte) error +} + +func (b *testBrokerEventHandler) HandlePublication(ch string, pub *centrifuge.Publication, sp centrifuge.StreamPosition, delta bool, prevPub *centrifuge.Publication) error { + if b.HandlePublicationFunc != nil { + return b.HandlePublicationFunc(ch, pub, sp, delta, prevPub) + } + return nil +} + +func (b *testBrokerEventHandler) HandleJoin(ch string, info *centrifuge.ClientInfo) error { + if b.HandleJoinFunc != nil { + return b.HandleJoinFunc(ch, info) + } + return nil +} + +func (b *testBrokerEventHandler) HandleLeave(ch string, info *centrifuge.ClientInfo) error { + if b.HandleLeaveFunc != nil { + return b.HandleLeaveFunc(ch, info) + } + return nil +} + +func (b *testBrokerEventHandler) HandleControl(data []byte) error { + if b.HandleControlFunc != nil { + return b.HandleControlFunc(data) + } + return nil +} diff --git a/internal/proxy/grpc.go b/internal/proxy/grpc.go index a1266a6e52..8e8ab596e5 100644 --- a/internal/proxy/grpc.go +++ b/internal/proxy/grpc.go @@ -56,7 +56,13 @@ func getDialOpts(p Config) ([]grpc.DialOption, error) { value: p.GrpcCredentialsValue, })) } - if p.GrpcCertFile != "" { + if p.GrpcTLS.Enabled { + tlsConfig, err := p.GrpcTLS.ToGoTLSConfig() + if err != nil { + return nil, fmt.Errorf("failed to create TLS config %v", err) + } + dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + } else if p.GrpcCertFile != "" { cred, err := credentials.NewClientTLSFromFile(p.GrpcCertFile, "") if err != nil { return nil, fmt.Errorf("failed to create TLS credentials %v", err) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 55fc18e6fa..73935b7b50 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -36,6 +36,8 @@ type Config struct { // IncludeConnectionMeta to each proxy request (except connect where it's obtained). IncludeConnectionMeta bool `mapstructure:"include_connection_meta" json:"include_connection_meta,omitempty"` + // GrpcTLS is a common configuration for GRPC TLS. + GrpcTLS tools.TLSConfig `mapstructure:"grpc_tls" json:"grpc_tls,omitempty"` // GrpcCertFile is a path to GRPC cert file on disk. GrpcCertFile string `mapstructure:"grpc_cert_file" json:"grpc_cert_file,omitempty"` // GrpcCredentialsKey is a custom key to add into per-RPC credentials. diff --git a/internal/proxy/unknown_keys.go b/internal/proxy/unknown_keys.go index 0c982df96f..7bf67966a9 100644 --- a/internal/proxy/unknown_keys.go +++ b/internal/proxy/unknown_keys.go @@ -25,5 +25,13 @@ func WarnUnknownProxyKeys(jsonProxies []byte) { } log.Warn().Str("key", key).Any("proxy_name", jsonMap["name"]).Msg("unknown key found in the proxy object") } + tls, ok := jsonMap["grpc_tls"].(map[string]any) + if ok { + var TLSConfig tools.TLSConfig + unknownKeys := tools.FindUnknownKeys(tls, TLSConfig) + for _, key := range unknownKeys { + log.Warn().Str("key", key).Any("proxy_name", jsonMap["name"]).Msg("unknown key found in the proxy tls config object") + } + } } } diff --git a/internal/redisnatsbroker/broker.go b/internal/redisnatsbroker/broker.go index 230b79423c..2f510571c8 100644 --- a/internal/redisnatsbroker/broker.go +++ b/internal/redisnatsbroker/broker.go @@ -26,7 +26,7 @@ func New(nats *natsbroker.NatsBroker, redis *centrifuge.RedisBroker) (*Broker, e } func (b *Broker) Publish(ch string, data []byte, opts centrifuge.PublishOptions) (centrifuge.StreamPosition, bool, error) { - if natsbroker.IsUnsupportedChannel(ch) { + if !b.NatsBroker.IsSupportedPublishChannel(ch) { // Do not support wildcard subscriptions just like natsbroker.NatsBroker. return centrifuge.StreamPosition{}, false, centrifuge.ErrorBadRequest } diff --git a/internal/tools/tls.go b/internal/tools/tls.go index e0a32ec38d..fcc957d165 100644 --- a/internal/tools/tls.go +++ b/internal/tools/tls.go @@ -226,7 +226,7 @@ type TLSOptions struct { ClientCAPem string `mapstructure:"tls_client_ca_pem" json:"tls_client_ca_pem"` InsecureSkipVerify bool `mapstructure:"tls_insecure_skip_verify" json:"tls_insecure_skip_verify"` - ServerName string `mapstructure:"server_name" json:"server_name"` + ServerName string `mapstructure:"tls_server_name" json:"tls_server_name"` } func (t TLSOptions) ToMap() (TLSOptionsMap, error) { diff --git a/internal/tools/tls_v2.go b/internal/tools/tls_v2.go new file mode 100644 index 0000000000..50c7db009f --- /dev/null +++ b/internal/tools/tls_v2.go @@ -0,0 +1,241 @@ +package tools + +import ( + "crypto/tls" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/centrifugal/centrifugo/v5/internal/envconfig" + + "github.com/FZambia/viper-lite" + "github.com/hashicorp/go-envparse" + "github.com/rs/zerolog/log" +) + +// TLSConfig is a common configuration for TLS. +// It allows to configure TLS settings using different sources. The order sources are used is the following: +// 1. File to PEM +// 2. Base64 encoded PEM +// 3. Raw PEM +// It's up to the user to only use a single source of configured values. I.e. if both file and raw PEM are set +// the file will be used and raw PEM will be just ignored. +type TLSConfig struct { + // Enabled turns on using TLS. + Enabled bool `mapstructure:"enabled" json:"enabled"` + + // CertPem is a certificate in PEM format. + CertPem string `mapstructure:"cert_pem" json:"cert_pem" envconfig:"cert_pem"` + // CertPemB64 is a certificate in base64 encoded PEM format. + CertPemB64 string `mapstructure:"cert_pem_b64" json:"cert_pem_b64" envconfig:"cert_pem_b64"` + // CertPemFile is a path to a file with certificate in PEM format. + CertPemFile string `mapstructure:"cert_pem_file" json:"cert_pem_file" envconfig:"cert_pem_file"` + + // KeyPem is a key in PEM format. + KeyPem string `mapstructure:"key_pem" json:"key_pem" envconfig:"key_pem"` + // KeyPemB64 is a key in base64 encoded PEM format. + KeyPemB64 string `mapstructure:"key_pem_b64" json:"key_pem_b64" envconfig:"key_pem_b64"` + // KeyPemFile is a path to a file with key in PEM format. + KeyPemFile string `mapstructure:"key_pem_file" json:"key_pem_file" envconfig:"key_pem_file"` + + // ServerCAPem is a server root CA certificate in PEM format. + // The client uses this certificate to verify the server's certificate during the TLS handshake. + ServerCAPem string `mapstructure:"server_ca_pem" json:"server_ca_pem" envconfig:"server_ca_pem"` + // ServerCAPemB64 is a server root CA certificate in base64 encoded PEM format. + ServerCAPemB64 string `mapstructure:"server_ca_pem_b64" json:"server_ca_pem_b64" envconfig:"server_ca_pem_b64"` + // ServerCAPemFile is a path to a file with server root CA certificate in PEM format. + ServerCAPemFile string `mapstructure:"server_ca_pem_file" json:"server_ca_pem_file" envconfig:"server_ca_pem_file"` + + // ClientCAPem is a client CA certificate in PEM format. + // The server uses this certificate to verify the client's certificate during the TLS handshake. + ClientCAPem string `mapstructure:"client_ca_pem" json:"client_ca_pem" envconfig:"client_ca_pem"` + // ClientCAPemB64 is a client CA certificate in base64 encoded PEM format. + ClientCAPemB64 string `mapstructure:"client_ca_pem_b64" json:"client_ca_pem_b64" envconfig:"client_ca_pem_b64"` + // ClientCAPemFile is a path to a file with client CA certificate in PEM format. + ClientCAPemFile string `mapstructure:"client_ca_pem_file" json:"client_ca_pem_file" envconfig:"client_ca_pem_file"` + + // InsecureSkipVerify turns off server certificate verification. + InsecureSkipVerify bool `mapstructure:"insecure_skip_verify" json:"insecure_skip_verify" envconfig:"insecure_skip_verify"` + // ServerName is used to verify the hostname on the returned certificates. + ServerName string `mapstructure:"server_name" json:"server_name" envconfig:"server_name"` +} + +func (c TLSConfig) ToMap() (TLSOptionsMap, error) { + var m TLSOptionsMap + jsonData, _ := json.Marshal(m) + err := json.Unmarshal(jsonData, &m) + return m, err +} + +func (c TLSConfig) ToGoTLSConfig() (*tls.Config, error) { + if !c.Enabled { + return nil, nil + } + return makeTLSConfig(c, os.ReadFile) +} + +// makeTLSConfig constructs a tls.Config instance using the given configuration. +func makeTLSConfig(cfg TLSConfig, readFile ReadFileFunc) (*tls.Config, error) { + tlsConfig := &tls.Config{} + + if cfg.CertPemFile != "" && cfg.KeyPemFile != "" { + certPEMBlock, err := readFile(cfg.CertPemFile) + if err != nil { + return nil, fmt.Errorf("read TLS certificate for %s: %w", cfg.CertPemFile, err) + } + keyPEMBlock, err := readFile(cfg.KeyPemFile) + if err != nil { + return nil, fmt.Errorf("read TLS key for %s: %w", cfg.KeyPemFile, err) + } + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return nil, fmt.Errorf("parse certificate/key pair for %s/%s: %w", cfg.CertPemFile, cfg.KeyPemFile, err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } else if cfg.CertPemB64 != "" && cfg.KeyPemB64 != "" { + certPem, err := base64.StdEncoding.DecodeString(cfg.CertPemB64) + if err != nil { + return nil, fmt.Errorf("error base64 decode certificate PEM: %w", err) + } + keyPem, err := base64.StdEncoding.DecodeString(cfg.KeyPemB64) + if err != nil { + return nil, fmt.Errorf("error base64 decode key PEM: %w", err) + } + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + return nil, fmt.Errorf("error parse certificate/key pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } else if cfg.CertPem != "" && cfg.KeyPem != "" { + cert, err := tls.X509KeyPair([]byte(cfg.CertPem), []byte(cfg.KeyPem)) + if err != nil { + return nil, fmt.Errorf("error parse certificate/key pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + if cfg.ServerCAPemFile != "" { + caCert, err := readFile(cfg.ServerCAPemFile) + if err != nil { + return nil, fmt.Errorf("read the root CA certificate for %s: %w", cfg.ServerCAPemFile, err) + } + caCertPool, err := newCertPoolFromPEM(caCert) + if err != nil { + return nil, fmt.Errorf("error parse root CA certificate: %w", err) + } + tlsConfig.RootCAs = caCertPool + } else if cfg.ServerCAPemB64 != "" { + caCert, err := base64.StdEncoding.DecodeString(cfg.ServerCAPemB64) + if err != nil { + return nil, fmt.Errorf("error base64 decode root CA PEM: %w", err) + } + caCertPool, err := newCertPoolFromPEM(caCert) + if err != nil { + return nil, fmt.Errorf("error parse root CA certificate: %w", err) + } + tlsConfig.RootCAs = caCertPool + } else if cfg.ServerCAPem != "" { + caCertPool, err := newCertPoolFromPEM([]byte(cfg.ServerCAPem)) + if err != nil { + return nil, fmt.Errorf("error parse root CA certificate: %w", err) + } + tlsConfig.RootCAs = caCertPool + } + + if cfg.ClientCAPemFile != "" { + caCert, err := readFile(cfg.ClientCAPemFile) + if err != nil { + return nil, fmt.Errorf("read the client CA certificate for %s: %w", cfg.ClientCAPemFile, err) + } + caCertPool, err := newCertPoolFromPEM(caCert) + if err != nil { + return nil, fmt.Errorf("error parse client CA certificate: %w", err) + } + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } else if cfg.ClientCAPemB64 != "" { + caCert, err := base64.StdEncoding.DecodeString(cfg.ClientCAPemB64) + if err != nil { + return nil, fmt.Errorf("error base64 decode client CA PEM: %w", err) + } + caCertPool, err := newCertPoolFromPEM(caCert) + if err != nil { + return nil, fmt.Errorf("error parse client CA certificate: %w", err) + } + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } else if cfg.ClientCAPem != "" { + caCertPool, err := newCertPoolFromPEM([]byte(cfg.ClientCAPem)) + if err != nil { + return nil, fmt.Errorf("error parse client CA certificate: %w", err) + } + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + tlsConfig.ServerName = cfg.ServerName + tlsConfig.InsecureSkipVerify = cfg.InsecureSkipVerify + + return tlsConfig, nil +} + +// ExtractTLSConfig extracts TLS configuration from Viper instance and applies environment variables. +func ExtractTLSConfig(v *viper.Viper, key string) (TLSConfig, error) { + var cfg TLSConfig + err := v.UnmarshalKey(key, &cfg) + if err != nil { + return cfg, err + } + prefix := "CENTRIFUGO_" + strings.ToUpper(key) + varInfo, err := envconfig.Process(prefix, &cfg) + if err != nil { + return cfg, err + } + checkEnvironmentVarInfo(prefix+"_", varInfo) + return cfg, nil +} + +// ExtractGoTLSConfig is a helper to ExtractTLSConfig and then convert it to *tls.Config. +func ExtractGoTLSConfig(v *viper.Viper, key string) (*tls.Config, error) { + cfg, err := ExtractTLSConfig(v, key) + if err != nil { + return nil, fmt.Errorf("extract TLS config: %w", err) + } + return cfg.ToGoTLSConfig() +} + +func checkEnvironmentVarInfo(envPrefix string, varInfo []envconfig.VarInfo) { + envVars := os.Environ() + + defaults := make(map[string]interface{}) + for _, info := range varInfo { + defaults[info.Key] = "" + } + + for _, envVar := range envVars { + kv, err := envparse.Parse(strings.NewReader(envVar)) + if err != nil { + continue + } + for envKey := range kv { + if !strings.HasPrefix(envKey, envPrefix) { + continue + } + // Kubernetes automatically adds some variables which are not used by Centrifugo + // itself. We skip warnings about them. + if isKubernetesEnvVar(envKey) { + continue + } + if !isKnownEnv(defaults, envKey) { + log.Warn().Str("key", envKey).Msg("unknown key found in the environment") + } + } + } +} + +func isKnownEnv(defaults map[string]any, envKey string) bool { + _, ok := defaults[envKey] + return ok +} diff --git a/internal/tools/unknown_keys.go b/internal/tools/unknown_keys.go index 516472e48e..02ab0c8ce7 100644 --- a/internal/tools/unknown_keys.go +++ b/internal/tools/unknown_keys.go @@ -66,6 +66,7 @@ func CheckPlainConfigKeys(defaults map[string]any, allKeys []string) { // allow arbitrary keys for maps we have this slice of such configuration options here. var mapStringStringKeys = []string{ "proxy_static_http_headers", + "nats_raw_mode.channel_replacements", } func isMapStringStringKey(key string) bool { diff --git a/main.go b/main.go index 0b6d2a6f46..79037e26d7 100644 --- a/main.go +++ b/main.go @@ -291,6 +291,7 @@ var defaults = map[string]any{ "proxy_include_connection_meta": false, "proxy_grpc_cert_file": "", "proxy_grpc_compression": false, + "proxy_grpc_tls": tools.TLSConfig{}, "tarantool_mode": "standalone", "tarantool_address": "tcp://127.0.0.1:3301", @@ -342,10 +343,15 @@ var defaults = map[string]any{ "graphite_interval": 10 * time.Second, "graphite_tags": false, - "nats_prefix": "centrifugo", - "nats_url": "nats://127.0.0.1:4222", - "nats_dial_timeout": time.Second, - "nats_write_timeout": time.Second, + "nats_prefix": "centrifugo", + "nats_url": "nats://127.0.0.1:4222", + "nats_dial_timeout": time.Second, + "nats_write_timeout": time.Second, + "nats_allow_wildcards": false, + + "nats_raw_mode.enabled": false, + "nats_raw_mode.channel_replacements": map[string]string{}, + "nats_raw_mode.prefix": "", "websocket_disable": false, "api_disable": false, @@ -458,6 +464,32 @@ func init() { defaults[k] = v } } + tlsConfigPrefixes := []string{ + "nats_tls.", + "proxy_grpc_tls.", + } + for _, prefix := range tlsConfigPrefixes { + keyMap := map[string]any{ + prefix + "enabled": false, + prefix + "cert_pem": "", + prefix + "cert_pem_file": "", + prefix + "cert_pem_b64": "", + prefix + "key_pem": "", + prefix + "key_pem_file": "", + prefix + "key_pem_b64": "", + prefix + "server_ca_pem": "", + prefix + "server_ca_pem_file": "", + prefix + "server_ca_pem_b64": "", + prefix + "client_ca_pem": "", + prefix + "client_ca_pem_file": "", + prefix + "client_ca_pem_b64": "", + prefix + "server_name": "", + prefix + "insecure_skip_verify": false, + } + for k, v := range keyMap { + defaults[k] = v + } + } } func bindCentrifugoConfig() { @@ -1917,10 +1949,16 @@ func proxyMapConfig() (*client.ProxyMap, bool) { SubscribeStreamProxies: map[string]*proxy.SubscribeStreamProxy{}, } + tlsConfig, err := tools.ExtractTLSConfig(viper.GetViper(), "proxy_grpc_tls") + if err != nil { + log.Fatal().Msgf("error extracting TLS config for proxy GRPC: %v", err) + } + proxyConfig := proxy.Config{ BinaryEncoding: v.GetBool("proxy_binary_encoding"), IncludeConnectionMeta: v.GetBool("proxy_include_connection_meta"), GrpcCertFile: v.GetString("proxy_grpc_cert_file"), + GrpcTLS: tlsConfig, GrpcCredentialsKey: v.GetString("proxy_grpc_credentials_key"), GrpcCredentialsValue: v.GetString("proxy_grpc_credentials_value"), GrpcMetadata: v.GetStringSlice("proxy_grpc_metadata"), @@ -2599,11 +2637,26 @@ func getRedisShards(n *centrifuge.Node) ([]*centrifuge.RedisShard, string, error } func initNatsBroker(node *centrifuge.Node) (*natsbroker.NatsBroker, error) { + replacements, err := tools.MapStringString(viper.GetViper(), "nats_raw_mode.channel_replacements") + if err != nil { + return nil, fmt.Errorf("error parsing nats_raw_mode_channel_replacements: %v", err) + } + tlsConfig, err := tools.ExtractGoTLSConfig(viper.GetViper(), "nats_tls") + if err != nil { + return nil, fmt.Errorf("error configuring nats tls: %v", err) + } return natsbroker.New(node, natsbroker.Config{ - URL: viper.GetString("nats_url"), - Prefix: viper.GetString("nats_prefix"), - DialTimeout: GetDuration("nats_dial_timeout"), - WriteTimeout: GetDuration("nats_write_timeout"), + URL: viper.GetString("nats_url"), + Prefix: viper.GetString("nats_prefix"), + DialTimeout: GetDuration("nats_dial_timeout"), + WriteTimeout: GetDuration("nats_write_timeout"), + AllowWildcards: viper.GetBool("nats_allow_wildcards"), + TLS: tlsConfig, + RawMode: natsbroker.RawModeConfig{ + Enabled: viper.GetBool("nats_raw_mode.enabled"), + Prefix: viper.GetString("nats_raw_mode.prefix"), + ChannelReplacements: replacements, + }, }) }