diff --git a/internal/cli/service.go b/internal/cli/service.go index 54660f6..3c84a62 100644 --- a/internal/cli/service.go +++ b/internal/cli/service.go @@ -10,6 +10,7 @@ import ( "github.com/rwx-research/mint-cli/internal/accesstoken" "github.com/rwx-research/mint-cli/internal/api" + "github.com/rwx-research/mint-cli/internal/dotenv" "github.com/rwx-research/mint-cli/internal/errors" "github.com/briandowns/spinner" @@ -274,20 +275,13 @@ func (s Service) SetSecretsInVault(cfg SetSecretsInVaultConfig) error { return errors.Wrapf(err, "error while reading %q", cfg.File) } - fileLines := strings.Split(string(fileContent), "\n") + dotenvMap := make(map[string]string) + err = dotenv.ParseBytes(fileContent, dotenvMap) + if err != nil { + return errors.Wrapf(err, "error while parsing %q", cfg.File) + } - for i := range fileLines { - if fileLines[i] == "" { - continue - } - key, value, found := strings.Cut(fileLines[i], "=") - if !found { - return errors.New(fmt.Sprintf("Invalid secret '%s' in file %s. Secrets must be specified in the form 'KEY=value'.", cfg.Secrets[i], cfg.File)) - } - // If a line is like ABC="def", we need to strip off the leading and trailing quotation mark - if strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"") { - value = value[1 : len(value)-1] - } + for key, value := range dotenvMap { secrets = append(secrets, api.Secret{ Name: key, Secret: value, diff --git a/internal/cli/service_test.go b/internal/cli/service_test.go index bbc3cdd..e5f1b14 100644 --- a/internal/cli/service_test.go +++ b/internal/cli/service_test.go @@ -1,6 +1,8 @@ package cli_test import ( + "sort" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -842,19 +844,26 @@ AAAEC6442PQKevgYgeT0SIu9zwlnEMl6MF59ZgM+i0ByMv4eLJPqG3xnZcEQmktHj/GY2i Context("when reading secrets from a file", func() { BeforeEach(func() { mockAPI.MockSetSecretsInVault = func(ssivc api.SetSecretsInVaultConfig) (*api.SetSecretsInVaultResult, error) { + sort.Slice(ssivc.Secrets, func(i, j int) bool { + return ssivc.Secrets[i].Name < ssivc.Secrets[j].Name + }) Expect(ssivc.VaultName).To(Equal("default")) - Expect(ssivc.Secrets[0].Name).To(Equal("ABC")) + Expect(ssivc.Secrets[0].Name).To(Equal("A")) Expect(ssivc.Secrets[0].Secret).To(Equal("123")) - Expect(ssivc.Secrets[1].Name).To(Equal("DEF")) + Expect(ssivc.Secrets[1].Name).To(Equal("B")) Expect(ssivc.Secrets[1].Secret).To(Equal("xyz")) + Expect(ssivc.Secrets[2].Name).To(Equal("C")) + Expect(ssivc.Secrets[2].Secret).To(Equal("q\\nqq")) + Expect(ssivc.Secrets[3].Name).To(Equal("D")) + Expect(ssivc.Secrets[3].Secret).To(Equal("a multiline\nstring\nspanning lines")) return &api.SetSecretsInVaultResult{ - SetSecrets: []string{"ABC","DEF"}, + SetSecrets: []string{"A", "B", "C", "D"}, }, nil } mockFS.MockOpen = func(name string) (fs.File, error) { Expect(name).To(Equal("secrets.txt")) - file := mocks.NewFile("ABC=123\nDEF=\"xyz\"\n") + file := mocks.NewFile("A=123\nB=\"xyz\"\nC='q\\nqq'\nD=\"a multiline\nstring\nspanning lines\"") return file, nil } }) @@ -868,7 +877,7 @@ AAAEC6442PQKevgYgeT0SIu9zwlnEMl6MF59ZgM+i0ByMv4eLJPqG3xnZcEQmktHj/GY2i }) Expect(err).NotTo(HaveOccurred()) - Expect(stdout.String()).To(Equal("\nSuccessfully set the following secrets: ABC, DEF")) + Expect(stdout.String()).To(Equal("\nSuccessfully set the following secrets: A, B, C, D")) }) }) }) diff --git a/internal/dotenv/parser.go b/internal/dotenv/parser.go new file mode 100644 index 0000000..392dcb6 --- /dev/null +++ b/internal/dotenv/parser.go @@ -0,0 +1,273 @@ +// This is copied from https://github.com/joho/godotenv/blob/e3b6eee84d15b0fa274078565ca46591f5e08876/parser.go + +package dotenv + +import ( + "bytes" + "errors" + "fmt" + "regexp" + "strings" + "unicode" +) + +const ( + charComment = '#' + prefixSingleQuote = '\'' + prefixDoubleQuote = '"' + + exportPrefix = "export" +) + +func ParseBytes(src []byte, out map[string]string) error { + src = bytes.Replace(src, []byte("\r\n"), []byte("\n"), -1) + cutset := src + for { + cutset = getStatementStart(cutset) + if cutset == nil { + // reached end of file + break + } + + key, left, err := locateKeyName(cutset) + if err != nil { + return err + } + + value, left, err := extractVarValue(left, out) + if err != nil { + return err + } + + out[key] = value + cutset = left + } + + return nil +} + +// getStatementPosition returns position of statement begin. +// +// It skips any comment line or non-whitespace character. +func getStatementStart(src []byte) []byte { + pos := indexOfNonSpaceChar(src) + if pos == -1 { + return nil + } + + src = src[pos:] + if src[0] != charComment { + return src + } + + // skip comment section + pos = bytes.IndexFunc(src, isCharFunc('\n')) + if pos == -1 { + return nil + } + + return getStatementStart(src[pos:]) +} + +// locateKeyName locates and parses key name and returns rest of slice +func locateKeyName(src []byte) (key string, cutset []byte, err error) { + // trim "export" and space at beginning + src = bytes.TrimLeftFunc(src, isSpace) + if bytes.HasPrefix(src, []byte(exportPrefix)) { + trimmed := bytes.TrimPrefix(src, []byte(exportPrefix)) + if bytes.IndexFunc(trimmed, isSpace) == 0 { + src = bytes.TrimLeftFunc(trimmed, isSpace) + } + } + + // locate key name end and validate it in single loop + offset := 0 +loop: + for i, char := range src { + rchar := rune(char) + if isSpace(rchar) { + continue + } + + switch char { + case '=', ':': + // library also supports yaml-style value declaration + key = string(src[0:i]) + offset = i + 1 + break loop + case '_': + default: + // variable name should match [A-Za-z0-9_.] + if unicode.IsLetter(rchar) || unicode.IsNumber(rchar) || rchar == '.' { + continue + } + + return "", nil, fmt.Errorf( + `unexpected character %q in variable name near %q`, + string(char), string(src)) + } + } + + if len(src) == 0 { + return "", nil, errors.New("zero length string") + } + + // trim whitespace + key = strings.TrimRightFunc(key, unicode.IsSpace) + cutset = bytes.TrimLeftFunc(src[offset:], isSpace) + return key, cutset, nil +} + +// extractVarValue extracts variable value and returns rest of slice +func extractVarValue(src []byte, vars map[string]string) (value string, rest []byte, err error) { + quote, hasPrefix := hasQuotePrefix(src) + if !hasPrefix { + // unquoted value - read until end of line + endOfLine := bytes.IndexFunc(src, isLineEnd) + + // Hit EOF without a trailing newline + if endOfLine == -1 { + endOfLine = len(src) + + if endOfLine == 0 { + return "", nil, nil + } + } + + // Convert line to rune away to do accurate countback of runes + line := []rune(string(src[0:endOfLine])) + + // Assume end of line is end of var + endOfVar := len(line) + if endOfVar == 0 { + return "", src[endOfLine:], nil + } + + // Work backwards to check if the line ends in whitespace then + // a comment (ie asdasd # some comment) + for i := endOfVar - 1; i >= 0; i-- { + if line[i] == charComment && i > 0 { + if isSpace(line[i-1]) { + endOfVar = i + break + } + } + } + + trimmed := strings.TrimFunc(string(line[0:endOfVar]), isSpace) + + return expandVariables(trimmed, vars), src[endOfLine:], nil + } + + // lookup quoted string terminator + for i := 1; i < len(src); i++ { + if char := src[i]; char != quote { + continue + } + + // skip escaped quote symbol (\" or \', depends on quote) + if prevChar := src[i-1]; prevChar == '\\' { + continue + } + + // trim quotes + trimFunc := isCharFunc(rune(quote)) + value = string(bytes.TrimLeftFunc(bytes.TrimRightFunc(src[0:i], trimFunc), trimFunc)) + if quote == prefixDoubleQuote { + // unescape newlines for double quote (this is compat feature) + // and expand environment variables + value = expandVariables(expandEscapes(value), vars) + } + + return value, src[i+1:], nil + } + + // return formatted error if quoted string is not terminated + valEndIndex := bytes.IndexFunc(src, isCharFunc('\n')) + if valEndIndex == -1 { + valEndIndex = len(src) + } + + return "", nil, fmt.Errorf("unterminated quoted value %s", src[:valEndIndex]) +} + +func expandEscapes(str string) string { + out := escapeRegex.ReplaceAllStringFunc(str, func(match string) string { + c := strings.TrimPrefix(match, `\`) + switch c { + case "n": + return "\n" + case "r": + return "\r" + default: + return match + } + }) + return unescapeCharsRegex.ReplaceAllString(out, "$1") +} + +func indexOfNonSpaceChar(src []byte) int { + return bytes.IndexFunc(src, func(r rune) bool { + return !unicode.IsSpace(r) + }) +} + +// hasQuotePrefix reports whether charset starts with single or double quote and returns quote character +func hasQuotePrefix(src []byte) (prefix byte, isQuored bool) { + if len(src) == 0 { + return 0, false + } + + switch prefix := src[0]; prefix { + case prefixDoubleQuote, prefixSingleQuote: + return prefix, true + default: + return 0, false + } +} + +func isCharFunc(char rune) func(rune) bool { + return func(v rune) bool { + return v == char + } +} + +// isSpace reports whether the rune is a space character but not line break character +// +// this differs from unicode.IsSpace, which also applies line break as space +func isSpace(r rune) bool { + switch r { + case '\t', '\v', '\f', '\r', ' ', 0x85, 0xA0: + return true + } + return false +} + +func isLineEnd(r rune) bool { + if r == '\n' || r == '\r' { + return true + } + return false +} + +var ( + escapeRegex = regexp.MustCompile(`\\.`) + expandVarRegex = regexp.MustCompile(`(\\)?(\$)(\()?\{?([A-Z0-9_]+)?\}?`) + unescapeCharsRegex = regexp.MustCompile(`\\([^$])`) +) + +func expandVariables(v string, m map[string]string) string { + return expandVarRegex.ReplaceAllStringFunc(v, func(s string) string { + submatch := expandVarRegex.FindStringSubmatch(s) + + if submatch == nil { + return s + } + if submatch[1] == "\\" || submatch[2] == "(" { + return submatch[0][1:] + } else if submatch[4] != "" { + return m[submatch[4]] + } + return s + }) +}