diff --git a/protoplugin.go b/protoplugin.go index e459745..858d222 100644 --- a/protoplugin.go +++ b/protoplugin.go @@ -131,9 +131,9 @@ func run( _, err := fmt.Fprintln(stdout, runOptions.version) return err } - return fmt.Errorf("unknown argument: %s", args[0]) + return newUnknownArgumentsError(args) default: - return fmt.Errorf("unknown arguments: %v", strings.Join(args, " ")) + return newUnknownArgumentsError(args) } if runOptions.warningHandlerFunc == nil { @@ -210,3 +210,18 @@ func (f mainOptionsFunc) applyMainOption(runOptions *runOptions) { func (f mainOptionsFunc) applyRunOption(runOptions *runOptions) { f(runOptions) } + +type unknownArgumentsError struct { + args []string +} + +func newUnknownArgumentsError(args []string) error { + return &unknownArgumentsError{args: args} +} + +func (e *unknownArgumentsError) Error() string { + if len(e.args) == 1 { + return fmt.Sprintf("unknown argument: %s", e.args[0]) + } + return fmt.Sprintf("unknown arguments: %s", strings.Join(e.args, " ")) +} diff --git a/protoplugin_test.go b/protoplugin_test.go index ec285bc..d901f8e 100644 --- a/protoplugin_test.go +++ b/protoplugin_test.go @@ -22,6 +22,7 @@ import ( "sort" "strings" "testing" + "testing/iotest" "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/protoutil" @@ -68,6 +69,38 @@ func TestBasic(t *testing.T) { ) } +func TestWithVersionOption(t *testing.T) { + t.Parallel() + + run := func(args []string, runOptions ...RunOption) (string, error) { + stdout := bytes.NewBuffer(nil) + err := Run( + context.Background(), + args, + iotest.ErrReader(io.EOF), + stdout, + io.Discard, + HandlerFunc(func(ctx context.Context, w *ResponseWriter, r *Request) error { return nil }), + runOptions..., + ) + return stdout.String(), err + } + + var unknownArgumentsError *unknownArgumentsError + _, err := run([]string{"--unsupported"}) + require.ErrorAs(t, err, &unknownArgumentsError) + _, err = run([]string{"--unsupported"}, WithVersion("0.0.1")) + require.ErrorAs(t, err, &unknownArgumentsError) + _, err = run([]string{"--version"}) + require.ErrorAs(t, err, &unknownArgumentsError) + _, err = run([]string{"--foo", "--bar"}) + require.ErrorAs(t, err, &unknownArgumentsError) + + out, err := run([]string{"--version"}, WithVersion("0.0.1")) + require.NoError(t, err) + require.Equal(t, "0.0.1\n", out) +} + func testBasic( t *testing.T, fileToGenerate []string,