diff --git a/internal/flag/context.go b/internal/flag/context.go index a3fb01beea..48c8738dbf 100644 --- a/internal/flag/context.go +++ b/internal/flag/context.go @@ -3,10 +3,13 @@ package flag import ( "context" "slices" + "strconv" "strings" "time" + "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/superfly/flyctl/internal/command_context" "github.com/superfly/flyctl/internal/env" "github.com/superfly/flyctl/internal/flag/flagctx" "github.com/superfly/flyctl/internal/flag/flagnames" @@ -38,12 +41,40 @@ func FirstArg(ctx context.Context) string { return "" } +func EnvNameFromCmd(cmd *cobra.Command) string { + if cmd.Parent() != nil { + varname := EnvNameFromCmd(cmd.Parent()) + "_" + cmd.Name() + return strings.ToUpper(varname) + } else { + return strings.ToUpper(cmd.Name()) + } + +} + +func FromEnv(ctx context.Context, name string) string { + cmd := command_context.FromContext(ctx) + value := "" + for cmd != nil { + var_name := EnvNameFromCmd(cmd) + "_" + strings.ToUpper(name) + var_name = strings.ReplaceAll(var_name, "-", "_") + value = env.First(var_name) + if value == "" { + cmd = cmd.Parent() + } else { + return value + } + } + return value +} + // GetString returns the value of the named string flag ctx carries. func GetString(ctx context.Context, name string) string { - if v, err := FromContext(ctx).GetString(name); err != nil { - return "" - } else { + if v, err := FromContext(ctx).GetString(name); err == nil && v != "" { return v + } else if v := FromEnv(ctx, name); v != "" { + return v + } else { + return "" } } @@ -55,20 +86,32 @@ func SetString(ctx context.Context, name, value string) error { // GetInt returns the value of the named int flag ctx carries. It panics // in case ctx carries no flags or in case the named flag isn't an int one. func GetInt(ctx context.Context, name string) int { - if v, err := FromContext(ctx).GetInt(name); err != nil { - panic(err) - } else { + if v, err := FromContext(ctx).GetInt(name); err == nil { return v + } else if v := FromEnv(ctx, name); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return i + } else { + panic(err) + } + } else { + panic(err) } } // GetFloat64 returns the value of the named int flag ctx carries. It panics // in case ctx carries no flags or in case the named flag isn't a float64 one. func GetFloat64(ctx context.Context, name string) float64 { - if v, err := FromContext(ctx).GetFloat64(name); err != nil { - panic(err) - } else { + if v, err := FromContext(ctx).GetFloat64(name); err == nil { return v + } else if v := FromEnv(ctx, name); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } else { + panic(err) + } + } else { + panic(err) } } @@ -76,20 +119,24 @@ func GetFloat64(ctx context.Context, name string) float64 { // Preserves commas (unlike the following `GetStringSlice`): in `--flag x,y` the value is string[]{`x,y`}. // This is useful to pass key-value pairs like environment variables or build arguments. func GetStringArray(ctx context.Context, name string) []string { - if v, err := FromContext(ctx).GetStringArray(name); err != nil { - return []string{} - } else { + if v, err := FromContext(ctx).GetStringArray(name); err == nil { return v + } else if v := FromEnv(ctx, name); v != "" { + return []string{v} + } else { + return []string{} } } // GetStringSlice returns the values of the named string flag ctx carries. // Can be comma separated or passed "by repeated flags": `--flag x,y` is equivalent to `--flag x --flag y`. func GetStringSlice(ctx context.Context, name string) []string { - if v, err := FromContext(ctx).GetStringSlice(name); err != nil { - return []string{} - } else { + if v, err := FromContext(ctx).GetStringSlice(name); err == nil { return v + } else if v := FromEnv(ctx, name); v != "" { + return strings.Split(v, ",") + } else { + return []string{} } } @@ -110,20 +157,30 @@ func GetNonEmptyStringSlice(ctx context.Context, name string) []string { // GetDuration returns the value of the named duration flag ctx carries. func GetDuration(ctx context.Context, name string) time.Duration { - if v, err := FromContext(ctx).GetDuration(name); err != nil { - return 0 - } else { + if v, err := FromContext(ctx).GetDuration(name); err == nil { return v + } else if v := FromEnv(ctx, name); v != "" { + if d, err := time.ParseDuration(v); err == nil { + return d + } else { + return 0 + } + } else { + return 0 } } // GetBool returns the value of the named boolean flag ctx carries. func GetBool(ctx context.Context, name string) bool { - if v, err := FromContext(ctx).GetBool(name); err != nil { - return false - } else { + isSpecified := IsSpecified(ctx, name) + if v, err := FromContext(ctx).GetBool(name); err == nil && isSpecified { return v + } else if v := FromEnv(ctx, name); v != "" { + if b, err := strconv.ParseBool(v); err == nil { + return b + } } + return false } // IsSpecified returns whether a flag has been specified at all or not. @@ -136,9 +193,6 @@ func IsSpecified(ctx context.Context, name string) bool { // GetOrg is shorthand for GetString(ctx, Org). func GetOrg(ctx context.Context) string { org := GetString(ctx, flagnames.Org) - if org == "" { - org = env.First("FLY_ORG") - } return org }