diff --git a/cmd/initialize/init_policies.go b/cmd/initialize/init_policies.go index 12ed1c9ce..798aa621d 100644 --- a/cmd/initialize/init_policies.go +++ b/cmd/initialize/init_policies.go @@ -55,21 +55,7 @@ func initPoliciesCmd() *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - fs := utils.FS(ctx) - workDir := destDir - err := fs.MkdirAll(workDir, 0755) - if err != nil { - log.Debug("Failed to create policy directory!") - return err - } - policyPath := filepath.Join(workDir, "sample.rego") - file, err := fs.Create(policyPath) - if err != nil { - log.Debug("Failed to create sample policy!") - return err - } - defer file.Close() - fmt.Fprintf(file, "%s", hd.Doc(` + samplePolicy := hd.Doc(` # Simplest never-failing policy package policy.release.my_package @@ -86,16 +72,30 @@ func initPoliciesCmd() *cobra.Command { false result := "Never denies" } - `)) - + `) + if destDir == "" { + fmt.Fprintf(cmd.OutOrStdout(), "%s", samplePolicy) + return nil + } + fs := utils.FS(ctx) + workDir := destDir + err := fs.MkdirAll(workDir, 0755) + if err != nil { + log.Debug("Failed to create policy directory!") + return err + } + policyPath := filepath.Join(workDir, "sample.rego") + file, err := fs.Create(policyPath) + if err != nil { + log.Debug("Failed to create sample policy!") + return err + } + defer file.Close() + fmt.Fprintf(file, "%s", samplePolicy) return nil }, } - cmd.Flags().StringVarP(&destDir, "dest-dir", "d", "", "Directory to use when creating EC policy scaffolding") - if err := cmd.MarkFlagRequired("dest-dir"); err != nil { - panic(err) - } - + cmd.Flags().StringVarP(&destDir, "dest-dir", "d", "", "Directory to use when creating EC policy scaffolding. If not specified stdout will be used.") return cmd } diff --git a/cmd/initialize/init_policies_test.go b/cmd/initialize/init_policies_test.go index b4fa3c40a..8d0018789 100644 --- a/cmd/initialize/init_policies_test.go +++ b/cmd/initialize/init_policies_test.go @@ -18,11 +18,11 @@ package initialize import ( "bytes" + "context" "testing" "github.com/spf13/afero" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" "github.com/enterprise-contract/ec-cli/internal/utils" ) @@ -33,14 +33,51 @@ func TestInitializeNoError(t *testing.T) { cmd := initPoliciesCmd() cmd.SetContext(ctx) - buffy := bytes.Buffer{} - cmd.SetOut(&buffy) + buffy := new(bytes.Buffer) + cmd.SetOut(buffy) + + cmd.SetArgs([]string{ + "--dest-dir", + "sample", + }) + + err := cmd.Execute() + assert.NoError(t, err) +} + +func TestInitializeSamplePolicy(t *testing.T) { + fs := afero.NewMemMapFs() + ctx := utils.WithFS(context.Background(), fs) + + cmd := initPoliciesCmd() + cmd.SetContext(ctx) + buffy := new(bytes.Buffer) + cmd.SetOut(buffy) cmd.SetArgs([]string{ "--dest-dir", - "todo", + "sample", }) err := cmd.Execute() assert.NoError(t, err) + samplePolicy, err := afero.ReadFile(fs, "sample/sample.rego") + if err != nil { + t.Fatal(err) + } + assert.Contains(t, string(samplePolicy), "Simplest never-failing policy") +} + +func TestInitializeStdOut(t *testing.T) { + fs := afero.NewMemMapFs() + ctx := utils.WithFS(context.Background(), fs) + + cmd := initPoliciesCmd() + cmd.SetContext(ctx) + buffy := bytes.Buffer{} + cmd.SetOut(&buffy) + + err := cmd.Execute() + assert.NoError(t, err) + assert.Contains(t, buffy.String(), "Simplest never-failing policy") }