From ea6f181946477efabe00d8360a4ca7c2e58678dd Mon Sep 17 00:00:00 2001 From: shawn Date: Thu, 4 Jul 2024 20:12:08 +0800 Subject: [PATCH 01/41] chore: fix typo in comment Signed-off-by: shawn --- tool/internal_pkg/util/util.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tool/internal_pkg/util/util.go b/tool/internal_pkg/util/util.go index e75255dec7..461b86c33e 100644 --- a/tool/internal_pkg/util/util.go +++ b/tool/internal_pkg/util/util.go @@ -88,7 +88,7 @@ func Exists(path string) bool { return !fi.IsDir() } -// LowerFirst converts the first letter to upper case for the given string. +// LowerFirst converts the first letter to lower case for the given string. func LowerFirst(s string) string { rs := []rune(s) rs[0] = unicode.ToLower(rs[0]) @@ -125,7 +125,7 @@ func UpperFirst(s string) string { return string(rs) } -// NotPtr converts an pointer type into non-pointer type. +// NotPtr converts a pointer type into non-pointer type. func NotPtr(s string) string { return strings.ReplaceAll(s, "*", "") } From e927571f98b4e9ee0431237bfccfa397ec190f2e Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 10 Jul 2024 17:25:22 +0800 Subject: [PATCH 02/41] feat: add template render subcommand using .tpl Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 7 +- tool/cmd/kitex/args/tpl_args.go | 265 ++++++++++++++++++ tool/cmd/kitex/main.go | 8 +- .../internal_pkg/generator/custom_template.go | 74 +++++ tool/internal_pkg/generator/generator.go | 9 + .../pluginmode/thriftgo/convertor.go | 2 +- .../pluginmode/thriftgo/plugin.go | 18 ++ tool/internal_pkg/util/command.go | 97 +++++++ 8 files changed, 475 insertions(+), 5 deletions(-) create mode 100644 tool/cmd/kitex/args/tpl_args.go create mode 100644 tool/internal_pkg/util/command.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index b2983d0b3e..b4ed627f93 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,14 +212,17 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" && a.TplDir == "" { if a.Use != "" { - return fmt.Errorf("-use must be used with -service or -template-dir") + return fmt.Errorf("-use must be used with -service or -template-dir or template render") } } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } + if a.ServiceName != "" && a.TplDir != "" { + return fmt.Errorf("template render and -service cannot be specified at the same time") + } if a.ServiceName != "" { a.GenerateMain = true } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go new file mode 100644 index 0000000000..4a00a3c6be --- /dev/null +++ b/tool/cmd/kitex/args/tpl_args.go @@ -0,0 +1,265 @@ +package args + +import ( + "flag" + "fmt" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" + "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/util" + "os" + "strings" +) + +func (a *Arguments) addBasicFlags(f *flag.FlagSet, version string) *flag.FlagSet { + f.BoolVar(&a.NoFastAPI, "no-fast-api", false, + "Generate codes without injecting fast method.") + f.StringVar(&a.ModuleName, "module", "", + "Specify the Go module name to generate go.mod.") + f.StringVar(&a.ServiceName, "service", "", + "Specify the service name to generate server side codes.") + f.StringVar(&a.Use, "use", "", + "Specify the kitex_gen package to import when generate server side codes.") + f.BoolVar(&a.Verbose, "v", false, "") // short for -verbose + f.BoolVar(&a.Verbose, "verbose", false, + "Turn on verbose mode.") + f.BoolVar(&a.GenerateInvoker, "invoker", false, + "Generate invoker side codes when service name is specified.") + f.StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + f.Var(&a.Includes, "I", "Add an IDL search path for includes.") + f.Var(&a.ThriftOptions, "thrift", "Specify arguments for the thrift go compiler.") + f.Var(&a.Hessian2Options, "hessian2", "Specify arguments for the hessian2 codec.") + f.DurationVar(&a.ThriftPluginTimeLimit, "thrift-plugin-time-limit", generator.DefaultThriftPluginTimeLimit, "Specify thrift plugin execution time limit.") + f.StringVar(&a.CompilerPath, "compiler-path", "", "Specify the path of thriftgo/protoc.") + f.Var(&a.ThriftPlugins, "thrift-plugin", "Specify thrift plugin arguments for the thrift compiler.") + f.Var(&a.ProtobufOptions, "protobuf", "Specify arguments for the protobuf compiler.") + f.Var(&a.ProtobufPlugins, "protobuf-plugin", "Specify protobuf plugin arguments for the protobuf compiler.(plugin_name:options:out_dir)") + f.BoolVar(&a.CombineService, "combine-service", false, + "Combine services in root thrift file.") + f.BoolVar(&a.CopyIDL, "copy-idl", false, + "Copy each IDL file to the output path.") + f.BoolVar(&a.HandlerReturnKeepResp, "handler-return-keep-resp", false, + "When the server-side handler returns both err and resp, the resp return is retained for use in middleware where both err and resp can be used simultaneously. Note: At the RPC communication level, if the handler returns an err, the framework still only returns err to the client without resp.") + f.StringVar(&a.ExtensionFile, "template-extension", a.ExtensionFile, + "Specify a file for template extension.") + f.BoolVar(&a.FrugalPretouch, "frugal-pretouch", false, + "Use frugal to compile arguments and results when new clients and servers.") + f.BoolVar(&a.Record, "record", false, + "Record Kitex cmd into kitex-all.sh.") + f.StringVar(&a.TemplateDir, "template-dir", "", + "Use custom template to generate codes.") + f.StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + f.BoolVar(&a.DeepCopyAPI, "deep-copy-api", false, + "Generate codes with injecting deep copy method.") + f.StringVar(&a.Protocol, "protocol", "", + "Specify a protocol for codec") + f.BoolVar(&a.NoDependencyCheck, "no-dependency-check", false, + "Skip dependency checking.") + a.RecordCmd = os.Args + a.Version = version + a.ThriftOptions = append(a.ThriftOptions, + "naming_style=golint", + "ignore_initialisms", + "gen_setter", + "gen_deep_equal", + "compatible_names", + "frugal_tag", + "thrift_streaming", + "no_processor", + ) + + for _, e := range a.extends { + e.Apply(f) + } + return f +} + +func (a *Arguments) buildInitFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("init", flag.ContinueOnError) + f.StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template init [flags] + +Examples: + %s template init -o /path/to/output + %s template init + +Flags: +`, version, os.Args[0], os.Args[0], os.Args[0]) + f.PrintDefaults() + } + return f +} + +func (a *Arguments) buildRenderFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("render", flag.ContinueOnError) + f.StringVar(&a.TemplateFile, "f", "", "Specify template init path") + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template render [template dir_path] [flags] IDL + +Examples: + %s template render ${template dir_path} -module ${module_name} idl/hello.thrift + %s template render ${template dir_path} -f service.go.tpl -module ${module_name} idl/hello.thrift + %s template render ${template dir_path} -module ${module_name} -I xxx.git idl/hello.thrift + +Flags: +`, version, os.Args[0], os.Args[0], os.Args[0], os.Args[0]) + f.PrintDefaults() + } + return f +} + +func (a *Arguments) buildCleanFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("clean", flag.ContinueOnError) + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template clean + +Examples: + %s template clean + +Flags: +`, version, os.Args[0], os.Args[0]) + f.PrintDefaults() + } + return f +} + +func (a *Arguments) TemplateArgs(version, curpath string) error { + templateCmd := &util.Command{ + Use: "template", + Short: "Template command", + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Template command executed") + return nil + }, + } + initCmd := &util.Command{ + Use: "init", + Short: "Init command", + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Init command executed") + f := a.buildInitFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, + } + renderCmd := &util.Command{ + Use: "render", + Short: "Render command", + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Render command executed") + if len(args) > 0 { + a.TplDir = args[0] + } + var tplDir string + for i, arg := range args { + if !strings.HasPrefix(arg, "-") { + tplDir = arg + args = append(args[:i], args[i+1:]...) + break + } + } + if tplDir == "" { + cmd.PrintUsage() + return fmt.Errorf("template directory is required") + } + + f := a.buildRenderFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, + } + cleanCmd := &util.Command{ + Use: "clean", + Short: "Clean command", + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Clean command executed") + f := a.buildCleanFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, + } + templateCmd.AddCommand(initCmd) + templateCmd.AddCommand(renderCmd) + templateCmd.AddCommand(cleanCmd) + + if _, err := templateCmd.ExecuteC(); err != nil { + return err + } + return nil +} diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 537894b97f..c03869d81c 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -76,8 +76,12 @@ func main() { log.Warn("Get current path failed:", err.Error()) os.Exit(1) } - // run as kitex - err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) + if os.Args[1] == "template" { + err = args.TemplateArgs(kitex.Version, curpath) + } else { + // run as kitex + err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) + } if err != nil { log.Warn(err.Error()) os.Exit(2) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index e8a74dbe64..0af738e60a 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -199,6 +199,80 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } +func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { + files, _ := os.ReadDir(currentDir) + //var ts []*Template + for _, f := range files { + // filter dir and non-yaml files + if f.IsDir() { + subDir := filepath.Join(currentDir, f.Name()) + subTemplates, err := readTpls(rootDir, subDir, ts) + if err != nil { + return nil, err + } + ts = append(ts, subTemplates...) + } else if strings.HasSuffix(f.Name(), ".tpl") { + p := filepath.Join(currentDir, f.Name()) + tplData, err := os.ReadFile(p) + if err != nil { + return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) + } + // Remove the .tpl suffix from the Path and compute relative path + relativePath, err := filepath.Rel(rootDir, p) + if err != nil { + return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) + } + trimmedPath := strings.TrimSuffix(relativePath, ".tpl") + t := &Template{ + Path: trimmedPath, + Body: string(tplData), + UpdateBehavior: &Update{Type: string(skip)}, + } + ts = append(ts, t) + } + } + + return ts, nil +} + +func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { + g.updatePackageInfo(pkg) + + g.setImports(HandlerFileName, pkg) + var tpls []*Template + tpls, err = readTpls(g.TplDir, g.TplDir, tpls) + if err != nil { + return nil, err + } + for _, tpl := range tpls { + newPath := filepath.Join(g.OutputPath, tpl.Path) + dir := filepath.Dir(newPath) + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, fmt.Errorf("failed to create directory %s: %v", dir, err) + } + if tpl.LoopService && g.CombineService { + svrInfo, cs := pkg.ServiceInfo, pkg.CombineServices + + for i := range cs { + pkg.ServiceInfo = cs[i] + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + pkg.ServiceInfo, pkg.CombineServices = svrInfo, cs + } else { + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + } + return fs, nil +} + func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { cg := NewCustomGenerator(pkg, outputPath) // special handling Methods field diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 5f792beee1..6120561303 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -96,6 +96,7 @@ type Generator interface { GenerateService(pkg *PackageInfo) ([]*File, error) GenerateMainPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) + GenerateCustomPackageWithTpl(pkg *PackageInfo) ([]*File, error) } // Config . @@ -133,6 +134,14 @@ type Config struct { TemplateDir string + // subcommand template + Template bool + TemplateInit bool + InitOutputDir string + TemplateClean bool + TplDir string + TemplateFile string + GenPath string DeepCopyAPI bool diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index 475a5fad09..760c3502c6 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -56,7 +56,7 @@ func (c *converter) init(req *plugin.Request) error { return fmt.Errorf("expect language to be 'go'. Encountered '%s'", req.Language) } - // resotre the arguments for kitex + // restore the arguments for kitex if err := c.Config.Unpack(req.PluginParameters); err != nil { return err } diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index a21c342e65..6a4d429871 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -39,6 +39,12 @@ func Run() int { println("Failed to get input:", err.Error()) return 1 } + //fmt.Println("data: ", string(data)) + err = os.WriteFile("/home/shawn/Develop/dump.txt", data, os.ModePerm) + if err != nil { + println("Failed to write file dump.txt", err.Error()) + return 1 + } req, err := plugin.UnmarshalRequest(data) if err != nil { @@ -107,6 +113,18 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } + if conv.Config.TplDir != "" { + if len(conv.Services) == 0 { + return conv.failResp(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.GenerateCustomPackageWithTpl(&conv.Package) + if err != nil { + return conv.failResp(err) + } + files = append(files, fs...) + } + res := &plugin.Response{ Warnings: conv.Warnings, } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go new file mode 100644 index 0000000000..8b3176d465 --- /dev/null +++ b/tool/internal_pkg/util/command.go @@ -0,0 +1,97 @@ +package util + +import ( + "flag" + "fmt" + "os" +) + +type Command struct { + Use string + Short string + Long string + RunE func(cmd *Command, args []string) error + commands []*Command + parent *Command + flags *flag.FlagSet +} + +func (c *Command) AddCommand(cmds ...*Command) { + for i, x := range cmds { + if cmds[i] == c { + panic("Command can't be a child of itself") + } + cmds[i].parent = c + c.commands = append(c.commands, x) + } +} + +// Flags returns the FlagSet of the Command +func (c *Command) Flags() *flag.FlagSet { + return c.flags +} + +// PrintUsage prints the usage of the Command +func (c *Command) PrintUsage() { + fmt.Fprintf(os.Stderr, "Usage: %s\n\n%s\n\n", c.Use, c.Long) + c.flags.PrintDefaults() + for _, cmd := range c.commands { + fmt.Fprintf(os.Stderr, " %s: %s\n", cmd.Use, cmd.Short) + } +} + +// Parent returns a commands parent command. +func (c *Command) Parent() *Command { + return c.parent +} + +// HasParent determines if the command is a child command. +func (c *Command) HasParent() bool { + return c.parent != nil +} + +// Root finds root command. +func (c *Command) Root() *Command { + if c.HasParent() { + return c.Parent().Root() + } + return c +} + +func (c *Command) Find(args []string) (*Command, []string, error) { + if len(args) == 0 { + return c, args, nil + } + + for _, cmd := range c.commands { + if cmd.Use == args[0] { + return cmd.Find(args[1:]) + } + } + + return c, args, nil +} + +// ExecuteC executes the command. +func (c *Command) ExecuteC() (cmd *Command, err error) { + args := os.Args[2:] + // Regardless of what command execute is called on, run on Root only + if c.HasParent() { + return c.Root().ExecuteC() + } + + cmd, flags, err := c.Find(args) + if err != nil { + fmt.Println(err) + return c, err + } + + if cmd.RunE != nil { + err = cmd.RunE(cmd, flags) + if err != nil { + fmt.Println(err) + } + } + + return cmd, nil +} From 07a86962c2f51cfd54a1dde78d562f31069af44e Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 10 Jul 2024 17:45:03 +0800 Subject: [PATCH 03/41] fix: fix some issue Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 5 ++++- tool/internal_pkg/generator/custom_template.go | 3 +-- tool/internal_pkg/generator/generator.go | 3 --- tool/internal_pkg/pluginmode/thriftgo/plugin.go | 6 ------ tool/internal_pkg/util/command.go | 10 +++++----- 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index b4ed627f93..c46ecda374 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -217,11 +217,14 @@ func (a *Arguments) checkServiceName() error { return fmt.Errorf("-use must be used with -service or -template-dir or template render") } } + if a.TemplateDir != "" && a.TplDir != "" { + return fmt.Errorf("template render and -template-dir cannot be used at the same time") + } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } if a.ServiceName != "" && a.TplDir != "" { - return fmt.Errorf("template render and -service cannot be specified at the same time") + return fmt.Errorf("template render and -service cannot be used at the same time") } if a.ServiceName != "" { a.GenerateMain = true diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 0af738e60a..ce7f43fb92 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -201,9 +201,8 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { files, _ := os.ReadDir(currentDir) - //var ts []*Template for _, f := range files { - // filter dir and non-yaml files + // filter dir and non-tpl files if f.IsDir() { subDir := filepath.Join(currentDir, f.Name()) subTemplates, err := readTpls(rootDir, subDir, ts) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 6120561303..92f4f8006d 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -135,10 +135,7 @@ type Config struct { TemplateDir string // subcommand template - Template bool - TemplateInit bool InitOutputDir string - TemplateClean bool TplDir string TemplateFile string diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 6a4d429871..01c0d1fe83 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -39,12 +39,6 @@ func Run() int { println("Failed to get input:", err.Error()) return 1 } - //fmt.Println("data: ", string(data)) - err = os.WriteFile("/home/shawn/Develop/dump.txt", data, os.ModePerm) - if err != nil { - println("Failed to write file dump.txt", err.Error()) - return 1 - } req, err := plugin.UnmarshalRequest(data) if err != nil { diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 8b3176d465..2d5e1e596e 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -2,7 +2,7 @@ package util import ( "flag" - "fmt" + "github.com/cloudwego/kitex/tool/internal_pkg/log" "os" ) @@ -33,10 +33,10 @@ func (c *Command) Flags() *flag.FlagSet { // PrintUsage prints the usage of the Command func (c *Command) PrintUsage() { - fmt.Fprintf(os.Stderr, "Usage: %s\n\n%s\n\n", c.Use, c.Long) + log.Warn("Usage: %s\n\n%s\n\n", c.Use, c.Long) c.flags.PrintDefaults() for _, cmd := range c.commands { - fmt.Fprintf(os.Stderr, " %s: %s\n", cmd.Use, cmd.Short) + log.Warnf(" %s: %s\n", cmd.Use, cmd.Short) } } @@ -82,14 +82,14 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd, flags, err := c.Find(args) if err != nil { - fmt.Println(err) + log.Warn(err) return c, err } if cmd.RunE != nil { err = cmd.RunE(cmd, flags) if err != nil { - fmt.Println(err) + log.Warn(err) } } From 4e2fbf693808b078e155082c02c0e136543828ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BA=AA=E5=8D=93=E5=BF=97?= Date: Thu, 11 Jul 2024 19:28:25 +0800 Subject: [PATCH 04/41] perf: custom allocator for fast codec ReadString/ReadBinary (#1427) --- pkg/protocol/bthrift/binary.go | 29 ++++++++++++++------- pkg/protocol/bthrift/binary_test.go | 39 +++++++++++++++++++++++++++++ pkg/protocol/bthrift/interface.go | 5 ++++ 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 83bfc40fba..7f54a253e7 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -32,19 +32,28 @@ import ( var ( // Binary protocol for bthrift. - Binary binaryProtocol - _ BTProtocol = binaryProtocol{} - spanCache = mem.NewSpanCache(1024 * 1024) - spanCacheEnable bool = false + Binary binaryProtocol + _ BTProtocol = binaryProtocol{} ) +var allocator Allocator + const binaryInplaceThreshold = 4096 // 4k type binaryProtocol struct{} // SetSpanCache enable/disable binary protocol bytes/string allocator func SetSpanCache(enable bool) { - spanCacheEnable = enable + if enable { + SetAllocator(mem.NewSpanCache(1024 * 1024)) + } else { + SetAllocator(nil) + } +} + +// SetAllocator set binary protocol bytes/string allocator. +func SetAllocator(alloc Allocator) { + allocator = alloc } func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int { @@ -473,8 +482,9 @@ func (binaryProtocol) ReadString(buf []byte) (value string, length int, err erro if size < 0 || int(size) > len(buf) { return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadString] the string size greater than buf length") } - if spanCacheEnable { - data := spanCache.Copy(buf[length : length+int(size)]) + alloc := allocator + if alloc != nil { + data := alloc.Copy(buf[length : length+int(size)]) value = utils.SliceByteToString(data) } else { value = string(buf[length : length+int(size)]) @@ -494,8 +504,9 @@ func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err erro if size < 0 || size > len(buf) { return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") } - if spanCacheEnable { - value = spanCache.Copy(buf[length : length+size]) + alloc := allocator + if alloc != nil { + value = alloc.Copy(buf[length : length+size]) } else { value = make([]byte, size) copy(value, buf[length:length+size]) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 2e42b2defd..43df81b165 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -295,6 +295,24 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } +// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator +func TestWriteAndReadStringWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + wn := Binary.WriteString(buf, "kitex") + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadString(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + test.Assert(t, v == "kitex") + SetSpanCache(false) +} + // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -314,6 +332,27 @@ func TestWriteAndReadBinary(t *testing.T) { } } +// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator +func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + val := []byte("kitex") + wn := Binary.WriteBinary(buf, val) + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadBinary(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + for i := 0; i < len(v); i++ { + test.Assert(t, val[i] == v[i]) + } + SetSpanCache(false) +} + // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index c5c64b152b..01890b93ba 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -96,3 +96,8 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } + +type Allocator interface { + Make(n int) []byte + Copy(buf []byte) (p []byte) +} From a3eb24f421e08015183dfb07f38d0164a2f6c2c5 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Fri, 12 Jul 2024 14:55:48 +0800 Subject: [PATCH 05/41] chore: remove useless reflection api (#1433) --- pkg/reflection/thrift/registry.go | 68 -------------------------- pkg/reflection/thrift/registry_test.go | 38 -------------- 2 files changed, 106 deletions(-) delete mode 100644 pkg/reflection/thrift/registry.go delete mode 100644 pkg/reflection/thrift/registry_test.go diff --git a/pkg/reflection/thrift/registry.go b/pkg/reflection/thrift/registry.go deleted file mode 100644 index b85fc12d0a..0000000000 --- a/pkg/reflection/thrift/registry.go +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2023 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thriftreflection - -import ( - "sync" - - "github.com/cloudwego/thriftgo/reflection" -) - -// Files is a registry for looking up or iterating over files and the -// descriptors contained within them. -type Files struct { - filesByPath map[string]*reflection.FileDescriptor -} - -func NewFiles() *Files { - return &Files{filesByPath: map[string]*reflection.FileDescriptor{}} -} - -var globalMutex sync.RWMutex - -// GlobalFiles is a global registry of file descriptors. -var GlobalFiles = NewFiles() - -// RegisterIDL provides function for generated code to register their reflection info to GlobalFiles. -func RegisterIDL(bytes []byte) { - desc := reflection.Decode(bytes) - GlobalFiles.Register(desc) -} - -// Register registers the input FileDescriptor to *Files type variables. -func (f *Files) Register(desc *reflection.FileDescriptor) { - if f == GlobalFiles { - globalMutex.Lock() - defer globalMutex.Unlock() - } - // TODO: check conflict - f.filesByPath[desc.Filename] = desc -} - -// GetFileDescriptors returns the inner registered reflection FileDescriptors. -func (f *Files) GetFileDescriptors() map[string]*reflection.FileDescriptor { - if f == GlobalFiles { - globalMutex.RLock() - defer globalMutex.RUnlock() - m := make(map[string]*reflection.FileDescriptor, len(f.filesByPath)) - for k, v := range f.filesByPath { - m[k] = v - } - return m - } - return f.filesByPath -} diff --git a/pkg/reflection/thrift/registry_test.go b/pkg/reflection/thrift/registry_test.go deleted file mode 100644 index 96116f1ff4..0000000000 --- a/pkg/reflection/thrift/registry_test.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2023 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thriftreflection - -import ( - "testing" - - "github.com/cloudwego/thriftgo/reflection" - - "github.com/cloudwego/kitex/internal/test" -) - -func TestRegistry(t *testing.T) { - GlobalFiles.Register(&reflection.FileDescriptor{ - Filename: "testa", - }) - test.Assert(t, GlobalFiles.GetFileDescriptors()["testa"] != nil) - - f := NewFiles() - f.Register(&reflection.FileDescriptor{ - Filename: "testb", - }) - test.Assert(t, f.GetFileDescriptors()["testb"] != nil) -} From c23b499d63255bce7746497426bfd0e2f7750b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BA=AA=E5=8D=93=E5=BF=97?= Date: Wed, 17 Jul 2024 14:39:09 +0800 Subject: [PATCH 06/41] optimize(lb): rebalance when instance weights updated (#1397) --- pkg/discovery/discovery.go | 19 +++++++++++++------ pkg/discovery/discovery_test.go | 5 +++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pkg/discovery/discovery.go b/pkg/discovery/discovery.go index 89d7b334b7..45d488a3fb 100644 --- a/pkg/discovery/discovery.go +++ b/pkg/discovery/discovery.go @@ -79,17 +79,20 @@ func DefaultDiff(cacheKey string, prev, next Result) (Change, bool) { }, } - prevMap := make(map[string]struct{}, len(prev.Instances)) + prevMap := make(map[string]Instance, len(prev.Instances)) for _, ins := range prev.Instances { - prevMap[ins.Address().String()] = struct{}{} + prevMap[ins.Address().String()] = ins } - nextMap := make(map[string]struct{}, len(next.Instances)) + nextMap := make(map[string]Instance, len(next.Instances)) for _, ins := range next.Instances { addr := ins.Address().String() - nextMap[addr] = struct{}{} - if _, found := prevMap[addr]; !found { + nextMap[addr] = ins + // FIXME(jizhuozhi): tags should also be used to determine whether the instance has updated + if prevIns, found := prevMap[addr]; !found { ch.Added = append(ch.Added, ins) + } else if prevIns.Weight() != ins.Weight() { + ch.Updated = append(ch.Updated, ins) } } @@ -98,7 +101,7 @@ func DefaultDiff(cacheKey string, prev, next Result) (Change, bool) { ch.Removed = append(ch.Removed, ins) } } - return ch, len(ch.Added)+len(ch.Removed) != 0 + return ch, len(ch.Added)+len(ch.Updated)+len(ch.Removed) != 0 } type instance struct { @@ -120,6 +123,10 @@ func (i *instance) Tag(key string) (value string, exist bool) { return } +func (i *instance) Tags() map[string]string { + return i.tags +} + // NewInstance creates a Instance using the given network, address and tags func NewInstance(network, address string, weight int, tags map[string]string) Instance { return &instance{ diff --git a/pkg/discovery/discovery_test.go b/pkg/discovery/discovery_test.go index 3acd994f1d..880f406b80 100644 --- a/pkg/discovery/discovery_test.go +++ b/pkg/discovery/discovery_test.go @@ -47,17 +47,18 @@ func TestDefaultDiff(t *testing.T) { Instances: []Instance{ NewInstance("tcp", "1", 10, nil), NewInstance("tcp", "2", 10, nil), - NewInstance("tcp", "3", 10, nil), + NewInstance("tcp", "3", 20, nil), NewInstance("tcp", "5", 10, nil), }, }}, Change{ Result: Result{Instances: []Instance{ NewInstance("tcp", "1", 10, nil), NewInstance("tcp", "2", 10, nil), - NewInstance("tcp", "3", 10, nil), + NewInstance("tcp", "3", 20, nil), NewInstance("tcp", "5", 10, nil), }, CacheKey: "1", Cacheable: true}, Added: []Instance{NewInstance("tcp", "5", 10, nil)}, + Updated: []Instance{NewInstance("tcp", "3", 20, nil)}, Removed: []Instance{NewInstance("tcp", "4", 10, nil)}, }, true}, } From 98bff0787a8616f09f668273ac722c9519ae552a Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Wed, 17 Jul 2024 20:07:32 +0800 Subject: [PATCH 07/41] fix: support setting PurePayload Transport Protocol (#1436) --- pkg/rpcinfo/rpcconfig.go | 14 ++++++++- pkg/rpcinfo/rpcconfig_test.go | 53 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 996c322c3f..c18fac054f 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -17,6 +17,7 @@ package rpcinfo import ( + "fmt" "sync" "time" @@ -166,7 +167,18 @@ func (r *rpcConfig) TransportProtocol() transport.Protocol { // SetTransportProtocol implements MutableRPCConfig interface. func (r *rpcConfig) SetTransportProtocol(tp transport.Protocol) error { - r.transportProtocol |= tp + if tp&transport.TTHeaderFramed != 0 { + if tp&(^transport.TTHeaderFramed) != 0 { + panic(fmt.Sprintf("invalid transport protocol: %b", tp)) + } + // TTHeader and Framed can be combined for [TTHeader + [FramedSize + Payload]] + r.transportProtocol &= transport.TTHeaderFramed // clear bits except TTHeader | Framed + r.transportProtocol |= tp + } else { + // other transports are mutually exclusive + // it's user's responsibility to set only one transport, not an OR-ed combination of multiple transports + r.transportProtocol = tp + } return nil } diff --git a/pkg/rpcinfo/rpcconfig_test.go b/pkg/rpcinfo/rpcconfig_test.go index d7effd7ec1..5c385f6189 100644 --- a/pkg/rpcinfo/rpcconfig_test.go +++ b/pkg/rpcinfo/rpcconfig_test.go @@ -33,3 +33,56 @@ func TestRPCConfig(t *testing.T) { test.Assert(t, c.IOBufferSize() != 0) test.Assert(t, c.TransportProtocol() == transport.PurePayload) } + +func TestSetTransportProtocol(t *testing.T) { + t.Run("set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + }) + t.Run("set-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.Framed, c.TransportProtocol()) + }) + t.Run("set-ttheader-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeaderFramed) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-ttheader-set-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-framed-set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.Framed, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-ttheader-set-grpc", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.GRPC) + test.Assert(t, c.TransportProtocol() == transport.GRPC, c.TransportProtocol()) + }) + t.Run("set-grpc-set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.GRPC) + test.Assert(t, c.TransportProtocol() == transport.GRPC, c.TransportProtocol()) + rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + }) + t.Run("set-invalid-transport", func(t *testing.T) { + defer func() { + test.Assert(t, recover() != nil) + }() + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader | transport.GRPC) + }) +} From a3046aaa26a19032fe92f76c315eec394137b499 Mon Sep 17 00:00:00 2001 From: shawn Date: Fri, 19 Jul 2024 16:45:29 +0800 Subject: [PATCH 08/41] bug fixes: fix render Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 6 +- tool/cmd/kitex/args/tpl_args.go | 205 +-- tool/cmd/kitex/main.go | 3 + .../internal_pkg/generator/custom_template.go | 2 +- tool/internal_pkg/generator/generator.go | 2 +- .../pluginmode/thriftgo/plugin.go | 2 +- tool/internal_pkg/util/command.go | 160 ++- tool/internal_pkg/util/command_test.go | 89 ++ tool/internal_pkg/util/flag.go | 1140 +++++++++++++++++ 9 files changed, 1432 insertions(+), 177 deletions(-) create mode 100644 tool/internal_pkg/util/command_test.go create mode 100644 tool/internal_pkg/util/flag.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index c46ecda374..4ce2b3eed9 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,18 +212,18 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" && a.TplDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" && a.RenderTplDir == "" { if a.Use != "" { return fmt.Errorf("-use must be used with -service or -template-dir or template render") } } - if a.TemplateDir != "" && a.TplDir != "" { + if a.TemplateDir != "" && a.RenderTplDir != "" { return fmt.Errorf("template render and -template-dir cannot be used at the same time") } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } - if a.ServiceName != "" && a.TplDir != "" { + if a.ServiceName != "" && a.RenderTplDir != "" { return fmt.Errorf("template render and -service cannot be used at the same time") } if a.ServiceName != "" { diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 4a00a3c6be..922153679c 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -1,7 +1,20 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package args import ( - "flag" "fmt" "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" @@ -10,144 +23,19 @@ import ( "strings" ) -func (a *Arguments) addBasicFlags(f *flag.FlagSet, version string) *flag.FlagSet { - f.BoolVar(&a.NoFastAPI, "no-fast-api", false, - "Generate codes without injecting fast method.") - f.StringVar(&a.ModuleName, "module", "", - "Specify the Go module name to generate go.mod.") - f.StringVar(&a.ServiceName, "service", "", - "Specify the service name to generate server side codes.") - f.StringVar(&a.Use, "use", "", - "Specify the kitex_gen package to import when generate server side codes.") - f.BoolVar(&a.Verbose, "v", false, "") // short for -verbose - f.BoolVar(&a.Verbose, "verbose", false, - "Turn on verbose mode.") - f.BoolVar(&a.GenerateInvoker, "invoker", false, - "Generate invoker side codes when service name is specified.") - f.StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") - f.Var(&a.Includes, "I", "Add an IDL search path for includes.") - f.Var(&a.ThriftOptions, "thrift", "Specify arguments for the thrift go compiler.") - f.Var(&a.Hessian2Options, "hessian2", "Specify arguments for the hessian2 codec.") - f.DurationVar(&a.ThriftPluginTimeLimit, "thrift-plugin-time-limit", generator.DefaultThriftPluginTimeLimit, "Specify thrift plugin execution time limit.") - f.StringVar(&a.CompilerPath, "compiler-path", "", "Specify the path of thriftgo/protoc.") - f.Var(&a.ThriftPlugins, "thrift-plugin", "Specify thrift plugin arguments for the thrift compiler.") - f.Var(&a.ProtobufOptions, "protobuf", "Specify arguments for the protobuf compiler.") - f.Var(&a.ProtobufPlugins, "protobuf-plugin", "Specify protobuf plugin arguments for the protobuf compiler.(plugin_name:options:out_dir)") - f.BoolVar(&a.CombineService, "combine-service", false, - "Combine services in root thrift file.") - f.BoolVar(&a.CopyIDL, "copy-idl", false, - "Copy each IDL file to the output path.") - f.BoolVar(&a.HandlerReturnKeepResp, "handler-return-keep-resp", false, - "When the server-side handler returns both err and resp, the resp return is retained for use in middleware where both err and resp can be used simultaneously. Note: At the RPC communication level, if the handler returns an err, the framework still only returns err to the client without resp.") - f.StringVar(&a.ExtensionFile, "template-extension", a.ExtensionFile, - "Specify a file for template extension.") - f.BoolVar(&a.FrugalPretouch, "frugal-pretouch", false, - "Use frugal to compile arguments and results when new clients and servers.") - f.BoolVar(&a.Record, "record", false, - "Record Kitex cmd into kitex-all.sh.") - f.StringVar(&a.TemplateDir, "template-dir", "", - "Use custom template to generate codes.") - f.StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, - "Specify a code gen path.") - f.BoolVar(&a.DeepCopyAPI, "deep-copy-api", false, - "Generate codes with injecting deep copy method.") - f.StringVar(&a.Protocol, "protocol", "", - "Specify a protocol for codec") - f.BoolVar(&a.NoDependencyCheck, "no-dependency-check", false, - "Skip dependency checking.") - a.RecordCmd = os.Args - a.Version = version - a.ThriftOptions = append(a.ThriftOptions, - "naming_style=golint", - "ignore_initialisms", - "gen_setter", - "gen_deep_equal", - "compatible_names", - "frugal_tag", - "thrift_streaming", - "no_processor", - ) - - for _, e := range a.extends { - e.Apply(f) - } - return f -} - -func (a *Arguments) buildInitFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("init", flag.ContinueOnError) - f.StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template init [flags] - -Examples: - %s template init -o /path/to/output - %s template init - -Flags: -`, version, os.Args[0], os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - -func (a *Arguments) buildRenderFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("render", flag.ContinueOnError) - f.StringVar(&a.TemplateFile, "f", "", "Specify template init path") - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template render [template dir_path] [flags] IDL - -Examples: - %s template render ${template dir_path} -module ${module_name} idl/hello.thrift - %s template render ${template dir_path} -f service.go.tpl -module ${module_name} idl/hello.thrift - %s template render ${template dir_path} -module ${module_name} -I xxx.git idl/hello.thrift - -Flags: -`, version, os.Args[0], os.Args[0], os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - -func (a *Arguments) buildCleanFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("clean", flag.ContinueOnError) - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template clean - -Examples: - %s template clean - -Flags: -`, version, os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - func (a *Arguments) TemplateArgs(version, curpath string) error { + kitexCmd := &util.Command{ + Use: "kitex", + Short: "Kitex command", + } templateCmd := &util.Command{ Use: "template", Short: "Template command", - RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Template command executed") - return nil - }, } initCmd := &util.Command{ Use: "init", Short: "Init command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Init command executed") - f := a.buildInitFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -157,7 +45,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()) if err != nil { return err } @@ -176,9 +64,8 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { Use: "render", Short: "Render command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Render command executed") if len(args) > 0 { - a.TplDir = args[0] + a.RenderTplDir = args[0] } var tplDir string for i, arg := range args { @@ -192,11 +79,6 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { cmd.PrintUsage() return fmt.Errorf("template directory is required") } - - f := a.buildRenderFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -206,7 +88,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()[1:]) if err != nil { return err } @@ -225,11 +107,6 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { Use: "clean", Short: "Clean command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Clean command executed") - f := a.buildCleanFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -239,7 +116,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()) if err != nil { return err } @@ -254,11 +131,41 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { return a.checkPath(curpath) }, } - templateCmd.AddCommand(initCmd) - templateCmd.AddCommand(renderCmd) - templateCmd.AddCommand(cleanCmd) + initCmd.Flags().StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") + renderCmd.Flags().StringVar(&a.ModuleName, "module", "", + "Specify the Go module name to generate go.mod.") + renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + renderCmd.Flags().StringVar(&a.TemplateFile, "f", "", "Specify template init path") + initCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: kitex template init [flags] + +Examples: + kitex template init -o /path/to/output + kitex template init + +Flags: +`, version) + }) + renderCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: template render [template dir_path] [flags] IDL +`, version) + }) + cleanCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: kitex template clean - if _, err := templateCmd.ExecuteC(); err != nil { +Examples: + kitex template clean + +Flags: +`, version) + }) + //renderCmd.PrintUsage() + templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) + kitexCmd.AddCommand(templateCmd) + if _, err := kitexCmd.ExecuteC(); err != nil { return err } return nil diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index c03869d81c..ede33fa89b 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -78,6 +78,9 @@ func main() { } if os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version, curpath) + } else if os.Args[1] != "template" { + log.Warnf("Unknown command %q", os.Args[1]) + os.Exit(1) } else { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index ce7f43fb92..1de769ce1b 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -239,7 +239,7 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, g.setImports(HandlerFileName, pkg) var tpls []*Template - tpls, err = readTpls(g.TplDir, g.TplDir, tpls) + tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) if err != nil { return nil, err } diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 92f4f8006d..c2a2404c5b 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -136,7 +136,7 @@ type Config struct { // subcommand template InitOutputDir string - TplDir string + RenderTplDir string TemplateFile string GenPath string diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 01c0d1fe83..e65c3e179e 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -107,7 +107,7 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } - if conv.Config.TplDir != "" { + if conv.Config.RenderTplDir != "" { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 2d5e1e596e..4746690060 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -1,9 +1,24 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package util import ( - "flag" + "fmt" "github.com/cloudwego/kitex/tool/internal_pkg/log" "os" + "strings" ) type Command struct { @@ -13,21 +28,27 @@ type Command struct { RunE func(cmd *Command, args []string) error commands []*Command parent *Command - flags *flag.FlagSet + flags *FlagSet + // helpFunc is help func defined by user. + usage func() } -func (c *Command) AddCommand(cmds ...*Command) { +func (c *Command) AddCommand(cmds ...*Command) error { for i, x := range cmds { if cmds[i] == c { - panic("Command can't be a child of itself") + return fmt.Errorf("command can't be a child of itself") } cmds[i].parent = c c.commands = append(c.commands, x) } + return nil } // Flags returns the FlagSet of the Command -func (c *Command) Flags() *flag.FlagSet { +func (c *Command) Flags() *FlagSet { + if c.flags == nil { + c.flags = NewFlagSet(c.Use, ContinueOnError) + } return c.flags } @@ -58,40 +79,135 @@ func (c *Command) Root() *Command { return c } -func (c *Command) Find(args []string) (*Command, []string, error) { - if len(args) == 0 { - return c, args, nil +// HasSubCommands determines if the command has children commands. +func (c *Command) HasSubCommands() bool { + return len(c.commands) > 0 +} + +func stripFlags(args []string) []string { + commands := make([]string, 0) + for len(args) > 0 { + s := args[0] + args = args[1:] + if s != "" && !strings.HasPrefix(s, "-") { + commands = append(commands, s) + } else if strings.HasPrefix(s, "-") { + break + } } + return commands +} +func (c *Command) findNext(next string) *Command { for _, cmd := range c.commands { - if cmd.Use == args[0] { - return cmd.Find(args[1:]) + if cmd.Use == next { + return cmd + } + } + return nil +} + +func nextArgs(args []string, x string) []string { + if len(args) == 0 { + return args + } + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case s == "--": + break + case strings.HasPrefix(s, "-"): + pos++ + continue + case !strings.HasPrefix(s, "-"): + if s == x { + var ret []string + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } + } + } + return args +} + +func validateArgs(cmd *Command, args []string) error { + // no subcommand, always take args + if !cmd.HasSubCommands() { + return nil + } + + // root command with subcommands, do subcommand checking. + if !cmd.HasParent() && len(args) > 0 { + return fmt.Errorf("unknown command %q", args[0]) + } + return nil +} + +// Find the target command given the args and command tree +func (c *Command) Find(args []string) (*Command, []string, error) { + var innerFind func(*Command, []string) (*Command, []string) + + innerFind = func(c *Command, innerArgs []string) (*Command, []string) { + argsWithoutFlags := stripFlags(innerArgs) + if len(argsWithoutFlags) == 0 { + return c, innerArgs + } + nextSubCmd := argsWithoutFlags[0] + + cmd := c.findNext(nextSubCmd) + if cmd != nil { + return innerFind(cmd, nextArgs(innerArgs, nextSubCmd)) } + return c, innerArgs } + commandFound, a := innerFind(c, args) + return commandFound, a, validateArgs(commandFound, stripFlags(a)) +} + +// ParseFlags parses persistent flag tree and local flags. +func (c *Command) ParseFlags(args []string) error { + err := c.Flags().Parse(args) + return err +} + +func (c *Command) SetUsageFunc(f func()) { + c.usage = f +} - return c, args, nil +func (c *Command) UsageFunc() func() { + return c.usage } // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { - args := os.Args[2:] - // Regardless of what command execute is called on, run on Root only - if c.HasParent() { - return c.Root().ExecuteC() - } + args := os.Args[1:] cmd, flags, err := c.Find(args) if err != nil { - log.Warn(err) return c, err } + err = cmd.execute(flags) + if err != nil { + cmd.usage() + } + + return cmd, err +} - if cmd.RunE != nil { - err = cmd.RunE(cmd, flags) +func (c *Command) execute(a []string) error { + if c == nil { + return fmt.Errorf("called Execute() on a nil Command") + } + err := c.ParseFlags(a) + if err != nil { + return err + } + if c.RunE != nil { + err := c.RunE(c, a) if err != nil { - log.Warn(err) + return err } } - - return cmd, nil + return nil } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go new file mode 100644 index 0000000000..a5a4766850 --- /dev/null +++ b/tool/internal_pkg/util/command_test.go @@ -0,0 +1,89 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "os" + "testing" +) + +func TestAddCommand(t *testing.T) { + rootCmd := &Command{Use: "root"} + childCmd := &Command{Use: "child"} + + // Test adding a valid child command + err := rootCmd.AddCommand(childCmd) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(rootCmd.commands) != 1 { + t.Fatalf("expected 1 command, got %d", len(rootCmd.commands)) + } + + if rootCmd.commands[0] != childCmd { + t.Fatalf("expected child command to be added") + } + + // Test adding a command to itself + err = rootCmd.AddCommand(rootCmd) + if err == nil { + t.Fatalf("expected an error, got nil") + } + + expectedErr := "command can't be a child of itself" + if err.Error() != expectedErr { + t.Fatalf("expected error %q, got %q", expectedErr, err.Error()) + } +} + +func TestExecuteC(t *testing.T) { + rootCmd := &Command{ + Use: "root", + RunE: func(cmd *Command, args []string) error { + return nil + }, + } + + subCmd := &Command{ + Use: "sub", + RunE: func(cmd *Command, args []string) error { + return nil + }, + } + + rootCmd.AddCommand(subCmd) + + // Simulate command line arguments + os.Args = []string{"root", "sub"} + + // Execute the command + cmd, err := rootCmd.ExecuteC() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if cmd.Use != "sub" { + t.Fatalf("expected sub command to be executed, got %s", cmd.Use) + } + + // Simulate command line arguments with an unknown command + os.Args = []string{"root", "unknown"} + + _, err = rootCmd.ExecuteC() + if err == nil { + t.Fatalf("expected an error for unknown command, got nil") + } +} diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go new file mode 100644 index 0000000000..798f566a7f --- /dev/null +++ b/tool/internal_pkg/util/flag.go @@ -0,0 +1,1140 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "bytes" + "errors" + goflag "flag" + "fmt" + "io" + "os" + "sort" + "strconv" + "strings" +) + +// ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. +var ErrHelp = errors.New("pflag: help requested") + +// ErrorHandling defines how to handle flag parsing errors. +type ErrorHandling int + +const ( + // ContinueOnError will return an err from Parse() if an error is found + ContinueOnError ErrorHandling = iota + // ExitOnError will call os.Exit(2) if an error is found when parsing + ExitOnError + // PanicOnError will panic() if an error is found when parsing flags + PanicOnError +) + +// ParseErrorsWhitelist defines the parsing errors that can be ignored +type ParseErrorsWhitelist struct { + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags bool +} + +// NormalizedName is a flag name that has been normalized according to rules +// for the FlagSet (e.g. making '-' and '_' equivalent). +type NormalizedName string + +// A FlagSet represents a set of defined flags. +type FlagSet struct { + // Usage is the function called when an error occurs while parsing flags. + // The field is a function (not a method) that may be changed to point to + // a custom error handler. + Usage func() + + // SortFlags is used to indicate, if user wants to have sorted flags in + // help/usage messages. + SortFlags bool + + // ParseErrorsWhitelist is used to configure a whitelist of errors + ParseErrorsWhitelist ParseErrorsWhitelist + + name string + parsed bool + actual map[NormalizedName]*Flag + orderedActual []*Flag + sortedActual []*Flag + formal map[NormalizedName]*Flag + orderedFormal []*Flag + sortedFormal []*Flag + shorthands map[byte]*Flag + args []string // arguments after flags + argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- + errorHandling ErrorHandling + output io.Writer // nil means stderr; use out() accessor + interspersed bool // allow interspersed option/non-option args + normalizeNameFunc func(f *FlagSet, name string) NormalizedName + + addedGoFlagSets []*goflag.FlagSet +} + +// A Flag represents the state of a flag. +type Flag struct { + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string // default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use + Annotations map[string][]string // used by cobra.Command bash autocomple code +} + +// Value is the interface to the dynamic value stored in a flag. +// (The default value is represented as a string.) +type Value interface { + String() string + Set(string) error +} + +// SliceValue is a secondary interface to all flags which hold a list +// of values. This allows full control over the value of list flags, +// and avoids complicated marshalling and unmarshalling to csv. +type SliceValue interface { + // Append adds the specified value to the end of the flag value list. + Append(string) error + // Replace will fully overwrite any data currently in the flag value list. + Replace([]string) error + // GetSlice returns the flag value list as an array of strings. + GetSlice() []string +} + +// sortFlags returns the flags as a slice in lexicographical sorted order. +func sortFlags(flags map[NormalizedName]*Flag) []*Flag { + list := make(sort.StringSlice, len(flags)) + i := 0 + for k := range flags { + list[i] = string(k) + i++ + } + list.Sort() + result := make([]*Flag, len(list)) + for i, name := range list { + result[i] = flags[NormalizedName(name)] + } + return result +} + +// SetNormalizeFunc allows you to add a function which can translate flag names. +// Flags added to the FlagSet will be translated and then when anything tries to +// look up the flag that will also be translated. So it would be possible to create +// a flag named "getURL" and have it translated to "geturl". A user could then pass +// "--getUrl" which may also be translated to "geturl" and everything will work. +func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { + f.normalizeNameFunc = n + f.sortedFormal = f.sortedFormal[:0] + for fname, flag := range f.formal { + nname := f.normalizeFlagName(flag.Name) + if fname == nname { + continue + } + flag.Name = string(nname) + delete(f.formal, fname) + f.formal[nname] = flag + if _, set := f.actual[fname]; set { + delete(f.actual, fname) + f.actual[nname] = flag + } + } +} + +// GetNormalizeFunc returns the previously set NormalizeFunc of a function which +// does no translation, if not set previously. +func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { + if f.normalizeNameFunc != nil { + return f.normalizeNameFunc + } + return func(f *FlagSet, name string) NormalizedName { return NormalizedName(name) } +} + +func (f *FlagSet) normalizeFlagName(name string) NormalizedName { + n := f.GetNormalizeFunc() + return n(f, name) +} + +func (f *FlagSet) out() io.Writer { + if f.output == nil { + return os.Stderr + } + return f.output +} + +// SetOutput sets the destination for usage and error messages. +// If output is nil, os.Stderr is used. +func (f *FlagSet) SetOutput(output io.Writer) { + f.output = output +} + +// VisitAll visits the flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits all flags, even those not set. +func (f *FlagSet) VisitAll(fn func(*Flag)) { + if len(f.formal) == 0 { + return + } + + var flags []*Flag + if f.SortFlags { + if len(f.formal) != len(f.sortedFormal) { + f.sortedFormal = sortFlags(f.formal) + } + flags = f.sortedFormal + } else { + flags = f.orderedFormal + } + + for _, flag := range flags { + fn(flag) + } +} + +// HasFlags returns a bool to indicate if the FlagSet has any flags defined. +func (f *FlagSet) HasFlags() bool { + return len(f.formal) > 0 +} + +// HasAvailableFlags returns a bool to indicate if the FlagSet has any flags +// that are not hidden. +func (f *FlagSet) HasAvailableFlags() bool { + for _, flag := range f.formal { + if !flag.Hidden { + return true + } + } + return false +} + +// VisitAll visits the command-line flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits all flags, even those not set. +func VisitAll(fn func(*Flag)) { + CommandLine.VisitAll(fn) +} + +// Visit visits the flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits only those flags that have been set. +func (f *FlagSet) Visit(fn func(*Flag)) { + if len(f.actual) == 0 { + return + } + + var flags []*Flag + if f.SortFlags { + if len(f.actual) != len(f.sortedActual) { + f.sortedActual = sortFlags(f.actual) + } + flags = f.sortedActual + } else { + flags = f.orderedActual + } + + for _, flag := range flags { + fn(flag) + } +} + +// Visit visits the command-line flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits only those flags that have been set. +func Visit(fn func(*Flag)) { + CommandLine.Visit(fn) +} + +// Lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) Lookup(name string) *Flag { + return f.lookup(f.normalizeFlagName(name)) +} + +// ShorthandLookup returns the Flag structure of the short handed flag, +// returning nil if none exists. +// It panics, if len(name) > 1. +func (f *FlagSet) ShorthandLookup(name string) *Flag { + if name == "" { + return nil + } + if len(name) > 1 { + msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name) + fmt.Fprintf(f.out(), msg) + panic(msg) + } + c := name[0] + return f.shorthands[c] +} + +// lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) lookup(name NormalizedName) *Flag { + return f.formal[name] +} + +// func to return a given type for a given flag name +func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { + flag := f.Lookup(name) + if flag == nil { + err := fmt.Errorf("flag accessed but not defined: %s", name) + return nil, err + } + + sval := flag.Value.String() + result, err := convFunc(sval) + if err != nil { + return nil, err + } + return result, nil +} + +// ArgsLenAtDash will return the length of f.Args at the moment when a -- was +// found during arg parsing. This allows your program to know which args were +// before the -- and which came after. +func (f *FlagSet) ArgsLenAtDash() int { + return f.argsLenAtDash +} + +// MarkDeprecated indicated that a flag is deprecated in your program. It will +// continue to function but will not show up in help or usage messages. Using +// this flag will also print the given usageMessage. +func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + if usageMessage == "" { + return fmt.Errorf("deprecated message for flag %q must be set", name) + } + flag.Deprecated = usageMessage + flag.Hidden = true + return nil +} + +// MarkShorthandDeprecated will mark the shorthand of a flag deprecated in your +// program. It will continue to function but will not show up in help or usage +// messages. Using this flag will also print the given usageMessage. +func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + if usageMessage == "" { + return fmt.Errorf("deprecated message for flag %q must be set", name) + } + flag.ShorthandDeprecated = usageMessage + return nil +} + +// MarkHidden sets a flag to 'hidden' in your program. It will continue to +// function but will not show up in help or usage messages. +func (f *FlagSet) MarkHidden(name string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + flag.Hidden = true + return nil +} + +// Lookup returns the Flag structure of the named command-line flag, +// returning nil if none exists. +func Lookup(name string) *Flag { + return CommandLine.Lookup(name) +} + +// ShorthandLookup returns the Flag structure of the short handed flag, +// returning nil if none exists. +func ShorthandLookup(name string) *Flag { + return CommandLine.ShorthandLookup(name) +} + +// Set sets the value of the named flag. +func (f *FlagSet) Set(name, value string) error { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + return fmt.Errorf("no such flag -%v", name) + } + + err := flag.Value.Set(value) + if err != nil { + var flagName string + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name) + } else { + flagName = fmt.Sprintf("--%s", flag.Name) + } + return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err) + } + + if !flag.Changed { + if f.actual == nil { + f.actual = make(map[NormalizedName]*Flag) + } + f.actual[normalName] = flag + f.orderedActual = append(f.orderedActual, flag) + + flag.Changed = true + } + + if flag.Deprecated != "" { + fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) + } + return nil +} + +// SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. +// This is sometimes used by spf13/cobra programs which want to generate additional +// bash completion information. +func (f *FlagSet) SetAnnotation(name, key string, values []string) error { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + return fmt.Errorf("no such flag -%v", name) + } + if flag.Annotations == nil { + flag.Annotations = map[string][]string{} + } + flag.Annotations[key] = values + return nil +} + +// Changed returns true if the flag was explicitly set during Parse() and false +// otherwise +func (f *FlagSet) Changed(name string) bool { + flag := f.Lookup(name) + // If a flag doesn't exist, it wasn't changed.... + if flag == nil { + return false + } + return flag.Changed +} + +// Set sets the value of the named command-line flag. +func Set(name, value string) error { + return CommandLine.Set(name, value) +} + +// PrintDefaults prints, to standard error unless configured +// otherwise, the default values of all defined flags in the set. +func (f *FlagSet) PrintDefaults() { + usages := f.FlagUsages() + fmt.Fprint(f.out(), usages) +} + +// UnquoteUsage extracts a back-quoted name from the usage +// string for a flag and returns it and the un-quoted usage. +// Given "a `name` to show" it returns ("name", "a name to show"). +// If there are no back quotes, the name is an educated guess of the +// type of the flag's value, or the empty string if the flag is boolean. +func UnquoteUsage(flag *Flag) (name string, usage string) { + // Look for a back-quoted name, but avoid the strings package. + usage = flag.Usage + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name = usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break // Only one back quote; use type name. + } + } + + return +} + +// Splits the string `s` on whitespace into an initial substring up to +// `i` runes in length and the remainder. Will go `slop` over `i` if +// that encompasses the entire string (which allows the caller to +// avoid short orphan words on the final line). +func wrapN(i, slop int, s string) (string, string) { + if i+slop > len(s) { + return s, "" + } + + w := strings.LastIndexAny(s[:i], " \t\n") + if w <= 0 { + return s, "" + } + nlPos := strings.LastIndex(s[:i], "\n") + if nlPos > 0 && nlPos < w { + return s[:nlPos], s[nlPos+1:] + } + return s[:w], s[w+1:] +} + +// Wraps the string `s` to a maximum width `w` with leading indent +// `i`. The first line is not indented (this is assumed to be done by +// caller). Pass `w` == 0 to do no wrapping +func wrap(i, w int, s string) string { + if w == 0 { + return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + // space between indent i and end of line width w into which + // we should wrap the text. + wrap := w - i + + var r, l string + + // Not enough space for sensible wrapping. Wrap as a block on + // the next line instead. + if wrap < 24 { + i = 16 + wrap = w - i + r += "\n" + strings.Repeat(" ", i) + } + // If still not enough space then don't even try to wrap. + if wrap < 24 { + return strings.Replace(s, "\n", r, -1) + } + + // Try to avoid short orphan words on the final line, by + // allowing wrapN to go a bit over if that would fit in the + // remainder of the line. + slop := 5 + wrap = wrap - slop + + // Handle first line, which is indented by the caller (or the + // special case above) + l, s = wrapN(wrap, slop, s) + r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) + + // Now wrap the rest + for s != "" { + var t string + + t, s = wrapN(wrap, slop, s) + r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + return r + +} + +// FlagUsagesWrapped returns a string containing the usage information +// for all flags in the FlagSet. Wrapped to `cols` columns (0 for no +// wrapping) +func (f *FlagSet) FlagUsagesWrapped(cols int) string { + buf := new(bytes.Buffer) + + lines := make([]string, 0, len(f.formal)) + + maxlen := 0 + f.VisitAll(func(flag *Flag) { + if flag.Hidden { + return + } + + line := "" + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) + } else { + line = fmt.Sprintf(" --%s", flag.Name) + } + + varname, usage := UnquoteUsage(flag) + if varname != "" { + line += " " + varname + } + + // This special character will be replaced with spacing once the + // correct alignment is calculated + line += "\x00" + if len(line) > maxlen { + maxlen = len(line) + } + + line += usage + if len(flag.Deprecated) != 0 { + line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) + } + + lines = append(lines, line) + }) + + for _, line := range lines { + sidx := strings.Index(line, "\x00") + spacing := strings.Repeat(" ", maxlen-sidx) + // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx + fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) + } + + return buf.String() +} + +// FlagUsages returns a string containing the usage information for all flags in +// the FlagSet +func (f *FlagSet) FlagUsages() string { + return f.FlagUsagesWrapped(0) +} + +// PrintDefaults prints to standard error the default values of all defined command-line flags. +func PrintDefaults() { + CommandLine.PrintDefaults() +} + +// defaultUsage is the default function to print a usage message. +func defaultUsage(f *FlagSet) { + fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) + f.PrintDefaults() +} + +// NOTE: Usage is not just defaultUsage(CommandLine) +// because it serves (via godoc flag Usage) as the example +// for how to write your own usage function. + +// Usage prints to standard error a usage message documenting all defined command-line flags. +// The function is a variable that may be changed to point to a custom function. +// By default it prints a simple header and calls PrintDefaults; for details about the +// format of the output and how to control it, see the documentation for PrintDefaults. +var Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + PrintDefaults() +} + +// NFlag returns the number of flags that have been set. +func (f *FlagSet) NFlag() int { return len(f.actual) } + +// NFlag returns the number of command-line flags that have been set. +func NFlag() int { return len(CommandLine.actual) } + +// Arg returns the i'th argument. Arg(0) is the first remaining argument +// after flags have been processed. +func (f *FlagSet) Arg(i int) string { + if i < 0 || i >= len(f.args) { + return "" + } + return f.args[i] +} + +// Arg returns the i'th command-line argument. Arg(0) is the first remaining argument +// after flags have been processed. +func Arg(i int) string { + return CommandLine.Arg(i) +} + +// NArg is the number of arguments remaining after flags have been processed. +func (f *FlagSet) NArg() int { return len(f.args) } + +// NArg is the number of arguments remaining after flags have been processed. +func NArg() int { return len(CommandLine.args) } + +// Args returns the non-flag arguments. +func (f *FlagSet) Args() []string { return f.args } + +// Args returns the non-flag command-line arguments. +func Args() []string { return CommandLine.args } + +// Var defines a flag with the specified name and usage string. The type and +// value of the flag are represented by the first argument, of type Value, which +// typically holds a user-defined implementation of Value. For instance, the +// caller could create a flag that turns a comma-separated string into a slice +// of strings by giving the slice the methods of Value; in particular, Set would +// decompose the comma-separated string into the slice. +func (f *FlagSet) Var(value Value, name string, usage string) { + f.VarP(value, name, "", usage) +} + +// VarPF is like VarP, but returns the flag created +func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag { + // Remember the default value as a string; it won't change. + flag := &Flag{ + Name: name, + Shorthand: shorthand, + Usage: usage, + Value: value, + DefValue: value.String(), + } + f.AddFlag(flag) + return flag +} + +// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { + f.VarPF(value, name, shorthand, usage) +} + +// AddFlag will add the flag to the FlagSet +func (f *FlagSet) AddFlag(flag *Flag) { + normalizedFlagName := f.normalizeFlagName(flag.Name) + + _, alreadyThere := f.formal[normalizedFlagName] + if alreadyThere { + msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) + fmt.Fprintln(f.out(), msg) + panic(msg) // Happens only if flags are declared with identical names + } + if f.formal == nil { + f.formal = make(map[NormalizedName]*Flag) + } + + flag.Name = string(normalizedFlagName) + f.formal[normalizedFlagName] = flag + f.orderedFormal = append(f.orderedFormal, flag) + + if flag.Shorthand == "" { + return + } + if len(flag.Shorthand) > 1 { + msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand) + fmt.Fprintf(f.out(), msg) + panic(msg) + } + if f.shorthands == nil { + f.shorthands = make(map[byte]*Flag) + } + c := flag.Shorthand[0] + used, alreadyThere := f.shorthands[c] + if alreadyThere { + msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) + fmt.Fprintf(f.out(), msg) + panic(msg) + } + f.shorthands[c] = flag +} + +// AddFlagSet adds one FlagSet to another. If a flag is already present in f +// the flag from newSet will be ignored. +func (f *FlagSet) AddFlagSet(newSet *FlagSet) { + if newSet == nil { + return + } + newSet.VisitAll(func(flag *Flag) { + if f.Lookup(flag.Name) == nil { + f.AddFlag(flag) + } + }) +} + +// Var defines a flag with the specified name and usage string. The type and +// value of the flag are represented by the first argument, of type Value, which +// typically holds a user-defined implementation of Value. For instance, the +// caller could create a flag that turns a comma-separated string into a slice +// of strings by giving the slice the methods of Value; in particular, Set would +// decompose the comma-separated string into the slice. +func Var(value Value, name string, usage string) { + CommandLine.VarP(value, name, "", usage) +} + +// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. +func VarP(value Value, name, shorthand, usage string) { + CommandLine.VarP(value, name, shorthand, usage) +} + +// failf prints to standard error a formatted error and usage message and +// returns the error. +func (f *FlagSet) failf(format string, a ...interface{}) error { + err := fmt.Errorf(format, a...) + if f.errorHandling != ContinueOnError { + fmt.Fprintln(f.out(), err) + f.usage() + } + return err +} + +// usage calls the Usage method for the flag set, or the usage function if +// the flag set is CommandLine. +func (f *FlagSet) usage() { + if f == CommandLine { + Usage() + } else if f.Usage == nil { + defaultUsage(f) + } else { + f.Usage() + } +} + +// --unknown (args will be empty) +// --unknown --next-flag ... (args will be --next-flag ...) +// --unknown arg ... (args will be arg ...) +func stripUnknownFlagValue(args []string) []string { + if len(args) == 0 { + //--unknown + return args + } + + first := args[0] + if len(first) > 0 && first[0] == '-' { + //--unknown --next-flag ... + return args + } + + //--unknown arg ... (args will be arg ...) + if len(args) > 1 { + return args[1:] + } + return nil +} + +func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { + a = args + name := s[2:] + if len(name) == 0 || name[0] == '-' || name[0] == '=' { + err = f.failf("bad flag syntax: %s", s) + return + } + + split := strings.SplitN(name, "=", 2) + name = split[0] + flag, exists := f.formal[f.normalizeFlagName(name)] + + if !exists { + switch { + case name == "help": + f.usage() + return a, ErrHelp + case f.ParseErrorsWhitelist.UnknownFlags: + // --unknown=unknownval arg ... + // we do not want to lose arg in this case + if len(split) >= 2 { + return a, nil + } + + return stripUnknownFlagValue(a), nil + default: + err = f.failf("unknown flag: --%s", name) + return + } + } + + var value string + if len(split) == 2 { + // '--flag=arg' + value = split[1] + } else if flag.NoOptDefVal != "" { + // '--flag' (arg was optional) + value = flag.NoOptDefVal + } else if len(a) > 0 { + // '--flag arg' + value = a[0] + a = a[1:] + } else { + // '--flag' (arg was required) + err = f.failf("flag needs an argument: %s", s) + return + } + + err = fn(flag, value) + if err != nil { + f.failf(err.Error()) + } + return +} + +func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) { + outArgs = args + + if strings.HasPrefix(shorthands, "test.") { + return + } + + outShorts = shorthands[1:] + c := shorthands[0] + + flag, exists := f.shorthands[c] + if !exists { + switch { + case c == 'h': + f.usage() + err = ErrHelp + return + case f.ParseErrorsWhitelist.UnknownFlags: + // '-f=arg arg ...' + // we do not want to lose arg in this case + if len(shorthands) > 2 && shorthands[1] == '=' { + outShorts = "" + return + } + + outArgs = stripUnknownFlagValue(outArgs) + return + default: + err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) + return + } + } + + var value string + if len(shorthands) > 2 && shorthands[1] == '=' { + // '-f=arg' + value = shorthands[2:] + outShorts = "" + } else if flag.NoOptDefVal != "" { + // '-f' (arg was optional) + value = flag.NoOptDefVal + } else if len(shorthands) > 1 { + // '-farg' + value = shorthands[1:] + outShorts = "" + } else if len(args) > 0 { + // '-f arg' + value = args[0] + outArgs = args[1:] + } else { + // '-f' (arg was required) + err = f.failf("flag needs an argument: %q in -%s", c, shorthands) + return + } + + if flag.ShorthandDeprecated != "" { + fmt.Fprintf(f.out(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated) + } + + err = fn(flag, value) + if err != nil { + f.failf(err.Error()) + } + return +} + +func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) { + a = args + shorthands := s[1:] + + // "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv"). + for len(shorthands) > 0 { + shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn) + if err != nil { + return + } + } + + return +} + +func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { + for len(args) > 0 { + s := args[0] + args = args[1:] + if len(s) == 0 || s[0] != '-' || len(s) == 1 { + if !f.interspersed { + f.args = append(f.args, s) + f.args = append(f.args, args...) + return nil + } + f.args = append(f.args, s) + continue + } + + if s[1] == '-' { + if len(s) == 2 { // "--" terminates the flags + f.argsLenAtDash = len(f.args) + f.args = append(f.args, args...) + break + } + args, err = f.parseLongArg(s, args, fn) + } else { + args, err = f.parseShortArg(s, args, fn) + } + if err != nil { + return + } + } + return +} + +// Parse parses flag definitions from the argument list, which should not +// include the command name. Must be called after all flags in the FlagSet +// are defined and before flags are accessed by the program. +// The return value will be ErrHelp if -help was set but not defined. +func (f *FlagSet) Parse(arguments []string) error { + if f.addedGoFlagSets != nil { + for _, goFlagSet := range f.addedGoFlagSets { + goFlagSet.Parse(nil) + } + } + f.parsed = true + + if len(arguments) < 0 { + return nil + } + + f.args = make([]string, 0, len(arguments)) + + set := func(flag *Flag, value string) error { + return f.Set(flag.Name, value) + } + + err := f.parseArgs(arguments, set) + if err != nil { + switch f.errorHandling { + case ContinueOnError: + return err + case ExitOnError: + fmt.Println(err) + os.Exit(2) + case PanicOnError: + panic(err) + } + } + return nil +} + +type parseFunc func(flag *Flag, value string) error + +// ParseAll parses flag definitions from the argument list, which should not +// include the command name. The arguments for fn are flag and value. Must be +// called after all flags in the FlagSet are defined and before flags are +// accessed by the program. The return value will be ErrHelp if -help was set +// but not defined. +func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error { + f.parsed = true + f.args = make([]string, 0, len(arguments)) + + err := f.parseArgs(arguments, fn) + if err != nil { + switch f.errorHandling { + case ContinueOnError: + return err + case ExitOnError: + os.Exit(2) + case PanicOnError: + panic(err) + } + } + return nil +} + +// Parsed reports whether f.Parse has been called. +func (f *FlagSet) Parsed() bool { + return f.parsed +} + +// Parse parses the command-line flags from os.Args[1:]. Must be called +// after all flags are defined and before flags are accessed by the program. +func Parse() { + // Ignore errors; CommandLine is set for ExitOnError. + CommandLine.Parse(os.Args[1:]) +} + +// ParseAll parses the command-line flags from os.Args[1:] and called fn for each. +// The arguments for fn are flag and value. Must be called after all flags are +// defined and before flags are accessed by the program. +func ParseAll(fn func(flag *Flag, value string) error) { + // Ignore errors; CommandLine is set for ExitOnError. + CommandLine.ParseAll(os.Args[1:], fn) +} + +// SetInterspersed sets whether to support interspersed option/non-option arguments. +func SetInterspersed(interspersed bool) { + CommandLine.SetInterspersed(interspersed) +} + +// Parsed returns true if the command-line flags have been parsed. +func Parsed() bool { + return CommandLine.Parsed() +} + +// CommandLine is the default set of command-line flags, parsed from os.Args. +var CommandLine = NewFlagSet(os.Args[0], ExitOnError) + +// NewFlagSet returns a new, empty flag set with the specified name, +// error handling property and SortFlags set to true. +func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { + f := &FlagSet{ + name: name, + errorHandling: errorHandling, + argsLenAtDash: -1, + interspersed: true, + SortFlags: true, + } + return f +} + +// SetInterspersed sets whether to support interspersed option/non-option arguments. +func (f *FlagSet) SetInterspersed(interspersed bool) { + f.interspersed = interspersed +} + +// Init sets the name and error handling property for a flag set. +// By default, the zero FlagSet uses an empty name and the +// ContinueOnError error handling policy. +func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { + f.name = name + f.errorHandling = errorHandling + f.argsLenAtDash = -1 +} + +// -- string Value +type stringValue string + +func newStringValue(val string, p *string) *stringValue { + *p = val + return (*stringValue)(p) +} + +func (s *stringValue) Set(val string) error { + *s = stringValue(val) + return nil +} +func (s *stringValue) Type() string { + return "string" +} + +func (s *stringValue) String() string { return string(*s) } + +// StringVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a string variable in which to store the value of the flag. +func (f *FlagSet) StringVar(p *string, name string, value string, usage string) { + f.VarP(newStringValue(value, p), name, "", usage) +} + +// optional interface to indicate boolean flags that can be +// supplied without "=value" text +type boolFlag interface { + Value + IsBoolFlag() bool +} + +// -- bool Value +type boolValue bool + +func newBoolValue(val bool, p *bool) *boolValue { + *p = val + return (*boolValue)(p) +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + *b = boolValue(v) + return err +} + +func (b *boolValue) Type() string { + return "bool" +} + +func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } + +func (b *boolValue) IsBoolFlag() bool { return true } + +// BoolVar defines a bool flag with specified name, default value, and usage string. +// The argument p points to a bool variable in which to store the value of the flag. +func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) { + f.BoolVarP(p, name, "", value, usage) +} + +// BoolVarP is like BoolVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) { + flag := f.VarPF(newBoolValue(value, p), name, shorthand, usage) + flag.NoOptDefVal = "true" +} From 834b687429bc9abef73d303a1639db443cd23294 Mon Sep 17 00:00:00 2001 From: shawn Date: Fri, 19 Jul 2024 17:21:26 +0800 Subject: [PATCH 09/41] bug fixes: fix bug in generator_test Signed-off-by: shawn --- tool/internal_pkg/generator/generator_test.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index cccee89e11..b907172f4f 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -56,6 +56,9 @@ func TestConfig_Pack(t *testing.T) { RecordCmd string ThriftPluginTimeLimit time.Duration TemplateDir string + InitOutputDir string + RenderTplDir string + TemplateFile string Protocol string HandlerReturnKeepResp bool } @@ -69,7 +72,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -97,6 +100,9 @@ func TestConfig_Pack(t *testing.T) { FrugalPretouch: tt.fields.FrugalPretouch, ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, TemplateDir: tt.fields.TemplateDir, + InitOutputDir: tt.fields.InitOutputDir, + RenderTplDir: tt.fields.RenderTplDir, + TemplateFile: tt.fields.TemplateFile, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { From 3339e16a2c14c22f85851d99a07f94e767485d9f Mon Sep 17 00:00:00 2001 From: shawn Date: Fri, 19 Jul 2024 22:21:26 +0800 Subject: [PATCH 10/41] bug fixes: add gen-path Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 3 - tool/cmd/kitex/args/tpl_args.go | 2 + tool/internal_pkg/util/command.go | 1 - tool/internal_pkg/util/flag.go | 324 ------------------------------ 4 files changed, 2 insertions(+), 328 deletions(-) diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 4ce2b3eed9..3941e8f19c 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -223,9 +223,6 @@ func (a *Arguments) checkServiceName() error { if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } - if a.ServiceName != "" && a.RenderTplDir != "" { - return fmt.Errorf("template render and -service cannot be used at the same time") - } if a.ServiceName != "" { a.GenerateMain = true } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 922153679c..fe937eb848 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -135,6 +135,8 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") renderCmd.Flags().StringVar(&a.TemplateFile, "f", "", "Specify template init path") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 4746690060..72bc2a23fc 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -55,7 +55,6 @@ func (c *Command) Flags() *FlagSet { // PrintUsage prints the usage of the Command func (c *Command) PrintUsage() { log.Warn("Usage: %s\n\n%s\n\n", c.Use, c.Long) - c.flags.PrintDefaults() for _, cmd := range c.commands { log.Warnf(" %s: %s\n", cmd.Use, cmd.Short) } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 798f566a7f..71906891b7 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -15,13 +15,11 @@ package util import ( - "bytes" "errors" goflag "flag" "fmt" "io" "os" - "sort" "strconv" "strings" ) @@ -118,22 +116,6 @@ type SliceValue interface { GetSlice() []string } -// sortFlags returns the flags as a slice in lexicographical sorted order. -func sortFlags(flags map[NormalizedName]*Flag) []*Flag { - list := make(sort.StringSlice, len(flags)) - i := 0 - for k := range flags { - list[i] = string(k) - i++ - } - list.Sort() - result := make([]*Flag, len(list)) - for i, name := range list { - result[i] = flags[NormalizedName(name)] - } - return result -} - // SetNormalizeFunc allows you to add a function which can translate flag names. // Flags added to the FlagSet will be translated and then when anything tries to // look up the flag that will also be translated. So it would be possible to create @@ -184,185 +166,6 @@ func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } -// VisitAll visits the flags in lexicographical order or -// in primordial order if f.SortFlags is false, calling fn for each. -// It visits all flags, even those not set. -func (f *FlagSet) VisitAll(fn func(*Flag)) { - if len(f.formal) == 0 { - return - } - - var flags []*Flag - if f.SortFlags { - if len(f.formal) != len(f.sortedFormal) { - f.sortedFormal = sortFlags(f.formal) - } - flags = f.sortedFormal - } else { - flags = f.orderedFormal - } - - for _, flag := range flags { - fn(flag) - } -} - -// HasFlags returns a bool to indicate if the FlagSet has any flags defined. -func (f *FlagSet) HasFlags() bool { - return len(f.formal) > 0 -} - -// HasAvailableFlags returns a bool to indicate if the FlagSet has any flags -// that are not hidden. -func (f *FlagSet) HasAvailableFlags() bool { - for _, flag := range f.formal { - if !flag.Hidden { - return true - } - } - return false -} - -// VisitAll visits the command-line flags in lexicographical order or -// in primordial order if f.SortFlags is false, calling fn for each. -// It visits all flags, even those not set. -func VisitAll(fn func(*Flag)) { - CommandLine.VisitAll(fn) -} - -// Visit visits the flags in lexicographical order or -// in primordial order if f.SortFlags is false, calling fn for each. -// It visits only those flags that have been set. -func (f *FlagSet) Visit(fn func(*Flag)) { - if len(f.actual) == 0 { - return - } - - var flags []*Flag - if f.SortFlags { - if len(f.actual) != len(f.sortedActual) { - f.sortedActual = sortFlags(f.actual) - } - flags = f.sortedActual - } else { - flags = f.orderedActual - } - - for _, flag := range flags { - fn(flag) - } -} - -// Visit visits the command-line flags in lexicographical order or -// in primordial order if f.SortFlags is false, calling fn for each. -// It visits only those flags that have been set. -func Visit(fn func(*Flag)) { - CommandLine.Visit(fn) -} - -// Lookup returns the Flag structure of the named flag, returning nil if none exists. -func (f *FlagSet) Lookup(name string) *Flag { - return f.lookup(f.normalizeFlagName(name)) -} - -// ShorthandLookup returns the Flag structure of the short handed flag, -// returning nil if none exists. -// It panics, if len(name) > 1. -func (f *FlagSet) ShorthandLookup(name string) *Flag { - if name == "" { - return nil - } - if len(name) > 1 { - msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name) - fmt.Fprintf(f.out(), msg) - panic(msg) - } - c := name[0] - return f.shorthands[c] -} - -// lookup returns the Flag structure of the named flag, returning nil if none exists. -func (f *FlagSet) lookup(name NormalizedName) *Flag { - return f.formal[name] -} - -// func to return a given type for a given flag name -func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { - flag := f.Lookup(name) - if flag == nil { - err := fmt.Errorf("flag accessed but not defined: %s", name) - return nil, err - } - - sval := flag.Value.String() - result, err := convFunc(sval) - if err != nil { - return nil, err - } - return result, nil -} - -// ArgsLenAtDash will return the length of f.Args at the moment when a -- was -// found during arg parsing. This allows your program to know which args were -// before the -- and which came after. -func (f *FlagSet) ArgsLenAtDash() int { - return f.argsLenAtDash -} - -// MarkDeprecated indicated that a flag is deprecated in your program. It will -// continue to function but will not show up in help or usage messages. Using -// this flag will also print the given usageMessage. -func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { - flag := f.Lookup(name) - if flag == nil { - return fmt.Errorf("flag %q does not exist", name) - } - if usageMessage == "" { - return fmt.Errorf("deprecated message for flag %q must be set", name) - } - flag.Deprecated = usageMessage - flag.Hidden = true - return nil -} - -// MarkShorthandDeprecated will mark the shorthand of a flag deprecated in your -// program. It will continue to function but will not show up in help or usage -// messages. Using this flag will also print the given usageMessage. -func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error { - flag := f.Lookup(name) - if flag == nil { - return fmt.Errorf("flag %q does not exist", name) - } - if usageMessage == "" { - return fmt.Errorf("deprecated message for flag %q must be set", name) - } - flag.ShorthandDeprecated = usageMessage - return nil -} - -// MarkHidden sets a flag to 'hidden' in your program. It will continue to -// function but will not show up in help or usage messages. -func (f *FlagSet) MarkHidden(name string) error { - flag := f.Lookup(name) - if flag == nil { - return fmt.Errorf("flag %q does not exist", name) - } - flag.Hidden = true - return nil -} - -// Lookup returns the Flag structure of the named command-line flag, -// returning nil if none exists. -func Lookup(name string) *Flag { - return CommandLine.Lookup(name) -} - -// ShorthandLookup returns the Flag structure of the short handed flag, -// returning nil if none exists. -func ShorthandLookup(name string) *Flag { - return CommandLine.ShorthandLookup(name) -} - // Set sets the value of the named flag. func (f *FlagSet) Set(name, value string) error { normalName := f.normalizeFlagName(name) @@ -414,29 +217,11 @@ func (f *FlagSet) SetAnnotation(name, key string, values []string) error { return nil } -// Changed returns true if the flag was explicitly set during Parse() and false -// otherwise -func (f *FlagSet) Changed(name string) bool { - flag := f.Lookup(name) - // If a flag doesn't exist, it wasn't changed.... - if flag == nil { - return false - } - return flag.Changed -} - // Set sets the value of the named command-line flag. func Set(name, value string) error { return CommandLine.Set(name, value) } -// PrintDefaults prints, to standard error unless configured -// otherwise, the default values of all defined flags in the set. -func (f *FlagSet) PrintDefaults() { - usages := f.FlagUsages() - fmt.Fprint(f.out(), usages) -} - // UnquoteUsage extracts a back-quoted name from the usage // string for a flag and returns it and the un-quoted usage. // Given "a `name` to show" it returns ("name", "a name to show"). @@ -530,87 +315,6 @@ func wrap(i, w int, s string) string { } -// FlagUsagesWrapped returns a string containing the usage information -// for all flags in the FlagSet. Wrapped to `cols` columns (0 for no -// wrapping) -func (f *FlagSet) FlagUsagesWrapped(cols int) string { - buf := new(bytes.Buffer) - - lines := make([]string, 0, len(f.formal)) - - maxlen := 0 - f.VisitAll(func(flag *Flag) { - if flag.Hidden { - return - } - - line := "" - if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { - line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) - } else { - line = fmt.Sprintf(" --%s", flag.Name) - } - - varname, usage := UnquoteUsage(flag) - if varname != "" { - line += " " + varname - } - - // This special character will be replaced with spacing once the - // correct alignment is calculated - line += "\x00" - if len(line) > maxlen { - maxlen = len(line) - } - - line += usage - if len(flag.Deprecated) != 0 { - line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) - } - - lines = append(lines, line) - }) - - for _, line := range lines { - sidx := strings.Index(line, "\x00") - spacing := strings.Repeat(" ", maxlen-sidx) - // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx - fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) - } - - return buf.String() -} - -// FlagUsages returns a string containing the usage information for all flags in -// the FlagSet -func (f *FlagSet) FlagUsages() string { - return f.FlagUsagesWrapped(0) -} - -// PrintDefaults prints to standard error the default values of all defined command-line flags. -func PrintDefaults() { - CommandLine.PrintDefaults() -} - -// defaultUsage is the default function to print a usage message. -func defaultUsage(f *FlagSet) { - fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) - f.PrintDefaults() -} - -// NOTE: Usage is not just defaultUsage(CommandLine) -// because it serves (via godoc flag Usage) as the example -// for how to write your own usage function. - -// Usage prints to standard error a usage message documenting all defined command-line flags. -// The function is a variable that may be changed to point to a custom function. -// By default it prints a simple header and calls PrintDefaults; for details about the -// format of the output and how to control it, see the documentation for PrintDefaults. -var Usage = func() { - fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) - PrintDefaults() -} - // NFlag returns the number of flags that have been set. func (f *FlagSet) NFlag() int { return len(f.actual) } @@ -712,19 +416,6 @@ func (f *FlagSet) AddFlag(flag *Flag) { f.shorthands[c] = flag } -// AddFlagSet adds one FlagSet to another. If a flag is already present in f -// the flag from newSet will be ignored. -func (f *FlagSet) AddFlagSet(newSet *FlagSet) { - if newSet == nil { - return - } - newSet.VisitAll(func(flag *Flag) { - if f.Lookup(flag.Name) == nil { - f.AddFlag(flag) - } - }) -} - // Var defines a flag with the specified name and usage string. The type and // value of the flag are represented by the first argument, of type Value, which // typically holds a user-defined implementation of Value. For instance, the @@ -746,23 +437,10 @@ func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) if f.errorHandling != ContinueOnError { fmt.Fprintln(f.out(), err) - f.usage() } return err } -// usage calls the Usage method for the flag set, or the usage function if -// the flag set is CommandLine. -func (f *FlagSet) usage() { - if f == CommandLine { - Usage() - } else if f.Usage == nil { - defaultUsage(f) - } else { - f.Usage() - } -} - // --unknown (args will be empty) // --unknown --next-flag ... (args will be --next-flag ...) // --unknown arg ... (args will be arg ...) @@ -800,7 +478,6 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if !exists { switch { case name == "help": - f.usage() return a, ErrHelp case f.ParseErrorsWhitelist.UnknownFlags: // --unknown=unknownval arg ... @@ -854,7 +531,6 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse if !exists { switch { case c == 'h': - f.usage() err = ErrHelp return case f.ParseErrorsWhitelist.UnknownFlags: From a9d100b5325edd955de5f37d63fe59b5ed5d7970 Mon Sep 17 00:00:00 2001 From: shawn Date: Sat, 20 Jul 2024 16:26:25 +0800 Subject: [PATCH 11/41] bug fixes: treat option as unknown command Signed-off-by: shawn --- tool/cmd/kitex/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index ede33fa89b..ad17926347 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -78,7 +78,7 @@ func main() { } if os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version, curpath) - } else if os.Args[1] != "template" { + } else if !strings.HasPrefix(os.Args[1], "-") { log.Warnf("Unknown command %q", os.Args[1]) os.Exit(1) } else { From b3388a04bd5e3659410cc35a4e230730885cf201 Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 22 Jul 2024 10:32:49 +0800 Subject: [PATCH 12/41] bug fix: fix unused code in command and flag Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 5 +- tool/internal_pkg/util/command.go | 11 --- tool/internal_pkg/util/flag.go | 123 +++++++----------------------- 3 files changed, 28 insertions(+), 111 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index fe937eb848..6bfac0fdd6 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -68,15 +68,13 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { a.RenderTplDir = args[0] } var tplDir string - for i, arg := range args { + for _, arg := range args { if !strings.HasPrefix(arg, "-") { tplDir = arg - args = append(args[:i], args[i+1:]...) break } } if tplDir == "" { - cmd.PrintUsage() return fmt.Errorf("template directory is required") } log.Verbose = a.Verbose @@ -138,6 +136,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, "Specify a code gen path.") renderCmd.Flags().StringVar(&a.TemplateFile, "f", "", "Specify template init path") + renderCmd.Flags().Var(&a.Includes, "I", "Add IDL search path and template search path for includes.") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template init [flags] diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 72bc2a23fc..32e1a3ad55 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -16,7 +16,6 @@ package util import ( "fmt" - "github.com/cloudwego/kitex/tool/internal_pkg/log" "os" "strings" ) @@ -52,14 +51,6 @@ func (c *Command) Flags() *FlagSet { return c.flags } -// PrintUsage prints the usage of the Command -func (c *Command) PrintUsage() { - log.Warn("Usage: %s\n\n%s\n\n", c.Use, c.Long) - for _, cmd := range c.commands { - log.Warnf(" %s: %s\n", cmd.Use, cmd.Short) - } -} - // Parent returns a commands parent command. func (c *Command) Parent() *Command { return c.parent @@ -113,8 +104,6 @@ func nextArgs(args []string, x string) []string { for pos := 0; pos < len(args); pos++ { s := args[pos] switch { - case s == "--": - break case strings.HasPrefix(s, "-"): pos++ continue diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 71906891b7..f07e9b1d4a 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -67,7 +67,6 @@ type FlagSet struct { parsed bool actual map[NormalizedName]*Flag orderedActual []*Flag - sortedActual []*Flag formal map[NormalizedName]*Flag orderedFormal []*Flag sortedFormal []*Flag @@ -246,75 +245,6 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { return } -// Splits the string `s` on whitespace into an initial substring up to -// `i` runes in length and the remainder. Will go `slop` over `i` if -// that encompasses the entire string (which allows the caller to -// avoid short orphan words on the final line). -func wrapN(i, slop int, s string) (string, string) { - if i+slop > len(s) { - return s, "" - } - - w := strings.LastIndexAny(s[:i], " \t\n") - if w <= 0 { - return s, "" - } - nlPos := strings.LastIndex(s[:i], "\n") - if nlPos > 0 && nlPos < w { - return s[:nlPos], s[nlPos+1:] - } - return s[:w], s[w+1:] -} - -// Wraps the string `s` to a maximum width `w` with leading indent -// `i`. The first line is not indented (this is assumed to be done by -// caller). Pass `w` == 0 to do no wrapping -func wrap(i, w int, s string) string { - if w == 0 { - return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) - } - - // space between indent i and end of line width w into which - // we should wrap the text. - wrap := w - i - - var r, l string - - // Not enough space for sensible wrapping. Wrap as a block on - // the next line instead. - if wrap < 24 { - i = 16 - wrap = w - i - r += "\n" + strings.Repeat(" ", i) - } - // If still not enough space then don't even try to wrap. - if wrap < 24 { - return strings.Replace(s, "\n", r, -1) - } - - // Try to avoid short orphan words on the final line, by - // allowing wrapN to go a bit over if that would fit in the - // remainder of the line. - slop := 5 - wrap = wrap - slop - - // Handle first line, which is indented by the caller (or the - // special case above) - l, s = wrapN(wrap, slop, s) - r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) - - // Now wrap the rest - for s != "" { - var t string - - t, s = wrapN(wrap, slop, s) - r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) - } - - return r - -} - // NFlag returns the number of flags that have been set. func (f *FlagSet) NFlag() int { return len(f.actual) } @@ -359,7 +289,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) { } // VarPF is like VarP, but returns the flag created -func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag { +func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) (*Flag, error) { // Remember the default value as a string; it won't change. flag := &Flag{ Name: name, @@ -368,8 +298,11 @@ func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag { Value: value, DefValue: value.String(), } - f.AddFlag(flag) - return flag + err := f.AddFlag(flag) + if err != nil { + return nil, err + } + return flag, nil } // VarP is like Var, but accepts a shorthand letter that can be used after a single dash. @@ -378,14 +311,14 @@ func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { } // AddFlag will add the flag to the FlagSet -func (f *FlagSet) AddFlag(flag *Flag) { +func (f *FlagSet) AddFlag(flag *Flag) error { normalizedFlagName := f.normalizeFlagName(flag.Name) _, alreadyThere := f.formal[normalizedFlagName] if alreadyThere { msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) fmt.Fprintln(f.out(), msg) - panic(msg) // Happens only if flags are declared with identical names + return fmt.Errorf("%s flag redefined: %s", f.name, flag.Name) } if f.formal == nil { f.formal = make(map[NormalizedName]*Flag) @@ -396,12 +329,11 @@ func (f *FlagSet) AddFlag(flag *Flag) { f.orderedFormal = append(f.orderedFormal, flag) if flag.Shorthand == "" { - return + return nil } if len(flag.Shorthand) > 1 { - msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand) - fmt.Fprintf(f.out(), msg) - panic(msg) + fmt.Fprintf(f.out(), "%q shorthand is more than one ASCII character", flag.Shorthand) + return fmt.Errorf("%q shorthand is more than one ASCII character", flag.Shorthand) } if f.shorthands == nil { f.shorthands = make(map[byte]*Flag) @@ -409,11 +341,11 @@ func (f *FlagSet) AddFlag(flag *Flag) { c := flag.Shorthand[0] used, alreadyThere := f.shorthands[c] if alreadyThere { - msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) - fmt.Fprintf(f.out(), msg) - panic(msg) + fmt.Fprintf(f.out(), "unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) + return fmt.Errorf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) } f.shorthands[c] = flag + return nil } // Var defines a flag with the specified name and usage string. The type and @@ -640,10 +572,6 @@ func (f *FlagSet) Parse(arguments []string) error { } f.parsed = true - if len(arguments) < 0 { - return nil - } - f.args = make([]string, 0, len(arguments)) set := func(flag *Flag, value string) error { @@ -774,13 +702,6 @@ func (f *FlagSet) StringVar(p *string, name string, value string, usage string) f.VarP(newStringValue(value, p), name, "", usage) } -// optional interface to indicate boolean flags that can be -// supplied without "=value" text -type boolFlag interface { - Value - IsBoolFlag() bool -} - // -- bool Value type boolValue bool @@ -805,12 +726,20 @@ func (b *boolValue) IsBoolFlag() bool { return true } // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. -func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) { - f.BoolVarP(p, name, "", value, usage) +func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) error { + err := f.BoolVarP(p, name, "", value, usage) + if err != nil { + return err + } + return nil } // BoolVarP is like BoolVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) { - flag := f.VarPF(newBoolValue(value, p), name, shorthand, usage) +func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) error { + flag, err := f.VarPF(newBoolValue(value, p), name, shorthand, usage) + if err != nil { + return err + } flag.NoOptDefVal = "true" + return nil } From 2987170c27df35eea8b4148c9c1868f0a5dfea60 Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 22 Jul 2024 10:56:27 +0800 Subject: [PATCH 13/41] bug fixes: fix lint Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 8 ++--- tool/internal_pkg/util/flag.go | 64 +++------------------------------ 2 files changed, 9 insertions(+), 63 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 6bfac0fdd6..7cf65ce8c9 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -129,14 +129,14 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { return a.checkPath(curpath) }, } - initCmd.Flags().StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") - renderCmd.Flags().StringVar(&a.ModuleName, "module", "", + initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") + renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, "Specify a code gen path.") - renderCmd.Flags().StringVar(&a.TemplateFile, "f", "", "Specify template init path") - renderCmd.Flags().Var(&a.Includes, "I", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVarP(&a.TemplateFile, "file", "f", "", "Specify single template path") + renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template init [flags] diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index f07e9b1d4a..d7aa1531ca 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -216,41 +216,9 @@ func (f *FlagSet) SetAnnotation(name, key string, values []string) error { return nil } -// Set sets the value of the named command-line flag. -func Set(name, value string) error { - return CommandLine.Set(name, value) -} - -// UnquoteUsage extracts a back-quoted name from the usage -// string for a flag and returns it and the un-quoted usage. -// Given "a `name` to show" it returns ("name", "a name to show"). -// If there are no back quotes, the name is an educated guess of the -// type of the flag's value, or the empty string if the flag is boolean. -func UnquoteUsage(flag *Flag) (name string, usage string) { - // Look for a back-quoted name, but avoid the strings package. - usage = flag.Usage - for i := 0; i < len(usage); i++ { - if usage[i] == '`' { - for j := i + 1; j < len(usage); j++ { - if usage[j] == '`' { - name = usage[i+1 : j] - usage = usage[:i] + name + usage[j+1:] - return name, usage - } - } - break // Only one back quote; use type name. - } - } - - return -} - // NFlag returns the number of flags that have been set. func (f *FlagSet) NFlag() int { return len(f.actual) } -// NFlag returns the number of command-line flags that have been set. -func NFlag() int { return len(CommandLine.actual) } - // Arg returns the i'th argument. Arg(0) is the first remaining argument // after flags have been processed. func (f *FlagSet) Arg(i int) string { @@ -260,24 +228,12 @@ func (f *FlagSet) Arg(i int) string { return f.args[i] } -// Arg returns the i'th command-line argument. Arg(0) is the first remaining argument -// after flags have been processed. -func Arg(i int) string { - return CommandLine.Arg(i) -} - // NArg is the number of arguments remaining after flags have been processed. func (f *FlagSet) NArg() int { return len(f.args) } -// NArg is the number of arguments remaining after flags have been processed. -func NArg() int { return len(CommandLine.args) } - // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } -// Args returns the non-flag command-line arguments. -func Args() []string { return CommandLine.args } - // Var defines a flag with the specified name and usage string. The type and // value of the flag are represented by the first argument, of type Value, which // typically holds a user-defined implementation of Value. For instance, the @@ -348,21 +304,6 @@ func (f *FlagSet) AddFlag(flag *Flag) error { return nil } -// Var defines a flag with the specified name and usage string. The type and -// value of the flag are represented by the first argument, of type Value, which -// typically holds a user-defined implementation of Value. For instance, the -// caller could create a flag that turns a comma-separated string into a slice -// of strings by giving the slice the methods of Value; in particular, Set would -// decompose the comma-separated string into the slice. -func Var(value Value, name string, usage string) { - CommandLine.VarP(value, name, "", usage) -} - -// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. -func VarP(value Value, name, shorthand, usage string) { - CommandLine.VarP(value, name, shorthand, usage) -} - // failf prints to standard error a formatted error and usage message and // returns the error. func (f *FlagSet) failf(format string, a ...interface{}) error { @@ -702,6 +643,11 @@ func (f *FlagSet) StringVar(p *string, name string, value string, usage string) f.VarP(newStringValue(value, p), name, "", usage) } +// StringVarP is like StringVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringVarP(p *string, name, shorthand string, value string, usage string) { + f.VarP(newStringValue(value, p), name, shorthand, usage) +} + // -- bool Value type boolValue bool From c7326e751cea4963e65d46af0fcb941d236922e9 Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 23 Jul 2024 16:20:20 +0800 Subject: [PATCH 14/41] fix: use gofumpt -extra to fix golanci lint Signed-off-by: shawn --- tool/internal_pkg/util/flag.go | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index d7aa1531ca..3c0b681197 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -234,16 +234,6 @@ func (f *FlagSet) NArg() int { return len(f.args) } // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } -// Var defines a flag with the specified name and usage string. The type and -// value of the flag are represented by the first argument, of type Value, which -// typically holds a user-defined implementation of Value. For instance, the -// caller could create a flag that turns a comma-separated string into a slice -// of strings by giving the slice the methods of Value; in particular, Set would -// decompose the comma-separated string into the slice. -func (f *FlagSet) Var(value Value, name string, usage string) { - f.VarP(value, name, "", usage) -} - // VarPF is like VarP, but returns the flag created func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) (*Flag, error) { // Remember the default value as a string; it won't change. @@ -631,6 +621,7 @@ func (s *stringValue) Set(val string) error { *s = stringValue(val) return nil } + func (s *stringValue) Type() string { return "string" } @@ -638,13 +629,12 @@ func (s *stringValue) Type() string { func (s *stringValue) String() string { return string(*s) } // StringVar defines a string flag with specified name, default value, and usage string. -// The argument p points to a string variable in which to store the value of the flag. -func (f *FlagSet) StringVar(p *string, name string, value string, usage string) { +func (f *FlagSet) StringVar(p *string, name, value, usage string) { f.VarP(newStringValue(value, p), name, "", usage) } // StringVarP is like StringVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) StringVarP(p *string, name, shorthand string, value string, usage string) { +func (f *FlagSet) StringVarP(p *string, name, shorthand, value, usage string) { f.VarP(newStringValue(value, p), name, shorthand, usage) } From 3441cfb5e549c96fd0c8b5226728f4e6fb320ebd Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 24 Jul 2024 22:28:10 +0800 Subject: [PATCH 15/41] fix: use gofumpt to avoid golangci lint Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 7cf65ce8c9..fc1391e73c 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -16,11 +16,12 @@ package args import ( "fmt" + "os" + "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/util" - "os" - "strings" ) func (a *Arguments) TemplateArgs(version, curpath string) error { @@ -163,7 +164,7 @@ Examples: Flags: `, version) }) - //renderCmd.PrintUsage() + // renderCmd.PrintUsage() templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) kitexCmd.AddCommand(templateCmd) if _, err := kitexCmd.ExecuteC(); err != nil { From 95049bd5b32675d1bb1a6235b09c651f03cd2aa0 Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 30 Jul 2024 19:07:27 +0800 Subject: [PATCH 16/41] add InitType --- tool/cmd/kitex/args/tpl_args.go | 1 + tool/internal_pkg/generator/generator.go | 1 + tool/internal_pkg/generator/generator_test.go | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index fc1391e73c..cadb49df82 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -131,6 +131,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { }, } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") + initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index c2a2404c5b..fe95db4bbc 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -136,6 +136,7 @@ type Config struct { // subcommand template InitOutputDir string + InitType string RenderTplDir string TemplateFile string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index b907172f4f..cc39bffc31 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -57,6 +57,7 @@ func TestConfig_Pack(t *testing.T) { ThriftPluginTimeLimit time.Duration TemplateDir string InitOutputDir string + InitType string RenderTplDir string TemplateFile string Protocol string @@ -72,7 +73,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -101,6 +102,7 @@ func TestConfig_Pack(t *testing.T) { ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, TemplateDir: tt.fields.TemplateDir, InitOutputDir: tt.fields.InitOutputDir, + InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, TemplateFile: tt.fields.TemplateFile, Protocol: tt.fields.Protocol, From f5ca3334b6fa3c22f8e51aae51f13ecb24f6045e Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 31 Jul 2024 20:05:20 +0800 Subject: [PATCH 17/41] add test for command and flag --- tool/cmd/kitex/args/tpl_args.go | 17 +- tool/cmd/kitex/main.go | 4 +- .../internal_pkg/generator/custom_template.go | 3 +- .../pluginmode/thriftgo/plugin.go | 4 + tool/internal_pkg/util/command.go | 45 +-- tool/internal_pkg/util/command_test.go | 256 +++++++++++++++--- tool/internal_pkg/util/flag.go | 245 ++++++++--------- tool/internal_pkg/util/flag_test.go | 160 +++++++++++ 8 files changed, 520 insertions(+), 214 deletions(-) create mode 100644 tool/internal_pkg/util/flag_test.go diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index cadb49df82..bd619433a1 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -17,7 +17,6 @@ package args import ( "fmt" "os" - "strings" "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" @@ -65,19 +64,10 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { Use: "render", Short: "Render command", RunE: func(cmd *util.Command, args []string) error { - if len(args) > 0 { - a.RenderTplDir = args[0] - } - var tplDir string - for _, arg := range args { - if !strings.HasPrefix(arg, "-") { - tplDir = arg - break - } - } - if tplDir == "" { - return fmt.Errorf("template directory is required") + if len(args) < 2 { + return fmt.Errorf("both template directory and idl is required") } + a.RenderTplDir = args[0] log.Verbose = a.Verbose for _, e := range a.extends { @@ -132,6 +122,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") + initCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index ec931c6506..6583596b53 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,6 +17,7 @@ package main import ( "bytes" "flag" + "fmt" "io/ioutil" "os" "os/exec" @@ -79,8 +80,7 @@ func main() { if os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version, curpath) } else if !strings.HasPrefix(os.Args[1], "-") { - log.Warnf("Unknown command %q", os.Args[1]) - os.Exit(1) + err = fmt.Errorf("unknown command %q", os.Args[1]) } else { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 1de769ce1b..17bebef833 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -15,6 +15,7 @@ package generator import ( + "errors" "fmt" "io/ioutil" "os" @@ -280,7 +281,7 @@ func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, } else { err = cg.commonGenerate(tpl) } - if err == errNoNewMethod { + if errors.Is(err, errNoNewMethod) { err = nil } return cg.fs, err diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index e65c3e179e..b3cd5c2052 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -119,6 +119,10 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } + //if conv.Config.InitType != "" { + // gen.GenerateCustomPackage(&conv.Package) + //} + res := &plugin.Response{ Warnings: conv.Warnings, } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 32e1a3ad55..fda97822d3 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -30,6 +30,14 @@ type Command struct { flags *FlagSet // helpFunc is help func defined by user. usage func() + // for debug + args []string +} + +// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden +// particularly useful when testing. +func (c *Command) SetArgs(a []string) { + c.args = a } func (c *Command) AddCommand(cmds ...*Command) error { @@ -51,24 +59,11 @@ func (c *Command) Flags() *FlagSet { return c.flags } -// Parent returns a commands parent command. -func (c *Command) Parent() *Command { - return c.parent -} - // HasParent determines if the command is a child command. func (c *Command) HasParent() bool { return c.parent != nil } -// Root finds root command. -func (c *Command) Root() *Command { - if c.HasParent() { - return c.Parent().Root() - } - return c -} - // HasSubCommands determines if the command has children commands. func (c *Command) HasSubCommands() bool { return len(c.commands) > 0 @@ -79,10 +74,16 @@ func stripFlags(args []string) []string { for len(args) > 0 { s := args[0] args = args[1:] - if s != "" && !strings.HasPrefix(s, "-") { + if strings.HasPrefix(s, "-") { + // handle "-f child child" args + if len(args) <= 1 { + break + } else { + args = args[1:] + continue + } + } else if s != "" && !strings.HasPrefix(s, "-") { commands = append(commands, s) - } else if strings.HasPrefix(s, "-") { - break } } return commands @@ -109,7 +110,8 @@ func nextArgs(args []string, x string) []string { continue case !strings.HasPrefix(s, "-"): if s == x { - var ret []string + // cannot use var ret []string cause it return nil + ret := make([]string, 0) ret = append(ret, args[:pos]...) ret = append(ret, args[pos+1:]...) return ret @@ -169,8 +171,10 @@ func (c *Command) UsageFunc() func() { // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { - args := os.Args[1:] - + args := c.args + if c.args == nil { + args = os.Args[1:] + } cmd, flags, err := c.Find(args) if err != nil { return c, err @@ -191,8 +195,9 @@ func (c *Command) execute(a []string) error { if err != nil { return err } + argWoFlags := c.Flags().Args() if c.RunE != nil { - err := c.RunE(c, a) + err := c.RunE(c, argWoFlags) if err != nil { return err } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go index a5a4766850..9bb5723d67 100644 --- a/tool/internal_pkg/util/command_test.go +++ b/tool/internal_pkg/util/command_test.go @@ -15,75 +15,249 @@ package util import ( - "os" + "fmt" + "reflect" + "strings" "testing" ) -func TestAddCommand(t *testing.T) { - rootCmd := &Command{Use: "root"} - childCmd := &Command{Use: "child"} +func emptyRun(*Command, []string) error { return nil } - // Test adding a valid child command - err := rootCmd.AddCommand(childCmd) - if err != nil { - t.Fatalf("expected no error, got %v", err) +func executeCommand(root *Command, args ...string) (err error) { + _, err = executeCommandC(root, args...) + return err +} + +func executeCommandC(root *Command, args ...string) (c *Command, err error) { + root.SetArgs(args) + c, err = root.ExecuteC() + return c, err +} + +const onetwo = "one two" + +func TestSingleCommand(t *testing.T) { + rootCmd := &Command{ + Use: "root", + RunE: func(_ *Command, args []string) error { return nil }, } + aCmd := &Command{Use: "a", RunE: emptyRun} + bCmd := &Command{Use: "b", RunE: emptyRun} + rootCmd.AddCommand(aCmd, bCmd) + + _ = executeCommand(rootCmd, "one", "two") +} - if len(rootCmd.commands) != 1 { - t.Fatalf("expected 1 command, got %d", len(rootCmd.commands)) +func TestChildCommand(t *testing.T) { + var child1CmdArgs []string + rootCmd := &Command{Use: "root", RunE: emptyRun} + child1Cmd := &Command{ + Use: "child1", + RunE: func(_ *Command, args []string) error { child1CmdArgs = args; return nil }, } + child2Cmd := &Command{Use: "child2", RunE: emptyRun} + rootCmd.AddCommand(child1Cmd, child2Cmd) - if rootCmd.commands[0] != childCmd { - t.Fatalf("expected child command to be added") + err := executeCommand(rootCmd, "child1", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) } - // Test adding a command to itself - err = rootCmd.AddCommand(rootCmd) - if err == nil { - t.Fatalf("expected an error, got nil") + got := strings.Join(child1CmdArgs, " ") + if got != onetwo { + t.Errorf("child1CmdArgs expected: %q, got: %q", onetwo, got) } +} - expectedErr := "command can't be a child of itself" - if err.Error() != expectedErr { - t.Fatalf("expected error %q, got %q", expectedErr, err.Error()) +func TestCallCommandWithoutSubcommands(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + err := executeCommand(rootCmd) + if err != nil { + t.Errorf("Calling command without subcommands should not have error: %v", err) } } -func TestExecuteC(t *testing.T) { - rootCmd := &Command{ +func TestRootExecuteUnknownCommand(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + rootCmd.AddCommand(&Command{Use: "child", RunE: emptyRun}) + + _ = executeCommand(rootCmd, "unknown") +} + +func TestSubcommandExecuteC(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + + _, err := executeCommandC(rootCmd, "child") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestFind(t *testing.T) { + var foo, bar string + root := &Command{ Use: "root", - RunE: func(cmd *Command, args []string) error { - return nil - }, } + root.Flags().StringVarP(&foo, "foo", "f", "", "") + root.Flags().StringVarP(&bar, "bar", "b", "something", "") + + child := &Command{ + Use: "child", + } + root.AddCommand(child) - subCmd := &Command{ - Use: "sub", - RunE: func(cmd *Command, args []string) error { - return nil + testCases := []struct { + args []string + expectedFoundArgs []string + }{ + { + []string{"child"}, + []string{}, + }, + { + []string{"child", "child"}, + []string{"child"}, + }, + { + []string{"child", "foo", "child", "bar", "child", "baz", "child"}, + []string{"foo", "child", "bar", "child", "baz", "child"}, + }, + { + []string{"-f", "child", "child"}, + []string{"-f", "child"}, + }, + { + []string{"child", "-f", "child"}, + []string{"-f", "child"}, + }, + { + []string{"-b", "child", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b"}, + []string{"-b"}, }, + { + []string{"-b", "-f", "child", "child"}, + []string{"-b", "-f", "child"}, + }, + { + []string{"-f", "child", "-b", "something", "child"}, + []string{"-f", "child", "-b", "something"}, + }, + { + []string{"-f", "child", "child", "-b"}, + []string{"-f", "child", "-b"}, + }, + { + []string{"-f=child", "-b=something", "child"}, + []string{"-f=child", "-b=something"}, + }, + { + []string{"--foo", "child", "--bar", "something", "child"}, + []string{"--foo", "child", "--bar", "something"}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { + cmd, foundArgs, err := root.Find(tc.args) + if err != nil { + t.Fatal(err) + } + + if cmd != child { + t.Fatal("Expected cmd to be child, but it was not") + } + + if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) { + t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs) + } + }) } +} - rootCmd.AddCommand(subCmd) +func TestFlagLong(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } - // Simulate command line arguments - os.Args = []string{"root", "sub"} + var intFlagValue int + var stringFlagValue string + c.Flags().IntVar(&intFlagValue, "intf", -1, "") + c.Flags().StringVar(&stringFlagValue, "sf", "", "") - // Execute the command - cmd, err := rootCmd.ExecuteC() + err := executeCommand(c, "--intf=7", "--sf=abc", "one", "--", "two") if err != nil { - t.Fatalf("expected no error, got %v", err) + t.Errorf("Unexpected error: %v", err) + } + + if intFlagValue != 7 { + t.Errorf("Expected intFlagValue: %v, got %v", 7, intFlagValue) + } + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } - if cmd.Use != "sub" { - t.Fatalf("expected sub command to be executed, got %s", cmd.Use) + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) } +} - // Simulate command line arguments with an unknown command - os.Args = []string{"root", "unknown"} +func TestFlagShort(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } + + var intFlagValue int + var stringFlagValue string + c.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") + c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") + + err := executeCommand(c, "-i", "7", "-sabc", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if intFlagValue != 7 { + t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) + } + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) + } + + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) + } +} + +func TestChildFlag(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + + var intFlagValue int + childCmd.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") + + err := executeCommand(rootCmd, "child", "-i7") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } - _, err = rootCmd.ExecuteC() - if err == nil { - t.Fatalf("expected an error for unknown command, got nil") + if intFlagValue != 7 { + t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) } } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 3c0b681197..21096cf8f1 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -22,6 +22,7 @@ import ( "os" "strconv" "strings" + "time" ) // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. @@ -69,7 +70,6 @@ type FlagSet struct { orderedActual []*Flag formal map[NormalizedName]*Flag orderedFormal []*Flag - sortedFormal []*Flag shorthands map[byte]*Flag args []string // arguments after flags argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- @@ -83,17 +83,15 @@ type FlagSet struct { // A Flag represents the state of a flag. type Flag struct { - Name string // name as it appears on command line - Shorthand string // one-letter abbreviated flag - Usage string // help message - Value Value // value as set - DefValue string // default value (as text); for usage message - Changed bool // If the user set the value (or if left to default) - NoOptDefVal string // default value (as text); if the flag is on the command line without any options - Deprecated string // If this flag is deprecated, this string is the new or now thing to use - Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text - ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use - Annotations map[string][]string // used by cobra.Command bash autocomple code + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string // default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use } // Value is the interface to the dynamic value stored in a flag. @@ -115,29 +113,6 @@ type SliceValue interface { GetSlice() []string } -// SetNormalizeFunc allows you to add a function which can translate flag names. -// Flags added to the FlagSet will be translated and then when anything tries to -// look up the flag that will also be translated. So it would be possible to create -// a flag named "getURL" and have it translated to "geturl". A user could then pass -// "--getUrl" which may also be translated to "geturl" and everything will work. -func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { - f.normalizeNameFunc = n - f.sortedFormal = f.sortedFormal[:0] - for fname, flag := range f.formal { - nname := f.normalizeFlagName(flag.Name) - if fname == nname { - continue - } - flag.Name = string(nname) - delete(f.formal, fname) - f.formal[nname] = flag - if _, set := f.actual[fname]; set { - delete(f.actual, fname) - f.actual[nname] = flag - } - } -} - // GetNormalizeFunc returns the previously set NormalizeFunc of a function which // does no translation, if not set previously. func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { @@ -152,6 +127,16 @@ func (f *FlagSet) normalizeFlagName(name string) NormalizedName { return n(f, name) } +// Lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) Lookup(name string) *Flag { + return f.lookup(f.normalizeFlagName(name)) +} + +// lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) lookup(name NormalizedName) *Flag { + return f.formal[name] +} + func (f *FlagSet) out() io.Writer { if f.output == nil { return os.Stderr @@ -200,37 +185,6 @@ func (f *FlagSet) Set(name, value string) error { return nil } -// SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. -// This is sometimes used by spf13/cobra programs which want to generate additional -// bash completion information. -func (f *FlagSet) SetAnnotation(name, key string, values []string) error { - normalName := f.normalizeFlagName(name) - flag, ok := f.formal[normalName] - if !ok { - return fmt.Errorf("no such flag -%v", name) - } - if flag.Annotations == nil { - flag.Annotations = map[string][]string{} - } - flag.Annotations[key] = values - return nil -} - -// NFlag returns the number of flags that have been set. -func (f *FlagSet) NFlag() int { return len(f.actual) } - -// Arg returns the i'th argument. Arg(0) is the first remaining argument -// after flags have been processed. -func (f *FlagSet) Arg(i int) string { - if i < 0 || i >= len(f.args) { - return "" - } - return f.args[i] -} - -// NArg is the number of arguments remaining after flags have been processed. -func (f *FlagSet) NArg() int { return len(f.args) } - // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } @@ -526,60 +480,11 @@ func (f *FlagSet) Parse(arguments []string) error { type parseFunc func(flag *Flag, value string) error -// ParseAll parses flag definitions from the argument list, which should not -// include the command name. The arguments for fn are flag and value. Must be -// called after all flags in the FlagSet are defined and before flags are -// accessed by the program. The return value will be ErrHelp if -help was set -// but not defined. -func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error { - f.parsed = true - f.args = make([]string, 0, len(arguments)) - - err := f.parseArgs(arguments, fn) - if err != nil { - switch f.errorHandling { - case ContinueOnError: - return err - case ExitOnError: - os.Exit(2) - case PanicOnError: - panic(err) - } - } - return nil -} - // Parsed reports whether f.Parse has been called. func (f *FlagSet) Parsed() bool { return f.parsed } -// Parse parses the command-line flags from os.Args[1:]. Must be called -// after all flags are defined and before flags are accessed by the program. -func Parse() { - // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.Parse(os.Args[1:]) -} - -// ParseAll parses the command-line flags from os.Args[1:] and called fn for each. -// The arguments for fn are flag and value. Must be called after all flags are -// defined and before flags are accessed by the program. -func ParseAll(fn func(flag *Flag, value string) error) { - // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.ParseAll(os.Args[1:], fn) -} - -// SetInterspersed sets whether to support interspersed option/non-option arguments. -func SetInterspersed(interspersed bool) { - CommandLine.SetInterspersed(interspersed) -} - -// Parsed returns true if the command-line flags have been parsed. -func Parsed() bool { - return CommandLine.Parsed() -} - -// CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) // NewFlagSet returns a new, empty flag set with the specified name, @@ -595,20 +500,6 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { return f } -// SetInterspersed sets whether to support interspersed option/non-option arguments. -func (f *FlagSet) SetInterspersed(interspersed bool) { - f.interspersed = interspersed -} - -// Init sets the name and error handling property for a flag set. -// By default, the zero FlagSet uses an empty name and the -// ContinueOnError error handling policy. -func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { - f.name = name - f.errorHandling = errorHandling - f.argsLenAtDash = -1 -} - // -- string Value type stringValue string @@ -622,10 +513,6 @@ func (s *stringValue) Set(val string) error { return nil } -func (s *stringValue) Type() string { - return "string" -} - func (s *stringValue) String() string { return string(*s) } // StringVar defines a string flag with specified name, default value, and usage string. @@ -638,6 +525,21 @@ func (f *FlagSet) StringVarP(p *string, name, shorthand, value, usage string) { f.VarP(newStringValue(value, p), name, shorthand, usage) } +// String defines a string flag with specified name, default value, and usage string. +// The return value is the address of a string variable that stores the value of the flag. +func (f *FlagSet) String(name, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, "", value, usage) + return p +} + +// StringP is like String, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringP(name, shorthand, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, shorthand, value, usage) + return p +} + // -- bool Value type boolValue bool @@ -652,10 +554,6 @@ func (b *boolValue) Set(s string) error { return err } -func (b *boolValue) Type() string { - return "bool" -} - func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } func (b *boolValue) IsBoolFlag() bool { return true } @@ -679,3 +577,76 @@ func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage st flag.NoOptDefVal = "true" return nil } + +// Bool defines a bool flag with specified name, default value, and usage string. +// The return value is the address of a bool variable that stores the value of the flag. +func (f *FlagSet) Bool(name string, value bool, usage string) *bool { + return f.BoolP(name, "", value, usage) +} + +// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { + p := new(bool) + f.BoolVarP(p, name, shorthand, value, usage) + return p +} + +// -- int Value +type intValue int + +func newIntValue(val int, p *int) *intValue { + *p = val + return (*intValue)(p) +} + +func (i *intValue) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + *i = intValue(v) + return err +} + +func (i *intValue) String() string { return strconv.Itoa(int(*i)) } + +// IntVar defines an int flag with specified name, default value, and usage string. +// The argument p points to an int variable in which to store the value of the flag. +func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { + f.VarP(newIntValue(value, p), name, "", usage) +} + +// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IntVarP(p *int, name, shorthand string, value int, usage string) { + f.VarP(newIntValue(value, p), name, shorthand, usage) +} + +func (f *FlagSet) Int(name string, value int, usage string) *int { + p := new(int) + f.IntVarP(p, name, "", value, usage) + return p +} + +// -- time.Duration Value +type durationValue time.Duration + +func (d *durationValue) Set(s string) error { + v, err := time.ParseDuration(s) + *d = durationValue(v) + return err +} + +func (d *durationValue) String() string { return (*time.Duration)(d).String() } + +func newDurationValue(val time.Duration, p *time.Duration) *durationValue { + *p = val + return (*durationValue)(p) +} + +// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { + f.VarP(newDurationValue(value, p), name, shorthand, usage) +} + +func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration { + p := new(time.Duration) + f.DurationVarP(p, name, "", value, usage) + return p +} diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go new file mode 100644 index 0000000000..7bde2edcd1 --- /dev/null +++ b/tool/internal_pkg/util/flag_test.go @@ -0,0 +1,160 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "io" + "os" + "testing" + "time" +) + +func ResetForTesting() { + CommandLine = &FlagSet{ + name: os.Args[0], + errorHandling: ContinueOnError, + output: io.Discard, + } +} + +// GetCommandLine returns the default FlagSet. +func GetCommandLine() *FlagSet { + return CommandLine +} + +func TestParse(t *testing.T) { + ResetForTesting() + testParse(GetCommandLine(), t) +} + +func testParse(f *FlagSet, t *testing.T) { + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolFlag := f.Bool("bool", false, "bool value") + bool2Flag := f.Bool("bool2", false, "bool2 value") + bool3Flag := f.Bool("bool3", false, "bool3 value") + intFlag := f.Int("int", 0, "int value") + stringFlag := f.String("string", "0", "string value") + durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value") + optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" + optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" + extra := "one-extra-argument" + args := []string{ + "--bool", + "--bool2=true", + "--bool3=false", + "--int=22", + "--string=hello", + "--duration=2m", + "--optional-int-no-value", + "--optional-int-with-value=42", + extra, + } + if err := f.Parse(args); err != nil { + t.Fatal(err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolFlag != true { + t.Error("bool flag should be true, is ", *boolFlag) + } + if *bool2Flag != true { + t.Error("bool2 flag should be true, is ", *bool2Flag) + } + if *bool3Flag != false { + t.Error("bool3 flag should be false, is ", *bool2Flag) + } + if *intFlag != 22 { + t.Error("int flag should be 22, is ", *intFlag) + } + if *stringFlag != "hello" { + t.Error("string flag should be `hello`, is ", *stringFlag) + } + if *durationFlag != 2*time.Minute { + t.Error("duration flag should be 2m, is ", *durationFlag) + } + if *optionalIntNoValueFlag != 9 { + t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) + } + if *optionalIntWithValueFlag != 42 { + t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) + } + if len(f.Args()) != 1 { + t.Error("expected one argument, got", len(f.Args())) + } else if f.Args()[0] != extra { + t.Errorf("expected argument %q got %q", extra, f.Args()[0]) + } +} + +func TestShorthand(t *testing.T) { + f := NewFlagSet("shorthand", ContinueOnError) + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolaFlag := f.BoolP("boola", "a", false, "bool value") + boolbFlag := f.BoolP("boolb", "b", false, "bool2 value") + boolcFlag := f.BoolP("boolc", "c", false, "bool3 value") + booldFlag := f.BoolP("boold", "d", false, "bool4 value") + stringaFlag := f.StringP("stringa", "s", "0", "string value") + stringzFlag := f.StringP("stringz", "z", "0", "string value") + extra := "interspersed-argument" + notaflag := "--i-look-like-a-flag" + args := []string{ + "-ab", + extra, + "-cs", + "hello", + "-z=something", + "-d=true", + "--", + notaflag, + } + f.SetOutput(io.Discard) + if err := f.Parse(args); err != nil { + t.Error("expected no error, got ", err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolaFlag != true { + t.Error("boola flag should be true, is ", *boolaFlag) + } + if *boolbFlag != true { + t.Error("boolb flag should be true, is ", *boolbFlag) + } + if *boolcFlag != true { + t.Error("boolc flag should be true, is ", *boolcFlag) + } + if *booldFlag != true { + t.Error("boold flag should be true, is ", *booldFlag) + } + if *stringaFlag != "hello" { + t.Error("stringa flag should be `hello`, is ", *stringaFlag) + } + if *stringzFlag != "something" { + t.Error("stringz flag should be `something`, is ", *stringzFlag) + } + if len(f.Args()) != 2 { + t.Error("expected one argument, got", len(f.Args())) + } else if f.Args()[0] != extra { + t.Errorf("expected argument %q got %q", extra, f.Args()[0]) + } else if f.Args()[1] != notaflag { + t.Errorf("expected argument %q got %q", notaflag, f.Args()[1]) + } +} From b2c165c4f53a1a401880b047a960d7dadf7c1c0b Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 31 Jul 2024 20:18:35 +0800 Subject: [PATCH 18/41] add test for command and flag --- internal/mocks/discovery/discovery.go | 3 +-- internal/mocks/generic/generic_service.go | 3 +-- internal/mocks/generic/thrift.go | 3 +-- internal/mocks/klog/log.go | 3 +-- internal/mocks/limiter/limiter.go | 3 +-- internal/mocks/loadbalance/loadbalancer.go | 3 +-- internal/mocks/net/net.go | 5 ++--- internal/mocks/netpoll/connection.go | 3 +-- internal/mocks/proxy/proxy.go | 3 +-- internal/mocks/remote/bytebuf.go | 3 +-- internal/mocks/remote/codec.go | 3 +-- internal/mocks/remote/conn_wrapper.go | 3 +-- internal/mocks/remote/connpool.go | 3 +-- internal/mocks/remote/dialer.go | 3 +-- internal/mocks/remote/payload_codec.go | 3 +-- internal/mocks/remote/trans_handler.go | 3 +-- internal/mocks/remote/trans_meta.go | 3 +-- internal/mocks/remote/trans_pipeline.go | 3 +-- internal/mocks/stats/tracer.go | 3 +-- internal/mocks/utils/sharedticker.go | 3 +-- tool/internal_pkg/pluginmode/thriftgo/plugin.go | 4 ---- 21 files changed, 21 insertions(+), 45 deletions(-) diff --git a/internal/mocks/discovery/discovery.go b/internal/mocks/discovery/discovery.go index 60c1d15bfc..e70f64e2e0 100644 --- a/internal/mocks/discovery/discovery.go +++ b/internal/mocks/discovery/discovery.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/discovery/discovery.go diff --git a/internal/mocks/generic/generic_service.go b/internal/mocks/generic/generic_service.go index d5aba3c5d2..8ca3f06056 100644 --- a/internal/mocks/generic/generic_service.go +++ b/internal/mocks/generic/generic_service.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/generic_service.go diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 142fd9f483..92b0697092 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/thrift/thrift.go diff --git a/internal/mocks/klog/log.go b/internal/mocks/klog/log.go index 792ffb1b6c..2c83208a32 100644 --- a/internal/mocks/klog/log.go +++ b/internal/mocks/klog/log.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/klog/log.go diff --git a/internal/mocks/limiter/limiter.go b/internal/mocks/limiter/limiter.go index 0d9653e9d9..5fac2ad3e3 100644 --- a/internal/mocks/limiter/limiter.go +++ b/internal/mocks/limiter/limiter.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/limiter/limiter.go diff --git a/internal/mocks/loadbalance/loadbalancer.go b/internal/mocks/loadbalance/loadbalancer.go index 41de5c7ad1..63e4fec5bb 100644 --- a/internal/mocks/loadbalance/loadbalancer.go +++ b/internal/mocks/loadbalance/loadbalancer.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/loadbalance/loadbalancer.go diff --git a/internal/mocks/net/net.go b/internal/mocks/net/net.go index e5e2c6ac5f..7ada9f882b 100644 --- a/internal/mocks/net/net.go +++ b/internal/mocks/net/net.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: /usr/local/go/src/net/net.go @@ -22,8 +21,8 @@ package net import ( - reflect "reflect" net "net" + reflect "reflect" time "time" gomock "github.com/golang/mock/gomock" diff --git a/internal/mocks/netpoll/connection.go b/internal/mocks/netpoll/connection.go index e7a1e0265a..325fff2af7 100644 --- a/internal/mocks/netpoll/connection.go +++ b/internal/mocks/netpoll/connection.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../../netpoll/connection.go diff --git a/internal/mocks/proxy/proxy.go b/internal/mocks/proxy/proxy.go index c028766f04..2e5ff5fb13 100644 --- a/internal/mocks/proxy/proxy.go +++ b/internal/mocks/proxy/proxy.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/proxy/proxy.go diff --git a/internal/mocks/remote/bytebuf.go b/internal/mocks/remote/bytebuf.go index ab468b07b6..0795e0fa78 100644 --- a/internal/mocks/remote/bytebuf.go +++ b/internal/mocks/remote/bytebuf.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/bytebuf.go diff --git a/internal/mocks/remote/codec.go b/internal/mocks/remote/codec.go index e478202702..aaed664113 100644 --- a/internal/mocks/remote/codec.go +++ b/internal/mocks/remote/codec.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/codec.go diff --git a/internal/mocks/remote/conn_wrapper.go b/internal/mocks/remote/conn_wrapper.go index 2e4c507890..e57a75d8c6 100644 --- a/internal/mocks/remote/conn_wrapper.go +++ b/internal/mocks/remote/conn_wrapper.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/remotecli/conn_wrapper.go diff --git a/internal/mocks/remote/connpool.go b/internal/mocks/remote/connpool.go index 211b312244..fcdfd5575c 100644 --- a/internal/mocks/remote/connpool.go +++ b/internal/mocks/remote/connpool.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/connpool.go diff --git a/internal/mocks/remote/dialer.go b/internal/mocks/remote/dialer.go index 55470af875..90ff6960bd 100644 --- a/internal/mocks/remote/dialer.go +++ b/internal/mocks/remote/dialer.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/dialer.go diff --git a/internal/mocks/remote/payload_codec.go b/internal/mocks/remote/payload_codec.go index fa76f995d9..2adda83bf2 100644 --- a/internal/mocks/remote/payload_codec.go +++ b/internal/mocks/remote/payload_codec.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/payload_codec.go diff --git a/internal/mocks/remote/trans_handler.go b/internal/mocks/remote/trans_handler.go index 04baa99546..ad32a17a14 100644 --- a/internal/mocks/remote/trans_handler.go +++ b/internal/mocks/remote/trans_handler.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_handler.go diff --git a/internal/mocks/remote/trans_meta.go b/internal/mocks/remote/trans_meta.go index f82e60732f..71297e4616 100644 --- a/internal/mocks/remote/trans_meta.go +++ b/internal/mocks/remote/trans_meta.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_meta.go diff --git a/internal/mocks/remote/trans_pipeline.go b/internal/mocks/remote/trans_pipeline.go index f40093ca22..60d89fc071 100644 --- a/internal/mocks/remote/trans_pipeline.go +++ b/internal/mocks/remote/trans_pipeline.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_pipeline.go diff --git a/internal/mocks/stats/tracer.go b/internal/mocks/stats/tracer.go index 02300542b6..4f8da5476d 100644 --- a/internal/mocks/stats/tracer.go +++ b/internal/mocks/stats/tracer.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/stats/tracer.go diff --git a/internal/mocks/utils/sharedticker.go b/internal/mocks/utils/sharedticker.go index 605c886ee3..a7ac22a7db 100644 --- a/internal/mocks/utils/sharedticker.go +++ b/internal/mocks/utils/sharedticker.go @@ -12,8 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ - + */ // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/utils/sharedticker.go diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index b3cd5c2052..e65c3e179e 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -119,10 +119,6 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } - //if conv.Config.InitType != "" { - // gen.GenerateCustomPackage(&conv.Package) - //} - res := &plugin.Response{ Warnings: conv.Warnings, } From ceb8ba95bdf8531f9c9ac710cca14df1ad7c4e9d Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 31 Jul 2024 20:44:31 +0800 Subject: [PATCH 19/41] resolve conflicts --- internal/mocks/discovery/discovery.go | 3 +- internal/mocks/generic/generic_service.go | 3 +- internal/mocks/generic/thrift.go | 3 +- internal/mocks/klog/log.go | 3 +- internal/mocks/limiter/limiter.go | 3 +- internal/mocks/loadbalance/loadbalancer.go | 3 +- internal/mocks/net/net.go | 5 +-- internal/mocks/netpoll/connection.go | 3 +- internal/mocks/proxy/proxy.go | 3 +- internal/mocks/remote/bytebuf.go | 3 +- internal/mocks/remote/codec.go | 3 +- internal/mocks/remote/conn_wrapper.go | 3 +- internal/mocks/remote/connpool.go | 3 +- internal/mocks/remote/dialer.go | 3 +- internal/mocks/remote/payload_codec.go | 3 +- internal/mocks/remote/trans_handler.go | 3 +- internal/mocks/remote/trans_meta.go | 3 +- internal/mocks/remote/trans_pipeline.go | 3 +- internal/mocks/stats/tracer.go | 3 +- internal/mocks/utils/sharedticker.go | 3 +- pkg/protocol/bthrift/binary.go | 24 ++++--------- pkg/protocol/bthrift/binary_test.go | 39 ---------------------- pkg/protocol/bthrift/interface.go | 5 --- pkg/rpcinfo/rpcconfig.go | 1 - 24 files changed, 48 insertions(+), 83 deletions(-) diff --git a/internal/mocks/discovery/discovery.go b/internal/mocks/discovery/discovery.go index e70f64e2e0..2ae32eb94f 100644 --- a/internal/mocks/discovery/discovery.go +++ b/internal/mocks/discovery/discovery.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/discovery/discovery.go @@ -175,4 +176,4 @@ func (m *MockInstance) Weight() int { func (mr *MockInstanceMockRecorder) Weight() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockInstance)(nil).Weight)) -} +} \ No newline at end of file diff --git a/internal/mocks/generic/generic_service.go b/internal/mocks/generic/generic_service.go index 8ca3f06056..bc18989b81 100644 --- a/internal/mocks/generic/generic_service.go +++ b/internal/mocks/generic/generic_service.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/generic_service.go @@ -98,4 +99,4 @@ func (m *MockWithCodec) SetCodec(codec interface{}) { func (mr *MockWithCodecMockRecorder) SetCodec(codec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCodec", reflect.TypeOf((*MockWithCodec)(nil).SetCodec), codec) -} +} \ No newline at end of file diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 92b0697092..eb0c5da86a 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/thrift/thrift.go @@ -103,4 +104,4 @@ func (m *MockMessageWriter) Write(ctx context.Context, out io.Writer, msg interf func (mr *MockMessageWriterMockRecorder) Write(ctx, out, msg, requestBase interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockMessageWriter)(nil).Write), ctx, out, msg, requestBase) -} +} \ No newline at end of file diff --git a/internal/mocks/klog/log.go b/internal/mocks/klog/log.go index 2c83208a32..2b6413079b 100644 --- a/internal/mocks/klog/log.go +++ b/internal/mocks/klog/log.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/klog/log.go @@ -890,4 +891,4 @@ func (mr *MockFullLoggerMockRecorder) Warnf(format interface{}, v ...interface{} mr.mock.ctrl.T.Helper() varargs := append([]interface{}{format}, v...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockFullLogger)(nil).Warnf), varargs...) -} +} \ No newline at end of file diff --git a/internal/mocks/limiter/limiter.go b/internal/mocks/limiter/limiter.go index 5fac2ad3e3..97cec6a9cf 100644 --- a/internal/mocks/limiter/limiter.go +++ b/internal/mocks/limiter/limiter.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/limiter/limiter.go @@ -225,4 +226,4 @@ func (m *MockLimitReporter) QPSOverloadReport() { func (mr *MockLimitReporterMockRecorder) QPSOverloadReport() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QPSOverloadReport", reflect.TypeOf((*MockLimitReporter)(nil).QPSOverloadReport)) -} +} \ No newline at end of file diff --git a/internal/mocks/loadbalance/loadbalancer.go b/internal/mocks/loadbalance/loadbalancer.go index 63e4fec5bb..e39fa9bc1c 100644 --- a/internal/mocks/loadbalance/loadbalancer.go +++ b/internal/mocks/loadbalance/loadbalancer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/loadbalance/loadbalancer.go @@ -162,4 +163,4 @@ func (m *MockRebalancer) Rebalance(arg0 discovery.Change) { func (mr *MockRebalancerMockRecorder) Rebalance(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rebalance", reflect.TypeOf((*MockRebalancer)(nil).Rebalance), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/net/net.go b/internal/mocks/net/net.go index 7ada9f882b..191cd25b19 100644 --- a/internal/mocks/net/net.go +++ b/internal/mocks/net/net.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: /usr/local/go/src/net/net.go @@ -21,8 +22,8 @@ package net import ( - net "net" reflect "reflect" + net "net" time "time" gomock "github.com/golang/mock/gomock" @@ -581,4 +582,4 @@ func (m *MockbuffersWriter) writeBuffers(arg0 *net.Buffers) (int64, error) { func (mr *MockbuffersWriterMockRecorder) writeBuffers(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "writeBuffers", reflect.TypeOf((*MockbuffersWriter)(nil).writeBuffers), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/netpoll/connection.go b/internal/mocks/netpoll/connection.go index 325fff2af7..94c932557e 100644 --- a/internal/mocks/netpoll/connection.go +++ b/internal/mocks/netpoll/connection.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../../netpoll/connection.go @@ -560,4 +561,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} +} \ No newline at end of file diff --git a/internal/mocks/proxy/proxy.go b/internal/mocks/proxy/proxy.go index 2e5ff5fb13..4d15434856 100644 --- a/internal/mocks/proxy/proxy.go +++ b/internal/mocks/proxy/proxy.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/proxy/proxy.go @@ -191,4 +192,4 @@ func (m *MockContextHandler) HandleContext(arg0 context.Context) context.Context func (mr *MockContextHandlerMockRecorder) HandleContext(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleContext", reflect.TypeOf((*MockContextHandler)(nil).HandleContext), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/bytebuf.go b/internal/mocks/remote/bytebuf.go index 0795e0fa78..24a4d7b9af 100644 --- a/internal/mocks/remote/bytebuf.go +++ b/internal/mocks/remote/bytebuf.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/bytebuf.go @@ -437,4 +438,4 @@ func (m *MockByteBuffer) WriteString(s string) (int, error) { func (mr *MockByteBufferMockRecorder) WriteString(s interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockByteBuffer)(nil).WriteString), s) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/codec.go b/internal/mocks/remote/codec.go index aaed664113..56593b0907 100644 --- a/internal/mocks/remote/codec.go +++ b/internal/mocks/remote/codec.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/codec.go @@ -193,4 +194,4 @@ func (m *MockMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, func (mr *MockMetaDecoderMockRecorder) DecodePayload(ctx, msg, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePayload", reflect.TypeOf((*MockMetaDecoder)(nil).DecodePayload), ctx, msg, in) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/conn_wrapper.go b/internal/mocks/remote/conn_wrapper.go index e57a75d8c6..03d159e9a5 100644 --- a/internal/mocks/remote/conn_wrapper.go +++ b/internal/mocks/remote/conn_wrapper.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/remotecli/conn_wrapper.go @@ -60,4 +61,4 @@ func (m *MockConnReleaser) ReleaseConn(err error, ri rpcinfo.RPCInfo) { func (mr *MockConnReleaserMockRecorder) ReleaseConn(err, ri interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseConn", reflect.TypeOf((*MockConnReleaser)(nil).ReleaseConn), err, ri) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/connpool.go b/internal/mocks/remote/connpool.go index fcdfd5575c..7fbf882329 100644 --- a/internal/mocks/remote/connpool.go +++ b/internal/mocks/remote/connpool.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/connpool.go @@ -308,4 +309,4 @@ func (m *MockIsActive) IsActive() bool { func (mr *MockIsActiveMockRecorder) IsActive() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsActive", reflect.TypeOf((*MockIsActive)(nil).IsActive)) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/dialer.go b/internal/mocks/remote/dialer.go index 90ff6960bd..cebd08a0b0 100644 --- a/internal/mocks/remote/dialer.go +++ b/internal/mocks/remote/dialer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/dialer.go @@ -64,4 +65,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/payload_codec.go b/internal/mocks/remote/payload_codec.go index 2adda83bf2..1db0d6128a 100644 --- a/internal/mocks/remote/payload_codec.go +++ b/internal/mocks/remote/payload_codec.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/payload_codec.go @@ -91,4 +92,4 @@ func (m *MockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message func (mr *MockPayloadCodecMockRecorder) Unmarshal(ctx, message, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmarshal", reflect.TypeOf((*MockPayloadCodec)(nil).Unmarshal), ctx, message, in) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_handler.go b/internal/mocks/remote/trans_handler.go index ad32a17a14..210ef935fa 100644 --- a/internal/mocks/remote/trans_handler.go +++ b/internal/mocks/remote/trans_handler.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_handler.go @@ -570,4 +571,4 @@ func (m *MockGracefulShutdown) GracefulShutdown(ctx context.Context) error { func (mr *MockGracefulShutdownMockRecorder) GracefulShutdown(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GracefulShutdown", reflect.TypeOf((*MockGracefulShutdown)(nil).GracefulShutdown), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_meta.go b/internal/mocks/remote/trans_meta.go index 71297e4616..38153da94d 100644 --- a/internal/mocks/remote/trans_meta.go +++ b/internal/mocks/remote/trans_meta.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_meta.go @@ -132,4 +133,4 @@ func (m *MockStreamingMetaHandler) OnReadStream(ctx context.Context) (context.Co func (mr *MockStreamingMetaHandlerMockRecorder) OnReadStream(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReadStream", reflect.TypeOf((*MockStreamingMetaHandler)(nil).OnReadStream), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_pipeline.go b/internal/mocks/remote/trans_pipeline.go index 60d89fc071..687e1a5fac 100644 --- a/internal/mocks/remote/trans_pipeline.go +++ b/internal/mocks/remote/trans_pipeline.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_pipeline.go @@ -267,4 +268,4 @@ func (m *MockDuplexBoundHandler) Write(ctx context.Context, conn net.Conn, send func (mr *MockDuplexBoundHandlerMockRecorder) Write(ctx, conn, send interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDuplexBoundHandler)(nil).Write), ctx, conn, send) -} +} \ No newline at end of file diff --git a/internal/mocks/stats/tracer.go b/internal/mocks/stats/tracer.go index 4f8da5476d..0734297631 100644 --- a/internal/mocks/stats/tracer.go +++ b/internal/mocks/stats/tracer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/stats/tracer.go @@ -74,4 +75,4 @@ func (m *MockTracer) Start(ctx context.Context) context.Context { func (mr *MockTracerMockRecorder) Start(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTracer)(nil).Start), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/utils/sharedticker.go b/internal/mocks/utils/sharedticker.go index a7ac22a7db..996951cbd1 100644 --- a/internal/mocks/utils/sharedticker.go +++ b/internal/mocks/utils/sharedticker.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/utils/sharedticker.go @@ -59,4 +60,4 @@ func (m *MockTickerTask) Tick() { func (mr *MockTickerTaskMockRecorder) Tick() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tick", reflect.TypeOf((*MockTickerTask)(nil).Tick)) -} +} \ No newline at end of file diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 0668e01cc8..8db034e1dd 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -30,28 +30,19 @@ import ( var ( // Binary protocol for bthrift. - Binary binaryProtocol - _ BTProtocol = binaryProtocol{} + Binary binaryProtocol + _ BTProtocol = binaryProtocol{} + spanCache = mem.NewSpanCache(1024 * 1024) + spanCacheEnable bool = false ) -var allocator Allocator - const binaryInplaceThreshold = 4096 // 4k type binaryProtocol struct{} // SetSpanCache enable/disable binary protocol bytes/string allocator func SetSpanCache(enable bool) { - if enable { - SetAllocator(mem.NewSpanCache(1024 * 1024)) - } else { - SetAllocator(nil) - } -} - -// SetAllocator set binary protocol bytes/string allocator. -func SetAllocator(alloc Allocator) { - allocator = alloc + spanCacheEnable = enable } func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int { @@ -500,9 +491,8 @@ func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err erro if size < 0 || size > len(buf) { return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") } - alloc := allocator - if alloc != nil { - value = alloc.Copy(buf[length : length+size]) + if spanCacheEnable { + value = spanCache.Copy(buf[length : length+size]) } else { value = make([]byte, size) copy(value, buf[length:length+size]) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 8e395bb5cd..a0754bcd55 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -291,24 +291,6 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } -// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator -func TestWriteAndReadStringWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - wn := Binary.WriteString(buf, "kitex") - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadString(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - test.Assert(t, v == "kitex") - SetSpanCache(false) -} - // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -328,27 +310,6 @@ func TestWriteAndReadBinary(t *testing.T) { } } -// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator -func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - val := []byte("kitex") - wn := Binary.WriteBinary(buf, val) - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadBinary(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - for i := 0; i < len(v); i++ { - test.Assert(t, val[i] == v[i]) - } - SetSpanCache(false) -} - // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index e65d667318..75fa0ce951 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -97,8 +97,3 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } - -type Allocator interface { - Make(n int) []byte - Copy(buf []byte) (p []byte) -} diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 88346ff99c..3c0f654ca6 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -17,7 +17,6 @@ package rpcinfo import ( - "fmt" "sync" "time" From 53995b409275a4966a2bffacbd0f5195c8d8a925 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 2 Aug 2024 20:05:16 +0800 Subject: [PATCH 20/41] refactor: rm apache thrift in internal/mocks (#1474) also: fix(server): invoker return err if no apache codec fix(server): listening on loopback addr --- client/mocks_test.go | 28 +-- internal/mocks/serviceinfo.go | 32 +-- internal/mocks/thrift_ttransport.go | 378 ---------------------------- internal/test/port.go | 2 +- pkg/remote/remotesvr/server_test.go | 2 +- server/invoke.go | 48 +++- server/invoke_test.go | 9 +- server/option_advanced_test.go | 8 +- server/option_test.go | 40 +-- server/server_test.go | 50 ++-- 10 files changed, 120 insertions(+), 477 deletions(-) delete mode 100644 internal/mocks/thrift_ttransport.go diff --git a/client/mocks_test.go b/client/mocks_test.go index 4064c70703..616047547e 100644 --- a/client/mocks_test.go +++ b/client/mocks_test.go @@ -16,28 +16,6 @@ package client -import ( - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// MockTStruct implements the thrift.TStruct interface. -type MockTStruct struct { - WriteFunc func(p thrift.TProtocol) (e error) - ReadFunc func(p thrift.TProtocol) (e error) -} - -// Write implements the thrift.TStruct interface. -func (m MockTStruct) Write(p thrift.TProtocol) (e error) { - if m.WriteFunc != nil { - return m.WriteFunc(p) - } - return -} - -// Read implements the thrift.TStruct interface. -func (m MockTStruct) Read(p thrift.TProtocol) (e error) { - if m.ReadFunc != nil { - return m.ReadFunc(p) - } - return -} +// MockTStruct was implemented the thrift.TStruct interface. +// But actually Read/Write are not in use, so removed... only empty struct left for testing +type MockTStruct struct{} diff --git a/internal/mocks/serviceinfo.go b/internal/mocks/serviceinfo.go index 7d6f0020b5..a2f82c185c 100644 --- a/internal/mocks/serviceinfo.go +++ b/internal/mocks/serviceinfo.go @@ -21,7 +21,7 @@ import ( "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -195,13 +195,9 @@ type myServiceMockArgs struct { Req *MyRequest `thrift:"req,1" json:"req"` } -func (p *myServiceMockArgs) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockArgs) Write(oprot thrift.TProtocol) error { - return nil -} +func (p *myServiceMockArgs) BLength() int { return 1 } +func (p *myServiceMockArgs) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { return 1 } +func (p *myServiceMockArgs) FastRead(buf []byte) (int, error) { return 1, nil } // MyRequest . type MyRequest struct { @@ -212,13 +208,9 @@ type myServiceMockResult struct { Success *MyResponse `thrift:"success,0" json:"success,omitempty"` } -func (p *myServiceMockResult) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockResult) Write(oprot thrift.TProtocol) error { - return nil -} +func (p *myServiceMockResult) BLength() int { return 1 } +func (p *myServiceMockResult) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { return 1 } +func (p *myServiceMockResult) FastRead(buf []byte) (int, error) { return 1, nil } // MyResponse . type MyResponse struct { @@ -230,13 +222,11 @@ type myServiceMockExceptionResult struct { MyException *MyException `thrift:"stException,1" json:"stException,omitempty"` } -func (p *myServiceMockExceptionResult) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockExceptionResult) Write(oprot thrift.TProtocol) error { - return nil +func (p *myServiceMockExceptionResult) BLength() int { return 1 } +func (p *myServiceMockExceptionResult) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { + return 1 } +func (p *myServiceMockExceptionResult) FastRead(buf []byte) (int, error) { return 1, nil } // MyException . type MyException struct { diff --git a/internal/mocks/thrift_ttransport.go b/internal/mocks/thrift_ttransport.go deleted file mode 100644 index 58920453f1..0000000000 --- a/internal/mocks/thrift_ttransport.go +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mocks - -import ( - "context" - - "github.com/apache/thrift/lib/go/thrift" -) - -type MockThriftTTransport struct { - WriteMessageBeginFunc func(name string, typeID thrift.TMessageType, seqID int32) error - WriteMessageEndFunc func() error - WriteStructBeginFunc func(name string) error - WriteStructEndFunc func() error - WriteFieldBeginFunc func(name string, typeID thrift.TType, id int16) error - WriteFieldEndFunc func() error - WriteFieldStopFunc func() error - WriteMapBeginFunc func(keyType, valueType thrift.TType, size int) error - WriteMapEndFunc func() error - WriteListBeginFunc func(elemType thrift.TType, size int) error - WriteListEndFunc func() error - WriteSetBeginFunc func(elemType thrift.TType, size int) error - WriteSetEndFunc func() error - WriteBoolFunc func(value bool) error - WriteByteFunc func(value int8) error - WriteI16Func func(value int16) error - WriteI32Func func(value int32) error - WriteI64Func func(value int64) error - WriteDoubleFunc func(value float64) error - WriteStringFunc func(value string) error - WriteBinaryFunc func(value []byte) error - ReadMessageBeginFunc func() (name string, typeID thrift.TMessageType, seqID int32, err error) - ReadMessageEndFunc func() error - ReadStructBeginFunc func() (name string, err error) - ReadStructEndFunc func() error - ReadFieldBeginFunc func() (name string, typeID thrift.TType, id int16, err error) - ReadFieldEndFunc func() error - ReadMapBeginFunc func() (keyType, valueType thrift.TType, size int, err error) - ReadMapEndFunc func() error - ReadListBeginFunc func() (elemType thrift.TType, size int, err error) - ReadListEndFunc func() error - ReadSetBeginFunc func() (elemType thrift.TType, size int, err error) - ReadSetEndFunc func() error - ReadBoolFunc func() (value bool, err error) - ReadByteFunc func() (value int8, err error) - ReadI16Func func() (value int16, err error) - ReadI32Func func() (value int32, err error) - ReadI64Func func() (value int64, err error) - ReadDoubleFunc func() (value float64, err error) - ReadStringFunc func() (value string, err error) - ReadBinaryFunc func() (value []byte, err error) - SkipFunc func(fieldType thrift.TType) (err error) - FlushFunc func(ctx context.Context) (err error) - TransportFunc func() thrift.TTransport -} - -func (m *MockThriftTTransport) WriteMessageBegin(name string, typeID thrift.TMessageType, seqID int32) error { - if m.WriteMessageBeginFunc != nil { - return m.WriteMessageBeginFunc(name, typeID, seqID) - } - return nil -} - -func (m *MockThriftTTransport) WriteMessageEnd() error { - if m.WriteMessageEndFunc != nil { - return m.WriteMessageEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteStructBegin(name string) error { - if m.WriteStructBeginFunc != nil { - return m.WriteStructBeginFunc(name) - } - return nil -} - -func (m *MockThriftTTransport) WriteStructEnd() error { - if m.WriteStructEndFunc != nil { - return m.WriteStructEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldBegin(name string, typeID thrift.TType, id int16) error { - if m.WriteFieldBeginFunc != nil { - return m.WriteFieldBeginFunc(name, typeID, id) - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldEnd() error { - if m.WriteFieldEndFunc != nil { - return m.WriteFieldEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldStop() error { - if m.WriteFieldStopFunc != nil { - return m.WriteFieldStopFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteMapBegin(keyType, valueType thrift.TType, size int) error { - if m.WriteMapBeginFunc != nil { - return m.WriteMapBeginFunc(keyType, valueType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteMapEnd() error { - if m.WriteMapEndFunc != nil { - return m.WriteMapEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteListBegin(elemType thrift.TType, size int) error { - if m.WriteListBeginFunc != nil { - return m.WriteListBeginFunc(elemType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteListEnd() error { - if m.WriteListEndFunc != nil { - return m.WriteListEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteSetBegin(elemType thrift.TType, size int) error { - if m.WriteSetBeginFunc != nil { - return m.WriteSetBeginFunc(elemType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteSetEnd() error { - if m.WriteSetEndFunc != nil { - return m.WriteSetEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteBool(value bool) error { - if m.WriteBoolFunc != nil { - return m.WriteBoolFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteByte(value int8) error { - if m.WriteByteFunc != nil { - return m.WriteByteFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI16(value int16) error { - if m.WriteI16Func != nil { - return m.WriteI16Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI32(value int32) error { - if m.WriteI32Func != nil { - return m.WriteI32Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI64(value int64) error { - if m.WriteI64Func != nil { - return m.WriteI64Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteDouble(value float64) error { - if m.WriteDoubleFunc != nil { - return m.WriteDoubleFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteString(value string) error { - if m.WriteStringFunc != nil { - return m.WriteStringFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteBinary(value []byte) error { - if m.WriteBinaryFunc != nil { - return m.WriteBinaryFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) ReadMessageBegin() (name string, typeID thrift.TMessageType, seqID int32, err error) { - if m.ReadMessageBeginFunc != nil { - return m.ReadMessageBeginFunc() - } - return "", thrift.INVALID_TMESSAGE_TYPE, 0, nil -} - -func (m *MockThriftTTransport) ReadMessageEnd() error { - if m.ReadMessageEndFunc != nil { - return m.ReadMessageEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadStructBegin() (name string, err error) { - if m.ReadStructBeginFunc != nil { - return m.ReadStructBeginFunc() - } - return "", nil -} - -func (m *MockThriftTTransport) ReadStructEnd() error { - if m.ReadStructEndFunc != nil { - return m.ReadStructEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadFieldBegin() (name string, typeID thrift.TType, id int16, err error) { - if m.ReadFieldBeginFunc != nil { - return m.ReadFieldBeginFunc() - } - return "", thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadFieldEnd() error { - if m.ReadFieldEndFunc != nil { - return m.ReadFieldEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadMapBegin() (keyType, valueType thrift.TType, size int, err error) { - if m.ReadMapBeginFunc != nil { - return m.ReadMapBeginFunc() - } - return thrift.STOP, thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadMapEnd() error { - if m.ReadMapEndFunc != nil { - return m.ReadMapEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadListBegin() (elemType thrift.TType, size int, err error) { - if m.ReadListBeginFunc != nil { - return m.ReadListBeginFunc() - } - return thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadListEnd() error { - if m.ReadListEndFunc != nil { - return m.ReadListEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadSetBegin() (elemType thrift.TType, size int, err error) { - if m.ReadSetBeginFunc != nil { - return m.ReadSetBeginFunc() - } - return thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadSetEnd() error { - if m.ReadSetEndFunc != nil { - return m.ReadSetEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadBool() (value bool, err error) { - if m.ReadBoolFunc != nil { - return m.ReadBoolFunc() - } - return false, nil -} - -func (m *MockThriftTTransport) ReadByte() (value int8, err error) { - if m.ReadByteFunc != nil { - return m.ReadByteFunc() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI16() (value int16, err error) { - if m.ReadI16Func != nil { - return m.ReadI16Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI32() (value int32, err error) { - if m.ReadI32Func != nil { - return m.ReadI32Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI64() (value int64, err error) { - if m.ReadI64Func != nil { - return m.ReadI64Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadDouble() (value float64, err error) { - if m.ReadDoubleFunc != nil { - return m.ReadDoubleFunc() - } - return 0.0, nil -} - -func (m *MockThriftTTransport) ReadString() (value string, err error) { - if m.ReadStringFunc != nil { - return m.ReadStringFunc() - } - return "", nil -} - -func (m *MockThriftTTransport) ReadBinary() (value []byte, err error) { - if m.ReadBinaryFunc != nil { - return m.ReadBinaryFunc() - } - return nil, nil -} - -func (m *MockThriftTTransport) Skip(fieldType thrift.TType) (err error) { - if m.SkipFunc != nil { - return m.SkipFunc(fieldType) - } - return nil -} - -func (m *MockThriftTTransport) Flush(ctx context.Context) (err error) { - if m.FlushFunc != nil { - return m.FlushFunc(ctx) - } - return nil -} - -func (m *MockThriftTTransport) Transport() thrift.TTransport { - if m.TransportFunc != nil { - return m.TransportFunc() - } - return nil -} diff --git a/internal/test/port.go b/internal/test/port.go index 5a620c13f9..4193c67cd3 100644 --- a/internal/test/port.go +++ b/internal/test/port.go @@ -47,7 +47,7 @@ func GetLocalAddress() string { for { time.Sleep(time.Millisecond * time.Duration(1+rand.Intn(10))) port := atomic.AddUint32(&curPort, 1+uint32(rand.Intn(10))) - addr := "127.0.0.1:" + strconv.Itoa(int(port)) + addr := "localhost:" + strconv.Itoa(int(port)) if !IsAddressInUse(addr) { trace := strings.Split(string(debug.Stack()), "\n") if len(trace) > 6 { diff --git a/pkg/remote/remotesvr/server_test.go b/pkg/remote/remotesvr/server_test.go index 1c9a1f8a72..078e50dfbd 100644 --- a/pkg/remote/remotesvr/server_test.go +++ b/pkg/remote/remotesvr/server_test.go @@ -35,7 +35,7 @@ func TestServerStart(t *testing.T) { transSvr := &mocks.MockTransServer{ CreateListenerFunc: func(addr net.Addr) (listener net.Listener, err error) { isCreateListener = true - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, BootstrapServerFunc: func(net.Listener) (err error) { diff --git a/server/invoke.go b/server/invoke.go index 8fada9d0c9..d26c8b2aa7 100644 --- a/server/invoke.go +++ b/server/invoke.go @@ -19,6 +19,7 @@ package server // Invoker is for calling handler function wrapped by Kitex suites without connection. import ( + "context" "errors" internal_server "github.com/cloudwego/kitex/internal/server" @@ -41,8 +42,42 @@ type Invoker interface { } type tInvoker struct { - invoke.Handler *server + + h invoke.Handler +} + +// invokerMetaDecoder is used to update `PayloadLen` of `remote.Message`. +// It fixes kitex returning err when apache codec is not available due to msg.PayloadLen() == 0. +// Because users may not add transport header like transport.Framed +// to invoke.Message when calling msg.SetRequestBytes. +// This is NOT expected and it's caused by kitex design fault. +type invokerMetaDecoder struct { + remote.Codec + + d remote.MetaDecoder +} + +func (d *invokerMetaDecoder) DecodeMeta(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + err := d.d.DecodeMeta(ctx, msg, in) + if err != nil { + return err + } + // cool ... no need to do anything. + // added transport header? + if msg.PayloadLen() > 0 { + return nil + } + // use the whole buffer + // coz for invoker remote.ByteBuffer always contains the whole msg payload + if n := in.ReadableLen(); n > 0 { + msg.SetPayloadLen(n) + } + return nil +} + +func (d *invokerMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + return d.d.DecodePayload(ctx, msg, in) } // NewInvoker creates new Invoker. @@ -51,6 +86,13 @@ func NewInvoker(opts ...Option) Invoker { opt: internal_server.NewOptions(opts), svcs: newServices(), } + if codec, ok := s.opt.RemoteOpt.Codec.(remote.MetaDecoder); ok { + // see comment on type `invokerMetaDecoder` + s.opt.RemoteOpt.Codec = &invokerMetaDecoder{ + Codec: s.opt.RemoteOpt.Codec, + d: codec, + } + } s.init() return &tInvoker{ server: s, @@ -69,7 +111,7 @@ func (s *tInvoker) Init() (err error) { doAddBoundHandler(transInfoHdlr, s.server.opt.RemoteOpt) } s.Lock() - s.Handler, err = s.newInvokeHandler() + s.h, err = s.newInvokeHandler() s.Unlock() if err != nil { return err @@ -82,7 +124,7 @@ func (s *tInvoker) Init() (err error) { // Call implements the InvokeCaller interface. func (s *tInvoker) Call(msg invoke.Message) error { - return s.Handler.Call(msg) + return s.h.Call(msg) } func (s *tInvoker) newInvokeHandler() (handler invoke.Handler, err error) { diff --git a/server/invoke_test.go b/server/invoke_test.go index 77089b4389..30a5178df7 100644 --- a/server/invoke_test.go +++ b/server/invoke_test.go @@ -22,11 +22,11 @@ import ( "sync/atomic" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/trans/invoke" - "github.com/cloudwego/kitex/pkg/utils" ) // TestInvokerCall tests Invoker, call Kitex server just like SDK. @@ -47,10 +47,9 @@ func TestInvokerCall(t *testing.T) { } args := mocks.NewMockArgs() - codec := utils.NewThriftMessageCodec() // call success - b, _ := codec.Encode("mock", thrift.CALL, 0, args.(thrift.TStruct)) + b, _ := thrift.MarshalFastMsg("mock", thrift.CALL, 0, args.(thrift.FastCodec)) msg := invoke.NewMessage(nil, nil) err = msg.SetRequestBytes(b) test.Assert(t, err == nil) @@ -66,7 +65,7 @@ func TestInvokerCall(t *testing.T) { test.Assert(t, gotErr.Load() == nil) // call fails - b, _ = codec.Encode("mockError", thrift.CALL, 0, args.(thrift.TStruct)) + b, _ = thrift.MarshalFastMsg("mockError", thrift.CALL, 0, args.(thrift.FastCodec)) msg = invoke.NewMessage(nil, nil) err = msg.SetRequestBytes(b) test.Assert(t, err == nil) diff --git a/server/option_advanced_test.go b/server/option_advanced_test.go index 75ae1f64c1..66ba46de80 100644 --- a/server/option_advanced_test.go +++ b/server/option_advanced_test.go @@ -55,7 +55,7 @@ func TestACLRulesOption(t *testing.T) { return nil }) - svr := NewServer(WithACLRules(rules...)) + svr, _ := NewTestServer(WithACLRules(rules...)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -98,7 +98,7 @@ func (m *myLimitReporter) QPSOverloadReport() { // TestLimitReporterOption tests the creation of a server with LimitReporter option func TestLimitReporterOption(t *testing.T) { my := &myLimitReporter{} - svr := NewServer(WithLimitReporter(my)) + svr, _ := NewTestServer(WithLimitReporter(my)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -122,7 +122,7 @@ func TestGenericOptionPanic(t *testing.T) { // TestGenericOption tests the creation of a server with RemoteOpt.PayloadCodec option func TestGenericOption(t *testing.T) { g := generic.BinaryThriftGeneric() - svr := NewServer(WithGeneric(g)) + svr, _ := NewTestServer(WithGeneric(g)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -190,7 +190,7 @@ func TestWithBoundHandler(t *testing.T) { func TestExitSignalOption(t *testing.T) { stopSignal := make(chan error, 1) stopErr := errors.New("stop signal") - svr := NewServer(WithExitSignal(func() <-chan error { + svr, _ := NewTestServer(WithExitSignal(func() <-chan error { return stopSignal })) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) diff --git a/server/option_test.go b/server/option_test.go index e0d5302d10..ea7818a9cb 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -88,7 +88,7 @@ func TestOptionDebugInfo(t *testing.T) { // TestProxyOption tests the creation of a server with Proxy option func TestProxyOption(t *testing.T) { var opts []Option - addr, err := net.ResolveTCPAddr("tcp", ":8888") + addr, err := net.ResolveTCPAddr("tcp", "localhost:8888") test.Assert(t, err == nil, err) opts = append(opts, WithServiceAddr(addr)) opts = append(opts, WithProxy(&proxyMock{})) @@ -136,7 +136,7 @@ func (m *mockDiagnosis) ProbePairs() map[diagnosis.ProbeName]diagnosis.ProbeFunc func TestExitWaitTimeOption(t *testing.T) { // random timeout value testTimeOut := time.Duration(time.Now().Unix()) * time.Microsecond - svr := NewServer(WithExitWaitTime(testTimeOut)) + svr, _ := NewTestServer(WithExitWaitTime(testTimeOut)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -153,7 +153,7 @@ func TestExitWaitTimeOption(t *testing.T) { func TestMaxConnIdleTimeOption(t *testing.T) { // random timeout value testTimeOut := time.Duration(time.Now().Unix()) - svr := NewServer(WithMaxConnIdleTime(testTimeOut)) + svr, _ := NewTestServer(WithMaxConnIdleTime(testTimeOut)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -177,7 +177,7 @@ func (t *myTracer) Finish(ctx context.Context) { // TestTracerOption tests the creation of a server with TracerCtl option func TestTracerOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -191,7 +191,7 @@ func TestTracerOption(t *testing.T) { test.Assert(t, iSvr1.opt.TracerCtl.HasTracer() != true) tracer := &myTracer{} - svr2 := NewServer(WithTracer(tracer)) + svr2, _ := NewTestServer(WithTracer(tracer)) err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -207,7 +207,7 @@ func TestTracerOption(t *testing.T) { // TestStatsLevelOption tests the creation of a server with StatsLevel option func TestStatsLevelOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -220,7 +220,7 @@ func TestStatsLevelOption(t *testing.T) { test.Assert(t, iSvr1.opt.StatsLevel != nil) test.Assert(t, *iSvr1.opt.StatsLevel == stats.LevelDisabled) - svr2 := NewServer(WithStatsLevel(stats.LevelDetailed)) + svr2, _ := NewTestServer(WithStatsLevel(stats.LevelDetailed)) err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -243,7 +243,7 @@ func (s *mySuiteOption) Options() []Option { // TestSuiteOption tests the creation of a server with SuiteOption option func TestSuiteOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr1.Stop() test.Assert(t, err == nil, err) @@ -264,7 +264,7 @@ func TestSuiteOption(t *testing.T) { WithExitWaitTime(tmpWaitTime), WithMaxConnIdleTime(tmpConnIdleTime), }} - svr2 := NewServer(WithSuite(suiteOpt)) + svr2, _ := NewTestServer(WithSuite(suiteOpt)) time.AfterFunc(100*time.Millisecond, func() { err := svr2.Stop() test.Assert(t, err == nil, err) @@ -283,7 +283,7 @@ func TestSuiteOption(t *testing.T) { // TestMuxTransportOption tests the creation of a server,with netpollmux remote.ServerTransHandlerFactory option, func TestMuxTransportOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr1.Stop() test.Assert(t, err == nil, err) @@ -295,7 +295,7 @@ func TestMuxTransportOption(t *testing.T) { iSvr1 := svr1.(*server) test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory())) - svr2 := NewServer(WithMuxTransport()) + svr2, _ := NewTestServer(WithMuxTransport()) time.AfterFunc(100*time.Millisecond, func() { err := svr2.Stop() test.Assert(t, err == nil, err) @@ -312,7 +312,7 @@ func TestMuxTransportOption(t *testing.T) { // TestPayloadCodecOption tests the creation of a server with RemoteOpt.PayloadCodec option func TestPayloadCodecOption(t *testing.T) { t.Run("NotSetPayloadCodec", func(t *testing.T) { - svr := NewServer() + svr, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -337,7 +337,7 @@ func TestPayloadCodecOption(t *testing.T) { test.Assert(t, protobuf.IsProtobufCodec(pc)) }) t.Run("SetPreRegisteredProtobufCodec", func(t *testing.T) { - svr := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec())) + svr, _ := NewTestServer(WithPayloadCodec(protobuf.NewProtobufCodec())) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -364,7 +364,7 @@ func TestPayloadCodecOption(t *testing.T) { t.Run("SetPreRegisteredThriftCodec", func(t *testing.T) { thriftCodec := thrift.NewThriftCodecDisableFastMode(false, true) - svr := NewServer(WithPayloadCodec(thriftCodec)) + svr, _ := NewTestServer(WithPayloadCodec(thriftCodec)) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -393,7 +393,7 @@ func TestPayloadCodecOption(t *testing.T) { t.Run("SetNonPreRegisteredCodec", func(t *testing.T) { // generic.BinaryThriftGeneric().PayloadCodec() is not the pre registered codec, RemoteOpt.PayloadCodec won't be nil binaryThriftCodec := generic.BinaryThriftGeneric().PayloadCodec() - svr := NewServer(WithPayloadCodec(binaryThriftCodec)) + svr, _ := NewTestServer(WithPayloadCodec(binaryThriftCodec)) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -431,7 +431,7 @@ func TestRemoteOptGRPCCfgUintValueOption(t *testing.T) { randUint3 := uint32(rand.Int31n(100)) + 1 randUint4 := uint32(rand.Int31n(100)) + 1 - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCInitialWindowSize(randUint1), WithGRPCInitialConnWindowSize(randUint2), WithGRPCMaxConcurrentStreams(randUint3), @@ -462,7 +462,7 @@ func TestGRPCKeepaliveEnforcementPolicyOption(t *testing.T) { MinTime: time.Duration(randInt) * time.Second, PermitWithoutStream: true, } - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCKeepaliveEnforcementPolicy(kep), ) @@ -493,7 +493,7 @@ func TestGRPCKeepaliveParamsOption(t *testing.T) { Time: randTimeDuration4, Timeout: randTimeDuration5, } - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCKeepaliveParams(kp), ) @@ -516,7 +516,7 @@ func TestWithProfilerMessageTagging(t *testing.T) { var msgTagging2 remote.MessageTagging = func(ctx context.Context, msg remote.Message) (context.Context, []string) { return context.WithValue(ctx, "ctx2", 2), []string{"b", "2", "c", "2"} } - svr := NewServer(WithProfilerMessageTagging(msgTagging1), WithProfilerMessageTagging(msgTagging2)) + svr, _ := NewTestServer(WithProfilerMessageTagging(msgTagging1), WithProfilerMessageTagging(msgTagging2)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -543,7 +543,7 @@ func TestWithProfilerMessageTagging(t *testing.T) { } func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) { - svr := NewServer(WithRefuseTrafficWithoutServiceName()) + svr, _ := NewTestServer(WithRefuseTrafficWithoutServiceName()) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { diff --git a/server/server_test.go b/server/server_test.go index e9b71f5e94..ede109d170 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -55,10 +55,26 @@ import ( var svcInfo = mocks.ServiceInfo() +// NOTE: always use this method to get addr for server listening +// should be used with `WithServiceAddr(addr)` +func getAddrForListener() net.Addr { + addr := test.GetLocalAddress() + ret, _ := net.ResolveTCPAddr("tcp", addr) + return ret +} + +// NewTestServer calls NewServer with a random addr +// DO NOT USE `NewServer` and `s.Run()` without specifying addr, it listens on :8888 ... +func NewTestServer(ops ...Option) (Server, net.Addr) { + addr := getAddrForListener() + svr := NewServer(append(ops, WithServiceAddr(addr))...) + return svr, addr +} + func TestServerRun(t *testing.T) { var opts []Option opts = append(opts, WithMetaHandler(noopMetahandler{})) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) time.AfterFunc(time.Millisecond*500, func() { err := svr.Stop() @@ -84,8 +100,7 @@ func TestServerRun(t *testing.T) { } func TestReusePortServerRun(t *testing.T) { - hostPort := test.GetLocalAddress() - addr, _ := net.ResolveTCPAddr("tcp", hostPort) + addr := getAddrForListener() var opts []Option opts = append(opts, WithReusePort(true)) opts = append(opts, WithServiceAddr(addr), WithExitWaitTime(time.Microsecond*10)) @@ -261,7 +276,7 @@ func TestServiceRegisterFailed(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -287,7 +302,7 @@ func TestServiceDeregisterFailed(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -308,7 +323,6 @@ func TestServiceRegistryInfo(t *testing.T) { checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) test.Assert(t, info.Weight == registryInfo.Weight, info.Addr) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) test.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) test.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) } @@ -329,7 +343,7 @@ func TestServiceRegistryInfo(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -346,7 +360,6 @@ func TestServiceRegistryInfo(t *testing.T) { func TestServiceRegistryNoInitInfo(t *testing.T) { checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) } var rCount int var drCount int @@ -364,7 +377,7 @@ func TestServiceRegistryNoInitInfo(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -386,7 +399,6 @@ func TestServiceRegistryInfoWithNilTags(t *testing.T) { } checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) test.Assert(t, info.Weight == registryInfo.Weight, info.Weight) test.Assert(t, info.Tags["aa"] == "bb", info.Tags) } @@ -410,7 +422,7 @@ func TestServiceRegistryInfoWithNilTags(t *testing.T) { opts = append(opts, WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ Tags: map[string]string{"aa": "bb"}, })) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -453,7 +465,7 @@ func TestServiceRegistryInfoWithSkipListenAddr(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -495,7 +507,7 @@ func TestServiceRegistryInfoWithoutSkipListenAddr(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -512,7 +524,7 @@ func TestServiceRegistryInfoWithoutSkipListenAddr(t *testing.T) { func TestGRPCServerMultipleServices(t *testing.T) { var opts []Option opts = append(opts, withGRPCTransport()) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler()) @@ -672,7 +684,7 @@ func TestServerBoundHandler(t *testing.T) { } for _, tcase := range cases { opts := append(tcase.opts, WithExitWaitTime(time.Millisecond*10)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() @@ -691,9 +703,9 @@ func TestServerBoundHandler(t *testing.T) { } func TestInvokeHandlerWithContextBackup(t *testing.T) { - testInvokeHandlerWithSession(t, true, ":8888") + testInvokeHandlerWithSession(t, true, "localhost:8888") os.Setenv(localsession.SESSION_CONFIG_KEY, "true,100,1h") - testInvokeHandlerWithSession(t, false, ":8889") + testInvokeHandlerWithSession(t, false, "localhost:8889") } func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { @@ -840,7 +852,7 @@ func TestInvokeHandlerExec(t *testing.T) { }, CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { var err error - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, } @@ -903,7 +915,7 @@ func TestInvokeHandlerPanic(t *testing.T) { }, CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { var err error - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, } From e63ac907ead40406fefde16c075169f4d09d292b Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Fri, 2 Aug 2024 20:21:22 +0800 Subject: [PATCH 21/41] chore: remove github.com/stretchr/testify direct dependency (#1475) --- CREDITS | 1 - go.mod | 2 +- internal/generic/proto/json_test.go | 12 +++++------- internal/generic/thrift/read_test.go | 4 ++-- pkg/utils/json_fuzz_test.go | 25 +++++++++++++------------ 5 files changed, 21 insertions(+), 23 deletions(-) delete mode 100644 CREDITS diff --git a/CREDITS b/CREDITS deleted file mode 100644 index fa01335116..0000000000 --- a/CREDITS +++ /dev/null @@ -1 +0,0 @@ -github.com/stretchr/testify \ No newline at end of file diff --git a/go.mod b/go.mod index 1afb570b3e..c86898ad6f 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 - github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.9.3 golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 @@ -45,6 +44,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oleiade/lane v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/internal/generic/proto/json_test.go b/internal/generic/proto/json_test.go index 748a17e5c7..1bec69eb12 100644 --- a/internal/generic/proto/json_test.go +++ b/internal/generic/proto/json_test.go @@ -20,12 +20,12 @@ import ( "context" "encoding/json" "io/ioutil" + "reflect" "testing" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/proto" "github.com/cloudwego/dynamicgo/testdata/kitex_gen/pb/example2" - "github.com/stretchr/testify/require" goprotowire "google.golang.org/protobuf/encoding/protowire" "github.com/cloudwego/kitex/internal/test" @@ -76,14 +76,13 @@ func TestWrite(t *testing.T) { l += tagLen buf = buf[tagLen:] offset, err := act.FastRead(buf, int8(wtyp), int32(id)) - require.Nil(t, err) + test.Assert(t, err == nil) buf = buf[offset:] l += offset } test.Assert(t, err == nil) - // compare exp and act struct - require.Equal(t, exp, act) + test.Assert(t, reflect.DeepEqual(exp, act)) } // Check NewReadJSON converting protobuf wire format to JSON @@ -114,7 +113,7 @@ func TestRead(t *testing.T) { l += tagLen in = in[tagLen:] offset, err := exp.FastRead(in, int8(wtyp), int32(id)) - require.Nil(t, err) + test.Assert(t, err == nil) in = in[offset:] l += offset } @@ -127,9 +126,8 @@ func TestRead(t *testing.T) { str, ok := out.(string) test.Assert(t, ok) json.Unmarshal([]byte(str), &act) - // compare exp and act struct - require.Equal(t, exp, act) + test.Assert(t, reflect.DeepEqual(exp, act)) } // helper methods diff --git a/internal/generic/thrift/read_test.go b/internal/generic/thrift/read_test.go index e50030d33a..2a77e120ee 100644 --- a/internal/generic/thrift/read_test.go +++ b/internal/generic/thrift/read_test.go @@ -25,9 +25,9 @@ import ( "github.com/cloudwego/gopkg/protocol/thrift" "github.com/jhump/protoreflect/desc/protoparse" - "github.com/stretchr/testify/require" "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/remote" ) @@ -656,7 +656,7 @@ func Test_readStruct(t *testing.T) { t.Errorf("readStruct() error = %v, wantErr %v", err, tt.wantErr) return } - require.Equal(t, tt.want, got) + test.Assert(t, reflect.DeepEqual(tt.want, got)) }) } } diff --git a/pkg/utils/json_fuzz_test.go b/pkg/utils/json_fuzz_test.go index c09851947f..40b5b3b720 100644 --- a/pkg/utils/json_fuzz_test.go +++ b/pkg/utils/json_fuzz_test.go @@ -21,9 +21,10 @@ package utils import ( "encoding/json" + "reflect" "testing" - "github.com/stretchr/testify/require" + "github.com/cloudwego/kitex/internal/test" ) func FuzzJSONStr2Map(f *testing.F) { @@ -42,10 +43,9 @@ func FuzzJSONStr2Map(f *testing.F) { } map1, err1 := JSONStr2Map(data) map2, err2 := _JSONStr2Map(data) - require.Equal(t, err2 == nil, err1 == nil, "json:%v", data) - if err2 == nil { - require.Equal(t, map2, map1, "json:%v", data) - } + test.Assert(t, err1 == nil, data) + test.Assert(t, err2 == nil, data) + test.Assert(t, reflect.DeepEqual(map1, map2), data) }) } @@ -64,13 +64,14 @@ func FuzzMap2JSON(f *testing.F) { if err := json.Unmarshal([]byte(data), &m); err != nil { return } - map1, err1 := Map2JSONStr(m) - map2, err2 := _Map2JSONStr(m) - require.Equal(t, err2 == nil, err1 == nil, "json:%v", data) - require.Equal(t, len(map2), len(map1), "json:%v", data) + str1, err1 := Map2JSONStr(m) + str2, err2 := _Map2JSONStr(m) + test.Assert(t, err1 == nil, data) + test.Assert(t, err2 == nil, data) + test.Assert(t, len(str1) == len(str2)) var m1, m2 map[string]string - require.NoError(t, json.Unmarshal([]byte(map1), &m1)) - require.NoError(t, json.Unmarshal([]byte(map2), &m2)) - require.Equal(t, m2, m1) + test.Assert(t, json.Unmarshal([]byte(str1), &m1) == nil) + test.Assert(t, json.Unmarshal([]byte(str2), &m2) == nil) + test.Assert(t, reflect.DeepEqual(m1, m2)) }) } From 5b18dabb8eec69efefdcaa918e4d57c636fbc4fa Mon Sep 17 00:00:00 2001 From: Joway Date: Mon, 5 Aug 2024 16:47:27 +0800 Subject: [PATCH 22/41] chore: upgrade gopkg to v0.1.0 (#1477) --- go.mod | 2 +- go.sum | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index c86898ad6f..ab45d6f291 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 + github.com/bytedance/gopkg v0.1.0 github.com/bytedance/sonic v1.11.8 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.2.9 diff --git a/go.sum b/go.sum index 8fe9f91469..586c0212f6 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,9 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 h1:rDwLxYTMoKHaw4cS0bQhaTZnkXp5e6ediCggGcRD/CA= github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= +github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.8 h1:Zw/j1KfiS+OYTi9lyB3bb0CFxPJVkM17k1wyDG32LRA= github.com/bytedance/sonic v1.11.8/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= From cfc3e9faab47ed89d820fdfcb6095c6ed561e407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BA=AA=E5=8D=93=E5=BF=97?= Date: Thu, 11 Jul 2024 19:28:25 +0800 Subject: [PATCH 23/41] perf: custom allocator for fast codec ReadString/ReadBinary (#1427) # Conflicts: # pkg/protocol/bthrift/binary.go --- pkg/protocol/bthrift/binary_test.go | 39 +++++++++++++++++++++++++++++ pkg/protocol/bthrift/interface.go | 5 ++++ 2 files changed, 44 insertions(+) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index a0754bcd55..8e395bb5cd 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -291,6 +291,24 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } +// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator +func TestWriteAndReadStringWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + wn := Binary.WriteString(buf, "kitex") + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadString(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + test.Assert(t, v == "kitex") + SetSpanCache(false) +} + // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -310,6 +328,27 @@ func TestWriteAndReadBinary(t *testing.T) { } } +// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator +func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + val := []byte("kitex") + wn := Binary.WriteBinary(buf, val) + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadBinary(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + for i := 0; i < len(v); i++ { + test.Assert(t, val[i] == v[i]) + } + SetSpanCache(false) +} + // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index 75fa0ce951..e65d667318 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -97,3 +97,8 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } + +type Allocator interface { + Make(n int) []byte + Copy(buf []byte) (p []byte) +} From ffe91a24a53d5e5b58db9955512d0eeffb3741a9 Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 31 Jul 2024 20:44:31 +0800 Subject: [PATCH 24/41] resolve conflicts --- pkg/protocol/bthrift/binary_test.go | 39 ----------------------------- pkg/protocol/bthrift/interface.go | 5 ---- 2 files changed, 44 deletions(-) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 8e395bb5cd..a0754bcd55 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -291,24 +291,6 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } -// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator -func TestWriteAndReadStringWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - wn := Binary.WriteString(buf, "kitex") - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadString(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - test.Assert(t, v == "kitex") - SetSpanCache(false) -} - // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -328,27 +310,6 @@ func TestWriteAndReadBinary(t *testing.T) { } } -// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator -func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - val := []byte("kitex") - wn := Binary.WriteBinary(buf, val) - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadBinary(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - for i := 0; i < len(v); i++ { - test.Assert(t, val[i] == v[i]) - } - SetSpanCache(false) -} - // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index e65d667318..75fa0ce951 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -97,8 +97,3 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } - -type Allocator interface { - Make(n int) []byte - Copy(buf []byte) (p []byte) -} From d43a571cd3d08a23cf462d05eb6aebb37ba4cb81 Mon Sep 17 00:00:00 2001 From: shawn Date: Sun, 4 Aug 2024 19:43:57 +0800 Subject: [PATCH 25/41] feat: add init Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 4 - tool/cmd/kitex/args/tpl_args.go | 282 ++++++++++++++++------- tool/cmd/kitex/main.go | 2 +- tool/internal_pkg/generator/generator.go | 7 +- tool/internal_pkg/tpl/mock.go | 27 +++ 5 files changed, 235 insertions(+), 87 deletions(-) create mode 100644 tool/internal_pkg/tpl/mock.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 3941e8f19c..d676979009 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -311,10 +311,6 @@ func (a *Arguments) BuildCmd(out io.Writer) (*exec.Cmd, error) { Stderr: io.MultiWriter(out, os.Stderr), } - if err != nil { - return nil, err - } - if a.IDLType == "thrift" { os.Setenv(EnvPluginMode, thriftgo.PluginName) cmd.Args = append(cmd.Args, "thriftgo") diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index bd619433a1..5ef114757a 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -16,13 +16,211 @@ package args import ( "fmt" - "os" - "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "os" + "path/filepath" +) + +// Constants . +const ( + KitexGenPath = "kitex_gen" + DefaultCodec = "thrift" + + BuildFileName = "build.sh" + BootstrapFileName = "bootstrap.sh" + ToolVersionFileName = "kitex_info.yaml" + HandlerFileName = "handler.go" + MainFileName = "main.go" + ClientFileName = "client.go" + ServerFileName = "server.go" + InvokerFileName = "invoker.go" + ServiceFileName = "*service.go" + ExtensionFilename = "extensions.yaml" + + ClientMockFilename = "client_mock.go" +) + +var defaultTemplates = map[string]string{ + BuildFileName: tpl.BuildTpl, + BootstrapFileName: tpl.BootstrapTpl, + ToolVersionFileName: tpl.ToolVersionTpl, + HandlerFileName: tpl.HandlerTpl, + MainFileName: tpl.MainTpl, + ClientFileName: tpl.ClientTpl, + ServerFileName: tpl.ServerTpl, + InvokerFileName: tpl.InvokerTpl, + ServiceFileName: tpl.ServiceTpl, +} + +var mockTemplates = map[string]string{ + ClientMockFilename: tpl.ClientMockTpl, +} + +const ( + DefaultType = "default" + MockType = "mock" ) +type TemplateGenerator func(string) error + +var genTplMap = map[string]TemplateGenerator{ + DefaultType: GenTemplates, + MockType: GenMockTemplates, +} + +// GenTemplates is the entry for command kitex template, +// it will create the specified path +func GenTemplates(path string) error { + for key := range defaultTemplates { + if key == BootstrapFileName { + defaultTemplates[key] = util.JoinPath(path, "script", BootstrapFileName) + } + } + return InitTemplates(path, defaultTemplates) +} + +func GenMockTemplates(path string) error { + return InitTemplates(path, mockTemplates) +} + +// InitTemplates creates template files. +func InitTemplates(path string, templates map[string]string) error { + if err := MkdirIfNotExist(path); err != nil { + return err + } + + for k, v := range templates { + if err := createTemplate(filepath.Join(path, k+".tpl"), v); err != nil { + return err + } + } + + return nil +} + +// GetTemplateDir returns the category path. +func GetTemplateDir(category string) (string, error) { + home, err := filepath.Abs(".") + if err != nil { + return "", err + } + return filepath.Join(home, category), nil +} + +// MkdirIfNotExist makes directories if the input path is not exists +func MkdirIfNotExist(dir string) error { + if len(dir) == 0 { + return nil + } + + if _, err := os.Stat(dir); os.IsNotExist(err) { + return os.MkdirAll(dir, os.ModePerm) + } + + return nil +} + +func createTemplate(file, content string) error { + if util.Exists(file) { + return nil + } + + f, err := os.Create(file) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(content) + return err +} + +func (a *Arguments) Init(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + path := a.InitOutputDir + initType := a.InitType + if initType == "" { + initType = DefaultType + } + if path == "" { + path = curpath + } + if err := genTplMap[initType](path); err != nil { + return err + } + fmt.Printf("Templates are generated in %s\n", path) + os.Exit(0) + return nil +} + +func (a *Arguments) Render(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + if len(args) < 2 { + return fmt.Errorf("both template directory and idl is required") + } + a.RenderTplDir = args[0] + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(cmd.Flags().Args()[1:]) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Clean(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(cmd.Flags().Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + func (a *Arguments) TemplateArgs(version, curpath string) error { kitexCmd := &util.Command{ Use: "kitex", @@ -35,94 +233,20 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { initCmd := &util.Command{ Use: "init", Short: "Init command", - RunE: func(cmd *util.Command, args []string) error { - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Init, } renderCmd := &util.Command{ Use: "render", Short: "Render command", - RunE: func(cmd *util.Command, args []string) error { - if len(args) < 2 { - return fmt.Errorf("both template directory and idl is required") - } - a.RenderTplDir = args[0] - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()[1:]) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Render, } cleanCmd := &util.Command{ Use: "clean", Short: "Clean command", - RunE: func(cmd *util.Command, args []string) error { - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Clean, } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") - initCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 6583596b53..237297067f 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -52,7 +52,7 @@ func init() { if err := versions.RegisterMinDepVersion( &versions.MinDepVersion{ RefPath: "github.com/cloudwego/kitex", - Version: "v0.11.0", + Version: "v0.10.3", }, ); err != nil { log.Warn(err) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index fe95db4bbc..238cfcace6 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -16,6 +16,7 @@ package generator import ( + "errors" "fmt" "go/token" "path/filepath" @@ -135,9 +136,9 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string + InitOutputDir string //specify the location path of init subcommand InitType string - RenderTplDir string + RenderTplDir string // specify the path of template directory for render subcommand TemplateFile string GenPath string @@ -374,7 +375,7 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error pkg.ServiceInfo.ServiceName) f, err := comp.CompleteMethods() if err != nil { - if err == errNoNewMethod { + if errors.Is(err, errNoNewMethod) { return fs, nil } return nil, err diff --git a/tool/internal_pkg/tpl/mock.go b/tool/internal_pkg/tpl/mock.go new file mode 100644 index 0000000000..27da3a59d7 --- /dev/null +++ b/tool/internal_pkg/tpl/mock.go @@ -0,0 +1,27 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpl + +var ClientMockTpl = ` +// Code generated by MockGen. DO NOT EDIT. +// Source: kitex_gen/example/shop/item/itemserviceb/client.go +// +// Generated by this command: +// +// ___1go_build_go_uber_org_mock_mockgen -source=kitex_gen/example/shop/item/itemserviceb/client.go -destination=client_mock.go -package=main +// + +// Package main is a generated GoMock package. +` From 5fc7f4adaef239ad6cc50ac33e1ddc4de6d116a1 Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 5 Aug 2024 21:40:47 +0800 Subject: [PATCH 26/41] feat: render with multiple files and debug mode Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 109 ++++++++-------- tool/cmd/kitex/main.go | 2 +- .../internal_pkg/generator/custom_template.go | 120 +++++++++++------- tool/internal_pkg/generator/generator.go | 8 +- tool/internal_pkg/generator/generator_test.go | 7 +- .../pluginmode/thriftgo/plugin.go | 12 ++ tool/internal_pkg/tpl/multiple_services.go | 20 +++ tool/internal_pkg/util/flag.go | 56 +++++--- tool/internal_pkg/util/flag_test.go | 6 - 9 files changed, 208 insertions(+), 132 deletions(-) create mode 100644 tool/internal_pkg/tpl/multiple_services.go diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 5ef114757a..191261844b 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -16,12 +16,15 @@ package args import ( "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" - "os" - "path/filepath" ) // Constants . @@ -40,7 +43,7 @@ const ( ServiceFileName = "*service.go" ExtensionFilename = "extensions.yaml" - ClientMockFilename = "client_mock.go" + MultipleServicesFileName = "multiple_services.go" ) var defaultTemplates = map[string]string{ @@ -55,35 +58,30 @@ var defaultTemplates = map[string]string{ ServiceFileName: tpl.ServiceTpl, } -var mockTemplates = map[string]string{ - ClientMockFilename: tpl.ClientMockTpl, +var multipleServicesTpl = map[string]string{ + MultipleServicesFileName: tpl.MultipleServicesTpl, } const ( - DefaultType = "default" - MockType = "mock" + DefaultType = "default" + MultipleServicesType = "multiple_services" ) type TemplateGenerator func(string) error var genTplMap = map[string]TemplateGenerator{ - DefaultType: GenTemplates, - MockType: GenMockTemplates, + DefaultType: GenTemplates, + MultipleServicesType: GenMultipleServicesTemplates, } // GenTemplates is the entry for command kitex template, // it will create the specified path func GenTemplates(path string) error { - for key := range defaultTemplates { - if key == BootstrapFileName { - defaultTemplates[key] = util.JoinPath(path, "script", BootstrapFileName) - } - } return InitTemplates(path, defaultTemplates) } -func GenMockTemplates(path string) error { - return InitTemplates(path, mockTemplates) +func GenMultipleServicesTemplates(path string) error { + return InitTemplates(path, multipleServicesTpl) } // InitTemplates creates template files. @@ -92,8 +90,18 @@ func InitTemplates(path string, templates map[string]string) error { return err } - for k, v := range templates { - if err := createTemplate(filepath.Join(path, k+".tpl"), v); err != nil { + for name, content := range templates { + var filePath string + if name == BootstrapFileName { + bootstrapDir := filepath.Join(path, "script") + if err := MkdirIfNotExist(bootstrapDir); err != nil { + return err + } + filePath = filepath.Join(bootstrapDir, name+".tpl") + } else { + filePath = filepath.Join(path, name+".tpl") + } + if err := createTemplate(filePath, content); err != nil { return err } } @@ -164,10 +172,6 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) } - if len(args) < 2 { - return fmt.Errorf("both template directory and idl is required") - } - a.RenderTplDir = args[0] log.Verbose = a.Verbose for _, e := range a.extends { @@ -177,7 +181,7 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { } } - err = a.checkIDL(cmd.Flags().Args()[1:]) + err = a.checkIDL(args) if err != nil { return err } @@ -197,31 +201,35 @@ func (a *Arguments) Clean(cmd *util.Command, args []string) error { if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) } - log.Verbose = a.Verbose - for _, e := range a.extends { - err := e.Check(a) + magicString := "// Kitex template debug file. use template clean to delete it." + err = filepath.WalkDir(curpath, func(path string, info fs.DirEntry, err error) error { if err != nil { return err } - } - - err = a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() + if info.IsDir() { + return nil + } + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read file %s faild: %v", path, err) + } + if strings.Contains(string(content), magicString) { + if err := os.Remove(path); err != nil { + return fmt.Errorf("delete file %s failed: %v", path, err) + } + } + return nil + }) if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath + return fmt.Errorf("error cleaning debug template files: %v", err) } - return a.checkPath(curpath) + fmt.Println("clean debug template files successfully...") + os.Exit(0) + return nil } -func (a *Arguments) TemplateArgs(version, curpath string) error { +func (a *Arguments) TemplateArgs(version string) error { kitexCmd := &util.Command{ Use: "kitex", Short: "Kitex command", @@ -246,38 +254,29 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { RunE: a.Clean, } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") - renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", + initCmd.Flags().StringVar(&a.InitType, "type", "", "Specify template init type") + renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") + renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, "Specify a code gen path.") - renderCmd.Flags().StringVarP(&a.TemplateFile, "file", "f", "", "Specify single template path") + renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") + renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template init [flags] - -Examples: - kitex template init -o /path/to/output - kitex template init - -Flags: `, version) }) renderCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s -Usage: template render [template dir_path] [flags] IDL +Usage: template render --dir [template dir_path] [flags] IDL `, version) }) cleanCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template clean - -Examples: - kitex template clean - -Flags: `, version) }) // renderCmd.PrintUsage() diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 237297067f..1fb5372ffe 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -78,7 +78,7 @@ func main() { os.Exit(1) } if os.Args[1] == "template" { - err = args.TemplateArgs(kitex.Version, curpath) + err = args.TemplateArgs(kitex.Version) } else if !strings.HasPrefix(os.Args[1], "-") { err = fmt.Errorf("unknown command %q", os.Args[1]) } else { diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 17bebef833..1162844270 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -200,34 +200,23 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } -func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { - files, _ := os.ReadDir(currentDir) +func readTemplates(dir string) ([]*Template, error) { + files, _ := ioutil.ReadDir(dir) + var ts []*Template for _, f := range files { - // filter dir and non-tpl files - if f.IsDir() { - subDir := filepath.Join(currentDir, f.Name()) - subTemplates, err := readTpls(rootDir, subDir, ts) - if err != nil { - return nil, err - } - ts = append(ts, subTemplates...) - } else if strings.HasSuffix(f.Name(), ".tpl") { - p := filepath.Join(currentDir, f.Name()) - tplData, err := os.ReadFile(p) + // filter dir and non-yaml files + if f.Name() != ExtensionFilename && !f.IsDir() && (strings.HasSuffix(f.Name(), "yaml") || strings.HasSuffix(f.Name(), "yml")) { + p := filepath.Join(dir, f.Name()) + tplData, err := ioutil.ReadFile(p) if err != nil { return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) } - // Remove the .tpl suffix from the Path and compute relative path - relativePath, err := filepath.Rel(rootDir, p) - if err != nil { - return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) - } - trimmedPath := strings.TrimSuffix(relativePath, ".tpl") t := &Template{ - Path: trimmedPath, - Body: string(tplData), UpdateBehavior: &Update{Type: string(skip)}, } + if err = yaml.Unmarshal(tplData, t); err != nil { + return nil, fmt.Errorf("%s: unmarshal layout config failed, err: %s", f.Name(), err.Error()) + } ts = append(ts, t) } } @@ -235,6 +224,20 @@ func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { return ts, nil } +func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { + cg := NewCustomGenerator(pkg, outputPath) + // special handling Methods field + if tpl.LoopMethod { + err = cg.loopGenerate(tpl) + } else { + err = cg.commonGenerate(tpl) + } + if errors.Is(err, errNoNewMethod) { + err = nil + } + return cg.fs, err +} + func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { g.updatePackageInfo(pkg) @@ -273,40 +276,65 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, return fs, nil } -func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { - cg := NewCustomGenerator(pkg, outputPath) - // special handling Methods field - if tpl.LoopMethod { - err = cg.loopGenerate(tpl) - } else { - err = cg.commonGenerate(tpl) - } - if errors.Is(err, errNoNewMethod) { - err = nil - } - return cg.fs, err -} - -func readTemplates(dir string) ([]*Template, error) { - files, _ := ioutil.ReadDir(dir) - var ts []*Template +func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { + files, _ := os.ReadDir(currentDir) for _, f := range files { - // filter dir and non-yaml files - if f.Name() != ExtensionFilename && !f.IsDir() && (strings.HasSuffix(f.Name(), "yaml") || strings.HasSuffix(f.Name(), "yml")) { - p := filepath.Join(dir, f.Name()) - tplData, err := ioutil.ReadFile(p) + // filter dir and non-tpl files + if f.IsDir() { + subDir := filepath.Join(currentDir, f.Name()) + subTemplates, err := readTpls(rootDir, subDir, ts) if err != nil { - return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) + return nil, err + } + ts = append(ts, subTemplates...) + } else if strings.HasSuffix(f.Name(), ".tpl") { + p := filepath.Join(currentDir, f.Name()) + tplData, err := os.ReadFile(p) + if err != nil { + return nil, fmt.Errorf("read file from %s failed, err: %v", p, err.Error()) } + // Remove the .tpl suffix from the Path and compute relative path + relativePath, err := filepath.Rel(rootDir, p) + if err != nil { + return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) + } + trimmedPath := strings.TrimSuffix(relativePath, ".tpl") t := &Template{ + Path: trimmedPath, + Body: string(tplData), UpdateBehavior: &Update{Type: string(skip)}, } - if err = yaml.Unmarshal(tplData, t); err != nil { - return nil, fmt.Errorf("%s: unmarshal layout config failed, err: %s", f.Name(), err.Error()) - } ts = append(ts, t) } } return ts, nil } + +func (g *generator) RenderWithMultipleFiles(pkg *PackageInfo) (fs []*File, err error) { + for _, file := range g.Config.TemplateFiles { + content, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("read file from %s failed, err: %v", file, err.Error()) + } + var updatedContent string + if g.Config.DebugTpl { + // --debug时 在模板内容顶部加上一段magic string用于区分 + updatedContent = "// Kitex template debug file. use template clean to delete it.\n\n" + string(content) + } else { + updatedContent = string(content) + } + filename := filepath.Base(strings.TrimSuffix(file, ".tpl")) + tpl := &Template{ + Path: filename, + Body: updatedContent, + UpdateBehavior: &Update{Type: string(skip)}, + } + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + return +} diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 238cfcace6..c7478887c0 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -98,6 +98,7 @@ type Generator interface { GenerateMainPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackageWithTpl(pkg *PackageInfo) ([]*File, error) + RenderWithMultipleFiles(pkg *PackageInfo) ([]*File, error) } // Config . @@ -136,10 +137,11 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string //specify the location path of init subcommand + InitOutputDir string // specify the location path of init subcommand InitType string - RenderTplDir string // specify the path of template directory for render subcommand - TemplateFile string + RenderTplDir string // specify the path of template directory for render subcommand + TemplateFiles []string // specify the path of single file or multiple file to render + DebugTpl bool GenPath string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index cc39bffc31..ebb5245b70 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -59,7 +59,8 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir string InitType string RenderTplDir string - TemplateFile string + TemplateFiles []string + DebugTpl bool Protocol string HandlerReturnKeepResp bool } @@ -73,7 +74,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -104,7 +105,7 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir: tt.fields.InitOutputDir, InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, - TemplateFile: tt.fields.TemplateFile, + TemplateFiles: tt.fields.TemplateFiles, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index e65c3e179e..45560251d0 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -107,6 +107,18 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } + if len(conv.Config.TemplateFiles) > 0 { + if len(conv.Services) == 0 { + return conv.failResp(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.RenderWithMultipleFiles(&conv.Package) + if err != nil { + return conv.failResp(err) + } + files = append(files, fs...) + } + if conv.Config.RenderTplDir != "" { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) diff --git a/tool/internal_pkg/tpl/multiple_services.go b/tool/internal_pkg/tpl/multiple_services.go new file mode 100644 index 0000000000..0ade8ffd8c --- /dev/null +++ b/tool/internal_pkg/tpl/multiple_services.go @@ -0,0 +1,20 @@ +package tpl + +var MultipleServicesTpl = `package main + +import ( + {{- range $path, $aliases := .Imports}} + {{- if not $aliases}} + "{{$path}}" + {{- else}} + {{- range $alias, $is := $aliases}} + {{$alias}} "{{$path}}" + {{- end}} + {{- end}} + {{- end}} +) + +// {{.ServiceName}}Impl implements the last service interface defined in the IDL. +type {{.ServiceName}}Impl struct{} +{{template "HandlerMethod" .}} +` diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 21096cf8f1..d80a7ac6cf 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -15,6 +15,8 @@ package util import ( + "bytes" + "encoding/csv" "errors" goflag "flag" "fmt" @@ -22,7 +24,6 @@ import ( "os" "strconv" "strings" - "time" ) // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. @@ -624,29 +625,48 @@ func (f *FlagSet) Int(name string, value int, usage string) *int { return p } -// -- time.Duration Value -type durationValue time.Duration +// -- stringArray Value +type stringArrayValue struct { + value *[]string + changed bool +} -func (d *durationValue) Set(s string) error { - v, err := time.ParseDuration(s) - *d = durationValue(v) - return err +func newStringArrayValue(val []string, p *[]string) *stringArrayValue { + ssv := new(stringArrayValue) + ssv.value = p + *ssv.value = val + return ssv } -func (d *durationValue) String() string { return (*time.Duration)(d).String() } +func (s *stringArrayValue) Set(val string) error { + if !s.changed { + *s.value = []string{val} + s.changed = true + } else { + *s.value = append(*s.value, val) + } + return nil +} -func newDurationValue(val time.Duration, p *time.Duration) *durationValue { - *p = val - return (*durationValue)(p) +func writeAsCSV(vals []string) (string, error) { + b := &bytes.Buffer{} + w := csv.NewWriter(b) + err := w.Write(vals) + if err != nil { + return "", err + } + w.Flush() + return strings.TrimSuffix(b.String(), "\n"), nil } -// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { - f.VarP(newDurationValue(value, p), name, shorthand, usage) +func (s *stringArrayValue) String() string { + str, _ := writeAsCSV(*s.value) + return "[" + str + "]" } -func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration { - p := new(time.Duration) - f.DurationVarP(p, name, "", value, usage) - return p +// StringArrayVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a []string variable in which to store the values of the multiple flags. +// The value of each argument will not try to be separated by comma. Use a StringSlice for that. +func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { + f.VarP(newStringArrayValue(value, p), name, "", usage) } diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go index 7bde2edcd1..364d379c1f 100644 --- a/tool/internal_pkg/util/flag_test.go +++ b/tool/internal_pkg/util/flag_test.go @@ -18,7 +18,6 @@ import ( "io" "os" "testing" - "time" ) func ResetForTesting() { @@ -48,7 +47,6 @@ func testParse(f *FlagSet, t *testing.T) { bool3Flag := f.Bool("bool3", false, "bool3 value") intFlag := f.Int("int", 0, "int value") stringFlag := f.String("string", "0", "string value") - durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value") optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") f.Lookup("optional-int-no-value").NoOptDefVal = "9" optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") @@ -60,7 +58,6 @@ func testParse(f *FlagSet, t *testing.T) { "--bool3=false", "--int=22", "--string=hello", - "--duration=2m", "--optional-int-no-value", "--optional-int-with-value=42", extra, @@ -86,9 +83,6 @@ func testParse(f *FlagSet, t *testing.T) { if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } - if *durationFlag != 2*time.Minute { - t.Error("duration flag should be 2m, is ", *durationFlag) - } if *optionalIntNoValueFlag != 9 { t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) } From d2d051214e1bb1febea70b8881a0f191e27eb96c Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 5 Aug 2024 21:42:51 +0800 Subject: [PATCH 27/41] feat: add subcommand clean Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 191261844b..222008c949 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -203,11 +203,11 @@ func (a *Arguments) Clean(cmd *util.Command, args []string) error { } magicString := "// Kitex template debug file. use template clean to delete it." - err = filepath.WalkDir(curpath, func(path string, info fs.DirEntry, err error) error { + err = filepath.WalkDir(curpath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } - if info.IsDir() { + if d.IsDir() { return nil } content, err := os.ReadFile(path) From 6efc947e84233334a4406227f3b5874fce1c1b3d Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 5 Aug 2024 22:10:06 +0800 Subject: [PATCH 28/41] feat: add kitex_render_meta.yaml for render Signed-off-by: shawn --- .../internal_pkg/generator/custom_template.go | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 1162844270..52306dc8b1 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -276,7 +276,44 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, return fs, nil } +const kitexRenderMetaFile = "kitex_render_meta.yaml" + +// Meta 代表kitex_render_meta.yaml文件的结构 +type Meta struct { + Templates []Template `yaml:"templates"` +} + +func readMetaFile(rootDir string) (*Meta, error) { + metaPath := filepath.Join(rootDir, kitexRenderMetaFile) + metaData, err := os.ReadFile(metaPath) + if err != nil { + return nil, fmt.Errorf("failed to read meta file from %s: %v", metaPath, err) + } + + var meta Meta + err = yaml.Unmarshal(metaData, &meta) + if err != nil { + return nil, fmt.Errorf("failed to parse yaml file %s: %v", metaPath, err) + } + + return &meta, nil +} + +func getUpdateBehavior(meta *Meta, relativePath string) *Update { + for _, template := range meta.Templates { + // fmt.Println(template.Path, "==", relativePath) + if template.Path == relativePath { + return template.UpdateBehavior + } + } + return &Update{Type: string(skip)} +} + func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { + meta, err := readMetaFile(rootDir) + if err != nil { + return nil, err + } files, _ := os.ReadDir(currentDir) for _, f := range files { // filter dir and non-tpl files @@ -299,10 +336,11 @@ func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) } trimmedPath := strings.TrimSuffix(relativePath, ".tpl") + updateBehavior := getUpdateBehavior(meta, relativePath) t := &Template{ Path: trimmedPath, Body: string(tplData), - UpdateBehavior: &Update{Type: string(skip)}, + UpdateBehavior: updateBehavior, } ts = append(ts, t) } From 26cbfcd41a4ae7715b919d9566112855e2cf96f5 Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 6 Aug 2024 16:19:12 +0800 Subject: [PATCH 29/41] feat: add remote repo for render and middlewares Signed-off-by: shawn --- internal/mocks/discovery/discovery.go | 3 +- internal/mocks/generic/generic_service.go | 3 +- internal/mocks/generic/thrift.go | 3 +- internal/mocks/klog/log.go | 3 +- internal/mocks/limiter/limiter.go | 3 +- internal/mocks/loadbalance/loadbalancer.go | 3 +- internal/mocks/net/net.go | 5 +- internal/mocks/netpoll/connection.go | 3 +- internal/mocks/proxy/proxy.go | 3 +- internal/mocks/remote/bytebuf.go | 3 +- internal/mocks/remote/codec.go | 3 +- internal/mocks/remote/conn_wrapper.go | 3 +- internal/mocks/remote/connpool.go | 3 +- internal/mocks/remote/dialer.go | 3 +- internal/mocks/remote/payload_codec.go | 3 +- internal/mocks/remote/trans_handler.go | 3 +- internal/mocks/remote/trans_meta.go | 3 +- internal/mocks/remote/trans_pipeline.go | 3 +- internal/mocks/stats/tracer.go | 3 +- internal/mocks/utils/sharedticker.go | 3 +- tool/cmd/kitex/args/args.go | 7 +- tool/cmd/kitex/args/tpl_args.go | 19 ++- .../internal_pkg/generator/custom_template.go | 135 +++++++++++++++--- tool/internal_pkg/generator/generator.go | 8 +- tool/internal_pkg/generator/generator_test.go | 6 +- tool/internal_pkg/generator/type.go | 9 ++ tool/internal_pkg/tpl/mock.go | 27 ---- 27 files changed, 174 insertions(+), 99 deletions(-) delete mode 100644 tool/internal_pkg/tpl/mock.go diff --git a/internal/mocks/discovery/discovery.go b/internal/mocks/discovery/discovery.go index 2ae32eb94f..e70f64e2e0 100644 --- a/internal/mocks/discovery/discovery.go +++ b/internal/mocks/discovery/discovery.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/discovery/discovery.go @@ -176,4 +175,4 @@ func (m *MockInstance) Weight() int { func (mr *MockInstanceMockRecorder) Weight() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockInstance)(nil).Weight)) -} \ No newline at end of file +} diff --git a/internal/mocks/generic/generic_service.go b/internal/mocks/generic/generic_service.go index bc18989b81..8ca3f06056 100644 --- a/internal/mocks/generic/generic_service.go +++ b/internal/mocks/generic/generic_service.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/generic_service.go @@ -99,4 +98,4 @@ func (m *MockWithCodec) SetCodec(codec interface{}) { func (mr *MockWithCodecMockRecorder) SetCodec(codec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCodec", reflect.TypeOf((*MockWithCodec)(nil).SetCodec), codec) -} \ No newline at end of file +} diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index eb0c5da86a..92b0697092 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/thrift/thrift.go @@ -104,4 +103,4 @@ func (m *MockMessageWriter) Write(ctx context.Context, out io.Writer, msg interf func (mr *MockMessageWriterMockRecorder) Write(ctx, out, msg, requestBase interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockMessageWriter)(nil).Write), ctx, out, msg, requestBase) -} \ No newline at end of file +} diff --git a/internal/mocks/klog/log.go b/internal/mocks/klog/log.go index 2b6413079b..2c83208a32 100644 --- a/internal/mocks/klog/log.go +++ b/internal/mocks/klog/log.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/klog/log.go @@ -891,4 +890,4 @@ func (mr *MockFullLoggerMockRecorder) Warnf(format interface{}, v ...interface{} mr.mock.ctrl.T.Helper() varargs := append([]interface{}{format}, v...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockFullLogger)(nil).Warnf), varargs...) -} \ No newline at end of file +} diff --git a/internal/mocks/limiter/limiter.go b/internal/mocks/limiter/limiter.go index 97cec6a9cf..5fac2ad3e3 100644 --- a/internal/mocks/limiter/limiter.go +++ b/internal/mocks/limiter/limiter.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/limiter/limiter.go @@ -226,4 +225,4 @@ func (m *MockLimitReporter) QPSOverloadReport() { func (mr *MockLimitReporterMockRecorder) QPSOverloadReport() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QPSOverloadReport", reflect.TypeOf((*MockLimitReporter)(nil).QPSOverloadReport)) -} \ No newline at end of file +} diff --git a/internal/mocks/loadbalance/loadbalancer.go b/internal/mocks/loadbalance/loadbalancer.go index e39fa9bc1c..63e4fec5bb 100644 --- a/internal/mocks/loadbalance/loadbalancer.go +++ b/internal/mocks/loadbalance/loadbalancer.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/loadbalance/loadbalancer.go @@ -163,4 +162,4 @@ func (m *MockRebalancer) Rebalance(arg0 discovery.Change) { func (mr *MockRebalancerMockRecorder) Rebalance(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rebalance", reflect.TypeOf((*MockRebalancer)(nil).Rebalance), arg0) -} \ No newline at end of file +} diff --git a/internal/mocks/net/net.go b/internal/mocks/net/net.go index 191cd25b19..7ada9f882b 100644 --- a/internal/mocks/net/net.go +++ b/internal/mocks/net/net.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: /usr/local/go/src/net/net.go @@ -22,8 +21,8 @@ package net import ( - reflect "reflect" net "net" + reflect "reflect" time "time" gomock "github.com/golang/mock/gomock" @@ -582,4 +581,4 @@ func (m *MockbuffersWriter) writeBuffers(arg0 *net.Buffers) (int64, error) { func (mr *MockbuffersWriterMockRecorder) writeBuffers(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "writeBuffers", reflect.TypeOf((*MockbuffersWriter)(nil).writeBuffers), arg0) -} \ No newline at end of file +} diff --git a/internal/mocks/netpoll/connection.go b/internal/mocks/netpoll/connection.go index 94c932557e..325fff2af7 100644 --- a/internal/mocks/netpoll/connection.go +++ b/internal/mocks/netpoll/connection.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../../netpoll/connection.go @@ -561,4 +560,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} \ No newline at end of file +} diff --git a/internal/mocks/proxy/proxy.go b/internal/mocks/proxy/proxy.go index 4d15434856..2e5ff5fb13 100644 --- a/internal/mocks/proxy/proxy.go +++ b/internal/mocks/proxy/proxy.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/proxy/proxy.go @@ -192,4 +191,4 @@ func (m *MockContextHandler) HandleContext(arg0 context.Context) context.Context func (mr *MockContextHandlerMockRecorder) HandleContext(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleContext", reflect.TypeOf((*MockContextHandler)(nil).HandleContext), arg0) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/bytebuf.go b/internal/mocks/remote/bytebuf.go index 24a4d7b9af..0795e0fa78 100644 --- a/internal/mocks/remote/bytebuf.go +++ b/internal/mocks/remote/bytebuf.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/bytebuf.go @@ -438,4 +437,4 @@ func (m *MockByteBuffer) WriteString(s string) (int, error) { func (mr *MockByteBufferMockRecorder) WriteString(s interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockByteBuffer)(nil).WriteString), s) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/codec.go b/internal/mocks/remote/codec.go index 56593b0907..aaed664113 100644 --- a/internal/mocks/remote/codec.go +++ b/internal/mocks/remote/codec.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/codec.go @@ -194,4 +193,4 @@ func (m *MockMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, func (mr *MockMetaDecoderMockRecorder) DecodePayload(ctx, msg, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePayload", reflect.TypeOf((*MockMetaDecoder)(nil).DecodePayload), ctx, msg, in) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/conn_wrapper.go b/internal/mocks/remote/conn_wrapper.go index 03d159e9a5..e57a75d8c6 100644 --- a/internal/mocks/remote/conn_wrapper.go +++ b/internal/mocks/remote/conn_wrapper.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/remotecli/conn_wrapper.go @@ -61,4 +60,4 @@ func (m *MockConnReleaser) ReleaseConn(err error, ri rpcinfo.RPCInfo) { func (mr *MockConnReleaserMockRecorder) ReleaseConn(err, ri interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseConn", reflect.TypeOf((*MockConnReleaser)(nil).ReleaseConn), err, ri) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/connpool.go b/internal/mocks/remote/connpool.go index 7fbf882329..fcdfd5575c 100644 --- a/internal/mocks/remote/connpool.go +++ b/internal/mocks/remote/connpool.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/connpool.go @@ -309,4 +308,4 @@ func (m *MockIsActive) IsActive() bool { func (mr *MockIsActiveMockRecorder) IsActive() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsActive", reflect.TypeOf((*MockIsActive)(nil).IsActive)) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/dialer.go b/internal/mocks/remote/dialer.go index cebd08a0b0..90ff6960bd 100644 --- a/internal/mocks/remote/dialer.go +++ b/internal/mocks/remote/dialer.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/dialer.go @@ -65,4 +64,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/payload_codec.go b/internal/mocks/remote/payload_codec.go index 1db0d6128a..2adda83bf2 100644 --- a/internal/mocks/remote/payload_codec.go +++ b/internal/mocks/remote/payload_codec.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/payload_codec.go @@ -92,4 +91,4 @@ func (m *MockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message func (mr *MockPayloadCodecMockRecorder) Unmarshal(ctx, message, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmarshal", reflect.TypeOf((*MockPayloadCodec)(nil).Unmarshal), ctx, message, in) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/trans_handler.go b/internal/mocks/remote/trans_handler.go index 210ef935fa..ad32a17a14 100644 --- a/internal/mocks/remote/trans_handler.go +++ b/internal/mocks/remote/trans_handler.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_handler.go @@ -571,4 +570,4 @@ func (m *MockGracefulShutdown) GracefulShutdown(ctx context.Context) error { func (mr *MockGracefulShutdownMockRecorder) GracefulShutdown(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GracefulShutdown", reflect.TypeOf((*MockGracefulShutdown)(nil).GracefulShutdown), ctx) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/trans_meta.go b/internal/mocks/remote/trans_meta.go index 38153da94d..71297e4616 100644 --- a/internal/mocks/remote/trans_meta.go +++ b/internal/mocks/remote/trans_meta.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_meta.go @@ -133,4 +132,4 @@ func (m *MockStreamingMetaHandler) OnReadStream(ctx context.Context) (context.Co func (mr *MockStreamingMetaHandlerMockRecorder) OnReadStream(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReadStream", reflect.TypeOf((*MockStreamingMetaHandler)(nil).OnReadStream), ctx) -} \ No newline at end of file +} diff --git a/internal/mocks/remote/trans_pipeline.go b/internal/mocks/remote/trans_pipeline.go index 687e1a5fac..60d89fc071 100644 --- a/internal/mocks/remote/trans_pipeline.go +++ b/internal/mocks/remote/trans_pipeline.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_pipeline.go @@ -268,4 +267,4 @@ func (m *MockDuplexBoundHandler) Write(ctx context.Context, conn net.Conn, send func (mr *MockDuplexBoundHandlerMockRecorder) Write(ctx, conn, send interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDuplexBoundHandler)(nil).Write), ctx, conn, send) -} \ No newline at end of file +} diff --git a/internal/mocks/stats/tracer.go b/internal/mocks/stats/tracer.go index 0734297631..4f8da5476d 100644 --- a/internal/mocks/stats/tracer.go +++ b/internal/mocks/stats/tracer.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/stats/tracer.go @@ -75,4 +74,4 @@ func (m *MockTracer) Start(ctx context.Context) context.Context { func (mr *MockTracerMockRecorder) Start(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTracer)(nil).Start), ctx) -} \ No newline at end of file +} diff --git a/internal/mocks/utils/sharedticker.go b/internal/mocks/utils/sharedticker.go index 996951cbd1..a7ac22a7db 100644 --- a/internal/mocks/utils/sharedticker.go +++ b/internal/mocks/utils/sharedticker.go @@ -14,7 +14,6 @@ * limitations under the License. */ - // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/utils/sharedticker.go @@ -60,4 +59,4 @@ func (m *MockTickerTask) Tick() { func (mr *MockTickerTaskMockRecorder) Tick() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tick", reflect.TypeOf((*MockTickerTask)(nil).Tick)) -} \ No newline at end of file +} diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index d676979009..0ccb63c41e 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,14 +212,11 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" && a.RenderTplDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" { if a.Use != "" { - return fmt.Errorf("-use must be used with -service or -template-dir or template render") + return fmt.Errorf("-use must be used with -service or -template-dir") } } - if a.TemplateDir != "" && a.RenderTplDir != "" { - return fmt.Errorf("template render and -template-dir cannot be used at the same time") - } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 222008c949..a43c0171ea 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -167,6 +167,16 @@ func (a *Arguments) Init(cmd *util.Command, args []string) error { return nil } +func (a *Arguments) checkTplArgs() error { + if a.TemplateDir != "" && a.RenderTplDir != "" { + return fmt.Errorf("template render --dir and -template-dir cannot be used at the same time") + } + if a.RenderTplDir != "" && len(a.TemplateFiles) > 0 { + return fmt.Errorf("template render --dir and --file option cannot be specified at the same time") + } + return nil +} + func (a *Arguments) Render(cmd *util.Command, args []string) error { curpath, err := filepath.Abs(".") if err != nil { @@ -189,6 +199,10 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { if err != nil { return err } + err = a.checkTplArgs() + if err != nil { + return err + } // todo finish protobuf if a.IDLType != "thrift" { a.GenPath = generator.KitexGenPath @@ -263,7 +277,9 @@ func (a *Arguments) TemplateArgs(version string) error { "Specify a code gen path.") renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") - renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template init [flags] @@ -279,7 +295,6 @@ Usage: template render --dir [template dir_path] [flags] IDL Usage: kitex template clean `, version) }) - // renderCmd.PrintUsage() templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) kitexCmd.AddCommand(templateCmd) if _, err := kitexCmd.ExecuteC(); err != nil { diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 52306dc8b1..f4ca18b690 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -238,12 +238,86 @@ func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, return cg.fs, err } +// parseMeta parses the meta flag and returns a map where the value is a slice of strings +func parseMeta(metaFlags string) (map[string][]string, error) { + meta := make(map[string][]string) + if metaFlags == "" { + return meta, nil + } + + // split for each key=value pairs + pairs := strings.Split(metaFlags, ";") + for _, pair := range pairs { + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + key := kv[0] + values := strings.Split(kv[1], ",") + meta[key] = values + } else { + return nil, fmt.Errorf("Invalid meta format: %s\n", pair) + } + } + return meta, nil +} + +func parseMiddlewares(middlewares []MiddlewareForResolve) ([]GlobalMiddleware, error) { + var mwList []GlobalMiddleware + + for _, mw := range middlewares { + content, err := os.ReadFile(mw.Path) + if err != nil { + return nil, fmt.Errorf("failed to read middleware file %s: %v", mw.Path, err) + } + mwList = append(mwList, GlobalMiddleware{ + Name: mw.Name, + Content: string(content), + }) + } + return mwList, nil +} + func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { g.updatePackageInfo(pkg) g.setImports(HandlerFileName, pkg) - var tpls []*Template - tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) + pkg.ExtendMeta, err = parseMeta(g.MetaFlags) + if err != nil { + return nil, err + } + if g.Config.IncludesTpl != "" { + inc := g.Config.IncludesTpl + if strings.HasPrefix(inc, "git@") || strings.HasPrefix(inc, "http://") || strings.HasPrefix(inc, "https://") { + localGitPath, errMsg, gitErr := util.RunGitCommand(inc) + if gitErr != nil { + if errMsg == "" { + errMsg = gitErr.Error() + } + return nil, fmt.Errorf("failed to pull IDL from git:%s\nYou can execute 'rm -rf ~/.kitex' to clean the git cache and try again", errMsg) + } + if g.RenderTplDir != "" { + g.RenderTplDir = filepath.Join(localGitPath, g.RenderTplDir) + } else { + g.RenderTplDir = localGitPath + } + if util.Exists(g.RenderTplDir) { + return nil, fmt.Errorf("the render template directory path you specified does not exists int the git path") + } + } + } + var meta *Meta + metaPath := filepath.Join(g.RenderTplDir, kitexRenderMetaFile) + if util.Exists(metaPath) { + meta, err = readMetaFile(metaPath) + if err != nil { + return nil, err + } + middlewares, err := parseMiddlewares(meta.MWs) + if err != nil { + return nil, err + } + pkg.MWs = middlewares + } + tpls, err := readTpls(g.RenderTplDir, g.RenderTplDir, meta) if err != nil { return nil, err } @@ -278,13 +352,25 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, const kitexRenderMetaFile = "kitex_render_meta.yaml" -// Meta 代表kitex_render_meta.yaml文件的结构 +// Meta represents the structure of the kitex_render_meta.yaml file. type Meta struct { - Templates []Template `yaml:"templates"` + Templates []Template `yaml:"templates"` + MWs []MiddlewareForResolve `yaml:"middlewares"` + ExtendMeta []ExtendMeta `yaml:"extend_meta"` } -func readMetaFile(rootDir string) (*Meta, error) { - metaPath := filepath.Join(rootDir, kitexRenderMetaFile) +type MiddlewareForResolve struct { + // name of the middleware + Name string `yaml:"name"` + // path of the middleware + Path string `yaml:"path"` +} + +type ExtendMeta struct { + key string +} + +func readMetaFile(metaPath string) (*Meta, error) { metaData, err := os.ReadFile(metaPath) if err != nil { return nil, fmt.Errorf("failed to read meta file from %s: %v", metaPath, err) @@ -299,27 +385,28 @@ func readMetaFile(rootDir string) (*Meta, error) { return &meta, nil } -func getUpdateBehavior(meta *Meta, relativePath string) *Update { - for _, template := range meta.Templates { - // fmt.Println(template.Path, "==", relativePath) - if template.Path == relativePath { - return template.UpdateBehavior +func getMetadata(meta *Meta, relativePath string) *Template { + for i := range meta.Templates { + if meta.Templates[i].Path == relativePath { + return &meta.Templates[i] } } - return &Update{Type: string(skip)} + return &Template{ + UpdateBehavior: &Update{Type: string(skip)}, + } } -func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { - meta, err := readMetaFile(rootDir) - if err != nil { - return nil, err +func readTpls(rootDir, currentDir string, meta *Meta) (ts []*Template, error error) { + defaultMetadata := &Template{ + UpdateBehavior: &Update{Type: string(skip)}, } + files, _ := os.ReadDir(currentDir) for _, f := range files { // filter dir and non-tpl files if f.IsDir() { subDir := filepath.Join(currentDir, f.Name()) - subTemplates, err := readTpls(rootDir, subDir, ts) + subTemplates, err := readTpls(rootDir, subDir, meta) if err != nil { return nil, err } @@ -336,11 +423,19 @@ func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) } trimmedPath := strings.TrimSuffix(relativePath, ".tpl") - updateBehavior := getUpdateBehavior(meta, relativePath) + // If kitex_render_meta.yaml exists, get the corresponding metadata; otherwise, use the default metadata + var metadata *Template + if meta != nil { + metadata = getMetadata(meta, relativePath) + } else { + metadata = defaultMetadata + } t := &Template{ Path: trimmedPath, Body: string(tplData), - UpdateBehavior: updateBehavior, + UpdateBehavior: metadata.UpdateBehavior, + LoopMethod: metadata.LoopMethod, + LoopService: metadata.LoopService, } ts = append(ts, t) } @@ -357,7 +452,7 @@ func (g *generator) RenderWithMultipleFiles(pkg *PackageInfo) (fs []*File, err e } var updatedContent string if g.Config.DebugTpl { - // --debug时 在模板内容顶部加上一段magic string用于区分 + // when --debug is enabled, add a magic string at the top of the template content for distinction. updatedContent = "// Kitex template debug file. use template clean to delete it.\n\n" + string(content) } else { updatedContent = string(content) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index c7478887c0..0a1966c421 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -137,11 +137,13 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string // specify the location path of init subcommand - InitType string + InitOutputDir string // specify the location path of init subcommand + InitType string // specify the type for init subcommand RenderTplDir string // specify the path of template directory for render subcommand TemplateFiles []string // specify the path of single file or multiple file to render - DebugTpl bool + DebugTpl bool // turn on the debug mode + IncludesTpl string // specify the path of remote template repository for render subcommand + MetaFlags string // Metadata in key=value format, keys separated by ';' values separated by ',' GenPath string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index ebb5245b70..0cba0caca8 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -61,6 +61,8 @@ func TestConfig_Pack(t *testing.T) { RenderTplDir string TemplateFiles []string DebugTpl bool + IncludesTpl string + MetaFlags string Protocol string HandlerReturnKeepResp bool } @@ -74,7 +76,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -106,6 +108,8 @@ func TestConfig_Pack(t *testing.T) { InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, TemplateFiles: tt.fields.TemplateFiles, + IncludesTpl: tt.fields.IncludesTpl, + MetaFlags: tt.fields.MetaFlags, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 367ac997e4..27ba00a711 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -51,6 +51,15 @@ type PackageInfo struct { Protocol transport.Protocol IDLName string ServerPkg string + ExtendMeta map[string][]string // key-value metadata for render + MWs []GlobalMiddleware +} + +type GlobalMiddleware struct { + // the name of the middleware + Name string + // the content of the middleware + Content string } // AddImport . diff --git a/tool/internal_pkg/tpl/mock.go b/tool/internal_pkg/tpl/mock.go deleted file mode 100644 index 27da3a59d7..0000000000 --- a/tool/internal_pkg/tpl/mock.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2024 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tpl - -var ClientMockTpl = ` -// Code generated by MockGen. DO NOT EDIT. -// Source: kitex_gen/example/shop/item/itemserviceb/client.go -// -// Generated by this command: -// -// ___1go_build_go_uber_org_mock_mockgen -source=kitex_gen/example/shop/item/itemserviceb/client.go -destination=client_mock.go -package=main -// - -// Package main is a generated GoMock package. -` From 5dda3b7a4b86b66ff6c224d246ba35e5735b8161 Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 6 Aug 2024 16:35:07 +0800 Subject: [PATCH 30/41] fix: remove redundant line Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 1 - .../internal_pkg/generator/custom_template.go | 28 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index a43c0171ea..64c2b52f30 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -278,7 +278,6 @@ func (a *Arguments) TemplateArgs(version string) error { renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") - renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index f4ca18b690..c78fa92390 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -200,6 +200,20 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } +func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { + cg := NewCustomGenerator(pkg, outputPath) + // special handling Methods field + if tpl.LoopMethod { + err = cg.loopGenerate(tpl) + } else { + err = cg.commonGenerate(tpl) + } + if errors.Is(err, errNoNewMethod) { + err = nil + } + return cg.fs, err +} + func readTemplates(dir string) ([]*Template, error) { files, _ := ioutil.ReadDir(dir) var ts []*Template @@ -224,20 +238,6 @@ func readTemplates(dir string) ([]*Template, error) { return ts, nil } -func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { - cg := NewCustomGenerator(pkg, outputPath) - // special handling Methods field - if tpl.LoopMethod { - err = cg.loopGenerate(tpl) - } else { - err = cg.commonGenerate(tpl) - } - if errors.Is(err, errNoNewMethod) { - err = nil - } - return cg.fs, err -} - // parseMeta parses the meta flag and returns a map where the value is a slice of strings func parseMeta(metaFlags string) (map[string][]string, error) { meta := make(map[string][]string) From 86f44d554769ff99165ce249d0a29f79d14e2497 Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 6 Aug 2024 20:03:40 +0800 Subject: [PATCH 31/41] fix: fix single command Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 133 ++++++++++-- tool/cmd/kitex/main.go | 4 +- tool/internal_pkg/util/command.go | 53 ++++- tool/internal_pkg/util/command_test.go | 24 +-- tool/internal_pkg/util/flag.go | 285 ++++++++++++++++++++++--- tool/internal_pkg/util/flag_test.go | 17 -- 6 files changed, 423 insertions(+), 93 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 64c2b52f30..a190ff7f2d 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -177,6 +177,72 @@ func (a *Arguments) checkTplArgs() error { return nil } +func (a *Arguments) Root(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Template(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + func (a *Arguments) Render(cmd *util.Command, args []string) error { curpath, err := filepath.Abs(".") if err != nil { @@ -247,10 +313,12 @@ func (a *Arguments) TemplateArgs(version string) error { kitexCmd := &util.Command{ Use: "kitex", Short: "Kitex command", + RunE: a.Root, } templateCmd := &util.Command{ Use: "template", Short: "Template command", + RunE: a.Template, } initCmd := &util.Command{ Use: "init", @@ -267,8 +335,12 @@ func (a *Arguments) TemplateArgs(version string) error { Short: "Clean command", RunE: a.Clean, } + kitexCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + templateCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVar(&a.InitType, "type", "", "Specify template init type") + initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") @@ -279,20 +351,55 @@ func (a *Arguments) TemplateArgs(version string) error { renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") - initCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: kitex template init [flags] -`, version) + templateCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Template operation + +Usage: + kitex template [command] + +Available Commands: + init Initialize the templates according to the type + render Render the template files + clean Clean the debug templates + `) }) - renderCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: template render --dir [template dir_path] [flags] IDL -`, version) + initCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Initialize the templates according to the type + +Usage: + kitex template init [flags] + +Flags: + -o, --output string Output directory + -t, --type string The init type of the template + `) }) - cleanCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: kitex template clean -`, version) + renderCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Render the template files + +Usage: + kitex template render [flags] + +Flags: + --dir string Output directory + --debug bool Turn on the debug mode + --file stringArray Specify multiple files for render + -I, --Includes string Add an template git search path for includes. + --meta string Specify meta data for render + --module string Specify the Go module name to generate go.mod. + -t, --type string The init type of the template + `) + }) + cleanCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Clean the debug templates + +Usage: + kitex template clean + `) }) templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) kitexCmd.AddCommand(templateCmd) diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 1fb5372ffe..0777986bd0 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -77,9 +77,9 @@ func main() { log.Warn("Get current path failed:", err.Error()) os.Exit(1) } - if os.Args[1] == "template" { + if len(os.Args) > 1 && os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version) - } else if !strings.HasPrefix(os.Args[1], "-") { + } else if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") { err = fmt.Errorf("unknown command %q", os.Args[1]) } else { // run as kitex diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index fda97822d3..35fe4b609d 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -15,7 +15,9 @@ package util import ( + "errors" "fmt" + "io" "os" "strings" ) @@ -29,7 +31,7 @@ type Command struct { parent *Command flags *FlagSet // helpFunc is help func defined by user. - usage func() + helpFunc func(*Command, []string) // for debug args []string } @@ -161,12 +163,43 @@ func (c *Command) ParseFlags(args []string) error { return err } -func (c *Command) SetUsageFunc(f func()) { - c.usage = f +// SetHelpFunc sets help function. Can be defined by Application. +func (c *Command) SetHelpFunc(f func(*Command, []string)) { + c.helpFunc = f } -func (c *Command) UsageFunc() func() { - return c.usage +// HelpFunc returns either the function set by SetHelpFunc for this command +// or a parent, or it returns a function with default help behavior. +func (c *Command) HelpFunc() func(*Command, []string) { + if c.helpFunc != nil { + return c.helpFunc + } + if c.HasParent() { + return c.parent.HelpFunc() + } + return nil +} + +// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrln(i ...interface{}) { + c.PrintErr(fmt.Sprintln(i...)) +} + +// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErr(i ...interface{}) { + fmt.Fprint(c.ErrOrStderr(), i...) +} + +// ErrOrStderr returns output to stderr +func (c *Command) ErrOrStderr() io.Writer { + return c.getErr(os.Stderr) +} + +func (c *Command) getErr(def io.Writer) io.Writer { + if c.HasParent() { + return c.parent.getErr(def) + } + return def } // ExecuteC executes the command. @@ -181,8 +214,16 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { } err = cmd.execute(flags) if err != nil { - cmd.usage() + // Always show help if requested, even if SilenceErrors is in + // effect + if errors.Is(err, ErrHelp) { + cmd.HelpFunc()(cmd, args) + return cmd, nil + } } + //if err != nil { + // cmd.usage() + //} return cmd, err } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go index 9bb5723d67..ae83875b52 100644 --- a/tool/internal_pkg/util/command_test.go +++ b/tool/internal_pkg/util/command_test.go @@ -191,19 +191,14 @@ func TestFlagLong(t *testing.T) { RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, } - var intFlagValue int var stringFlagValue string - c.Flags().IntVar(&intFlagValue, "intf", -1, "") c.Flags().StringVar(&stringFlagValue, "sf", "", "") - err := executeCommand(c, "--intf=7", "--sf=abc", "one", "--", "two") + err := executeCommand(c, "--sf=abc", "one", "--", "two") if err != nil { t.Errorf("Unexpected error: %v", err) } - if intFlagValue != 7 { - t.Errorf("Expected intFlagValue: %v, got %v", 7, intFlagValue) - } if stringFlagValue != "abc" { t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } @@ -221,19 +216,14 @@ func TestFlagShort(t *testing.T) { RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, } - var intFlagValue int var stringFlagValue string - c.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") - err := executeCommand(c, "-i", "7", "-sabc", "one", "two") + err := executeCommand(c, "-sabc", "one", "two") if err != nil { t.Errorf("Unexpected error: %v", err) } - if intFlagValue != 7 { - t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) - } if stringFlagValue != "abc" { t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } @@ -248,16 +238,8 @@ func TestChildFlag(t *testing.T) { rootCmd := &Command{Use: "root", RunE: emptyRun} childCmd := &Command{Use: "child", RunE: emptyRun} rootCmd.AddCommand(childCmd) - - var intFlagValue int - childCmd.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") - - err := executeCommand(rootCmd, "child", "-i7") + err := executeCommand(rootCmd, "child") if err != nil { t.Errorf("Unexpected error: %v", err) } - - if intFlagValue != 7 { - t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) - } } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index d80a7ac6cf..3c936b72ef 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "os" + "sort" "strconv" "strings" ) @@ -71,6 +72,7 @@ type FlagSet struct { orderedActual []*Flag formal map[NormalizedName]*Flag orderedFormal []*Flag + sortedFormal []*Flag shorthands map[byte]*Flag args []string // arguments after flags argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- @@ -100,6 +102,7 @@ type Flag struct { type Value interface { String() string Set(string) error + Type() string } // SliceValue is a secondary interface to all flags which hold a list @@ -114,6 +117,22 @@ type SliceValue interface { GetSlice() []string } +// sortFlags returns the flags as a slice in lexicographical sorted order. +func sortFlags(flags map[NormalizedName]*Flag) []*Flag { + list := make(sort.StringSlice, len(flags)) + i := 0 + for k := range flags { + list[i] = string(k) + i++ + } + list.Sort() + result := make([]*Flag, len(list)) + for i, name := range list { + result[i] = flags[NormalizedName(name)] + } + return result +} + // GetNormalizeFunc returns the previously set NormalizeFunc of a function which // does no translation, if not set previously. func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { @@ -186,6 +205,201 @@ func (f *FlagSet) Set(name, value string) error { return nil } +func (f *FlagSet) VisitAll(fn func(*Flag)) { + if len(f.formal) == 0 { + return + } + + var flags []*Flag + if f.SortFlags { + if len(f.formal) != len(f.sortedFormal) { + f.sortedFormal = sortFlags(f.formal) + } + flags = f.sortedFormal + } else { + flags = f.orderedFormal + } + + for _, flag := range flags { + fn(flag) + } +} + +func UnquoteUsage(flag *Flag) (name, usage string) { + // Look for a back-quoted name, but avoid the strings package. + usage = flag.Usage + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name = usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break // Only one back quote; use type name. + } + } + + name = flag.Value.Type() + switch name { + case "bool": + name = "" + case "float64": + name = "float" + case "int64": + name = "int" + case "uint64": + name = "uint" + case "stringSlice": + name = "strings" + case "intSlice": + name = "ints" + case "uintSlice": + name = "uints" + case "boolSlice": + name = "bools" + } + + return +} + +func (f *FlagSet) FlagUsagesWrapped(cols int) string { + buf := new(bytes.Buffer) + + lines := make([]string, 0, len(f.formal)) + + maxlen := 0 + f.VisitAll(func(flag *Flag) { + line := "" + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) + } else { + line = fmt.Sprintf(" --%s", flag.Name) + } + + varname, usage := UnquoteUsage(flag) + if varname != "" { + line += " " + varname + } + if flag.NoOptDefVal != "" { + switch flag.Value.Type() { + case "string": + line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) + case "bool": + if flag.NoOptDefVal != "true" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + case "count": + if flag.NoOptDefVal != "+1" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + default: + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + } + + // This special character will be replaced with spacing once the + // correct alignment is calculated + line += "\x00" + if len(line) > maxlen { + maxlen = len(line) + } + + line += usage + if len(flag.Deprecated) != 0 { + line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) + } + + lines = append(lines, line) + }) + + for _, line := range lines { + sidx := strings.Index(line, "\x00") + spacing := strings.Repeat(" ", maxlen-sidx) + // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx + fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) + } + + return buf.String() +} + +func wrap(i, w int, s string) string { + if w == 0 { + return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + // space between indent i and end of line width w into which + // we should wrap the text. + wrap := w - i + + var r, l string + + // Not enough space for sensible wrapping. Wrap as a block on + // the next line instead. + if wrap < 24 { + i = 16 + wrap = w - i + r += "\n" + strings.Repeat(" ", i) + } + // If still not enough space then don't even try to wrap. + if wrap < 24 { + return strings.Replace(s, "\n", r, -1) + } + + // Try to avoid short orphan words on the final line, by + // allowing wrapN to go a bit over if that would fit in the + // remainder of the line. + slop := 5 + wrap = wrap - slop + + // Handle first line, which is indented by the caller (or the + // special case above) + l, s = wrapN(wrap, slop, s) + r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) + + // Now wrap the rest + for s != "" { + var t string + + t, s = wrapN(wrap, slop, s) + r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + return r +} + +func wrapN(i, slop int, s string) (string, string) { + if i+slop > len(s) { + return s, "" + } + + w := strings.LastIndexAny(s[:i], " \t\n") + if w <= 0 { + return s, "" + } + nlPos := strings.LastIndex(s[:i], "\n") + if nlPos > 0 && nlPos < w { + return s[:nlPos], s[nlPos+1:] + } + return s[:w], s[w+1:] +} + +func (f *FlagSet) FlagUsages() string { + return f.FlagUsagesWrapped(0) +} + +func (f *FlagSet) PrintDefaults() { + usages := f.FlagUsages() + fmt.Fprint(f.out(), usages) +} + +// defaultUsage is the default function to print a usage message. +func defaultUsage(f *FlagSet) { + fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) + f.PrintDefaults() +} + // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } @@ -255,6 +469,7 @@ func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) if f.errorHandling != ContinueOnError { fmt.Fprintln(f.out(), err) + f.usage() } return err } @@ -296,6 +511,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if !exists { switch { case name == "help": + f.usage() return a, ErrHelp case f.ParseErrorsWhitelist.UnknownFlags: // --unknown=unknownval arg ... @@ -349,6 +565,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse if !exists { switch { case c == 'h': + f.usage() err = ErrHelp return case f.ParseErrorsWhitelist.UnknownFlags: @@ -501,6 +718,28 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { return f } +// PrintDefaults prints to standard error the default values of all defined command-line flags. +func PrintDefaults() { + CommandLine.PrintDefaults() +} + +var Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + PrintDefaults() +} + +// usage calls the Usage method for the flag set, or the usage function if +// the flag set is CommandLine. +func (f *FlagSet) usage() { + if f == CommandLine { + Usage() + } else if f.Usage == nil { + defaultUsage(f) + } else { + f.Usage() + } +} + // -- string Value type stringValue string @@ -516,6 +755,10 @@ func (s *stringValue) Set(val string) error { func (s *stringValue) String() string { return string(*s) } +func (s *stringValue) Type() string { + return "string" +} + // StringVar defines a string flag with specified name, default value, and usage string. func (f *FlagSet) StringVar(p *string, name, value, usage string) { f.VarP(newStringValue(value, p), name, "", usage) @@ -557,6 +800,10 @@ func (b *boolValue) Set(s string) error { func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } +func (b *boolValue) Type() string { + return "bool" +} + func (b *boolValue) IsBoolFlag() bool { return true } // BoolVar defines a bool flag with specified name, default value, and usage string. @@ -592,39 +839,6 @@ func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool return p } -// -- int Value -type intValue int - -func newIntValue(val int, p *int) *intValue { - *p = val - return (*intValue)(p) -} - -func (i *intValue) Set(s string) error { - v, err := strconv.ParseInt(s, 0, 64) - *i = intValue(v) - return err -} - -func (i *intValue) String() string { return strconv.Itoa(int(*i)) } - -// IntVar defines an int flag with specified name, default value, and usage string. -// The argument p points to an int variable in which to store the value of the flag. -func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { - f.VarP(newIntValue(value, p), name, "", usage) -} - -// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) IntVarP(p *int, name, shorthand string, value int, usage string) { - f.VarP(newIntValue(value, p), name, shorthand, usage) -} - -func (f *FlagSet) Int(name string, value int, usage string) *int { - p := new(int) - f.IntVarP(p, name, "", value, usage) - return p -} - // -- stringArray Value type stringArrayValue struct { value *[]string @@ -664,9 +878,12 @@ func (s *stringArrayValue) String() string { return "[" + str + "]" } +func (s *stringArrayValue) Type() string { + return "stringArray" +} + // StringArrayVar defines a string flag with specified name, default value, and usage string. // The argument p points to a []string variable in which to store the values of the multiple flags. -// The value of each argument will not try to be separated by comma. Use a StringSlice for that. func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { f.VarP(newStringArrayValue(value, p), name, "", usage) } diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go index 364d379c1f..916f906df5 100644 --- a/tool/internal_pkg/util/flag_test.go +++ b/tool/internal_pkg/util/flag_test.go @@ -45,21 +45,13 @@ func testParse(f *FlagSet, t *testing.T) { boolFlag := f.Bool("bool", false, "bool value") bool2Flag := f.Bool("bool2", false, "bool2 value") bool3Flag := f.Bool("bool3", false, "bool3 value") - intFlag := f.Int("int", 0, "int value") stringFlag := f.String("string", "0", "string value") - optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") - f.Lookup("optional-int-no-value").NoOptDefVal = "9" - optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") - f.Lookup("optional-int-no-value").NoOptDefVal = "9" extra := "one-extra-argument" args := []string{ "--bool", "--bool2=true", "--bool3=false", - "--int=22", "--string=hello", - "--optional-int-no-value", - "--optional-int-with-value=42", extra, } if err := f.Parse(args); err != nil { @@ -77,18 +69,9 @@ func testParse(f *FlagSet, t *testing.T) { if *bool3Flag != false { t.Error("bool3 flag should be false, is ", *bool2Flag) } - if *intFlag != 22 { - t.Error("int flag should be 22, is ", *intFlag) - } if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } - if *optionalIntNoValueFlag != 9 { - t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) - } - if *optionalIntWithValueFlag != 42 { - t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) - } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra { From eee50851cf115053b389ec239e83296fff5a25f6 Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 10 Jul 2024 17:25:22 +0800 Subject: [PATCH 32/41] feat: add template render subcommand using .tpl Signed-off-by: shawn fix: fix some issue Signed-off-by: shawn perf: custom allocator for fast codec ReadString/ReadBinary (#1427) chore: remove useless reflection api (#1433) optimize(lb): rebalance when instance weights updated (#1397) fix: support setting PurePayload Transport Protocol (#1436) # Conflicts: # internal/generic/thrift/http_fallback.go # pkg/protocol/bthrift/binary.go # pkg/rpcinfo/rpcconfig.go # pkg/rpcinfo/rpcconfig_test.go # tool/cmd/kitex/args/tpl_args.go # tool/cmd/kitex/main.go # tool/internal_pkg/generator/generator.go # tool/internal_pkg/pluginmode/thriftgo/plugin.go # tool/internal_pkg/util/command.go --- pkg/protocol/bthrift/binary_test.go | 39 ++ pkg/protocol/bthrift/interface.go | 5 + tool/cmd/kitex/args/args.go | 10 +- tool/cmd/kitex/args/tpl_args.go | 575 +++++++----------- tool/cmd/kitex/main.go | 9 +- .../internal_pkg/generator/custom_template.go | 73 +++ tool/internal_pkg/generator/generator.go | 14 +- .../pluginmode/thriftgo/plugin.go | 14 +- tool/internal_pkg/util/command.go | 232 ++----- 9 files changed, 389 insertions(+), 582 deletions(-) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index ba86bd6a53..25382f95c6 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -291,6 +291,24 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } +// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator +func TestWriteAndReadStringWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + wn := Binary.WriteString(buf, "kitex") + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadString(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + test.Assert(t, v == "kitex") + SetSpanCache(false) +} + // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -310,6 +328,27 @@ func TestWriteAndReadBinary(t *testing.T) { } } +// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator +func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { + buf := make([]byte, 128) + exceptWs := "000000056b69746578" + exceptSize := 9 + val := []byte("kitex") + wn := Binary.WriteBinary(buf, val) + ws := fmt.Sprintf("%x", buf[:wn]) + test.Assert(t, wn == exceptSize, wn, exceptSize) + test.Assert(t, ws == exceptWs, ws, exceptWs) + + SetSpanCache(true) + v, length, err := Binary.ReadBinary(buf) + test.Assert(t, nil == err) + test.Assert(t, exceptSize == length) + for i := 0; i < len(v); i++ { + test.Assert(t, val[i] == v[i]) + } + SetSpanCache(false) +} + // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index 75fa0ce951..e65d667318 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -97,3 +97,8 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } + +type Allocator interface { + Make(n int) []byte + Copy(buf []byte) (p []byte) +} diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index fab0cb7c0e..f57526fb3e 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,14 +212,20 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" && a.TplDir == "" { if a.Use != "" { - return fmt.Errorf("-use must be used with -service or -template-dir") + return fmt.Errorf("-use must be used with -service or -template-dir or template render") } } + if a.TemplateDir != "" && a.TplDir != "" { + return fmt.Errorf("template render and -template-dir cannot be used at the same time") + } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } + if a.ServiceName != "" && a.TplDir != "" { + return fmt.Errorf("template render and -service cannot be used at the same time") + } if a.ServiceName != "" { a.GenerateMain = true } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index a190ff7f2d..4a00a3c6be 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -1,409 +1,264 @@ -// Copyright 2024 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package args import ( + "flag" "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" - "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "os" + "strings" ) -// Constants . -const ( - KitexGenPath = "kitex_gen" - DefaultCodec = "thrift" - - BuildFileName = "build.sh" - BootstrapFileName = "bootstrap.sh" - ToolVersionFileName = "kitex_info.yaml" - HandlerFileName = "handler.go" - MainFileName = "main.go" - ClientFileName = "client.go" - ServerFileName = "server.go" - InvokerFileName = "invoker.go" - ServiceFileName = "*service.go" - ExtensionFilename = "extensions.yaml" - - MultipleServicesFileName = "multiple_services.go" -) - -var defaultTemplates = map[string]string{ - BuildFileName: tpl.BuildTpl, - BootstrapFileName: tpl.BootstrapTpl, - ToolVersionFileName: tpl.ToolVersionTpl, - HandlerFileName: tpl.HandlerTpl, - MainFileName: tpl.MainTpl, - ClientFileName: tpl.ClientTpl, - ServerFileName: tpl.ServerTpl, - InvokerFileName: tpl.InvokerTpl, - ServiceFileName: tpl.ServiceTpl, -} - -var multipleServicesTpl = map[string]string{ - MultipleServicesFileName: tpl.MultipleServicesTpl, -} - -const ( - DefaultType = "default" - MultipleServicesType = "multiple_services" -) - -type TemplateGenerator func(string) error - -var genTplMap = map[string]TemplateGenerator{ - DefaultType: GenTemplates, - MultipleServicesType: GenMultipleServicesTemplates, -} - -// GenTemplates is the entry for command kitex template, -// it will create the specified path -func GenTemplates(path string) error { - return InitTemplates(path, defaultTemplates) -} - -func GenMultipleServicesTemplates(path string) error { - return InitTemplates(path, multipleServicesTpl) -} - -// InitTemplates creates template files. -func InitTemplates(path string, templates map[string]string) error { - if err := MkdirIfNotExist(path); err != nil { - return err - } - - for name, content := range templates { - var filePath string - if name == BootstrapFileName { - bootstrapDir := filepath.Join(path, "script") - if err := MkdirIfNotExist(bootstrapDir); err != nil { - return err - } - filePath = filepath.Join(bootstrapDir, name+".tpl") - } else { - filePath = filepath.Join(path, name+".tpl") - } - if err := createTemplate(filePath, content); err != nil { - return err - } - } - - return nil -} - -// GetTemplateDir returns the category path. -func GetTemplateDir(category string) (string, error) { - home, err := filepath.Abs(".") - if err != nil { - return "", err - } - return filepath.Join(home, category), nil -} - -// MkdirIfNotExist makes directories if the input path is not exists -func MkdirIfNotExist(dir string) error { - if len(dir) == 0 { - return nil - } - - if _, err := os.Stat(dir); os.IsNotExist(err) { - return os.MkdirAll(dir, os.ModePerm) - } - - return nil -} - -func createTemplate(file, content string) error { - if util.Exists(file) { - return nil - } - - f, err := os.Create(file) - if err != nil { - return err - } - defer f.Close() - - _, err = f.WriteString(content) - return err -} - -func (a *Arguments) Init(cmd *util.Command, args []string) error { - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) - } - path := a.InitOutputDir - initType := a.InitType - if initType == "" { - initType = DefaultType - } - if path == "" { - path = curpath - } - if err := genTplMap[initType](path); err != nil { - return err - } - fmt.Printf("Templates are generated in %s\n", path) - os.Exit(0) - return nil -} - -func (a *Arguments) checkTplArgs() error { - if a.TemplateDir != "" && a.RenderTplDir != "" { - return fmt.Errorf("template render --dir and -template-dir cannot be used at the same time") - } - if a.RenderTplDir != "" && len(a.TemplateFiles) > 0 { - return fmt.Errorf("template render --dir and --file option cannot be specified at the same time") - } - return nil -} - -func (a *Arguments) Root(cmd *util.Command, args []string) error { - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) - } - log.Verbose = a.Verbose +func (a *Arguments) addBasicFlags(f *flag.FlagSet, version string) *flag.FlagSet { + f.BoolVar(&a.NoFastAPI, "no-fast-api", false, + "Generate codes without injecting fast method.") + f.StringVar(&a.ModuleName, "module", "", + "Specify the Go module name to generate go.mod.") + f.StringVar(&a.ServiceName, "service", "", + "Specify the service name to generate server side codes.") + f.StringVar(&a.Use, "use", "", + "Specify the kitex_gen package to import when generate server side codes.") + f.BoolVar(&a.Verbose, "v", false, "") // short for -verbose + f.BoolVar(&a.Verbose, "verbose", false, + "Turn on verbose mode.") + f.BoolVar(&a.GenerateInvoker, "invoker", false, + "Generate invoker side codes when service name is specified.") + f.StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + f.Var(&a.Includes, "I", "Add an IDL search path for includes.") + f.Var(&a.ThriftOptions, "thrift", "Specify arguments for the thrift go compiler.") + f.Var(&a.Hessian2Options, "hessian2", "Specify arguments for the hessian2 codec.") + f.DurationVar(&a.ThriftPluginTimeLimit, "thrift-plugin-time-limit", generator.DefaultThriftPluginTimeLimit, "Specify thrift plugin execution time limit.") + f.StringVar(&a.CompilerPath, "compiler-path", "", "Specify the path of thriftgo/protoc.") + f.Var(&a.ThriftPlugins, "thrift-plugin", "Specify thrift plugin arguments for the thrift compiler.") + f.Var(&a.ProtobufOptions, "protobuf", "Specify arguments for the protobuf compiler.") + f.Var(&a.ProtobufPlugins, "protobuf-plugin", "Specify protobuf plugin arguments for the protobuf compiler.(plugin_name:options:out_dir)") + f.BoolVar(&a.CombineService, "combine-service", false, + "Combine services in root thrift file.") + f.BoolVar(&a.CopyIDL, "copy-idl", false, + "Copy each IDL file to the output path.") + f.BoolVar(&a.HandlerReturnKeepResp, "handler-return-keep-resp", false, + "When the server-side handler returns both err and resp, the resp return is retained for use in middleware where both err and resp can be used simultaneously. Note: At the RPC communication level, if the handler returns an err, the framework still only returns err to the client without resp.") + f.StringVar(&a.ExtensionFile, "template-extension", a.ExtensionFile, + "Specify a file for template extension.") + f.BoolVar(&a.FrugalPretouch, "frugal-pretouch", false, + "Use frugal to compile arguments and results when new clients and servers.") + f.BoolVar(&a.Record, "record", false, + "Record Kitex cmd into kitex-all.sh.") + f.StringVar(&a.TemplateDir, "template-dir", "", + "Use custom template to generate codes.") + f.StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + f.BoolVar(&a.DeepCopyAPI, "deep-copy-api", false, + "Generate codes with injecting deep copy method.") + f.StringVar(&a.Protocol, "protocol", "", + "Specify a protocol for codec") + f.BoolVar(&a.NoDependencyCheck, "no-dependency-check", false, + "Skip dependency checking.") + a.RecordCmd = os.Args + a.Version = version + a.ThriftOptions = append(a.ThriftOptions, + "naming_style=golint", + "ignore_initialisms", + "gen_setter", + "gen_deep_equal", + "compatible_names", + "frugal_tag", + "thrift_streaming", + "no_processor", + ) for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err = a.checkIDL(args) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - err = a.checkTplArgs() - if err != nil { - return err + e.Apply(f) } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) + return f } -func (a *Arguments) Template(cmd *util.Command, args []string) error { - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) - } - log.Verbose = a.Verbose +func (a *Arguments) buildInitFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("init", flag.ContinueOnError) + f.StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template init [flags] - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } +Examples: + %s template init -o /path/to/output + %s template init - err = a.checkIDL(args) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - err = a.checkTplArgs() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath +Flags: +`, version, os.Args[0], os.Args[0], os.Args[0]) + f.PrintDefaults() } - return a.checkPath(curpath) + return f } -func (a *Arguments) Render(cmd *util.Command, args []string) error { - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) - } - log.Verbose = a.Verbose +func (a *Arguments) buildRenderFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("render", flag.ContinueOnError) + f.StringVar(&a.TemplateFile, "f", "", "Specify template init path") + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template render [template dir_path] [flags] IDL - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } +Examples: + %s template render ${template dir_path} -module ${module_name} idl/hello.thrift + %s template render ${template dir_path} -f service.go.tpl -module ${module_name} idl/hello.thrift + %s template render ${template dir_path} -module ${module_name} -I xxx.git idl/hello.thrift - err = a.checkIDL(args) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - err = a.checkTplArgs() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath +Flags: +`, version, os.Args[0], os.Args[0], os.Args[0], os.Args[0]) + f.PrintDefaults() } - return a.checkPath(curpath) + return f } -func (a *Arguments) Clean(cmd *util.Command, args []string) error { - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) - } +func (a *Arguments) buildCleanFlags(version string) *flag.FlagSet { + f := flag.NewFlagSet("clean", flag.ContinueOnError) + f = a.addBasicFlags(f, version) + f.Usage = func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: %s template clean - magicString := "// Kitex template debug file. use template clean to delete it." - err = filepath.WalkDir(curpath, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - content, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("read file %s faild: %v", path, err) - } - if strings.Contains(string(content), magicString) { - if err := os.Remove(path); err != nil { - return fmt.Errorf("delete file %s failed: %v", path, err) - } - } - return nil - }) - if err != nil { - return fmt.Errorf("error cleaning debug template files: %v", err) +Examples: + %s template clean + +Flags: +`, version, os.Args[0], os.Args[0]) + f.PrintDefaults() } - fmt.Println("clean debug template files successfully...") - os.Exit(0) - return nil + return f } -func (a *Arguments) TemplateArgs(version string) error { - kitexCmd := &util.Command{ - Use: "kitex", - Short: "Kitex command", - RunE: a.Root, - } +func (a *Arguments) TemplateArgs(version, curpath string) error { templateCmd := &util.Command{ Use: "template", Short: "Template command", - RunE: a.Template, + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Template command executed") + return nil + }, } initCmd := &util.Command{ Use: "init", Short: "Init command", - RunE: a.Init, + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Init command executed") + f := a.buildInitFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, } renderCmd := &util.Command{ Use: "render", Short: "Render command", - RunE: a.Render, + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Render command executed") + if len(args) > 0 { + a.TplDir = args[0] + } + var tplDir string + for i, arg := range args { + if !strings.HasPrefix(arg, "-") { + tplDir = arg + args = append(args[:i], args[i+1:]...) + break + } + } + if tplDir == "" { + cmd.PrintUsage() + return fmt.Errorf("template directory is required") + } + + f := a.buildRenderFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, } cleanCmd := &util.Command{ Use: "clean", Short: "Clean command", - RunE: a.Clean, - } - kitexCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, - "Specify a code gen path.") - templateCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, - "Specify a code gen path.") - initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") - renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") - renderCmd.Flags().StringVar(&a.ModuleName, "module", "", - "Specify the Go module name to generate go.mod.") - renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") - renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, - "Specify a code gen path.") - renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") - renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") - renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") - renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") - templateCmd.SetHelpFunc(func(*util.Command, []string) { - fmt.Fprintln(os.Stderr, ` -Template operation - -Usage: - kitex template [command] - -Available Commands: - init Initialize the templates according to the type - render Render the template files - clean Clean the debug templates - `) - }) - initCmd.SetHelpFunc(func(*util.Command, []string) { - fmt.Fprintln(os.Stderr, ` -Initialize the templates according to the type - -Usage: - kitex template init [flags] - -Flags: - -o, --output string Output directory - -t, --type string The init type of the template - `) - }) - renderCmd.SetHelpFunc(func(*util.Command, []string) { - fmt.Fprintln(os.Stderr, ` -Render the template files + RunE: func(cmd *util.Command, args []string) error { + fmt.Println("Clean command executed") + f := a.buildCleanFlags(version) + if err := f.Parse(args); err != nil { + return err + } + log.Verbose = a.Verbose -Usage: - kitex template render [flags] + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } -Flags: - --dir string Output directory - --debug bool Turn on the debug mode - --file stringArray Specify multiple files for render - -I, --Includes string Add an template git search path for includes. - --meta string Specify meta data for render - --module string Specify the Go module name to generate go.mod. - -t, --type string The init type of the template - `) - }) - cleanCmd.SetHelpFunc(func(*util.Command, []string) { - fmt.Fprintln(os.Stderr, ` -Clean the debug templates + err := a.checkIDL(f.Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) + }, + } + templateCmd.AddCommand(initCmd) + templateCmd.AddCommand(renderCmd) + templateCmd.AddCommand(cleanCmd) -Usage: - kitex template clean - `) - }) - templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) - kitexCmd.AddCommand(templateCmd) - if _, err := kitexCmd.ExecuteC(); err != nil { + if _, err := templateCmd.ExecuteC(); err != nil { return err } return nil diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 0777986bd0..c03869d81c 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,7 +17,6 @@ package main import ( "bytes" "flag" - "fmt" "io/ioutil" "os" "os/exec" @@ -52,7 +51,7 @@ func init() { if err := versions.RegisterMinDepVersion( &versions.MinDepVersion{ RefPath: "github.com/cloudwego/kitex", - Version: "v0.10.3", + Version: "v0.9.0", }, ); err != nil { log.Warn(err) @@ -77,10 +76,8 @@ func main() { log.Warn("Get current path failed:", err.Error()) os.Exit(1) } - if len(os.Args) > 1 && os.Args[1] == "template" { - err = args.TemplateArgs(kitex.Version) - } else if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") { - err = fmt.Errorf("unknown command %q", os.Args[1]) + if os.Args[1] == "template" { + err = args.TemplateArgs(kitex.Version, curpath) } else { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index c78fa92390..82efd8eb91 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -200,6 +200,79 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } +func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { + files, _ := os.ReadDir(currentDir) + for _, f := range files { + // filter dir and non-tpl files + if f.IsDir() { + subDir := filepath.Join(currentDir, f.Name()) + subTemplates, err := readTpls(rootDir, subDir, ts) + if err != nil { + return nil, err + } + ts = append(ts, subTemplates...) + } else if strings.HasSuffix(f.Name(), ".tpl") { + p := filepath.Join(currentDir, f.Name()) + tplData, err := os.ReadFile(p) + if err != nil { + return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) + } + // Remove the .tpl suffix from the Path and compute relative path + relativePath, err := filepath.Rel(rootDir, p) + if err != nil { + return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) + } + trimmedPath := strings.TrimSuffix(relativePath, ".tpl") + t := &Template{ + Path: trimmedPath, + Body: string(tplData), + UpdateBehavior: &Update{Type: string(skip)}, + } + ts = append(ts, t) + } + } + + return ts, nil +} + +func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { + g.updatePackageInfo(pkg) + + g.setImports(HandlerFileName, pkg) + var tpls []*Template + tpls, err = readTpls(g.TplDir, g.TplDir, tpls) + if err != nil { + return nil, err + } + for _, tpl := range tpls { + newPath := filepath.Join(g.OutputPath, tpl.Path) + dir := filepath.Dir(newPath) + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, fmt.Errorf("failed to create directory %s: %v", dir, err) + } + if tpl.LoopService && g.CombineService { + svrInfo, cs := pkg.ServiceInfo, pkg.CombineServices + + for i := range cs { + pkg.ServiceInfo = cs[i] + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + pkg.ServiceInfo, pkg.CombineServices = svrInfo, cs + } else { + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + } + return fs, nil +} + func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { cg := NewCustomGenerator(pkg, outputPath) // special handling Methods field diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 0a1966c421..92f4f8006d 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -16,7 +16,6 @@ package generator import ( - "errors" "fmt" "go/token" "path/filepath" @@ -98,7 +97,6 @@ type Generator interface { GenerateMainPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackageWithTpl(pkg *PackageInfo) ([]*File, error) - RenderWithMultipleFiles(pkg *PackageInfo) ([]*File, error) } // Config . @@ -137,13 +135,9 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string // specify the location path of init subcommand - InitType string // specify the type for init subcommand - RenderTplDir string // specify the path of template directory for render subcommand - TemplateFiles []string // specify the path of single file or multiple file to render - DebugTpl bool // turn on the debug mode - IncludesTpl string // specify the path of remote template repository for render subcommand - MetaFlags string // Metadata in key=value format, keys separated by ';' values separated by ',' + InitOutputDir string + TplDir string + TemplateFile string GenPath string @@ -379,7 +373,7 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error pkg.ServiceInfo.ServiceName) f, err := comp.CompleteMethods() if err != nil { - if errors.Is(err, errNoNewMethod) { + if err == errNoNewMethod { return fs, nil } return nil, err diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 45560251d0..01c0d1fe83 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -107,19 +107,7 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } - if len(conv.Config.TemplateFiles) > 0 { - if len(conv.Services) == 0 { - return conv.failResp(errors.New("no service defined in the IDL")) - } - conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] - fs, err := gen.RenderWithMultipleFiles(&conv.Package) - if err != nil { - return conv.failResp(err) - } - files = append(files, fs...) - } - - if conv.Config.RenderTplDir != "" { + if conv.Config.TplDir != "" { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 35fe4b609d..2d5e1e596e 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -1,25 +1,9 @@ -// Copyright 2024 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package util import ( - "errors" - "fmt" - "io" + "flag" + "github.com/cloudwego/kitex/tool/internal_pkg/log" "os" - "strings" ) type Command struct { @@ -29,219 +13,85 @@ type Command struct { RunE func(cmd *Command, args []string) error commands []*Command parent *Command - flags *FlagSet - // helpFunc is help func defined by user. - helpFunc func(*Command, []string) - // for debug - args []string -} - -// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden -// particularly useful when testing. -func (c *Command) SetArgs(a []string) { - c.args = a + flags *flag.FlagSet } -func (c *Command) AddCommand(cmds ...*Command) error { +func (c *Command) AddCommand(cmds ...*Command) { for i, x := range cmds { if cmds[i] == c { - return fmt.Errorf("command can't be a child of itself") + panic("Command can't be a child of itself") } cmds[i].parent = c c.commands = append(c.commands, x) } - return nil } // Flags returns the FlagSet of the Command -func (c *Command) Flags() *FlagSet { - if c.flags == nil { - c.flags = NewFlagSet(c.Use, ContinueOnError) - } +func (c *Command) Flags() *flag.FlagSet { return c.flags } -// HasParent determines if the command is a child command. -func (c *Command) HasParent() bool { - return c.parent != nil -} - -// HasSubCommands determines if the command has children commands. -func (c *Command) HasSubCommands() bool { - return len(c.commands) > 0 -} - -func stripFlags(args []string) []string { - commands := make([]string, 0) - for len(args) > 0 { - s := args[0] - args = args[1:] - if strings.HasPrefix(s, "-") { - // handle "-f child child" args - if len(args) <= 1 { - break - } else { - args = args[1:] - continue - } - } else if s != "" && !strings.HasPrefix(s, "-") { - commands = append(commands, s) - } - } - return commands -} - -func (c *Command) findNext(next string) *Command { +// PrintUsage prints the usage of the Command +func (c *Command) PrintUsage() { + log.Warn("Usage: %s\n\n%s\n\n", c.Use, c.Long) + c.flags.PrintDefaults() for _, cmd := range c.commands { - if cmd.Use == next { - return cmd - } + log.Warnf(" %s: %s\n", cmd.Use, cmd.Short) } - return nil } -func nextArgs(args []string, x string) []string { - if len(args) == 0 { - return args - } - for pos := 0; pos < len(args); pos++ { - s := args[pos] - switch { - case strings.HasPrefix(s, "-"): - pos++ - continue - case !strings.HasPrefix(s, "-"): - if s == x { - // cannot use var ret []string cause it return nil - ret := make([]string, 0) - ret = append(ret, args[:pos]...) - ret = append(ret, args[pos+1:]...) - return ret - } - } - } - return args +// Parent returns a commands parent command. +func (c *Command) Parent() *Command { + return c.parent } -func validateArgs(cmd *Command, args []string) error { - // no subcommand, always take args - if !cmd.HasSubCommands() { - return nil - } +// HasParent determines if the command is a child command. +func (c *Command) HasParent() bool { + return c.parent != nil +} - // root command with subcommands, do subcommand checking. - if !cmd.HasParent() && len(args) > 0 { - return fmt.Errorf("unknown command %q", args[0]) +// Root finds root command. +func (c *Command) Root() *Command { + if c.HasParent() { + return c.Parent().Root() } - return nil + return c } -// Find the target command given the args and command tree func (c *Command) Find(args []string) (*Command, []string, error) { - var innerFind func(*Command, []string) (*Command, []string) - - innerFind = func(c *Command, innerArgs []string) (*Command, []string) { - argsWithoutFlags := stripFlags(innerArgs) - if len(argsWithoutFlags) == 0 { - return c, innerArgs - } - nextSubCmd := argsWithoutFlags[0] - - cmd := c.findNext(nextSubCmd) - if cmd != nil { - return innerFind(cmd, nextArgs(innerArgs, nextSubCmd)) - } - return c, innerArgs + if len(args) == 0 { + return c, args, nil } - commandFound, a := innerFind(c, args) - return commandFound, a, validateArgs(commandFound, stripFlags(a)) -} - -// ParseFlags parses persistent flag tree and local flags. -func (c *Command) ParseFlags(args []string) error { - err := c.Flags().Parse(args) - return err -} -// SetHelpFunc sets help function. Can be defined by Application. -func (c *Command) SetHelpFunc(f func(*Command, []string)) { - c.helpFunc = f -} - -// HelpFunc returns either the function set by SetHelpFunc for this command -// or a parent, or it returns a function with default help behavior. -func (c *Command) HelpFunc() func(*Command, []string) { - if c.helpFunc != nil { - return c.helpFunc - } - if c.HasParent() { - return c.parent.HelpFunc() + for _, cmd := range c.commands { + if cmd.Use == args[0] { + return cmd.Find(args[1:]) + } } - return nil -} - -// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. -func (c *Command) PrintErrln(i ...interface{}) { - c.PrintErr(fmt.Sprintln(i...)) -} - -// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. -func (c *Command) PrintErr(i ...interface{}) { - fmt.Fprint(c.ErrOrStderr(), i...) -} -// ErrOrStderr returns output to stderr -func (c *Command) ErrOrStderr() io.Writer { - return c.getErr(os.Stderr) -} - -func (c *Command) getErr(def io.Writer) io.Writer { - if c.HasParent() { - return c.parent.getErr(def) - } - return def + return c, args, nil } // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { - args := c.args - if c.args == nil { - args = os.Args[1:] + args := os.Args[2:] + // Regardless of what command execute is called on, run on Root only + if c.HasParent() { + return c.Root().ExecuteC() } + cmd, flags, err := c.Find(args) if err != nil { + log.Warn(err) return c, err } - err = cmd.execute(flags) - if err != nil { - // Always show help if requested, even if SilenceErrors is in - // effect - if errors.Is(err, ErrHelp) { - cmd.HelpFunc()(cmd, args) - return cmd, nil - } - } - //if err != nil { - // cmd.usage() - //} - - return cmd, err -} -func (c *Command) execute(a []string) error { - if c == nil { - return fmt.Errorf("called Execute() on a nil Command") - } - err := c.ParseFlags(a) - if err != nil { - return err - } - argWoFlags := c.Flags().Args() - if c.RunE != nil { - err := c.RunE(c, argWoFlags) + if cmd.RunE != nil { + err = cmd.RunE(cmd, flags) if err != nil { - return err + log.Warn(err) } } - return nil + + return cmd, nil } From 4ca56a43460e78938918963909a834e5b0924044 Mon Sep 17 00:00:00 2001 From: shawn Date: Fri, 19 Jul 2024 16:45:29 +0800 Subject: [PATCH 33/41] bug fixes: fix render Signed-off-by: shawn bug fixes: fix bug in generator_test Signed-off-by: shawn bug fixes: add gen-path Signed-off-by: shawn bug fixes: treat option as unknown command Signed-off-by: shawn bug fix: fix unused code in command and flag Signed-off-by: shawn bug fixes: fix lint Signed-off-by: shawn fix: use gofumpt -extra to fix golanci lint Signed-off-by: shawn fix: use gofumpt to avoid golangci lint Signed-off-by: shawn add InitType perf: custom allocator for fast codec ReadString/ReadBinary (#1427) chore: remove useless reflection api (#1433) optimize(lb): rebalance when instance weights updated (#1397) fix: support setting PurePayload Transport Protocol (#1436) fix: fix a bug "unknown service xxx" when using CombineService client by not writing IDLServiceName and searching serviceInfo by method name (#1402) fix: support setting PurePayload with new style (#1438) refactor(multi-services): refactoring service and method routing for multi-services (#1439) refactor: move apache code to separated pkg (#1381) Co-authored-by: QihengZhou Co-authored-by: Yi Duan refactor(generic): refactor existing generic to have new ServiceInfo which has the generic's reader and writer info directly (#1408) chore: upgrade go directive version to 1.17 of go.mod (#1415) refactor: get rid of apache TApplicationException (#1389) feat(generic): support grpc json generic for client (#1411) feat: add PrependError for thriftgo (#1420) feat(thrift): generic fastcodec (#1424) feat(tool): no apache for fastcodec (#1426) * use bthrift.PrependError * updated kitex tool MinDepVersion to 0.11.0 * fixed `undefined: bthrift.KitexUnusedProtection` chore: fixed undefined KitexUnusedProtection (#1428) test: works without apache code (#1429) chore: update CI version and readme community (#1431) refactor: new generic interface without thrift apache (#1434) fix(generic): fix payload length check for http generic (#1442) chore: pick and fix conflict commits from develop branch (#1457) Co-authored-by: Jayant chore: update dependency (#1447) chore(ci): disable cache for lint and staticchecks (#1451) refactor(test): perf optimize and log loc correct (#1455) chore(ci): speed up multiple ci processes 8min -> 1min (#1454) * rm unused codeconv, it didn't work as expected due to quota and user experience * don't use cache for self-hosted runners * cache for github hosted runners hotfix: multi service registry issue chore(ci): optimized bench tests. it takes <1m now (#1461) chore(test): fix xorshift64 in consist_test.go (#1462) chore: fix grpc keepalive test by start server responsiblly (#1463) refactor: deprecate bthrift, use cloudwego/gopkg (#1441) * tool: generates code for using cloudwego/gopkg * internal/mocks: updated thrift using the latest tool * pkg/utils/fastthrift: moved to cloudwego/gopkg * pkg/remote/codec/thrift: uses skipdecoder of cloudwego/gopkg * pkg/remote/codec/thrift: add fastcodec as a fallback, and always use fastcodec for Ex * pkg/generic: uses cloudwego/gopkg for pkg/protocol/bthrift: * type BinaryWriter = gopkgthrift.NocopyWriter * Removed ThriftFastCodec, moved to cloudwego/gopkg before releasing * Removed bthrift/exception.go, moved to cloudwego/gopkg before releasing * Removed bthrift/test, only for unknownfields testing * bthrift/unknown: moved to cloudwego/gopkg refactor(generic): remove apache thrift.TProtocol from generic (#1450) --- internal/mocks/thrift/utils.go | 70 +++ pkg/generic/map_test/generic_init.go | 8 +- pkg/generic/reflect_test/reflect_test.go | 1 + .../bthrift/apache/application_exception.go | 39 ++ .../bthrift/apache/binary_protocol.go | 23 + pkg/protocol/bthrift/apache/memory_buffer.go | 57 +++ pkg/protocol/bthrift/apache/messagetype.go | 32 ++ pkg/protocol/bthrift/apache/protocol.go | 33 ++ .../bthrift/apache/protocol_exception.go | 33 ++ pkg/protocol/bthrift/apache/serializer.go | 24 + pkg/protocol/bthrift/apache/transport.go | 23 + pkg/protocol/bthrift/apache/type.go | 43 ++ pkg/protocol/bthrift/binary_test.go | 41 +- pkg/protocol/bthrift/interface.go | 5 - tool/cmd/kitex/args/args.go | 7 +- tool/cmd/kitex/args/tpl_args.go | 218 +++------ tool/cmd/kitex/main.go | 5 +- .../internal_pkg/generator/custom_template.go | 2 +- tool/internal_pkg/generator/generator.go | 3 +- tool/internal_pkg/generator/generator_test.go | 11 +- .../pluginmode/thriftgo/plugin.go | 2 +- tool/internal_pkg/util/command.go | 168 +++++-- tool/internal_pkg/util/command_test.go | 238 ++-------- tool/internal_pkg/util/flag.go | 440 +++++------------- 24 files changed, 753 insertions(+), 773 deletions(-) create mode 100644 internal/mocks/thrift/utils.go create mode 100644 pkg/protocol/bthrift/apache/application_exception.go create mode 100644 pkg/protocol/bthrift/apache/binary_protocol.go create mode 100644 pkg/protocol/bthrift/apache/memory_buffer.go create mode 100644 pkg/protocol/bthrift/apache/messagetype.go create mode 100644 pkg/protocol/bthrift/apache/protocol.go create mode 100644 pkg/protocol/bthrift/apache/protocol_exception.go create mode 100644 pkg/protocol/bthrift/apache/serializer.go create mode 100644 pkg/protocol/bthrift/apache/transport.go create mode 100644 pkg/protocol/bthrift/apache/type.go diff --git a/internal/mocks/thrift/utils.go b/internal/mocks/thrift/utils.go new file mode 100644 index 0000000000..5d080bc513 --- /dev/null +++ b/internal/mocks/thrift/utils.go @@ -0,0 +1,70 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "io" + + "github.com/cloudwego/gopkg/protocol/thrift" + + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +) + +// ApacheCodecAdapter converts a fastcodec struct to apache codec +type ApacheCodecAdapter struct { + p thrift.FastCodec +} + +// Write implements athrift.TStruct +func (p ApacheCodecAdapter) Write(tp athrift.TProtocol) error { + b := make([]byte, p.p.BLength()) + b = b[:p.p.FastWriteNocopy(b, nil)] + _, err := tp.Transport().Write(b) + return err +} + +// Read implements athrift.TStruct +func (p ApacheCodecAdapter) Read(tp athrift.TProtocol) error { + var err error + var b []byte + trans := tp.Transport() + n := trans.RemainingBytes() + if int64(n) < 0 { + return errors.New("unknown buffer len") + } + b = make([]byte, n) + _, err = io.ReadFull(trans, b) + if err == nil { + _, err = p.p.FastRead(b) + } + return err +} + +// ToApacheCodec converts a thrift.FastCodec to athrift.TStruct +func ToApacheCodec(p thrift.FastCodec) athrift.TStruct { + return ApacheCodecAdapter{p: p} +} + +// UnpackApacheCodec unpacks the value returned by `ToApacheCodec` +func UnpackApacheCodec(v interface{}) interface{} { + a, ok := v.(ApacheCodecAdapter) + if ok { + return a.p + } + return v +} diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index 2b92a303a3..7ccdc17499 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -205,16 +205,16 @@ func serviceInfo() *serviceinfo.ServiceInfo { } func newMockTestArgs() interface{} { - return kt.NewMockTestArgs() + return kt.ToApacheCodec(kt.NewMockTestArgs()) } func newMockTestResult() interface{} { - return kt.NewMockTestResult() + return kt.ToApacheCodec(kt.NewMockTestResult()) } func testHandler(ctx context.Context, handler, arg, result interface{}) error { - realArg := arg.(*kt.MockTestArgs) - realResult := result.(*kt.MockTestResult) + realArg := kt.UnpackApacheCodec(arg).(*kt.MockTestArgs) + realResult := kt.UnpackApacheCodec(result).(*kt.MockTestResult) success, err := handler.(kt.Mock).Test(ctx, realArg.Req) if err != nil { return err diff --git a/pkg/generic/reflect_test/reflect_test.go b/pkg/generic/reflect_test/reflect_test.go index 714cebf2c4..6974bc6762 100644 --- a/pkg/generic/reflect_test/reflect_test.go +++ b/pkg/generic/reflect_test/reflect_test.go @@ -34,6 +34,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/klog" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" diff --git a/pkg/protocol/bthrift/apache/application_exception.go b/pkg/protocol/bthrift/apache/application_exception.go new file mode 100644 index 0000000000..ac02e24cef --- /dev/null +++ b/pkg/protocol/bthrift/apache/application_exception.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/application_exception.go + +const ( + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 +) + +type TApplicationException = thrift.TApplicationException + +var NewTApplicationException = thrift.NewTApplicationException diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go new file mode 100644 index 0000000000..2a2a4538b2 --- /dev/null +++ b/pkg/protocol/bthrift/apache/binary_protocol.go @@ -0,0 +1,23 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/binary_protocol.go + +var NewTBinaryProtocol = thrift.NewTBinaryProtocol diff --git a/pkg/protocol/bthrift/apache/memory_buffer.go b/pkg/protocol/bthrift/apache/memory_buffer.go new file mode 100644 index 0000000000..10a0af751f --- /dev/null +++ b/pkg/protocol/bthrift/apache/memory_buffer.go @@ -0,0 +1,57 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "bytes" + "context" +) + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/memory_buffer.go + +// Memory buffer-based implementation of the TTransport interface. +type TMemoryBuffer struct { + *bytes.Buffer + size int +} + +func NewTMemoryBufferLen(size int) *TMemoryBuffer { + buf := make([]byte, 0, size) + return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} +} + +func (p *TMemoryBuffer) IsOpen() bool { + return true +} + +func (p *TMemoryBuffer) Open() error { + return nil +} + +func (p *TMemoryBuffer) Close() error { + p.Buffer.Reset() + return nil +} + +// Flushing a memory buffer is a no-op +func (p *TMemoryBuffer) Flush(ctx context.Context) error { + return nil +} + +func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { + return uint64(p.Buffer.Len()) +} diff --git a/pkg/protocol/bthrift/apache/messagetype.go b/pkg/protocol/bthrift/apache/messagetype.go new file mode 100644 index 0000000000..1885144aee --- /dev/null +++ b/pkg/protocol/bthrift/apache/messagetype.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/messagetype.go + +// Message type constants in the Thrift protocol. +type TMessageType = thrift.TMessageType + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) diff --git a/pkg/protocol/bthrift/apache/protocol.go b/pkg/protocol/bthrift/apache/protocol.go new file mode 100644 index 0000000000..9d0a991d96 --- /dev/null +++ b/pkg/protocol/bthrift/apache/protocol.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol.go + +const ( + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 +) + +type TProtocol = thrift.TProtocol + +// The maximum recursive depth the skip() function will traverse +const DEFAULT_RECURSION_DEPTH = 64 + +var SkipDefaultDepth = thrift.SkipDefaultDepth diff --git a/pkg/protocol/bthrift/apache/protocol_exception.go b/pkg/protocol/bthrift/apache/protocol_exception.go new file mode 100644 index 0000000000..7b020797f5 --- /dev/null +++ b/pkg/protocol/bthrift/apache/protocol_exception.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol_exception.go + +var NewTProtocolExceptionWithType = thrift.NewTProtocolExceptionWithType + +const ( + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) diff --git a/pkg/protocol/bthrift/apache/serializer.go b/pkg/protocol/bthrift/apache/serializer.go new file mode 100644 index 0000000000..c255250301 --- /dev/null +++ b/pkg/protocol/bthrift/apache/serializer.go @@ -0,0 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/serializer.go + +type TStruct interface { + Write(p TProtocol) error + Read(p TProtocol) error +} diff --git a/pkg/protocol/bthrift/apache/transport.go b/pkg/protocol/bthrift/apache/transport.go new file mode 100644 index 0000000000..25a752ae52 --- /dev/null +++ b/pkg/protocol/bthrift/apache/transport.go @@ -0,0 +1,23 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/transport.go + +type TTransport = thrift.TTransport diff --git a/pkg/protocol/bthrift/apache/type.go b/pkg/protocol/bthrift/apache/type.go new file mode 100644 index 0000000000..42533b085e --- /dev/null +++ b/pkg/protocol/bthrift/apache/type.go @@ -0,0 +1,43 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/type.go + +type TType = thrift.TType + +const ( + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 +) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 25382f95c6..a0754bcd55 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -21,8 +21,8 @@ import ( "fmt" "testing" + "github.com/cloudwego/kitex/internal/test" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/protocol/bthrift/internal/test" ) // TestWriteMessageEnd test binary WriteMessageEnd function @@ -291,24 +291,6 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } -// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator -func TestWriteAndReadStringWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - wn := Binary.WriteString(buf, "kitex") - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadString(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - test.Assert(t, v == "kitex") - SetSpanCache(false) -} - // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -328,27 +310,6 @@ func TestWriteAndReadBinary(t *testing.T) { } } -// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator -func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - val := []byte("kitex") - wn := Binary.WriteBinary(buf, val) - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadBinary(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - for i := 0; i < len(v); i++ { - test.Assert(t, val[i] == v[i]) - } - SetSpanCache(false) -} - // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index e65d667318..75fa0ce951 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -97,8 +97,3 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } - -type Allocator interface { - Make(n int) []byte - Copy(buf []byte) (p []byte) -} diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index f57526fb3e..0e8e9e9300 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,20 +212,17 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" && a.TplDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" && a.RenderTplDir == "" { if a.Use != "" { return fmt.Errorf("-use must be used with -service or -template-dir or template render") } } - if a.TemplateDir != "" && a.TplDir != "" { + if a.TemplateDir != "" && a.RenderTplDir != "" { return fmt.Errorf("template render and -template-dir cannot be used at the same time") } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } - if a.ServiceName != "" && a.TplDir != "" { - return fmt.Errorf("template render and -service cannot be used at the same time") - } if a.ServiceName != "" { a.GenerateMain = true } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 4a00a3c6be..cadb49df82 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -1,153 +1,42 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package args import ( - "flag" "fmt" + "os" + "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/util" - "os" - "strings" ) -func (a *Arguments) addBasicFlags(f *flag.FlagSet, version string) *flag.FlagSet { - f.BoolVar(&a.NoFastAPI, "no-fast-api", false, - "Generate codes without injecting fast method.") - f.StringVar(&a.ModuleName, "module", "", - "Specify the Go module name to generate go.mod.") - f.StringVar(&a.ServiceName, "service", "", - "Specify the service name to generate server side codes.") - f.StringVar(&a.Use, "use", "", - "Specify the kitex_gen package to import when generate server side codes.") - f.BoolVar(&a.Verbose, "v", false, "") // short for -verbose - f.BoolVar(&a.Verbose, "verbose", false, - "Turn on verbose mode.") - f.BoolVar(&a.GenerateInvoker, "invoker", false, - "Generate invoker side codes when service name is specified.") - f.StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") - f.Var(&a.Includes, "I", "Add an IDL search path for includes.") - f.Var(&a.ThriftOptions, "thrift", "Specify arguments for the thrift go compiler.") - f.Var(&a.Hessian2Options, "hessian2", "Specify arguments for the hessian2 codec.") - f.DurationVar(&a.ThriftPluginTimeLimit, "thrift-plugin-time-limit", generator.DefaultThriftPluginTimeLimit, "Specify thrift plugin execution time limit.") - f.StringVar(&a.CompilerPath, "compiler-path", "", "Specify the path of thriftgo/protoc.") - f.Var(&a.ThriftPlugins, "thrift-plugin", "Specify thrift plugin arguments for the thrift compiler.") - f.Var(&a.ProtobufOptions, "protobuf", "Specify arguments for the protobuf compiler.") - f.Var(&a.ProtobufPlugins, "protobuf-plugin", "Specify protobuf plugin arguments for the protobuf compiler.(plugin_name:options:out_dir)") - f.BoolVar(&a.CombineService, "combine-service", false, - "Combine services in root thrift file.") - f.BoolVar(&a.CopyIDL, "copy-idl", false, - "Copy each IDL file to the output path.") - f.BoolVar(&a.HandlerReturnKeepResp, "handler-return-keep-resp", false, - "When the server-side handler returns both err and resp, the resp return is retained for use in middleware where both err and resp can be used simultaneously. Note: At the RPC communication level, if the handler returns an err, the framework still only returns err to the client without resp.") - f.StringVar(&a.ExtensionFile, "template-extension", a.ExtensionFile, - "Specify a file for template extension.") - f.BoolVar(&a.FrugalPretouch, "frugal-pretouch", false, - "Use frugal to compile arguments and results when new clients and servers.") - f.BoolVar(&a.Record, "record", false, - "Record Kitex cmd into kitex-all.sh.") - f.StringVar(&a.TemplateDir, "template-dir", "", - "Use custom template to generate codes.") - f.StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, - "Specify a code gen path.") - f.BoolVar(&a.DeepCopyAPI, "deep-copy-api", false, - "Generate codes with injecting deep copy method.") - f.StringVar(&a.Protocol, "protocol", "", - "Specify a protocol for codec") - f.BoolVar(&a.NoDependencyCheck, "no-dependency-check", false, - "Skip dependency checking.") - a.RecordCmd = os.Args - a.Version = version - a.ThriftOptions = append(a.ThriftOptions, - "naming_style=golint", - "ignore_initialisms", - "gen_setter", - "gen_deep_equal", - "compatible_names", - "frugal_tag", - "thrift_streaming", - "no_processor", - ) - - for _, e := range a.extends { - e.Apply(f) - } - return f -} - -func (a *Arguments) buildInitFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("init", flag.ContinueOnError) - f.StringVar(&a.InitOutputDir, "o", ".", "Specify template init path (default current directory)") - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template init [flags] - -Examples: - %s template init -o /path/to/output - %s template init - -Flags: -`, version, os.Args[0], os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - -func (a *Arguments) buildRenderFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("render", flag.ContinueOnError) - f.StringVar(&a.TemplateFile, "f", "", "Specify template init path") - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template render [template dir_path] [flags] IDL - -Examples: - %s template render ${template dir_path} -module ${module_name} idl/hello.thrift - %s template render ${template dir_path} -f service.go.tpl -module ${module_name} idl/hello.thrift - %s template render ${template dir_path} -module ${module_name} -I xxx.git idl/hello.thrift - -Flags: -`, version, os.Args[0], os.Args[0], os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - -func (a *Arguments) buildCleanFlags(version string) *flag.FlagSet { - f := flag.NewFlagSet("clean", flag.ContinueOnError) - f = a.addBasicFlags(f, version) - f.Usage = func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: %s template clean - -Examples: - %s template clean - -Flags: -`, version, os.Args[0], os.Args[0]) - f.PrintDefaults() - } - return f -} - func (a *Arguments) TemplateArgs(version, curpath string) error { + kitexCmd := &util.Command{ + Use: "kitex", + Short: "Kitex command", + } templateCmd := &util.Command{ Use: "template", Short: "Template command", - RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Template command executed") - return nil - }, } initCmd := &util.Command{ Use: "init", Short: "Init command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Init command executed") - f := a.buildInitFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -157,7 +46,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()) if err != nil { return err } @@ -176,27 +65,19 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { Use: "render", Short: "Render command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Render command executed") if len(args) > 0 { - a.TplDir = args[0] + a.RenderTplDir = args[0] } var tplDir string - for i, arg := range args { + for _, arg := range args { if !strings.HasPrefix(arg, "-") { tplDir = arg - args = append(args[:i], args[i+1:]...) break } } if tplDir == "" { - cmd.PrintUsage() return fmt.Errorf("template directory is required") } - - f := a.buildRenderFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -206,7 +87,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()[1:]) if err != nil { return err } @@ -225,11 +106,6 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { Use: "clean", Short: "Clean command", RunE: func(cmd *util.Command, args []string) error { - fmt.Println("Clean command executed") - f := a.buildCleanFlags(version) - if err := f.Parse(args); err != nil { - return err - } log.Verbose = a.Verbose for _, e := range a.extends { @@ -239,7 +115,7 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { } } - err := a.checkIDL(f.Args()) + err := a.checkIDL(cmd.Flags().Args()) if err != nil { return err } @@ -254,11 +130,45 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { return a.checkPath(curpath) }, } - templateCmd.AddCommand(initCmd) - templateCmd.AddCommand(renderCmd) - templateCmd.AddCommand(cleanCmd) + initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") + initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") + renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", + "Specify the Go module name to generate go.mod.") + renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + renderCmd.Flags().StringVarP(&a.TemplateFile, "file", "f", "", "Specify single template path") + renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") + initCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: kitex template init [flags] + +Examples: + kitex template init -o /path/to/output + kitex template init - if _, err := templateCmd.ExecuteC(); err != nil { +Flags: +`, version) + }) + renderCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: template render [template dir_path] [flags] IDL +`, version) + }) + cleanCmd.SetUsageFunc(func() { + fmt.Fprintf(os.Stderr, `Version %s +Usage: kitex template clean + +Examples: + kitex template clean + +Flags: +`, version) + }) + // renderCmd.PrintUsage() + templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) + kitexCmd.AddCommand(templateCmd) + if _, err := kitexCmd.ExecuteC(); err != nil { return err } return nil diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index c03869d81c..ec931c6506 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -51,7 +51,7 @@ func init() { if err := versions.RegisterMinDepVersion( &versions.MinDepVersion{ RefPath: "github.com/cloudwego/kitex", - Version: "v0.9.0", + Version: "v0.11.0", }, ); err != nil { log.Warn(err) @@ -78,6 +78,9 @@ func main() { } if os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version, curpath) + } else if !strings.HasPrefix(os.Args[1], "-") { + log.Warnf("Unknown command %q", os.Args[1]) + os.Exit(1) } else { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 82efd8eb91..48dc283133 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -240,7 +240,7 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, g.setImports(HandlerFileName, pkg) var tpls []*Template - tpls, err = readTpls(g.TplDir, g.TplDir, tpls) + tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) if err != nil { return nil, err } diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 92f4f8006d..fe95db4bbc 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -136,7 +136,8 @@ type Config struct { // subcommand template InitOutputDir string - TplDir string + InitType string + RenderTplDir string TemplateFile string GenPath string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 0cba0caca8..cc39bffc31 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -59,10 +59,7 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir string InitType string RenderTplDir string - TemplateFiles []string - DebugTpl bool - IncludesTpl string - MetaFlags string + TemplateFile string Protocol string HandlerReturnKeepResp bool } @@ -76,7 +73,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -107,9 +104,7 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir: tt.fields.InitOutputDir, InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, - TemplateFiles: tt.fields.TemplateFiles, - IncludesTpl: tt.fields.IncludesTpl, - MetaFlags: tt.fields.MetaFlags, + TemplateFile: tt.fields.TemplateFile, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 01c0d1fe83..e65c3e179e 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -107,7 +107,7 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } - if conv.Config.TplDir != "" { + if conv.Config.RenderTplDir != "" { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 2d5e1e596e..32e1a3ad55 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -1,9 +1,23 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package util import ( - "flag" - "github.com/cloudwego/kitex/tool/internal_pkg/log" + "fmt" "os" + "strings" ) type Command struct { @@ -13,31 +27,28 @@ type Command struct { RunE func(cmd *Command, args []string) error commands []*Command parent *Command - flags *flag.FlagSet + flags *FlagSet + // helpFunc is help func defined by user. + usage func() } -func (c *Command) AddCommand(cmds ...*Command) { +func (c *Command) AddCommand(cmds ...*Command) error { for i, x := range cmds { if cmds[i] == c { - panic("Command can't be a child of itself") + return fmt.Errorf("command can't be a child of itself") } cmds[i].parent = c c.commands = append(c.commands, x) } + return nil } // Flags returns the FlagSet of the Command -func (c *Command) Flags() *flag.FlagSet { - return c.flags -} - -// PrintUsage prints the usage of the Command -func (c *Command) PrintUsage() { - log.Warn("Usage: %s\n\n%s\n\n", c.Use, c.Long) - c.flags.PrintDefaults() - for _, cmd := range c.commands { - log.Warnf(" %s: %s\n", cmd.Use, cmd.Short) +func (c *Command) Flags() *FlagSet { + if c.flags == nil { + c.flags = NewFlagSet(c.Use, ContinueOnError) } + return c.flags } // Parent returns a commands parent command. @@ -58,40 +69,133 @@ func (c *Command) Root() *Command { return c } -func (c *Command) Find(args []string) (*Command, []string, error) { - if len(args) == 0 { - return c, args, nil +// HasSubCommands determines if the command has children commands. +func (c *Command) HasSubCommands() bool { + return len(c.commands) > 0 +} + +func stripFlags(args []string) []string { + commands := make([]string, 0) + for len(args) > 0 { + s := args[0] + args = args[1:] + if s != "" && !strings.HasPrefix(s, "-") { + commands = append(commands, s) + } else if strings.HasPrefix(s, "-") { + break + } } + return commands +} +func (c *Command) findNext(next string) *Command { for _, cmd := range c.commands { - if cmd.Use == args[0] { - return cmd.Find(args[1:]) + if cmd.Use == next { + return cmd + } + } + return nil +} + +func nextArgs(args []string, x string) []string { + if len(args) == 0 { + return args + } + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case strings.HasPrefix(s, "-"): + pos++ + continue + case !strings.HasPrefix(s, "-"): + if s == x { + var ret []string + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } + } + } + return args +} + +func validateArgs(cmd *Command, args []string) error { + // no subcommand, always take args + if !cmd.HasSubCommands() { + return nil + } + + // root command with subcommands, do subcommand checking. + if !cmd.HasParent() && len(args) > 0 { + return fmt.Errorf("unknown command %q", args[0]) + } + return nil +} + +// Find the target command given the args and command tree +func (c *Command) Find(args []string) (*Command, []string, error) { + var innerFind func(*Command, []string) (*Command, []string) + + innerFind = func(c *Command, innerArgs []string) (*Command, []string) { + argsWithoutFlags := stripFlags(innerArgs) + if len(argsWithoutFlags) == 0 { + return c, innerArgs } + nextSubCmd := argsWithoutFlags[0] + + cmd := c.findNext(nextSubCmd) + if cmd != nil { + return innerFind(cmd, nextArgs(innerArgs, nextSubCmd)) + } + return c, innerArgs } + commandFound, a := innerFind(c, args) + return commandFound, a, validateArgs(commandFound, stripFlags(a)) +} - return c, args, nil +// ParseFlags parses persistent flag tree and local flags. +func (c *Command) ParseFlags(args []string) error { + err := c.Flags().Parse(args) + return err +} + +func (c *Command) SetUsageFunc(f func()) { + c.usage = f +} + +func (c *Command) UsageFunc() func() { + return c.usage } // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { - args := os.Args[2:] - // Regardless of what command execute is called on, run on Root only - if c.HasParent() { - return c.Root().ExecuteC() - } + args := os.Args[1:] cmd, flags, err := c.Find(args) if err != nil { - log.Warn(err) return c, err } + err = cmd.execute(flags) + if err != nil { + cmd.usage() + } - if cmd.RunE != nil { - err = cmd.RunE(cmd, flags) + return cmd, err +} + +func (c *Command) execute(a []string) error { + if c == nil { + return fmt.Errorf("called Execute() on a nil Command") + } + err := c.ParseFlags(a) + if err != nil { + return err + } + if c.RunE != nil { + err := c.RunE(c, a) if err != nil { - log.Warn(err) + return err } } - - return cmd, nil + return nil } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go index ae83875b52..a5a4766850 100644 --- a/tool/internal_pkg/util/command_test.go +++ b/tool/internal_pkg/util/command_test.go @@ -15,231 +15,75 @@ package util import ( - "fmt" - "reflect" - "strings" + "os" "testing" ) -func emptyRun(*Command, []string) error { return nil } +func TestAddCommand(t *testing.T) { + rootCmd := &Command{Use: "root"} + childCmd := &Command{Use: "child"} -func executeCommand(root *Command, args ...string) (err error) { - _, err = executeCommandC(root, args...) - return err -} - -func executeCommandC(root *Command, args ...string) (c *Command, err error) { - root.SetArgs(args) - c, err = root.ExecuteC() - return c, err -} - -const onetwo = "one two" - -func TestSingleCommand(t *testing.T) { - rootCmd := &Command{ - Use: "root", - RunE: func(_ *Command, args []string) error { return nil }, - } - aCmd := &Command{Use: "a", RunE: emptyRun} - bCmd := &Command{Use: "b", RunE: emptyRun} - rootCmd.AddCommand(aCmd, bCmd) - - _ = executeCommand(rootCmd, "one", "two") -} - -func TestChildCommand(t *testing.T) { - var child1CmdArgs []string - rootCmd := &Command{Use: "root", RunE: emptyRun} - child1Cmd := &Command{ - Use: "child1", - RunE: func(_ *Command, args []string) error { child1CmdArgs = args; return nil }, - } - child2Cmd := &Command{Use: "child2", RunE: emptyRun} - rootCmd.AddCommand(child1Cmd, child2Cmd) - - err := executeCommand(rootCmd, "child1", "one", "two") + // Test adding a valid child command + err := rootCmd.AddCommand(childCmd) if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("expected no error, got %v", err) } - got := strings.Join(child1CmdArgs, " ") - if got != onetwo { - t.Errorf("child1CmdArgs expected: %q, got: %q", onetwo, got) + if len(rootCmd.commands) != 1 { + t.Fatalf("expected 1 command, got %d", len(rootCmd.commands)) } -} -func TestCallCommandWithoutSubcommands(t *testing.T) { - rootCmd := &Command{Use: "root", RunE: emptyRun} - err := executeCommand(rootCmd) - if err != nil { - t.Errorf("Calling command without subcommands should not have error: %v", err) + if rootCmd.commands[0] != childCmd { + t.Fatalf("expected child command to be added") } -} - -func TestRootExecuteUnknownCommand(t *testing.T) { - rootCmd := &Command{Use: "root", RunE: emptyRun} - rootCmd.AddCommand(&Command{Use: "child", RunE: emptyRun}) - - _ = executeCommand(rootCmd, "unknown") -} -func TestSubcommandExecuteC(t *testing.T) { - rootCmd := &Command{Use: "root", RunE: emptyRun} - childCmd := &Command{Use: "child", RunE: emptyRun} - rootCmd.AddCommand(childCmd) + // Test adding a command to itself + err = rootCmd.AddCommand(rootCmd) + if err == nil { + t.Fatalf("expected an error, got nil") + } - _, err := executeCommandC(rootCmd, "child") - if err != nil { - t.Errorf("Unexpected error: %v", err) + expectedErr := "command can't be a child of itself" + if err.Error() != expectedErr { + t.Fatalf("expected error %q, got %q", expectedErr, err.Error()) } } -func TestFind(t *testing.T) { - var foo, bar string - root := &Command{ +func TestExecuteC(t *testing.T) { + rootCmd := &Command{ Use: "root", - } - root.Flags().StringVarP(&foo, "foo", "f", "", "") - root.Flags().StringVarP(&bar, "bar", "b", "something", "") - - child := &Command{ - Use: "child", - } - root.AddCommand(child) - - testCases := []struct { - args []string - expectedFoundArgs []string - }{ - { - []string{"child"}, - []string{}, - }, - { - []string{"child", "child"}, - []string{"child"}, - }, - { - []string{"child", "foo", "child", "bar", "child", "baz", "child"}, - []string{"foo", "child", "bar", "child", "baz", "child"}, - }, - { - []string{"-f", "child", "child"}, - []string{"-f", "child"}, + RunE: func(cmd *Command, args []string) error { + return nil }, - { - []string{"child", "-f", "child"}, - []string{"-f", "child"}, - }, - { - []string{"-b", "child", "child"}, - []string{"-b", "child"}, - }, - { - []string{"child", "-b", "child"}, - []string{"-b", "child"}, - }, - { - []string{"child", "-b"}, - []string{"-b"}, - }, - { - []string{"-b", "-f", "child", "child"}, - []string{"-b", "-f", "child"}, - }, - { - []string{"-f", "child", "-b", "something", "child"}, - []string{"-f", "child", "-b", "something"}, - }, - { - []string{"-f", "child", "child", "-b"}, - []string{"-f", "child", "-b"}, - }, - { - []string{"-f=child", "-b=something", "child"}, - []string{"-f=child", "-b=something"}, - }, - { - []string{"--foo", "child", "--bar", "something", "child"}, - []string{"--foo", "child", "--bar", "something"}, - }, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { - cmd, foundArgs, err := root.Find(tc.args) - if err != nil { - t.Fatal(err) - } - - if cmd != child { - t.Fatal("Expected cmd to be child, but it was not") - } - - if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) { - t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs) - } - }) - } -} - -func TestFlagLong(t *testing.T) { - var cArgs []string - c := &Command{ - Use: "c", - RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, - } - - var stringFlagValue string - c.Flags().StringVar(&stringFlagValue, "sf", "", "") - - err := executeCommand(c, "--sf=abc", "one", "--", "two") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if stringFlagValue != "abc" { - t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } - got := strings.Join(cArgs, " ") - if got != onetwo { - t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) + subCmd := &Command{ + Use: "sub", + RunE: func(cmd *Command, args []string) error { + return nil + }, } -} -func TestFlagShort(t *testing.T) { - var cArgs []string - c := &Command{ - Use: "c", - RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, - } + rootCmd.AddCommand(subCmd) - var stringFlagValue string - c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") + // Simulate command line arguments + os.Args = []string{"root", "sub"} - err := executeCommand(c, "-sabc", "one", "two") + // Execute the command + cmd, err := rootCmd.ExecuteC() if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("expected no error, got %v", err) } - if stringFlagValue != "abc" { - t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) + if cmd.Use != "sub" { + t.Fatalf("expected sub command to be executed, got %s", cmd.Use) } - got := strings.Join(cArgs, " ") - if got != onetwo { - t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) - } -} + // Simulate command line arguments with an unknown command + os.Args = []string{"root", "unknown"} -func TestChildFlag(t *testing.T) { - rootCmd := &Command{Use: "root", RunE: emptyRun} - childCmd := &Command{Use: "child", RunE: emptyRun} - rootCmd.AddCommand(childCmd) - err := executeCommand(rootCmd, "child") - if err != nil { - t.Errorf("Unexpected error: %v", err) + _, err = rootCmd.ExecuteC() + if err == nil { + t.Fatalf("expected an error for unknown command, got nil") } } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 3c936b72ef..3c0b681197 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -15,14 +15,11 @@ package util import ( - "bytes" - "encoding/csv" "errors" goflag "flag" "fmt" "io" "os" - "sort" "strconv" "strings" ) @@ -86,15 +83,17 @@ type FlagSet struct { // A Flag represents the state of a flag. type Flag struct { - Name string // name as it appears on command line - Shorthand string // one-letter abbreviated flag - Usage string // help message - Value Value // value as set - DefValue string // default value (as text); for usage message - Changed bool // If the user set the value (or if left to default) - NoOptDefVal string // default value (as text); if the flag is on the command line without any options - Deprecated string // If this flag is deprecated, this string is the new or now thing to use - ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string // default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use + Annotations map[string][]string // used by cobra.Command bash autocomple code } // Value is the interface to the dynamic value stored in a flag. @@ -102,7 +101,6 @@ type Flag struct { type Value interface { String() string Set(string) error - Type() string } // SliceValue is a secondary interface to all flags which hold a list @@ -117,20 +115,27 @@ type SliceValue interface { GetSlice() []string } -// sortFlags returns the flags as a slice in lexicographical sorted order. -func sortFlags(flags map[NormalizedName]*Flag) []*Flag { - list := make(sort.StringSlice, len(flags)) - i := 0 - for k := range flags { - list[i] = string(k) - i++ - } - list.Sort() - result := make([]*Flag, len(list)) - for i, name := range list { - result[i] = flags[NormalizedName(name)] +// SetNormalizeFunc allows you to add a function which can translate flag names. +// Flags added to the FlagSet will be translated and then when anything tries to +// look up the flag that will also be translated. So it would be possible to create +// a flag named "getURL" and have it translated to "geturl". A user could then pass +// "--getUrl" which may also be translated to "geturl" and everything will work. +func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { + f.normalizeNameFunc = n + f.sortedFormal = f.sortedFormal[:0] + for fname, flag := range f.formal { + nname := f.normalizeFlagName(flag.Name) + if fname == nname { + continue + } + flag.Name = string(nname) + delete(f.formal, fname) + f.formal[nname] = flag + if _, set := f.actual[fname]; set { + delete(f.actual, fname) + f.actual[nname] = flag + } } - return result } // GetNormalizeFunc returns the previously set NormalizeFunc of a function which @@ -147,16 +152,6 @@ func (f *FlagSet) normalizeFlagName(name string) NormalizedName { return n(f, name) } -// Lookup returns the Flag structure of the named flag, returning nil if none exists. -func (f *FlagSet) Lookup(name string) *Flag { - return f.lookup(f.normalizeFlagName(name)) -} - -// lookup returns the Flag structure of the named flag, returning nil if none exists. -func (f *FlagSet) lookup(name NormalizedName) *Flag { - return f.formal[name] -} - func (f *FlagSet) out() io.Writer { if f.output == nil { return os.Stderr @@ -205,200 +200,36 @@ func (f *FlagSet) Set(name, value string) error { return nil } -func (f *FlagSet) VisitAll(fn func(*Flag)) { - if len(f.formal) == 0 { - return - } - - var flags []*Flag - if f.SortFlags { - if len(f.formal) != len(f.sortedFormal) { - f.sortedFormal = sortFlags(f.formal) - } - flags = f.sortedFormal - } else { - flags = f.orderedFormal - } - - for _, flag := range flags { - fn(flag) - } -} - -func UnquoteUsage(flag *Flag) (name, usage string) { - // Look for a back-quoted name, but avoid the strings package. - usage = flag.Usage - for i := 0; i < len(usage); i++ { - if usage[i] == '`' { - for j := i + 1; j < len(usage); j++ { - if usage[j] == '`' { - name = usage[i+1 : j] - usage = usage[:i] + name + usage[j+1:] - return name, usage - } - } - break // Only one back quote; use type name. - } - } - - name = flag.Value.Type() - switch name { - case "bool": - name = "" - case "float64": - name = "float" - case "int64": - name = "int" - case "uint64": - name = "uint" - case "stringSlice": - name = "strings" - case "intSlice": - name = "ints" - case "uintSlice": - name = "uints" - case "boolSlice": - name = "bools" - } - - return -} - -func (f *FlagSet) FlagUsagesWrapped(cols int) string { - buf := new(bytes.Buffer) - - lines := make([]string, 0, len(f.formal)) - - maxlen := 0 - f.VisitAll(func(flag *Flag) { - line := "" - if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { - line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) - } else { - line = fmt.Sprintf(" --%s", flag.Name) - } - - varname, usage := UnquoteUsage(flag) - if varname != "" { - line += " " + varname - } - if flag.NoOptDefVal != "" { - switch flag.Value.Type() { - case "string": - line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) - case "bool": - if flag.NoOptDefVal != "true" { - line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) - } - case "count": - if flag.NoOptDefVal != "+1" { - line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) - } - default: - line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) - } - } - - // This special character will be replaced with spacing once the - // correct alignment is calculated - line += "\x00" - if len(line) > maxlen { - maxlen = len(line) - } - - line += usage - if len(flag.Deprecated) != 0 { - line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) - } - - lines = append(lines, line) - }) - - for _, line := range lines { - sidx := strings.Index(line, "\x00") - spacing := strings.Repeat(" ", maxlen-sidx) - // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx - fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) - } - - return buf.String() -} - -func wrap(i, w int, s string) string { - if w == 0 { - return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) - } - - // space between indent i and end of line width w into which - // we should wrap the text. - wrap := w - i - - var r, l string - - // Not enough space for sensible wrapping. Wrap as a block on - // the next line instead. - if wrap < 24 { - i = 16 - wrap = w - i - r += "\n" + strings.Repeat(" ", i) - } - // If still not enough space then don't even try to wrap. - if wrap < 24 { - return strings.Replace(s, "\n", r, -1) +// SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. +// This is sometimes used by spf13/cobra programs which want to generate additional +// bash completion information. +func (f *FlagSet) SetAnnotation(name, key string, values []string) error { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + return fmt.Errorf("no such flag -%v", name) } - - // Try to avoid short orphan words on the final line, by - // allowing wrapN to go a bit over if that would fit in the - // remainder of the line. - slop := 5 - wrap = wrap - slop - - // Handle first line, which is indented by the caller (or the - // special case above) - l, s = wrapN(wrap, slop, s) - r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) - - // Now wrap the rest - for s != "" { - var t string - - t, s = wrapN(wrap, slop, s) - r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + if flag.Annotations == nil { + flag.Annotations = map[string][]string{} } - - return r + flag.Annotations[key] = values + return nil } -func wrapN(i, slop int, s string) (string, string) { - if i+slop > len(s) { - return s, "" - } +// NFlag returns the number of flags that have been set. +func (f *FlagSet) NFlag() int { return len(f.actual) } - w := strings.LastIndexAny(s[:i], " \t\n") - if w <= 0 { - return s, "" +// Arg returns the i'th argument. Arg(0) is the first remaining argument +// after flags have been processed. +func (f *FlagSet) Arg(i int) string { + if i < 0 || i >= len(f.args) { + return "" } - nlPos := strings.LastIndex(s[:i], "\n") - if nlPos > 0 && nlPos < w { - return s[:nlPos], s[nlPos+1:] - } - return s[:w], s[w+1:] -} - -func (f *FlagSet) FlagUsages() string { - return f.FlagUsagesWrapped(0) + return f.args[i] } -func (f *FlagSet) PrintDefaults() { - usages := f.FlagUsages() - fmt.Fprint(f.out(), usages) -} - -// defaultUsage is the default function to print a usage message. -func defaultUsage(f *FlagSet) { - fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) - f.PrintDefaults() -} +// NArg is the number of arguments remaining after flags have been processed. +func (f *FlagSet) NArg() int { return len(f.args) } // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } @@ -469,7 +300,6 @@ func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) if f.errorHandling != ContinueOnError { fmt.Fprintln(f.out(), err) - f.usage() } return err } @@ -511,7 +341,6 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if !exists { switch { case name == "help": - f.usage() return a, ErrHelp case f.ParseErrorsWhitelist.UnknownFlags: // --unknown=unknownval arg ... @@ -565,7 +394,6 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse if !exists { switch { case c == 'h': - f.usage() err = ErrHelp return case f.ParseErrorsWhitelist.UnknownFlags: @@ -698,11 +526,60 @@ func (f *FlagSet) Parse(arguments []string) error { type parseFunc func(flag *Flag, value string) error +// ParseAll parses flag definitions from the argument list, which should not +// include the command name. The arguments for fn are flag and value. Must be +// called after all flags in the FlagSet are defined and before flags are +// accessed by the program. The return value will be ErrHelp if -help was set +// but not defined. +func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error { + f.parsed = true + f.args = make([]string, 0, len(arguments)) + + err := f.parseArgs(arguments, fn) + if err != nil { + switch f.errorHandling { + case ContinueOnError: + return err + case ExitOnError: + os.Exit(2) + case PanicOnError: + panic(err) + } + } + return nil +} + // Parsed reports whether f.Parse has been called. func (f *FlagSet) Parsed() bool { return f.parsed } +// Parse parses the command-line flags from os.Args[1:]. Must be called +// after all flags are defined and before flags are accessed by the program. +func Parse() { + // Ignore errors; CommandLine is set for ExitOnError. + CommandLine.Parse(os.Args[1:]) +} + +// ParseAll parses the command-line flags from os.Args[1:] and called fn for each. +// The arguments for fn are flag and value. Must be called after all flags are +// defined and before flags are accessed by the program. +func ParseAll(fn func(flag *Flag, value string) error) { + // Ignore errors; CommandLine is set for ExitOnError. + CommandLine.ParseAll(os.Args[1:], fn) +} + +// SetInterspersed sets whether to support interspersed option/non-option arguments. +func SetInterspersed(interspersed bool) { + CommandLine.SetInterspersed(interspersed) +} + +// Parsed returns true if the command-line flags have been parsed. +func Parsed() bool { + return CommandLine.Parsed() +} + +// CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) // NewFlagSet returns a new, empty flag set with the specified name, @@ -718,26 +595,18 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { return f } -// PrintDefaults prints to standard error the default values of all defined command-line flags. -func PrintDefaults() { - CommandLine.PrintDefaults() -} - -var Usage = func() { - fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) - PrintDefaults() +// SetInterspersed sets whether to support interspersed option/non-option arguments. +func (f *FlagSet) SetInterspersed(interspersed bool) { + f.interspersed = interspersed } -// usage calls the Usage method for the flag set, or the usage function if -// the flag set is CommandLine. -func (f *FlagSet) usage() { - if f == CommandLine { - Usage() - } else if f.Usage == nil { - defaultUsage(f) - } else { - f.Usage() - } +// Init sets the name and error handling property for a flag set. +// By default, the zero FlagSet uses an empty name and the +// ContinueOnError error handling policy. +func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { + f.name = name + f.errorHandling = errorHandling + f.argsLenAtDash = -1 } // -- string Value @@ -753,12 +622,12 @@ func (s *stringValue) Set(val string) error { return nil } -func (s *stringValue) String() string { return string(*s) } - func (s *stringValue) Type() string { return "string" } +func (s *stringValue) String() string { return string(*s) } + // StringVar defines a string flag with specified name, default value, and usage string. func (f *FlagSet) StringVar(p *string, name, value, usage string) { f.VarP(newStringValue(value, p), name, "", usage) @@ -769,21 +638,6 @@ func (f *FlagSet) StringVarP(p *string, name, shorthand, value, usage string) { f.VarP(newStringValue(value, p), name, shorthand, usage) } -// String defines a string flag with specified name, default value, and usage string. -// The return value is the address of a string variable that stores the value of the flag. -func (f *FlagSet) String(name, value, usage string) *string { - p := new(string) - f.StringVarP(p, name, "", value, usage) - return p -} - -// StringP is like String, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) StringP(name, shorthand, value, usage string) *string { - p := new(string) - f.StringVarP(p, name, shorthand, value, usage) - return p -} - // -- bool Value type boolValue bool @@ -798,12 +652,12 @@ func (b *boolValue) Set(s string) error { return err } -func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } - func (b *boolValue) Type() string { return "bool" } +func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } + func (b *boolValue) IsBoolFlag() bool { return true } // BoolVar defines a bool flag with specified name, default value, and usage string. @@ -825,65 +679,3 @@ func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage st flag.NoOptDefVal = "true" return nil } - -// Bool defines a bool flag with specified name, default value, and usage string. -// The return value is the address of a bool variable that stores the value of the flag. -func (f *FlagSet) Bool(name string, value bool, usage string) *bool { - return f.BoolP(name, "", value, usage) -} - -// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { - p := new(bool) - f.BoolVarP(p, name, shorthand, value, usage) - return p -} - -// -- stringArray Value -type stringArrayValue struct { - value *[]string - changed bool -} - -func newStringArrayValue(val []string, p *[]string) *stringArrayValue { - ssv := new(stringArrayValue) - ssv.value = p - *ssv.value = val - return ssv -} - -func (s *stringArrayValue) Set(val string) error { - if !s.changed { - *s.value = []string{val} - s.changed = true - } else { - *s.value = append(*s.value, val) - } - return nil -} - -func writeAsCSV(vals []string) (string, error) { - b := &bytes.Buffer{} - w := csv.NewWriter(b) - err := w.Write(vals) - if err != nil { - return "", err - } - w.Flush() - return strings.TrimSuffix(b.String(), "\n"), nil -} - -func (s *stringArrayValue) String() string { - str, _ := writeAsCSV(*s.value) - return "[" + str + "]" -} - -func (s *stringArrayValue) Type() string { - return "stringArray" -} - -// StringArrayVar defines a string flag with specified name, default value, and usage string. -// The argument p points to a []string variable in which to store the values of the multiple flags. -func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { - f.VarP(newStringArrayValue(value, p), name, "", usage) -} From 6215b050257a4e398ab7b0fa18353568600a2149 Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 31 Jul 2024 20:05:20 +0800 Subject: [PATCH 34/41] add test for command and flag add test for command and flag resolve conflicts refactor: rm apache thrift in pkg/generic & netpollmux (#1470) fix: allow HEADERS frame with empty header block fragment (#1466) refactor: rm apache thrift in internal/mocks (#1474) also: fix(server): invoker return err if no apache codec fix(server): listening on loopback addr chore: remove github.com/stretchr/testify direct dependency (#1475) chore: upgrade gopkg to v0.1.0 (#1477) perf: custom allocator for fast codec ReadString/ReadBinary (#1427) # Conflicts: # pkg/protocol/bthrift/binary.go resolve conflicts feat: add init Signed-off-by: shawn --- pkg/generic/reflect_test/reflect_test.go | 1 - tool/cmd/kitex/args/tpl_args.go | 291 ++++++++++++++++------- tool/cmd/kitex/main.go | 6 +- tool/internal_pkg/generator/generator.go | 7 +- tool/internal_pkg/tpl/mock.go | 27 +++ tool/internal_pkg/util/command.go | 45 ++-- tool/internal_pkg/util/command_test.go | 256 ++++++++++++++++---- tool/internal_pkg/util/flag.go | 245 +++++++++---------- 8 files changed, 585 insertions(+), 293 deletions(-) create mode 100644 tool/internal_pkg/tpl/mock.go diff --git a/pkg/generic/reflect_test/reflect_test.go b/pkg/generic/reflect_test/reflect_test.go index 6974bc6762..714cebf2c4 100644 --- a/pkg/generic/reflect_test/reflect_test.go +++ b/pkg/generic/reflect_test/reflect_test.go @@ -34,7 +34,6 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/klog" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index cadb49df82..5ef114757a 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -16,14 +16,211 @@ package args import ( "fmt" - "os" - "strings" - "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "os" + "path/filepath" +) + +// Constants . +const ( + KitexGenPath = "kitex_gen" + DefaultCodec = "thrift" + + BuildFileName = "build.sh" + BootstrapFileName = "bootstrap.sh" + ToolVersionFileName = "kitex_info.yaml" + HandlerFileName = "handler.go" + MainFileName = "main.go" + ClientFileName = "client.go" + ServerFileName = "server.go" + InvokerFileName = "invoker.go" + ServiceFileName = "*service.go" + ExtensionFilename = "extensions.yaml" + + ClientMockFilename = "client_mock.go" +) + +var defaultTemplates = map[string]string{ + BuildFileName: tpl.BuildTpl, + BootstrapFileName: tpl.BootstrapTpl, + ToolVersionFileName: tpl.ToolVersionTpl, + HandlerFileName: tpl.HandlerTpl, + MainFileName: tpl.MainTpl, + ClientFileName: tpl.ClientTpl, + ServerFileName: tpl.ServerTpl, + InvokerFileName: tpl.InvokerTpl, + ServiceFileName: tpl.ServiceTpl, +} + +var mockTemplates = map[string]string{ + ClientMockFilename: tpl.ClientMockTpl, +} + +const ( + DefaultType = "default" + MockType = "mock" ) +type TemplateGenerator func(string) error + +var genTplMap = map[string]TemplateGenerator{ + DefaultType: GenTemplates, + MockType: GenMockTemplates, +} + +// GenTemplates is the entry for command kitex template, +// it will create the specified path +func GenTemplates(path string) error { + for key := range defaultTemplates { + if key == BootstrapFileName { + defaultTemplates[key] = util.JoinPath(path, "script", BootstrapFileName) + } + } + return InitTemplates(path, defaultTemplates) +} + +func GenMockTemplates(path string) error { + return InitTemplates(path, mockTemplates) +} + +// InitTemplates creates template files. +func InitTemplates(path string, templates map[string]string) error { + if err := MkdirIfNotExist(path); err != nil { + return err + } + + for k, v := range templates { + if err := createTemplate(filepath.Join(path, k+".tpl"), v); err != nil { + return err + } + } + + return nil +} + +// GetTemplateDir returns the category path. +func GetTemplateDir(category string) (string, error) { + home, err := filepath.Abs(".") + if err != nil { + return "", err + } + return filepath.Join(home, category), nil +} + +// MkdirIfNotExist makes directories if the input path is not exists +func MkdirIfNotExist(dir string) error { + if len(dir) == 0 { + return nil + } + + if _, err := os.Stat(dir); os.IsNotExist(err) { + return os.MkdirAll(dir, os.ModePerm) + } + + return nil +} + +func createTemplate(file, content string) error { + if util.Exists(file) { + return nil + } + + f, err := os.Create(file) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(content) + return err +} + +func (a *Arguments) Init(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + path := a.InitOutputDir + initType := a.InitType + if initType == "" { + initType = DefaultType + } + if path == "" { + path = curpath + } + if err := genTplMap[initType](path); err != nil { + return err + } + fmt.Printf("Templates are generated in %s\n", path) + os.Exit(0) + return nil +} + +func (a *Arguments) Render(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + if len(args) < 2 { + return fmt.Errorf("both template directory and idl is required") + } + a.RenderTplDir = args[0] + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(cmd.Flags().Args()[1:]) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Clean(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(cmd.Flags().Args()) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + func (a *Arguments) TemplateArgs(version, curpath string) error { kitexCmd := &util.Command{ Use: "kitex", @@ -36,99 +233,17 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { initCmd := &util.Command{ Use: "init", Short: "Init command", - RunE: func(cmd *util.Command, args []string) error { - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Init, } renderCmd := &util.Command{ Use: "render", Short: "Render command", - RunE: func(cmd *util.Command, args []string) error { - if len(args) > 0 { - a.RenderTplDir = args[0] - } - var tplDir string - for _, arg := range args { - if !strings.HasPrefix(arg, "-") { - tplDir = arg - break - } - } - if tplDir == "" { - return fmt.Errorf("template directory is required") - } - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()[1:]) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Render, } cleanCmd := &util.Command{ Use: "clean", Short: "Clean command", - RunE: func(cmd *util.Command, args []string) error { - log.Verbose = a.Verbose - - for _, e := range a.extends { - err := e.Check(a) - if err != nil { - return err - } - } - - err := a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() - if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath - } - return a.checkPath(curpath) - }, + RunE: a.Clean, } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index ec931c6506..237297067f 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,6 +17,7 @@ package main import ( "bytes" "flag" + "fmt" "io/ioutil" "os" "os/exec" @@ -51,7 +52,7 @@ func init() { if err := versions.RegisterMinDepVersion( &versions.MinDepVersion{ RefPath: "github.com/cloudwego/kitex", - Version: "v0.11.0", + Version: "v0.10.3", }, ); err != nil { log.Warn(err) @@ -79,8 +80,7 @@ func main() { if os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version, curpath) } else if !strings.HasPrefix(os.Args[1], "-") { - log.Warnf("Unknown command %q", os.Args[1]) - os.Exit(1) + err = fmt.Errorf("unknown command %q", os.Args[1]) } else { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index fe95db4bbc..238cfcace6 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -16,6 +16,7 @@ package generator import ( + "errors" "fmt" "go/token" "path/filepath" @@ -135,9 +136,9 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string + InitOutputDir string //specify the location path of init subcommand InitType string - RenderTplDir string + RenderTplDir string // specify the path of template directory for render subcommand TemplateFile string GenPath string @@ -374,7 +375,7 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error pkg.ServiceInfo.ServiceName) f, err := comp.CompleteMethods() if err != nil { - if err == errNoNewMethod { + if errors.Is(err, errNoNewMethod) { return fs, nil } return nil, err diff --git a/tool/internal_pkg/tpl/mock.go b/tool/internal_pkg/tpl/mock.go new file mode 100644 index 0000000000..27da3a59d7 --- /dev/null +++ b/tool/internal_pkg/tpl/mock.go @@ -0,0 +1,27 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpl + +var ClientMockTpl = ` +// Code generated by MockGen. DO NOT EDIT. +// Source: kitex_gen/example/shop/item/itemserviceb/client.go +// +// Generated by this command: +// +// ___1go_build_go_uber_org_mock_mockgen -source=kitex_gen/example/shop/item/itemserviceb/client.go -destination=client_mock.go -package=main +// + +// Package main is a generated GoMock package. +` diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 32e1a3ad55..fda97822d3 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -30,6 +30,14 @@ type Command struct { flags *FlagSet // helpFunc is help func defined by user. usage func() + // for debug + args []string +} + +// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden +// particularly useful when testing. +func (c *Command) SetArgs(a []string) { + c.args = a } func (c *Command) AddCommand(cmds ...*Command) error { @@ -51,24 +59,11 @@ func (c *Command) Flags() *FlagSet { return c.flags } -// Parent returns a commands parent command. -func (c *Command) Parent() *Command { - return c.parent -} - // HasParent determines if the command is a child command. func (c *Command) HasParent() bool { return c.parent != nil } -// Root finds root command. -func (c *Command) Root() *Command { - if c.HasParent() { - return c.Parent().Root() - } - return c -} - // HasSubCommands determines if the command has children commands. func (c *Command) HasSubCommands() bool { return len(c.commands) > 0 @@ -79,10 +74,16 @@ func stripFlags(args []string) []string { for len(args) > 0 { s := args[0] args = args[1:] - if s != "" && !strings.HasPrefix(s, "-") { + if strings.HasPrefix(s, "-") { + // handle "-f child child" args + if len(args) <= 1 { + break + } else { + args = args[1:] + continue + } + } else if s != "" && !strings.HasPrefix(s, "-") { commands = append(commands, s) - } else if strings.HasPrefix(s, "-") { - break } } return commands @@ -109,7 +110,8 @@ func nextArgs(args []string, x string) []string { continue case !strings.HasPrefix(s, "-"): if s == x { - var ret []string + // cannot use var ret []string cause it return nil + ret := make([]string, 0) ret = append(ret, args[:pos]...) ret = append(ret, args[pos+1:]...) return ret @@ -169,8 +171,10 @@ func (c *Command) UsageFunc() func() { // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { - args := os.Args[1:] - + args := c.args + if c.args == nil { + args = os.Args[1:] + } cmd, flags, err := c.Find(args) if err != nil { return c, err @@ -191,8 +195,9 @@ func (c *Command) execute(a []string) error { if err != nil { return err } + argWoFlags := c.Flags().Args() if c.RunE != nil { - err := c.RunE(c, a) + err := c.RunE(c, argWoFlags) if err != nil { return err } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go index a5a4766850..9bb5723d67 100644 --- a/tool/internal_pkg/util/command_test.go +++ b/tool/internal_pkg/util/command_test.go @@ -15,75 +15,249 @@ package util import ( - "os" + "fmt" + "reflect" + "strings" "testing" ) -func TestAddCommand(t *testing.T) { - rootCmd := &Command{Use: "root"} - childCmd := &Command{Use: "child"} +func emptyRun(*Command, []string) error { return nil } - // Test adding a valid child command - err := rootCmd.AddCommand(childCmd) - if err != nil { - t.Fatalf("expected no error, got %v", err) +func executeCommand(root *Command, args ...string) (err error) { + _, err = executeCommandC(root, args...) + return err +} + +func executeCommandC(root *Command, args ...string) (c *Command, err error) { + root.SetArgs(args) + c, err = root.ExecuteC() + return c, err +} + +const onetwo = "one two" + +func TestSingleCommand(t *testing.T) { + rootCmd := &Command{ + Use: "root", + RunE: func(_ *Command, args []string) error { return nil }, } + aCmd := &Command{Use: "a", RunE: emptyRun} + bCmd := &Command{Use: "b", RunE: emptyRun} + rootCmd.AddCommand(aCmd, bCmd) + + _ = executeCommand(rootCmd, "one", "two") +} - if len(rootCmd.commands) != 1 { - t.Fatalf("expected 1 command, got %d", len(rootCmd.commands)) +func TestChildCommand(t *testing.T) { + var child1CmdArgs []string + rootCmd := &Command{Use: "root", RunE: emptyRun} + child1Cmd := &Command{ + Use: "child1", + RunE: func(_ *Command, args []string) error { child1CmdArgs = args; return nil }, } + child2Cmd := &Command{Use: "child2", RunE: emptyRun} + rootCmd.AddCommand(child1Cmd, child2Cmd) - if rootCmd.commands[0] != childCmd { - t.Fatalf("expected child command to be added") + err := executeCommand(rootCmd, "child1", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) } - // Test adding a command to itself - err = rootCmd.AddCommand(rootCmd) - if err == nil { - t.Fatalf("expected an error, got nil") + got := strings.Join(child1CmdArgs, " ") + if got != onetwo { + t.Errorf("child1CmdArgs expected: %q, got: %q", onetwo, got) } +} - expectedErr := "command can't be a child of itself" - if err.Error() != expectedErr { - t.Fatalf("expected error %q, got %q", expectedErr, err.Error()) +func TestCallCommandWithoutSubcommands(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + err := executeCommand(rootCmd) + if err != nil { + t.Errorf("Calling command without subcommands should not have error: %v", err) } } -func TestExecuteC(t *testing.T) { - rootCmd := &Command{ +func TestRootExecuteUnknownCommand(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + rootCmd.AddCommand(&Command{Use: "child", RunE: emptyRun}) + + _ = executeCommand(rootCmd, "unknown") +} + +func TestSubcommandExecuteC(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + + _, err := executeCommandC(rootCmd, "child") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestFind(t *testing.T) { + var foo, bar string + root := &Command{ Use: "root", - RunE: func(cmd *Command, args []string) error { - return nil - }, } + root.Flags().StringVarP(&foo, "foo", "f", "", "") + root.Flags().StringVarP(&bar, "bar", "b", "something", "") + + child := &Command{ + Use: "child", + } + root.AddCommand(child) - subCmd := &Command{ - Use: "sub", - RunE: func(cmd *Command, args []string) error { - return nil + testCases := []struct { + args []string + expectedFoundArgs []string + }{ + { + []string{"child"}, + []string{}, + }, + { + []string{"child", "child"}, + []string{"child"}, + }, + { + []string{"child", "foo", "child", "bar", "child", "baz", "child"}, + []string{"foo", "child", "bar", "child", "baz", "child"}, + }, + { + []string{"-f", "child", "child"}, + []string{"-f", "child"}, + }, + { + []string{"child", "-f", "child"}, + []string{"-f", "child"}, + }, + { + []string{"-b", "child", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b"}, + []string{"-b"}, }, + { + []string{"-b", "-f", "child", "child"}, + []string{"-b", "-f", "child"}, + }, + { + []string{"-f", "child", "-b", "something", "child"}, + []string{"-f", "child", "-b", "something"}, + }, + { + []string{"-f", "child", "child", "-b"}, + []string{"-f", "child", "-b"}, + }, + { + []string{"-f=child", "-b=something", "child"}, + []string{"-f=child", "-b=something"}, + }, + { + []string{"--foo", "child", "--bar", "something", "child"}, + []string{"--foo", "child", "--bar", "something"}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { + cmd, foundArgs, err := root.Find(tc.args) + if err != nil { + t.Fatal(err) + } + + if cmd != child { + t.Fatal("Expected cmd to be child, but it was not") + } + + if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) { + t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs) + } + }) } +} - rootCmd.AddCommand(subCmd) +func TestFlagLong(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } - // Simulate command line arguments - os.Args = []string{"root", "sub"} + var intFlagValue int + var stringFlagValue string + c.Flags().IntVar(&intFlagValue, "intf", -1, "") + c.Flags().StringVar(&stringFlagValue, "sf", "", "") - // Execute the command - cmd, err := rootCmd.ExecuteC() + err := executeCommand(c, "--intf=7", "--sf=abc", "one", "--", "two") if err != nil { - t.Fatalf("expected no error, got %v", err) + t.Errorf("Unexpected error: %v", err) + } + + if intFlagValue != 7 { + t.Errorf("Expected intFlagValue: %v, got %v", 7, intFlagValue) + } + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } - if cmd.Use != "sub" { - t.Fatalf("expected sub command to be executed, got %s", cmd.Use) + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) } +} - // Simulate command line arguments with an unknown command - os.Args = []string{"root", "unknown"} +func TestFlagShort(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } + + var intFlagValue int + var stringFlagValue string + c.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") + c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") + + err := executeCommand(c, "-i", "7", "-sabc", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if intFlagValue != 7 { + t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) + } + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) + } + + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) + } +} + +func TestChildFlag(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + + var intFlagValue int + childCmd.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") + + err := executeCommand(rootCmd, "child", "-i7") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } - _, err = rootCmd.ExecuteC() - if err == nil { - t.Fatalf("expected an error for unknown command, got nil") + if intFlagValue != 7 { + t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) } } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 3c0b681197..21096cf8f1 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -22,6 +22,7 @@ import ( "os" "strconv" "strings" + "time" ) // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. @@ -69,7 +70,6 @@ type FlagSet struct { orderedActual []*Flag formal map[NormalizedName]*Flag orderedFormal []*Flag - sortedFormal []*Flag shorthands map[byte]*Flag args []string // arguments after flags argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- @@ -83,17 +83,15 @@ type FlagSet struct { // A Flag represents the state of a flag. type Flag struct { - Name string // name as it appears on command line - Shorthand string // one-letter abbreviated flag - Usage string // help message - Value Value // value as set - DefValue string // default value (as text); for usage message - Changed bool // If the user set the value (or if left to default) - NoOptDefVal string // default value (as text); if the flag is on the command line without any options - Deprecated string // If this flag is deprecated, this string is the new or now thing to use - Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text - ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use - Annotations map[string][]string // used by cobra.Command bash autocomple code + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string // default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use } // Value is the interface to the dynamic value stored in a flag. @@ -115,29 +113,6 @@ type SliceValue interface { GetSlice() []string } -// SetNormalizeFunc allows you to add a function which can translate flag names. -// Flags added to the FlagSet will be translated and then when anything tries to -// look up the flag that will also be translated. So it would be possible to create -// a flag named "getURL" and have it translated to "geturl". A user could then pass -// "--getUrl" which may also be translated to "geturl" and everything will work. -func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { - f.normalizeNameFunc = n - f.sortedFormal = f.sortedFormal[:0] - for fname, flag := range f.formal { - nname := f.normalizeFlagName(flag.Name) - if fname == nname { - continue - } - flag.Name = string(nname) - delete(f.formal, fname) - f.formal[nname] = flag - if _, set := f.actual[fname]; set { - delete(f.actual, fname) - f.actual[nname] = flag - } - } -} - // GetNormalizeFunc returns the previously set NormalizeFunc of a function which // does no translation, if not set previously. func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { @@ -152,6 +127,16 @@ func (f *FlagSet) normalizeFlagName(name string) NormalizedName { return n(f, name) } +// Lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) Lookup(name string) *Flag { + return f.lookup(f.normalizeFlagName(name)) +} + +// lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) lookup(name NormalizedName) *Flag { + return f.formal[name] +} + func (f *FlagSet) out() io.Writer { if f.output == nil { return os.Stderr @@ -200,37 +185,6 @@ func (f *FlagSet) Set(name, value string) error { return nil } -// SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. -// This is sometimes used by spf13/cobra programs which want to generate additional -// bash completion information. -func (f *FlagSet) SetAnnotation(name, key string, values []string) error { - normalName := f.normalizeFlagName(name) - flag, ok := f.formal[normalName] - if !ok { - return fmt.Errorf("no such flag -%v", name) - } - if flag.Annotations == nil { - flag.Annotations = map[string][]string{} - } - flag.Annotations[key] = values - return nil -} - -// NFlag returns the number of flags that have been set. -func (f *FlagSet) NFlag() int { return len(f.actual) } - -// Arg returns the i'th argument. Arg(0) is the first remaining argument -// after flags have been processed. -func (f *FlagSet) Arg(i int) string { - if i < 0 || i >= len(f.args) { - return "" - } - return f.args[i] -} - -// NArg is the number of arguments remaining after flags have been processed. -func (f *FlagSet) NArg() int { return len(f.args) } - // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } @@ -526,60 +480,11 @@ func (f *FlagSet) Parse(arguments []string) error { type parseFunc func(flag *Flag, value string) error -// ParseAll parses flag definitions from the argument list, which should not -// include the command name. The arguments for fn are flag and value. Must be -// called after all flags in the FlagSet are defined and before flags are -// accessed by the program. The return value will be ErrHelp if -help was set -// but not defined. -func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error { - f.parsed = true - f.args = make([]string, 0, len(arguments)) - - err := f.parseArgs(arguments, fn) - if err != nil { - switch f.errorHandling { - case ContinueOnError: - return err - case ExitOnError: - os.Exit(2) - case PanicOnError: - panic(err) - } - } - return nil -} - // Parsed reports whether f.Parse has been called. func (f *FlagSet) Parsed() bool { return f.parsed } -// Parse parses the command-line flags from os.Args[1:]. Must be called -// after all flags are defined and before flags are accessed by the program. -func Parse() { - // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.Parse(os.Args[1:]) -} - -// ParseAll parses the command-line flags from os.Args[1:] and called fn for each. -// The arguments for fn are flag and value. Must be called after all flags are -// defined and before flags are accessed by the program. -func ParseAll(fn func(flag *Flag, value string) error) { - // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.ParseAll(os.Args[1:], fn) -} - -// SetInterspersed sets whether to support interspersed option/non-option arguments. -func SetInterspersed(interspersed bool) { - CommandLine.SetInterspersed(interspersed) -} - -// Parsed returns true if the command-line flags have been parsed. -func Parsed() bool { - return CommandLine.Parsed() -} - -// CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) // NewFlagSet returns a new, empty flag set with the specified name, @@ -595,20 +500,6 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { return f } -// SetInterspersed sets whether to support interspersed option/non-option arguments. -func (f *FlagSet) SetInterspersed(interspersed bool) { - f.interspersed = interspersed -} - -// Init sets the name and error handling property for a flag set. -// By default, the zero FlagSet uses an empty name and the -// ContinueOnError error handling policy. -func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { - f.name = name - f.errorHandling = errorHandling - f.argsLenAtDash = -1 -} - // -- string Value type stringValue string @@ -622,10 +513,6 @@ func (s *stringValue) Set(val string) error { return nil } -func (s *stringValue) Type() string { - return "string" -} - func (s *stringValue) String() string { return string(*s) } // StringVar defines a string flag with specified name, default value, and usage string. @@ -638,6 +525,21 @@ func (f *FlagSet) StringVarP(p *string, name, shorthand, value, usage string) { f.VarP(newStringValue(value, p), name, shorthand, usage) } +// String defines a string flag with specified name, default value, and usage string. +// The return value is the address of a string variable that stores the value of the flag. +func (f *FlagSet) String(name, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, "", value, usage) + return p +} + +// StringP is like String, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringP(name, shorthand, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, shorthand, value, usage) + return p +} + // -- bool Value type boolValue bool @@ -652,10 +554,6 @@ func (b *boolValue) Set(s string) error { return err } -func (b *boolValue) Type() string { - return "bool" -} - func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } func (b *boolValue) IsBoolFlag() bool { return true } @@ -679,3 +577,76 @@ func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage st flag.NoOptDefVal = "true" return nil } + +// Bool defines a bool flag with specified name, default value, and usage string. +// The return value is the address of a bool variable that stores the value of the flag. +func (f *FlagSet) Bool(name string, value bool, usage string) *bool { + return f.BoolP(name, "", value, usage) +} + +// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { + p := new(bool) + f.BoolVarP(p, name, shorthand, value, usage) + return p +} + +// -- int Value +type intValue int + +func newIntValue(val int, p *int) *intValue { + *p = val + return (*intValue)(p) +} + +func (i *intValue) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + *i = intValue(v) + return err +} + +func (i *intValue) String() string { return strconv.Itoa(int(*i)) } + +// IntVar defines an int flag with specified name, default value, and usage string. +// The argument p points to an int variable in which to store the value of the flag. +func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { + f.VarP(newIntValue(value, p), name, "", usage) +} + +// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IntVarP(p *int, name, shorthand string, value int, usage string) { + f.VarP(newIntValue(value, p), name, shorthand, usage) +} + +func (f *FlagSet) Int(name string, value int, usage string) *int { + p := new(int) + f.IntVarP(p, name, "", value, usage) + return p +} + +// -- time.Duration Value +type durationValue time.Duration + +func (d *durationValue) Set(s string) error { + v, err := time.ParseDuration(s) + *d = durationValue(v) + return err +} + +func (d *durationValue) String() string { return (*time.Duration)(d).String() } + +func newDurationValue(val time.Duration, p *time.Duration) *durationValue { + *p = val + return (*durationValue)(p) +} + +// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { + f.VarP(newDurationValue(value, p), name, shorthand, usage) +} + +func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration { + p := new(time.Duration) + f.DurationVarP(p, name, "", value, usage) + return p +} From e335067b87942884ce0fbadf0400913f4427774a Mon Sep 17 00:00:00 2001 From: shawn Date: Mon, 5 Aug 2024 21:40:47 +0800 Subject: [PATCH 35/41] feat: render with multiple files and debug mode Signed-off-by: shawn feat: add subcommand clean Signed-off-by: shawn feat: add kitex_render_meta.yaml for render Signed-off-by: shawn --- tool/cmd/kitex/args/tpl_args.go | 109 +++++---- tool/cmd/kitex/main.go | 2 +- .../internal_pkg/generator/custom_template.go | 228 +++--------------- tool/internal_pkg/generator/generator.go | 8 +- tool/internal_pkg/generator/generator_test.go | 7 +- .../pluginmode/thriftgo/plugin.go | 12 + tool/internal_pkg/util/flag.go | 56 +++-- tool/internal_pkg/util/flag_test.go | 17 ++ 8 files changed, 161 insertions(+), 278 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 5ef114757a..222008c949 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -16,12 +16,15 @@ package args import ( "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" - "os" - "path/filepath" ) // Constants . @@ -40,7 +43,7 @@ const ( ServiceFileName = "*service.go" ExtensionFilename = "extensions.yaml" - ClientMockFilename = "client_mock.go" + MultipleServicesFileName = "multiple_services.go" ) var defaultTemplates = map[string]string{ @@ -55,35 +58,30 @@ var defaultTemplates = map[string]string{ ServiceFileName: tpl.ServiceTpl, } -var mockTemplates = map[string]string{ - ClientMockFilename: tpl.ClientMockTpl, +var multipleServicesTpl = map[string]string{ + MultipleServicesFileName: tpl.MultipleServicesTpl, } const ( - DefaultType = "default" - MockType = "mock" + DefaultType = "default" + MultipleServicesType = "multiple_services" ) type TemplateGenerator func(string) error var genTplMap = map[string]TemplateGenerator{ - DefaultType: GenTemplates, - MockType: GenMockTemplates, + DefaultType: GenTemplates, + MultipleServicesType: GenMultipleServicesTemplates, } // GenTemplates is the entry for command kitex template, // it will create the specified path func GenTemplates(path string) error { - for key := range defaultTemplates { - if key == BootstrapFileName { - defaultTemplates[key] = util.JoinPath(path, "script", BootstrapFileName) - } - } return InitTemplates(path, defaultTemplates) } -func GenMockTemplates(path string) error { - return InitTemplates(path, mockTemplates) +func GenMultipleServicesTemplates(path string) error { + return InitTemplates(path, multipleServicesTpl) } // InitTemplates creates template files. @@ -92,8 +90,18 @@ func InitTemplates(path string, templates map[string]string) error { return err } - for k, v := range templates { - if err := createTemplate(filepath.Join(path, k+".tpl"), v); err != nil { + for name, content := range templates { + var filePath string + if name == BootstrapFileName { + bootstrapDir := filepath.Join(path, "script") + if err := MkdirIfNotExist(bootstrapDir); err != nil { + return err + } + filePath = filepath.Join(bootstrapDir, name+".tpl") + } else { + filePath = filepath.Join(path, name+".tpl") + } + if err := createTemplate(filePath, content); err != nil { return err } } @@ -164,10 +172,6 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) } - if len(args) < 2 { - return fmt.Errorf("both template directory and idl is required") - } - a.RenderTplDir = args[0] log.Verbose = a.Verbose for _, e := range a.extends { @@ -177,7 +181,7 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { } } - err = a.checkIDL(cmd.Flags().Args()[1:]) + err = a.checkIDL(args) if err != nil { return err } @@ -197,31 +201,35 @@ func (a *Arguments) Clean(cmd *util.Command, args []string) error { if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) } - log.Verbose = a.Verbose - for _, e := range a.extends { - err := e.Check(a) + magicString := "// Kitex template debug file. use template clean to delete it." + err = filepath.WalkDir(curpath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } - } - - err = a.checkIDL(cmd.Flags().Args()) - if err != nil { - return err - } - err = a.checkServiceName() + if d.IsDir() { + return nil + } + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read file %s faild: %v", path, err) + } + if strings.Contains(string(content), magicString) { + if err := os.Remove(path); err != nil { + return fmt.Errorf("delete file %s failed: %v", path, err) + } + } + return nil + }) if err != nil { - return err - } - // todo finish protobuf - if a.IDLType != "thrift" { - a.GenPath = generator.KitexGenPath + return fmt.Errorf("error cleaning debug template files: %v", err) } - return a.checkPath(curpath) + fmt.Println("clean debug template files successfully...") + os.Exit(0) + return nil } -func (a *Arguments) TemplateArgs(version, curpath string) error { +func (a *Arguments) TemplateArgs(version string) error { kitexCmd := &util.Command{ Use: "kitex", Short: "Kitex command", @@ -246,38 +254,29 @@ func (a *Arguments) TemplateArgs(version, curpath string) error { RunE: a.Clean, } initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") - renderCmd.Flags().StringVarP(&a.ModuleName, "module", "m", "", + initCmd.Flags().StringVar(&a.InitType, "type", "", "Specify template init type") + renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") + renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, "Specify a code gen path.") - renderCmd.Flags().StringVarP(&a.TemplateFile, "file", "f", "", "Specify single template path") + renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") + renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") initCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template init [flags] - -Examples: - kitex template init -o /path/to/output - kitex template init - -Flags: `, version) }) renderCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s -Usage: template render [template dir_path] [flags] IDL +Usage: template render --dir [template dir_path] [flags] IDL `, version) }) cleanCmd.SetUsageFunc(func() { fmt.Fprintf(os.Stderr, `Version %s Usage: kitex template clean - -Examples: - kitex template clean - -Flags: `, version) }) // renderCmd.PrintUsage() diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 237297067f..1fb5372ffe 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -78,7 +78,7 @@ func main() { os.Exit(1) } if os.Args[1] == "template" { - err = args.TemplateArgs(kitex.Version, curpath) + err = args.TemplateArgs(kitex.Version) } else if !strings.HasPrefix(os.Args[1], "-") { err = fmt.Errorf("unknown command %q", os.Args[1]) } else { diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 48dc283133..52306dc8b1 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -200,93 +200,6 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } -func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { - files, _ := os.ReadDir(currentDir) - for _, f := range files { - // filter dir and non-tpl files - if f.IsDir() { - subDir := filepath.Join(currentDir, f.Name()) - subTemplates, err := readTpls(rootDir, subDir, ts) - if err != nil { - return nil, err - } - ts = append(ts, subTemplates...) - } else if strings.HasSuffix(f.Name(), ".tpl") { - p := filepath.Join(currentDir, f.Name()) - tplData, err := os.ReadFile(p) - if err != nil { - return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) - } - // Remove the .tpl suffix from the Path and compute relative path - relativePath, err := filepath.Rel(rootDir, p) - if err != nil { - return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) - } - trimmedPath := strings.TrimSuffix(relativePath, ".tpl") - t := &Template{ - Path: trimmedPath, - Body: string(tplData), - UpdateBehavior: &Update{Type: string(skip)}, - } - ts = append(ts, t) - } - } - - return ts, nil -} - -func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { - g.updatePackageInfo(pkg) - - g.setImports(HandlerFileName, pkg) - var tpls []*Template - tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) - if err != nil { - return nil, err - } - for _, tpl := range tpls { - newPath := filepath.Join(g.OutputPath, tpl.Path) - dir := filepath.Dir(newPath) - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return nil, fmt.Errorf("failed to create directory %s: %v", dir, err) - } - if tpl.LoopService && g.CombineService { - svrInfo, cs := pkg.ServiceInfo, pkg.CombineServices - - for i := range cs { - pkg.ServiceInfo = cs[i] - f, err := renderFile(pkg, g.OutputPath, tpl) - if err != nil { - return nil, err - } - fs = append(fs, f...) - } - pkg.ServiceInfo, pkg.CombineServices = svrInfo, cs - } else { - f, err := renderFile(pkg, g.OutputPath, tpl) - if err != nil { - return nil, err - } - fs = append(fs, f...) - } - } - return fs, nil -} - -func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { - cg := NewCustomGenerator(pkg, outputPath) - // special handling Methods field - if tpl.LoopMethod { - err = cg.loopGenerate(tpl) - } else { - err = cg.commonGenerate(tpl) - } - if errors.Is(err, errNoNewMethod) { - err = nil - } - return cg.fs, err -} - func readTemplates(dir string) ([]*Template, error) { files, _ := ioutil.ReadDir(dir) var ts []*Template @@ -311,86 +224,26 @@ func readTemplates(dir string) ([]*Template, error) { return ts, nil } -// parseMeta parses the meta flag and returns a map where the value is a slice of strings -func parseMeta(metaFlags string) (map[string][]string, error) { - meta := make(map[string][]string) - if metaFlags == "" { - return meta, nil - } - - // split for each key=value pairs - pairs := strings.Split(metaFlags, ";") - for _, pair := range pairs { - kv := strings.SplitN(pair, "=", 2) - if len(kv) == 2 { - key := kv[0] - values := strings.Split(kv[1], ",") - meta[key] = values - } else { - return nil, fmt.Errorf("Invalid meta format: %s\n", pair) - } +func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { + cg := NewCustomGenerator(pkg, outputPath) + // special handling Methods field + if tpl.LoopMethod { + err = cg.loopGenerate(tpl) + } else { + err = cg.commonGenerate(tpl) } - return meta, nil -} - -func parseMiddlewares(middlewares []MiddlewareForResolve) ([]GlobalMiddleware, error) { - var mwList []GlobalMiddleware - - for _, mw := range middlewares { - content, err := os.ReadFile(mw.Path) - if err != nil { - return nil, fmt.Errorf("failed to read middleware file %s: %v", mw.Path, err) - } - mwList = append(mwList, GlobalMiddleware{ - Name: mw.Name, - Content: string(content), - }) + if errors.Is(err, errNoNewMethod) { + err = nil } - return mwList, nil + return cg.fs, err } func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { g.updatePackageInfo(pkg) g.setImports(HandlerFileName, pkg) - pkg.ExtendMeta, err = parseMeta(g.MetaFlags) - if err != nil { - return nil, err - } - if g.Config.IncludesTpl != "" { - inc := g.Config.IncludesTpl - if strings.HasPrefix(inc, "git@") || strings.HasPrefix(inc, "http://") || strings.HasPrefix(inc, "https://") { - localGitPath, errMsg, gitErr := util.RunGitCommand(inc) - if gitErr != nil { - if errMsg == "" { - errMsg = gitErr.Error() - } - return nil, fmt.Errorf("failed to pull IDL from git:%s\nYou can execute 'rm -rf ~/.kitex' to clean the git cache and try again", errMsg) - } - if g.RenderTplDir != "" { - g.RenderTplDir = filepath.Join(localGitPath, g.RenderTplDir) - } else { - g.RenderTplDir = localGitPath - } - if util.Exists(g.RenderTplDir) { - return nil, fmt.Errorf("the render template directory path you specified does not exists int the git path") - } - } - } - var meta *Meta - metaPath := filepath.Join(g.RenderTplDir, kitexRenderMetaFile) - if util.Exists(metaPath) { - meta, err = readMetaFile(metaPath) - if err != nil { - return nil, err - } - middlewares, err := parseMiddlewares(meta.MWs) - if err != nil { - return nil, err - } - pkg.MWs = middlewares - } - tpls, err := readTpls(g.RenderTplDir, g.RenderTplDir, meta) + var tpls []*Template + tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) if err != nil { return nil, err } @@ -425,25 +278,13 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, const kitexRenderMetaFile = "kitex_render_meta.yaml" -// Meta represents the structure of the kitex_render_meta.yaml file. +// Meta 代表kitex_render_meta.yaml文件的结构 type Meta struct { - Templates []Template `yaml:"templates"` - MWs []MiddlewareForResolve `yaml:"middlewares"` - ExtendMeta []ExtendMeta `yaml:"extend_meta"` -} - -type MiddlewareForResolve struct { - // name of the middleware - Name string `yaml:"name"` - // path of the middleware - Path string `yaml:"path"` -} - -type ExtendMeta struct { - key string + Templates []Template `yaml:"templates"` } -func readMetaFile(metaPath string) (*Meta, error) { +func readMetaFile(rootDir string) (*Meta, error) { + metaPath := filepath.Join(rootDir, kitexRenderMetaFile) metaData, err := os.ReadFile(metaPath) if err != nil { return nil, fmt.Errorf("failed to read meta file from %s: %v", metaPath, err) @@ -458,28 +299,27 @@ func readMetaFile(metaPath string) (*Meta, error) { return &meta, nil } -func getMetadata(meta *Meta, relativePath string) *Template { - for i := range meta.Templates { - if meta.Templates[i].Path == relativePath { - return &meta.Templates[i] +func getUpdateBehavior(meta *Meta, relativePath string) *Update { + for _, template := range meta.Templates { + // fmt.Println(template.Path, "==", relativePath) + if template.Path == relativePath { + return template.UpdateBehavior } } - return &Template{ - UpdateBehavior: &Update{Type: string(skip)}, - } + return &Update{Type: string(skip)} } -func readTpls(rootDir, currentDir string, meta *Meta) (ts []*Template, error error) { - defaultMetadata := &Template{ - UpdateBehavior: &Update{Type: string(skip)}, +func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { + meta, err := readMetaFile(rootDir) + if err != nil { + return nil, err } - files, _ := os.ReadDir(currentDir) for _, f := range files { // filter dir and non-tpl files if f.IsDir() { subDir := filepath.Join(currentDir, f.Name()) - subTemplates, err := readTpls(rootDir, subDir, meta) + subTemplates, err := readTpls(rootDir, subDir, ts) if err != nil { return nil, err } @@ -496,19 +336,11 @@ func readTpls(rootDir, currentDir string, meta *Meta) (ts []*Template, error err return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) } trimmedPath := strings.TrimSuffix(relativePath, ".tpl") - // If kitex_render_meta.yaml exists, get the corresponding metadata; otherwise, use the default metadata - var metadata *Template - if meta != nil { - metadata = getMetadata(meta, relativePath) - } else { - metadata = defaultMetadata - } + updateBehavior := getUpdateBehavior(meta, relativePath) t := &Template{ Path: trimmedPath, Body: string(tplData), - UpdateBehavior: metadata.UpdateBehavior, - LoopMethod: metadata.LoopMethod, - LoopService: metadata.LoopService, + UpdateBehavior: updateBehavior, } ts = append(ts, t) } @@ -525,7 +357,7 @@ func (g *generator) RenderWithMultipleFiles(pkg *PackageInfo) (fs []*File, err e } var updatedContent string if g.Config.DebugTpl { - // when --debug is enabled, add a magic string at the top of the template content for distinction. + // --debug时 在模板内容顶部加上一段magic string用于区分 updatedContent = "// Kitex template debug file. use template clean to delete it.\n\n" + string(content) } else { updatedContent = string(content) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 238cfcace6..c7478887c0 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -98,6 +98,7 @@ type Generator interface { GenerateMainPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackageWithTpl(pkg *PackageInfo) ([]*File, error) + RenderWithMultipleFiles(pkg *PackageInfo) ([]*File, error) } // Config . @@ -136,10 +137,11 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string //specify the location path of init subcommand + InitOutputDir string // specify the location path of init subcommand InitType string - RenderTplDir string // specify the path of template directory for render subcommand - TemplateFile string + RenderTplDir string // specify the path of template directory for render subcommand + TemplateFiles []string // specify the path of single file or multiple file to render + DebugTpl bool GenPath string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index cc39bffc31..ebb5245b70 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -59,7 +59,8 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir string InitType string RenderTplDir string - TemplateFile string + TemplateFiles []string + DebugTpl bool Protocol string HandlerReturnKeepResp bool } @@ -73,7 +74,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFile=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -104,7 +105,7 @@ func TestConfig_Pack(t *testing.T) { InitOutputDir: tt.fields.InitOutputDir, InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, - TemplateFile: tt.fields.TemplateFile, + TemplateFiles: tt.fields.TemplateFiles, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index e65c3e179e..45560251d0 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -107,6 +107,18 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } + if len(conv.Config.TemplateFiles) > 0 { + if len(conv.Services) == 0 { + return conv.failResp(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.RenderWithMultipleFiles(&conv.Package) + if err != nil { + return conv.failResp(err) + } + files = append(files, fs...) + } + if conv.Config.RenderTplDir != "" { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 21096cf8f1..d80a7ac6cf 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -15,6 +15,8 @@ package util import ( + "bytes" + "encoding/csv" "errors" goflag "flag" "fmt" @@ -22,7 +24,6 @@ import ( "os" "strconv" "strings" - "time" ) // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. @@ -624,29 +625,48 @@ func (f *FlagSet) Int(name string, value int, usage string) *int { return p } -// -- time.Duration Value -type durationValue time.Duration +// -- stringArray Value +type stringArrayValue struct { + value *[]string + changed bool +} -func (d *durationValue) Set(s string) error { - v, err := time.ParseDuration(s) - *d = durationValue(v) - return err +func newStringArrayValue(val []string, p *[]string) *stringArrayValue { + ssv := new(stringArrayValue) + ssv.value = p + *ssv.value = val + return ssv } -func (d *durationValue) String() string { return (*time.Duration)(d).String() } +func (s *stringArrayValue) Set(val string) error { + if !s.changed { + *s.value = []string{val} + s.changed = true + } else { + *s.value = append(*s.value, val) + } + return nil +} -func newDurationValue(val time.Duration, p *time.Duration) *durationValue { - *p = val - return (*durationValue)(p) +func writeAsCSV(vals []string) (string, error) { + b := &bytes.Buffer{} + w := csv.NewWriter(b) + err := w.Write(vals) + if err != nil { + return "", err + } + w.Flush() + return strings.TrimSuffix(b.String(), "\n"), nil } -// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { - f.VarP(newDurationValue(value, p), name, shorthand, usage) +func (s *stringArrayValue) String() string { + str, _ := writeAsCSV(*s.value) + return "[" + str + "]" } -func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration { - p := new(time.Duration) - f.DurationVarP(p, name, "", value, usage) - return p +// StringArrayVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a []string variable in which to store the values of the multiple flags. +// The value of each argument will not try to be separated by comma. Use a StringSlice for that. +func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { + f.VarP(newStringArrayValue(value, p), name, "", usage) } diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go index 916f906df5..364d379c1f 100644 --- a/tool/internal_pkg/util/flag_test.go +++ b/tool/internal_pkg/util/flag_test.go @@ -45,13 +45,21 @@ func testParse(f *FlagSet, t *testing.T) { boolFlag := f.Bool("bool", false, "bool value") bool2Flag := f.Bool("bool2", false, "bool2 value") bool3Flag := f.Bool("bool3", false, "bool3 value") + intFlag := f.Int("int", 0, "int value") stringFlag := f.String("string", "0", "string value") + optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" + optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" extra := "one-extra-argument" args := []string{ "--bool", "--bool2=true", "--bool3=false", + "--int=22", "--string=hello", + "--optional-int-no-value", + "--optional-int-with-value=42", extra, } if err := f.Parse(args); err != nil { @@ -69,9 +77,18 @@ func testParse(f *FlagSet, t *testing.T) { if *bool3Flag != false { t.Error("bool3 flag should be false, is ", *bool2Flag) } + if *intFlag != 22 { + t.Error("int flag should be 22, is ", *intFlag) + } if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } + if *optionalIntNoValueFlag != 9 { + t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) + } + if *optionalIntWithValueFlag != 42 { + t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) + } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra { From 4e4e8de24fd9a0af9cf66f73ef9e608d9c8f82bf Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 6 Aug 2024 16:19:12 +0800 Subject: [PATCH 36/41] feat: add remote repo for render and middlewares Signed-off-by: shawn chore(generic): add generic base using gopkg base (#1482) fix(gonet): adjust gonet server read timeout to avoid read error (#1481) fix: remove redundant line Signed-off-by: shawn fix: fix single command Signed-off-by: shawn feat: add userDefinedMiddleware Signed-off-by: shawn --- tool/cmd/kitex/args/args.go | 7 +- tool/cmd/kitex/args/tpl_args.go | 151 +++++++++- tool/cmd/kitex/main.go | 4 +- .../internal_pkg/generator/custom_template.go | 155 ++++++++-- tool/internal_pkg/generator/generator.go | 8 +- tool/internal_pkg/generator/generator_test.go | 6 +- tool/internal_pkg/generator/type.go | 4 +- tool/internal_pkg/tpl/mock.go | 27 -- tool/internal_pkg/util/command.go | 53 +++- tool/internal_pkg/util/command_test.go | 24 +- tool/internal_pkg/util/flag.go | 285 +++++++++++++++--- tool/internal_pkg/util/flag_test.go | 17 -- 12 files changed, 578 insertions(+), 163 deletions(-) delete mode 100644 tool/internal_pkg/tpl/mock.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 0e8e9e9300..fab0cb7c0e 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -212,14 +212,11 @@ func (a *Arguments) checkIDL(files []string) error { } func (a *Arguments) checkServiceName() error { - if a.ServiceName == "" && a.TemplateDir == "" && a.RenderTplDir == "" { + if a.ServiceName == "" && a.TemplateDir == "" { if a.Use != "" { - return fmt.Errorf("-use must be used with -service or -template-dir or template render") + return fmt.Errorf("-use must be used with -service or -template-dir") } } - if a.TemplateDir != "" && a.RenderTplDir != "" { - return fmt.Errorf("template render and -template-dir cannot be used at the same time") - } if a.ServiceName != "" && a.TemplateDir != "" { return fmt.Errorf("-template-dir and -service cannot be specified at the same time") } diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index 222008c949..a190ff7f2d 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -167,6 +167,82 @@ func (a *Arguments) Init(cmd *util.Command, args []string) error { return nil } +func (a *Arguments) checkTplArgs() error { + if a.TemplateDir != "" && a.RenderTplDir != "" { + return fmt.Errorf("template render --dir and -template-dir cannot be used at the same time") + } + if a.RenderTplDir != "" && len(a.TemplateFiles) > 0 { + return fmt.Errorf("template render --dir and --file option cannot be specified at the same time") + } + return nil +} + +func (a *Arguments) Root(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Template(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + func (a *Arguments) Render(cmd *util.Command, args []string) error { curpath, err := filepath.Abs(".") if err != nil { @@ -189,6 +265,10 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { if err != nil { return err } + err = a.checkTplArgs() + if err != nil { + return err + } // todo finish protobuf if a.IDLType != "thrift" { a.GenPath = generator.KitexGenPath @@ -233,10 +313,12 @@ func (a *Arguments) TemplateArgs(version string) error { kitexCmd := &util.Command{ Use: "kitex", Short: "Kitex command", + RunE: a.Root, } templateCmd := &util.Command{ Use: "template", Short: "Template command", + RunE: a.Template, } initCmd := &util.Command{ Use: "init", @@ -253,8 +335,12 @@ func (a *Arguments) TemplateArgs(version string) error { Short: "Clean command", RunE: a.Clean, } + kitexCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + templateCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVar(&a.InitType, "type", "", "Specify template init type") + initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") @@ -263,23 +349,58 @@ func (a *Arguments) TemplateArgs(version string) error { "Specify a code gen path.") renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") - renderCmd.Flags().VarP(&a.Includes, "Includes", "I", "Add IDL search path and template search path for includes.") - initCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: kitex template init [flags] -`, version) + renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") + templateCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Template operation + +Usage: + kitex template [command] + +Available Commands: + init Initialize the templates according to the type + render Render the template files + clean Clean the debug templates + `) + }) + initCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Initialize the templates according to the type + +Usage: + kitex template init [flags] + +Flags: + -o, --output string Output directory + -t, --type string The init type of the template + `) }) - renderCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: template render --dir [template dir_path] [flags] IDL -`, version) + renderCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Render the template files + +Usage: + kitex template render [flags] + +Flags: + --dir string Output directory + --debug bool Turn on the debug mode + --file stringArray Specify multiple files for render + -I, --Includes string Add an template git search path for includes. + --meta string Specify meta data for render + --module string Specify the Go module name to generate go.mod. + -t, --type string The init type of the template + `) }) - cleanCmd.SetUsageFunc(func() { - fmt.Fprintf(os.Stderr, `Version %s -Usage: kitex template clean -`, version) + cleanCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Clean the debug templates + +Usage: + kitex template clean + `) }) - // renderCmd.PrintUsage() templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) kitexCmd.AddCommand(templateCmd) if _, err := kitexCmd.ExecuteC(); err != nil { diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 1fb5372ffe..0777986bd0 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -77,9 +77,9 @@ func main() { log.Warn("Get current path failed:", err.Error()) os.Exit(1) } - if os.Args[1] == "template" { + if len(os.Args) > 1 && os.Args[1] == "template" { err = args.TemplateArgs(kitex.Version) - } else if !strings.HasPrefix(os.Args[1], "-") { + } else if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") { err = fmt.Errorf("unknown command %q", os.Args[1]) } else { // run as kitex diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 52306dc8b1..54530713bf 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -200,6 +200,20 @@ func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err err return fs, nil } +func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { + cg := NewCustomGenerator(pkg, outputPath) + // special handling Methods field + if tpl.LoopMethod { + err = cg.loopGenerate(tpl) + } else { + err = cg.commonGenerate(tpl) + } + if errors.Is(err, errNoNewMethod) { + err = nil + } + return cg.fs, err +} + func readTemplates(dir string) ([]*Template, error) { files, _ := ioutil.ReadDir(dir) var ts []*Template @@ -224,26 +238,86 @@ func readTemplates(dir string) ([]*Template, error) { return ts, nil } -func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, err error) { - cg := NewCustomGenerator(pkg, outputPath) - // special handling Methods field - if tpl.LoopMethod { - err = cg.loopGenerate(tpl) - } else { - err = cg.commonGenerate(tpl) +// parseMeta parses the meta flag and returns a map where the value is a slice of strings +func parseMeta(metaFlags string) (map[string][]string, error) { + meta := make(map[string][]string) + if metaFlags == "" { + return meta, nil } - if errors.Is(err, errNoNewMethod) { - err = nil + + // split for each key=value pairs + pairs := strings.Split(metaFlags, ";") + for _, pair := range pairs { + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + key := kv[0] + values := strings.Split(kv[1], ",") + meta[key] = values + } else { + return nil, fmt.Errorf("Invalid meta format: %s\n", pair) + } } - return cg.fs, err + return meta, nil +} + +func parseMiddlewares(middlewares []MiddlewareForResolve) ([]UserDefinedMiddleware, error) { + var mwList []UserDefinedMiddleware + + for _, mw := range middlewares { + content, err := os.ReadFile(mw.Path) + if err != nil { + return nil, fmt.Errorf("failed to read middleware file %s: %v", mw.Path, err) + } + mwList = append(mwList, UserDefinedMiddleware{ + Name: mw.Name, + Content: string(content), + }) + } + return mwList, nil } func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { g.updatePackageInfo(pkg) g.setImports(HandlerFileName, pkg) - var tpls []*Template - tpls, err = readTpls(g.RenderTplDir, g.RenderTplDir, tpls) + pkg.ExtendMeta, err = parseMeta(g.MetaFlags) + if err != nil { + return nil, err + } + if g.Config.IncludesTpl != "" { + inc := g.Config.IncludesTpl + if strings.HasPrefix(inc, "git@") || strings.HasPrefix(inc, "http://") || strings.HasPrefix(inc, "https://") { + localGitPath, errMsg, gitErr := util.RunGitCommand(inc) + if gitErr != nil { + if errMsg == "" { + errMsg = gitErr.Error() + } + return nil, fmt.Errorf("failed to pull IDL from git:%s\nYou can execute 'rm -rf ~/.kitex' to clean the git cache and try again", errMsg) + } + if g.RenderTplDir != "" { + g.RenderTplDir = filepath.Join(localGitPath, g.RenderTplDir) + } else { + g.RenderTplDir = localGitPath + } + if util.Exists(g.RenderTplDir) { + return nil, fmt.Errorf("the render template directory path you specified does not exists int the git path") + } + } + } + var meta *Meta + metaPath := filepath.Join(g.RenderTplDir, kitexRenderMetaFile) + if util.Exists(metaPath) { + meta, err = readMetaFile(metaPath) + if err != nil { + return nil, err + } + middlewares, err := parseMiddlewares(meta.MWs) + if err != nil { + return nil, err + } + pkg.MWs = middlewares + } + tpls, err := readTpls(g.RenderTplDir, g.RenderTplDir, meta) if err != nil { return nil, err } @@ -278,13 +352,25 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, const kitexRenderMetaFile = "kitex_render_meta.yaml" -// Meta 代表kitex_render_meta.yaml文件的结构 +// Meta represents the structure of the kitex_render_meta.yaml file. type Meta struct { - Templates []Template `yaml:"templates"` + Templates []Template `yaml:"templates"` + MWs []MiddlewareForResolve `yaml:"middlewares"` + ExtendMeta []ExtendMeta `yaml:"extend_meta"` +} + +type MiddlewareForResolve struct { + // name of the middleware + Name string `yaml:"name"` + // path of the middleware + Path string `yaml:"path"` } -func readMetaFile(rootDir string) (*Meta, error) { - metaPath := filepath.Join(rootDir, kitexRenderMetaFile) +type ExtendMeta struct { + key string +} + +func readMetaFile(metaPath string) (*Meta, error) { metaData, err := os.ReadFile(metaPath) if err != nil { return nil, fmt.Errorf("failed to read meta file from %s: %v", metaPath, err) @@ -299,27 +385,28 @@ func readMetaFile(rootDir string) (*Meta, error) { return &meta, nil } -func getUpdateBehavior(meta *Meta, relativePath string) *Update { - for _, template := range meta.Templates { - // fmt.Println(template.Path, "==", relativePath) - if template.Path == relativePath { - return template.UpdateBehavior +func getMetadata(meta *Meta, relativePath string) *Template { + for i := range meta.Templates { + if meta.Templates[i].Path == relativePath { + return &meta.Templates[i] } } - return &Update{Type: string(skip)} + return &Template{ + UpdateBehavior: &Update{Type: string(skip)}, + } } -func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { - meta, err := readMetaFile(rootDir) - if err != nil { - return nil, err +func readTpls(rootDir, currentDir string, meta *Meta) (ts []*Template, error error) { + defaultMetadata := &Template{ + UpdateBehavior: &Update{Type: string(skip)}, } + files, _ := os.ReadDir(currentDir) for _, f := range files { // filter dir and non-tpl files if f.IsDir() { subDir := filepath.Join(currentDir, f.Name()) - subTemplates, err := readTpls(rootDir, subDir, ts) + subTemplates, err := readTpls(rootDir, subDir, meta) if err != nil { return nil, err } @@ -336,11 +423,19 @@ func readTpls(rootDir, currentDir string, ts []*Template) ([]*Template, error) { return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) } trimmedPath := strings.TrimSuffix(relativePath, ".tpl") - updateBehavior := getUpdateBehavior(meta, relativePath) + // If kitex_render_meta.yaml exists, get the corresponding metadata; otherwise, use the default metadata + var metadata *Template + if meta != nil { + metadata = getMetadata(meta, relativePath) + } else { + metadata = defaultMetadata + } t := &Template{ Path: trimmedPath, Body: string(tplData), - UpdateBehavior: updateBehavior, + UpdateBehavior: metadata.UpdateBehavior, + LoopMethod: metadata.LoopMethod, + LoopService: metadata.LoopService, } ts = append(ts, t) } @@ -357,7 +452,7 @@ func (g *generator) RenderWithMultipleFiles(pkg *PackageInfo) (fs []*File, err e } var updatedContent string if g.Config.DebugTpl { - // --debug时 在模板内容顶部加上一段magic string用于区分 + // when --debug is enabled, add a magic string at the top of the template content for distinction. updatedContent = "// Kitex template debug file. use template clean to delete it.\n\n" + string(content) } else { updatedContent = string(content) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index c7478887c0..0a1966c421 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -137,11 +137,13 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string // specify the location path of init subcommand - InitType string + InitOutputDir string // specify the location path of init subcommand + InitType string // specify the type for init subcommand RenderTplDir string // specify the path of template directory for render subcommand TemplateFiles []string // specify the path of single file or multiple file to render - DebugTpl bool + DebugTpl bool // turn on the debug mode + IncludesTpl string // specify the path of remote template repository for render subcommand + MetaFlags string // Metadata in key=value format, keys separated by ';' values separated by ',' GenPath string diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index ebb5245b70..0cba0caca8 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -61,6 +61,8 @@ func TestConfig_Pack(t *testing.T) { RenderTplDir string TemplateFiles []string DebugTpl bool + IncludesTpl string + MetaFlags string Protocol string HandlerReturnKeepResp bool } @@ -74,7 +76,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -106,6 +108,8 @@ func TestConfig_Pack(t *testing.T) { InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, TemplateFiles: tt.fields.TemplateFiles, + IncludesTpl: tt.fields.IncludesTpl, + MetaFlags: tt.fields.MetaFlags, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 27ba00a711..01d149f372 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -52,10 +52,10 @@ type PackageInfo struct { IDLName string ServerPkg string ExtendMeta map[string][]string // key-value metadata for render - MWs []GlobalMiddleware + MWs []UserDefinedMiddleware } -type GlobalMiddleware struct { +type UserDefinedMiddleware struct { // the name of the middleware Name string // the content of the middleware diff --git a/tool/internal_pkg/tpl/mock.go b/tool/internal_pkg/tpl/mock.go deleted file mode 100644 index 27da3a59d7..0000000000 --- a/tool/internal_pkg/tpl/mock.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2024 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tpl - -var ClientMockTpl = ` -// Code generated by MockGen. DO NOT EDIT. -// Source: kitex_gen/example/shop/item/itemserviceb/client.go -// -// Generated by this command: -// -// ___1go_build_go_uber_org_mock_mockgen -source=kitex_gen/example/shop/item/itemserviceb/client.go -destination=client_mock.go -package=main -// - -// Package main is a generated GoMock package. -` diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index fda97822d3..35fe4b609d 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -15,7 +15,9 @@ package util import ( + "errors" "fmt" + "io" "os" "strings" ) @@ -29,7 +31,7 @@ type Command struct { parent *Command flags *FlagSet // helpFunc is help func defined by user. - usage func() + helpFunc func(*Command, []string) // for debug args []string } @@ -161,12 +163,43 @@ func (c *Command) ParseFlags(args []string) error { return err } -func (c *Command) SetUsageFunc(f func()) { - c.usage = f +// SetHelpFunc sets help function. Can be defined by Application. +func (c *Command) SetHelpFunc(f func(*Command, []string)) { + c.helpFunc = f } -func (c *Command) UsageFunc() func() { - return c.usage +// HelpFunc returns either the function set by SetHelpFunc for this command +// or a parent, or it returns a function with default help behavior. +func (c *Command) HelpFunc() func(*Command, []string) { + if c.helpFunc != nil { + return c.helpFunc + } + if c.HasParent() { + return c.parent.HelpFunc() + } + return nil +} + +// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrln(i ...interface{}) { + c.PrintErr(fmt.Sprintln(i...)) +} + +// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErr(i ...interface{}) { + fmt.Fprint(c.ErrOrStderr(), i...) +} + +// ErrOrStderr returns output to stderr +func (c *Command) ErrOrStderr() io.Writer { + return c.getErr(os.Stderr) +} + +func (c *Command) getErr(def io.Writer) io.Writer { + if c.HasParent() { + return c.parent.getErr(def) + } + return def } // ExecuteC executes the command. @@ -181,8 +214,16 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { } err = cmd.execute(flags) if err != nil { - cmd.usage() + // Always show help if requested, even if SilenceErrors is in + // effect + if errors.Is(err, ErrHelp) { + cmd.HelpFunc()(cmd, args) + return cmd, nil + } } + //if err != nil { + // cmd.usage() + //} return cmd, err } diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go index 9bb5723d67..ae83875b52 100644 --- a/tool/internal_pkg/util/command_test.go +++ b/tool/internal_pkg/util/command_test.go @@ -191,19 +191,14 @@ func TestFlagLong(t *testing.T) { RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, } - var intFlagValue int var stringFlagValue string - c.Flags().IntVar(&intFlagValue, "intf", -1, "") c.Flags().StringVar(&stringFlagValue, "sf", "", "") - err := executeCommand(c, "--intf=7", "--sf=abc", "one", "--", "two") + err := executeCommand(c, "--sf=abc", "one", "--", "two") if err != nil { t.Errorf("Unexpected error: %v", err) } - if intFlagValue != 7 { - t.Errorf("Expected intFlagValue: %v, got %v", 7, intFlagValue) - } if stringFlagValue != "abc" { t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } @@ -221,19 +216,14 @@ func TestFlagShort(t *testing.T) { RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, } - var intFlagValue int var stringFlagValue string - c.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") - err := executeCommand(c, "-i", "7", "-sabc", "one", "two") + err := executeCommand(c, "-sabc", "one", "two") if err != nil { t.Errorf("Unexpected error: %v", err) } - if intFlagValue != 7 { - t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) - } if stringFlagValue != "abc" { t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) } @@ -248,16 +238,8 @@ func TestChildFlag(t *testing.T) { rootCmd := &Command{Use: "root", RunE: emptyRun} childCmd := &Command{Use: "child", RunE: emptyRun} rootCmd.AddCommand(childCmd) - - var intFlagValue int - childCmd.Flags().IntVarP(&intFlagValue, "intf", "i", -1, "") - - err := executeCommand(rootCmd, "child", "-i7") + err := executeCommand(rootCmd, "child") if err != nil { t.Errorf("Unexpected error: %v", err) } - - if intFlagValue != 7 { - t.Errorf("Expected flag value: %v, got %v", 7, intFlagValue) - } } diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index d80a7ac6cf..3c936b72ef 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "os" + "sort" "strconv" "strings" ) @@ -71,6 +72,7 @@ type FlagSet struct { orderedActual []*Flag formal map[NormalizedName]*Flag orderedFormal []*Flag + sortedFormal []*Flag shorthands map[byte]*Flag args []string // arguments after flags argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- @@ -100,6 +102,7 @@ type Flag struct { type Value interface { String() string Set(string) error + Type() string } // SliceValue is a secondary interface to all flags which hold a list @@ -114,6 +117,22 @@ type SliceValue interface { GetSlice() []string } +// sortFlags returns the flags as a slice in lexicographical sorted order. +func sortFlags(flags map[NormalizedName]*Flag) []*Flag { + list := make(sort.StringSlice, len(flags)) + i := 0 + for k := range flags { + list[i] = string(k) + i++ + } + list.Sort() + result := make([]*Flag, len(list)) + for i, name := range list { + result[i] = flags[NormalizedName(name)] + } + return result +} + // GetNormalizeFunc returns the previously set NormalizeFunc of a function which // does no translation, if not set previously. func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { @@ -186,6 +205,201 @@ func (f *FlagSet) Set(name, value string) error { return nil } +func (f *FlagSet) VisitAll(fn func(*Flag)) { + if len(f.formal) == 0 { + return + } + + var flags []*Flag + if f.SortFlags { + if len(f.formal) != len(f.sortedFormal) { + f.sortedFormal = sortFlags(f.formal) + } + flags = f.sortedFormal + } else { + flags = f.orderedFormal + } + + for _, flag := range flags { + fn(flag) + } +} + +func UnquoteUsage(flag *Flag) (name, usage string) { + // Look for a back-quoted name, but avoid the strings package. + usage = flag.Usage + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name = usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break // Only one back quote; use type name. + } + } + + name = flag.Value.Type() + switch name { + case "bool": + name = "" + case "float64": + name = "float" + case "int64": + name = "int" + case "uint64": + name = "uint" + case "stringSlice": + name = "strings" + case "intSlice": + name = "ints" + case "uintSlice": + name = "uints" + case "boolSlice": + name = "bools" + } + + return +} + +func (f *FlagSet) FlagUsagesWrapped(cols int) string { + buf := new(bytes.Buffer) + + lines := make([]string, 0, len(f.formal)) + + maxlen := 0 + f.VisitAll(func(flag *Flag) { + line := "" + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) + } else { + line = fmt.Sprintf(" --%s", flag.Name) + } + + varname, usage := UnquoteUsage(flag) + if varname != "" { + line += " " + varname + } + if flag.NoOptDefVal != "" { + switch flag.Value.Type() { + case "string": + line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) + case "bool": + if flag.NoOptDefVal != "true" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + case "count": + if flag.NoOptDefVal != "+1" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + default: + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + } + + // This special character will be replaced with spacing once the + // correct alignment is calculated + line += "\x00" + if len(line) > maxlen { + maxlen = len(line) + } + + line += usage + if len(flag.Deprecated) != 0 { + line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) + } + + lines = append(lines, line) + }) + + for _, line := range lines { + sidx := strings.Index(line, "\x00") + spacing := strings.Repeat(" ", maxlen-sidx) + // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx + fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) + } + + return buf.String() +} + +func wrap(i, w int, s string) string { + if w == 0 { + return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + // space between indent i and end of line width w into which + // we should wrap the text. + wrap := w - i + + var r, l string + + // Not enough space for sensible wrapping. Wrap as a block on + // the next line instead. + if wrap < 24 { + i = 16 + wrap = w - i + r += "\n" + strings.Repeat(" ", i) + } + // If still not enough space then don't even try to wrap. + if wrap < 24 { + return strings.Replace(s, "\n", r, -1) + } + + // Try to avoid short orphan words on the final line, by + // allowing wrapN to go a bit over if that would fit in the + // remainder of the line. + slop := 5 + wrap = wrap - slop + + // Handle first line, which is indented by the caller (or the + // special case above) + l, s = wrapN(wrap, slop, s) + r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) + + // Now wrap the rest + for s != "" { + var t string + + t, s = wrapN(wrap, slop, s) + r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + return r +} + +func wrapN(i, slop int, s string) (string, string) { + if i+slop > len(s) { + return s, "" + } + + w := strings.LastIndexAny(s[:i], " \t\n") + if w <= 0 { + return s, "" + } + nlPos := strings.LastIndex(s[:i], "\n") + if nlPos > 0 && nlPos < w { + return s[:nlPos], s[nlPos+1:] + } + return s[:w], s[w+1:] +} + +func (f *FlagSet) FlagUsages() string { + return f.FlagUsagesWrapped(0) +} + +func (f *FlagSet) PrintDefaults() { + usages := f.FlagUsages() + fmt.Fprint(f.out(), usages) +} + +// defaultUsage is the default function to print a usage message. +func defaultUsage(f *FlagSet) { + fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) + f.PrintDefaults() +} + // Args returns the non-flag arguments. func (f *FlagSet) Args() []string { return f.args } @@ -255,6 +469,7 @@ func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) if f.errorHandling != ContinueOnError { fmt.Fprintln(f.out(), err) + f.usage() } return err } @@ -296,6 +511,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if !exists { switch { case name == "help": + f.usage() return a, ErrHelp case f.ParseErrorsWhitelist.UnknownFlags: // --unknown=unknownval arg ... @@ -349,6 +565,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse if !exists { switch { case c == 'h': + f.usage() err = ErrHelp return case f.ParseErrorsWhitelist.UnknownFlags: @@ -501,6 +718,28 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { return f } +// PrintDefaults prints to standard error the default values of all defined command-line flags. +func PrintDefaults() { + CommandLine.PrintDefaults() +} + +var Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + PrintDefaults() +} + +// usage calls the Usage method for the flag set, or the usage function if +// the flag set is CommandLine. +func (f *FlagSet) usage() { + if f == CommandLine { + Usage() + } else if f.Usage == nil { + defaultUsage(f) + } else { + f.Usage() + } +} + // -- string Value type stringValue string @@ -516,6 +755,10 @@ func (s *stringValue) Set(val string) error { func (s *stringValue) String() string { return string(*s) } +func (s *stringValue) Type() string { + return "string" +} + // StringVar defines a string flag with specified name, default value, and usage string. func (f *FlagSet) StringVar(p *string, name, value, usage string) { f.VarP(newStringValue(value, p), name, "", usage) @@ -557,6 +800,10 @@ func (b *boolValue) Set(s string) error { func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } +func (b *boolValue) Type() string { + return "bool" +} + func (b *boolValue) IsBoolFlag() bool { return true } // BoolVar defines a bool flag with specified name, default value, and usage string. @@ -592,39 +839,6 @@ func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool return p } -// -- int Value -type intValue int - -func newIntValue(val int, p *int) *intValue { - *p = val - return (*intValue)(p) -} - -func (i *intValue) Set(s string) error { - v, err := strconv.ParseInt(s, 0, 64) - *i = intValue(v) - return err -} - -func (i *intValue) String() string { return strconv.Itoa(int(*i)) } - -// IntVar defines an int flag with specified name, default value, and usage string. -// The argument p points to an int variable in which to store the value of the flag. -func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { - f.VarP(newIntValue(value, p), name, "", usage) -} - -// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) IntVarP(p *int, name, shorthand string, value int, usage string) { - f.VarP(newIntValue(value, p), name, shorthand, usage) -} - -func (f *FlagSet) Int(name string, value int, usage string) *int { - p := new(int) - f.IntVarP(p, name, "", value, usage) - return p -} - // -- stringArray Value type stringArrayValue struct { value *[]string @@ -664,9 +878,12 @@ func (s *stringArrayValue) String() string { return "[" + str + "]" } +func (s *stringArrayValue) Type() string { + return "stringArray" +} + // StringArrayVar defines a string flag with specified name, default value, and usage string. // The argument p points to a []string variable in which to store the values of the multiple flags. -// The value of each argument will not try to be separated by comma. Use a StringSlice for that. func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { f.VarP(newStringArrayValue(value, p), name, "", usage) } diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go index 364d379c1f..916f906df5 100644 --- a/tool/internal_pkg/util/flag_test.go +++ b/tool/internal_pkg/util/flag_test.go @@ -45,21 +45,13 @@ func testParse(f *FlagSet, t *testing.T) { boolFlag := f.Bool("bool", false, "bool value") bool2Flag := f.Bool("bool2", false, "bool2 value") bool3Flag := f.Bool("bool3", false, "bool3 value") - intFlag := f.Int("int", 0, "int value") stringFlag := f.String("string", "0", "string value") - optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") - f.Lookup("optional-int-no-value").NoOptDefVal = "9" - optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") - f.Lookup("optional-int-no-value").NoOptDefVal = "9" extra := "one-extra-argument" args := []string{ "--bool", "--bool2=true", "--bool3=false", - "--int=22", "--string=hello", - "--optional-int-no-value", - "--optional-int-with-value=42", extra, } if err := f.Parse(args); err != nil { @@ -77,18 +69,9 @@ func testParse(f *FlagSet, t *testing.T) { if *bool3Flag != false { t.Error("bool3 flag should be false, is ", *bool2Flag) } - if *intFlag != 22 { - t.Error("int flag should be 22, is ", *intFlag) - } if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } - if *optionalIntNoValueFlag != 9 { - t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) - } - if *optionalIntWithValueFlag != 42 { - t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) - } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra { From d2ea6a8c02984c6fe0f2474883a6af09c1c5dc88 Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 7 Aug 2024 16:09:06 +0800 Subject: [PATCH 37/41] fix: resolve conflics Signed-off-by: shawn --- internal/mocks/discovery/discovery.go | 3 +- internal/mocks/generic/generic_service.go | 3 +- internal/mocks/generic/thrift.go | 3 +- internal/mocks/klog/log.go | 3 +- internal/mocks/limiter/limiter.go | 3 +- internal/mocks/loadbalance/loadbalancer.go | 3 +- internal/mocks/net/net.go | 5 +- internal/mocks/netpoll/connection.go | 3 +- internal/mocks/proxy/proxy.go | 3 +- internal/mocks/remote/bytebuf.go | 3 +- internal/mocks/remote/codec.go | 3 +- internal/mocks/remote/conn_wrapper.go | 3 +- internal/mocks/remote/connpool.go | 3 +- internal/mocks/remote/dialer.go | 3 +- internal/mocks/remote/payload_codec.go | 3 +- internal/mocks/remote/trans_handler.go | 3 +- internal/mocks/remote/trans_meta.go | 3 +- internal/mocks/remote/trans_pipeline.go | 3 +- internal/mocks/stats/tracer.go | 3 +- internal/mocks/thrift/utils.go | 70 ------------------- internal/mocks/utils/sharedticker.go | 3 +- pkg/generic/map_test/generic_init.go | 8 +-- .../bthrift/apache/application_exception.go | 39 ----------- .../bthrift/apache/binary_protocol.go | 23 ------ pkg/protocol/bthrift/apache/memory_buffer.go | 57 --------------- pkg/protocol/bthrift/apache/messagetype.go | 32 --------- pkg/protocol/bthrift/apache/protocol.go | 33 --------- .../bthrift/apache/protocol_exception.go | 33 --------- pkg/protocol/bthrift/apache/serializer.go | 24 ------- pkg/protocol/bthrift/apache/transport.go | 23 ------ pkg/protocol/bthrift/apache/type.go | 43 ------------ pkg/protocol/bthrift/binary_test.go | 2 +- 32 files changed, 46 insertions(+), 403 deletions(-) delete mode 100644 internal/mocks/thrift/utils.go delete mode 100644 pkg/protocol/bthrift/apache/application_exception.go delete mode 100644 pkg/protocol/bthrift/apache/binary_protocol.go delete mode 100644 pkg/protocol/bthrift/apache/memory_buffer.go delete mode 100644 pkg/protocol/bthrift/apache/messagetype.go delete mode 100644 pkg/protocol/bthrift/apache/protocol.go delete mode 100644 pkg/protocol/bthrift/apache/protocol_exception.go delete mode 100644 pkg/protocol/bthrift/apache/serializer.go delete mode 100644 pkg/protocol/bthrift/apache/transport.go delete mode 100644 pkg/protocol/bthrift/apache/type.go diff --git a/internal/mocks/discovery/discovery.go b/internal/mocks/discovery/discovery.go index e70f64e2e0..2ae32eb94f 100644 --- a/internal/mocks/discovery/discovery.go +++ b/internal/mocks/discovery/discovery.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/discovery/discovery.go @@ -175,4 +176,4 @@ func (m *MockInstance) Weight() int { func (mr *MockInstanceMockRecorder) Weight() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockInstance)(nil).Weight)) -} +} \ No newline at end of file diff --git a/internal/mocks/generic/generic_service.go b/internal/mocks/generic/generic_service.go index 8ca3f06056..bc18989b81 100644 --- a/internal/mocks/generic/generic_service.go +++ b/internal/mocks/generic/generic_service.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/generic_service.go @@ -98,4 +99,4 @@ func (m *MockWithCodec) SetCodec(codec interface{}) { func (mr *MockWithCodecMockRecorder) SetCodec(codec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCodec", reflect.TypeOf((*MockWithCodec)(nil).SetCodec), codec) -} +} \ No newline at end of file diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 8362762539..00f7cce74d 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/generic/thrift/thrift.go @@ -102,4 +103,4 @@ func (m *MockMessageWriter) Write(ctx context.Context, out io.Writer, msg interf func (mr *MockMessageWriterMockRecorder) Write(ctx, out, msg, requestBase interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockMessageWriter)(nil).Write), ctx, out, msg, requestBase) -} +} \ No newline at end of file diff --git a/internal/mocks/klog/log.go b/internal/mocks/klog/log.go index 2c83208a32..2b6413079b 100644 --- a/internal/mocks/klog/log.go +++ b/internal/mocks/klog/log.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/klog/log.go @@ -890,4 +891,4 @@ func (mr *MockFullLoggerMockRecorder) Warnf(format interface{}, v ...interface{} mr.mock.ctrl.T.Helper() varargs := append([]interface{}{format}, v...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockFullLogger)(nil).Warnf), varargs...) -} +} \ No newline at end of file diff --git a/internal/mocks/limiter/limiter.go b/internal/mocks/limiter/limiter.go index 5fac2ad3e3..97cec6a9cf 100644 --- a/internal/mocks/limiter/limiter.go +++ b/internal/mocks/limiter/limiter.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/limiter/limiter.go @@ -225,4 +226,4 @@ func (m *MockLimitReporter) QPSOverloadReport() { func (mr *MockLimitReporterMockRecorder) QPSOverloadReport() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QPSOverloadReport", reflect.TypeOf((*MockLimitReporter)(nil).QPSOverloadReport)) -} +} \ No newline at end of file diff --git a/internal/mocks/loadbalance/loadbalancer.go b/internal/mocks/loadbalance/loadbalancer.go index 63e4fec5bb..e39fa9bc1c 100644 --- a/internal/mocks/loadbalance/loadbalancer.go +++ b/internal/mocks/loadbalance/loadbalancer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/loadbalance/loadbalancer.go @@ -162,4 +163,4 @@ func (m *MockRebalancer) Rebalance(arg0 discovery.Change) { func (mr *MockRebalancerMockRecorder) Rebalance(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rebalance", reflect.TypeOf((*MockRebalancer)(nil).Rebalance), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/net/net.go b/internal/mocks/net/net.go index 7ada9f882b..191cd25b19 100644 --- a/internal/mocks/net/net.go +++ b/internal/mocks/net/net.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: /usr/local/go/src/net/net.go @@ -21,8 +22,8 @@ package net import ( - net "net" reflect "reflect" + net "net" time "time" gomock "github.com/golang/mock/gomock" @@ -581,4 +582,4 @@ func (m *MockbuffersWriter) writeBuffers(arg0 *net.Buffers) (int64, error) { func (mr *MockbuffersWriterMockRecorder) writeBuffers(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "writeBuffers", reflect.TypeOf((*MockbuffersWriter)(nil).writeBuffers), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/netpoll/connection.go b/internal/mocks/netpoll/connection.go index 325fff2af7..94c932557e 100644 --- a/internal/mocks/netpoll/connection.go +++ b/internal/mocks/netpoll/connection.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../../netpoll/connection.go @@ -560,4 +561,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} +} \ No newline at end of file diff --git a/internal/mocks/proxy/proxy.go b/internal/mocks/proxy/proxy.go index 2e5ff5fb13..4d15434856 100644 --- a/internal/mocks/proxy/proxy.go +++ b/internal/mocks/proxy/proxy.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/proxy/proxy.go @@ -191,4 +192,4 @@ func (m *MockContextHandler) HandleContext(arg0 context.Context) context.Context func (mr *MockContextHandlerMockRecorder) HandleContext(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleContext", reflect.TypeOf((*MockContextHandler)(nil).HandleContext), arg0) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/bytebuf.go b/internal/mocks/remote/bytebuf.go index 0795e0fa78..24a4d7b9af 100644 --- a/internal/mocks/remote/bytebuf.go +++ b/internal/mocks/remote/bytebuf.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/bytebuf.go @@ -437,4 +438,4 @@ func (m *MockByteBuffer) WriteString(s string) (int, error) { func (mr *MockByteBufferMockRecorder) WriteString(s interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockByteBuffer)(nil).WriteString), s) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/codec.go b/internal/mocks/remote/codec.go index aaed664113..56593b0907 100644 --- a/internal/mocks/remote/codec.go +++ b/internal/mocks/remote/codec.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/codec.go @@ -193,4 +194,4 @@ func (m *MockMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, func (mr *MockMetaDecoderMockRecorder) DecodePayload(ctx, msg, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePayload", reflect.TypeOf((*MockMetaDecoder)(nil).DecodePayload), ctx, msg, in) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/conn_wrapper.go b/internal/mocks/remote/conn_wrapper.go index e57a75d8c6..03d159e9a5 100644 --- a/internal/mocks/remote/conn_wrapper.go +++ b/internal/mocks/remote/conn_wrapper.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/remotecli/conn_wrapper.go @@ -60,4 +61,4 @@ func (m *MockConnReleaser) ReleaseConn(err error, ri rpcinfo.RPCInfo) { func (mr *MockConnReleaserMockRecorder) ReleaseConn(err, ri interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseConn", reflect.TypeOf((*MockConnReleaser)(nil).ReleaseConn), err, ri) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/connpool.go b/internal/mocks/remote/connpool.go index fcdfd5575c..7fbf882329 100644 --- a/internal/mocks/remote/connpool.go +++ b/internal/mocks/remote/connpool.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/connpool.go @@ -308,4 +309,4 @@ func (m *MockIsActive) IsActive() bool { func (mr *MockIsActiveMockRecorder) IsActive() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsActive", reflect.TypeOf((*MockIsActive)(nil).IsActive)) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/dialer.go b/internal/mocks/remote/dialer.go index 90ff6960bd..cebd08a0b0 100644 --- a/internal/mocks/remote/dialer.go +++ b/internal/mocks/remote/dialer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/dialer.go @@ -64,4 +65,4 @@ func (m *MockDialer) DialTimeout(network, address string, timeout time.Duration) func (mr *MockDialerMockRecorder) DialTimeout(network, address, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTimeout", reflect.TypeOf((*MockDialer)(nil).DialTimeout), network, address, timeout) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/payload_codec.go b/internal/mocks/remote/payload_codec.go index 2adda83bf2..1db0d6128a 100644 --- a/internal/mocks/remote/payload_codec.go +++ b/internal/mocks/remote/payload_codec.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/payload_codec.go @@ -91,4 +92,4 @@ func (m *MockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message func (mr *MockPayloadCodecMockRecorder) Unmarshal(ctx, message, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmarshal", reflect.TypeOf((*MockPayloadCodec)(nil).Unmarshal), ctx, message, in) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_handler.go b/internal/mocks/remote/trans_handler.go index ad32a17a14..210ef935fa 100644 --- a/internal/mocks/remote/trans_handler.go +++ b/internal/mocks/remote/trans_handler.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_handler.go @@ -570,4 +571,4 @@ func (m *MockGracefulShutdown) GracefulShutdown(ctx context.Context) error { func (mr *MockGracefulShutdownMockRecorder) GracefulShutdown(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GracefulShutdown", reflect.TypeOf((*MockGracefulShutdown)(nil).GracefulShutdown), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_meta.go b/internal/mocks/remote/trans_meta.go index 71297e4616..38153da94d 100644 --- a/internal/mocks/remote/trans_meta.go +++ b/internal/mocks/remote/trans_meta.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_meta.go @@ -132,4 +133,4 @@ func (m *MockStreamingMetaHandler) OnReadStream(ctx context.Context) (context.Co func (mr *MockStreamingMetaHandlerMockRecorder) OnReadStream(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReadStream", reflect.TypeOf((*MockStreamingMetaHandler)(nil).OnReadStream), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/remote/trans_pipeline.go b/internal/mocks/remote/trans_pipeline.go index 60d89fc071..687e1a5fac 100644 --- a/internal/mocks/remote/trans_pipeline.go +++ b/internal/mocks/remote/trans_pipeline.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/remote/trans_pipeline.go @@ -267,4 +268,4 @@ func (m *MockDuplexBoundHandler) Write(ctx context.Context, conn net.Conn, send func (mr *MockDuplexBoundHandlerMockRecorder) Write(ctx, conn, send interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDuplexBoundHandler)(nil).Write), ctx, conn, send) -} +} \ No newline at end of file diff --git a/internal/mocks/stats/tracer.go b/internal/mocks/stats/tracer.go index 4f8da5476d..0734297631 100644 --- a/internal/mocks/stats/tracer.go +++ b/internal/mocks/stats/tracer.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/stats/tracer.go @@ -74,4 +75,4 @@ func (m *MockTracer) Start(ctx context.Context) context.Context { func (mr *MockTracerMockRecorder) Start(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTracer)(nil).Start), ctx) -} +} \ No newline at end of file diff --git a/internal/mocks/thrift/utils.go b/internal/mocks/thrift/utils.go deleted file mode 100644 index 5d080bc513..0000000000 --- a/internal/mocks/thrift/utils.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "errors" - "io" - - "github.com/cloudwego/gopkg/protocol/thrift" - - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// ApacheCodecAdapter converts a fastcodec struct to apache codec -type ApacheCodecAdapter struct { - p thrift.FastCodec -} - -// Write implements athrift.TStruct -func (p ApacheCodecAdapter) Write(tp athrift.TProtocol) error { - b := make([]byte, p.p.BLength()) - b = b[:p.p.FastWriteNocopy(b, nil)] - _, err := tp.Transport().Write(b) - return err -} - -// Read implements athrift.TStruct -func (p ApacheCodecAdapter) Read(tp athrift.TProtocol) error { - var err error - var b []byte - trans := tp.Transport() - n := trans.RemainingBytes() - if int64(n) < 0 { - return errors.New("unknown buffer len") - } - b = make([]byte, n) - _, err = io.ReadFull(trans, b) - if err == nil { - _, err = p.p.FastRead(b) - } - return err -} - -// ToApacheCodec converts a thrift.FastCodec to athrift.TStruct -func ToApacheCodec(p thrift.FastCodec) athrift.TStruct { - return ApacheCodecAdapter{p: p} -} - -// UnpackApacheCodec unpacks the value returned by `ToApacheCodec` -func UnpackApacheCodec(v interface{}) interface{} { - a, ok := v.(ApacheCodecAdapter) - if ok { - return a.p - } - return v -} diff --git a/internal/mocks/utils/sharedticker.go b/internal/mocks/utils/sharedticker.go index a7ac22a7db..996951cbd1 100644 --- a/internal/mocks/utils/sharedticker.go +++ b/internal/mocks/utils/sharedticker.go @@ -14,6 +14,7 @@ * limitations under the License. */ + // Code generated by MockGen. DO NOT EDIT. // Source: ../../pkg/utils/sharedticker.go @@ -59,4 +60,4 @@ func (m *MockTickerTask) Tick() { func (mr *MockTickerTaskMockRecorder) Tick() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tick", reflect.TypeOf((*MockTickerTask)(nil).Tick)) -} +} \ No newline at end of file diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index 7ccdc17499..2b92a303a3 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -205,16 +205,16 @@ func serviceInfo() *serviceinfo.ServiceInfo { } func newMockTestArgs() interface{} { - return kt.ToApacheCodec(kt.NewMockTestArgs()) + return kt.NewMockTestArgs() } func newMockTestResult() interface{} { - return kt.ToApacheCodec(kt.NewMockTestResult()) + return kt.NewMockTestResult() } func testHandler(ctx context.Context, handler, arg, result interface{}) error { - realArg := kt.UnpackApacheCodec(arg).(*kt.MockTestArgs) - realResult := kt.UnpackApacheCodec(result).(*kt.MockTestResult) + realArg := arg.(*kt.MockTestArgs) + realResult := result.(*kt.MockTestResult) success, err := handler.(kt.Mock).Test(ctx, realArg.Req) if err != nil { return err diff --git a/pkg/protocol/bthrift/apache/application_exception.go b/pkg/protocol/bthrift/apache/application_exception.go deleted file mode 100644 index ac02e24cef..0000000000 --- a/pkg/protocol/bthrift/apache/application_exception.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/application_exception.go - -const ( - UNKNOWN_APPLICATION_EXCEPTION = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE_EXCEPTION = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - INVALID_TRANSFORM = 8 - INVALID_PROTOCOL = 9 - UNSUPPORTED_CLIENT_TYPE = 10 -) - -type TApplicationException = thrift.TApplicationException - -var NewTApplicationException = thrift.NewTApplicationException diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go deleted file mode 100644 index 2a2a4538b2..0000000000 --- a/pkg/protocol/bthrift/apache/binary_protocol.go +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/binary_protocol.go - -var NewTBinaryProtocol = thrift.NewTBinaryProtocol diff --git a/pkg/protocol/bthrift/apache/memory_buffer.go b/pkg/protocol/bthrift/apache/memory_buffer.go deleted file mode 100644 index 10a0af751f..0000000000 --- a/pkg/protocol/bthrift/apache/memory_buffer.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import ( - "bytes" - "context" -) - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/memory_buffer.go - -// Memory buffer-based implementation of the TTransport interface. -type TMemoryBuffer struct { - *bytes.Buffer - size int -} - -func NewTMemoryBufferLen(size int) *TMemoryBuffer { - buf := make([]byte, 0, size) - return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} -} - -func (p *TMemoryBuffer) IsOpen() bool { - return true -} - -func (p *TMemoryBuffer) Open() error { - return nil -} - -func (p *TMemoryBuffer) Close() error { - p.Buffer.Reset() - return nil -} - -// Flushing a memory buffer is a no-op -func (p *TMemoryBuffer) Flush(ctx context.Context) error { - return nil -} - -func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { - return uint64(p.Buffer.Len()) -} diff --git a/pkg/protocol/bthrift/apache/messagetype.go b/pkg/protocol/bthrift/apache/messagetype.go deleted file mode 100644 index 1885144aee..0000000000 --- a/pkg/protocol/bthrift/apache/messagetype.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/messagetype.go - -// Message type constants in the Thrift protocol. -type TMessageType = thrift.TMessageType - -const ( - INVALID_TMESSAGE_TYPE TMessageType = 0 - CALL TMessageType = 1 - REPLY TMessageType = 2 - EXCEPTION TMessageType = 3 - ONEWAY TMessageType = 4 -) diff --git a/pkg/protocol/bthrift/apache/protocol.go b/pkg/protocol/bthrift/apache/protocol.go deleted file mode 100644 index 9d0a991d96..0000000000 --- a/pkg/protocol/bthrift/apache/protocol.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol.go - -const ( - VERSION_MASK = 0xffff0000 - VERSION_1 = 0x80010000 -) - -type TProtocol = thrift.TProtocol - -// The maximum recursive depth the skip() function will traverse -const DEFAULT_RECURSION_DEPTH = 64 - -var SkipDefaultDepth = thrift.SkipDefaultDepth diff --git a/pkg/protocol/bthrift/apache/protocol_exception.go b/pkg/protocol/bthrift/apache/protocol_exception.go deleted file mode 100644 index 7b020797f5..0000000000 --- a/pkg/protocol/bthrift/apache/protocol_exception.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol_exception.go - -var NewTProtocolExceptionWithType = thrift.NewTProtocolExceptionWithType - -const ( - UNKNOWN_PROTOCOL_EXCEPTION = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - NOT_IMPLEMENTED = 5 - DEPTH_LIMIT = 6 -) diff --git a/pkg/protocol/bthrift/apache/serializer.go b/pkg/protocol/bthrift/apache/serializer.go deleted file mode 100644 index c255250301..0000000000 --- a/pkg/protocol/bthrift/apache/serializer.go +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/serializer.go - -type TStruct interface { - Write(p TProtocol) error - Read(p TProtocol) error -} diff --git a/pkg/protocol/bthrift/apache/transport.go b/pkg/protocol/bthrift/apache/transport.go deleted file mode 100644 index 25a752ae52..0000000000 --- a/pkg/protocol/bthrift/apache/transport.go +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/transport.go - -type TTransport = thrift.TTransport diff --git a/pkg/protocol/bthrift/apache/type.go b/pkg/protocol/bthrift/apache/type.go deleted file mode 100644 index 42533b085e..0000000000 --- a/pkg/protocol/bthrift/apache/type.go +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/type.go - -type TType = thrift.TType - -const ( - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 -) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index a0754bcd55..ba86bd6a53 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -21,8 +21,8 @@ import ( "fmt" "testing" - "github.com/cloudwego/kitex/internal/test" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift/internal/test" ) // TestWriteMessageEnd test binary WriteMessageEnd function From f2b5d2426c1e74d9c6239f56cae70eb0ed95a22e Mon Sep 17 00:00:00 2001 From: shawn Date: Wed, 7 Aug 2024 16:12:03 +0800 Subject: [PATCH 38/41] fix: resolve conflics Signed-off-by: shawn --- tool/internal_pkg/generator/custom_template.go | 11 +++-------- tool/internal_pkg/tpl/multiple_services.go | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 54530713bf..0b4a32ab84 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -254,7 +254,7 @@ func parseMeta(metaFlags string) (map[string][]string, error) { values := strings.Split(kv[1], ",") meta[key] = values } else { - return nil, fmt.Errorf("Invalid meta format: %s\n", pair) + return nil, fmt.Errorf("invalid meta format: %s", pair) } } return meta, nil @@ -354,9 +354,8 @@ const kitexRenderMetaFile = "kitex_render_meta.yaml" // Meta represents the structure of the kitex_render_meta.yaml file. type Meta struct { - Templates []Template `yaml:"templates"` - MWs []MiddlewareForResolve `yaml:"middlewares"` - ExtendMeta []ExtendMeta `yaml:"extend_meta"` + Templates []Template `yaml:"templates"` + MWs []MiddlewareForResolve `yaml:"middlewares"` } type MiddlewareForResolve struct { @@ -366,10 +365,6 @@ type MiddlewareForResolve struct { Path string `yaml:"path"` } -type ExtendMeta struct { - key string -} - func readMetaFile(metaPath string) (*Meta, error) { metaData, err := os.ReadFile(metaPath) if err != nil { diff --git a/tool/internal_pkg/tpl/multiple_services.go b/tool/internal_pkg/tpl/multiple_services.go index 0ade8ffd8c..ffad15f379 100644 --- a/tool/internal_pkg/tpl/multiple_services.go +++ b/tool/internal_pkg/tpl/multiple_services.go @@ -1,3 +1,17 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package tpl var MultipleServicesTpl = `package main From 84ea22ba2f87cd09cbf79f1bf41e17757caa30c9 Mon Sep 17 00:00:00 2001 From: shawn Date: Thu, 29 Aug 2024 16:25:39 +0800 Subject: [PATCH 39/41] feat: add migrate flag for new-style custom_template --- tool/cmd/kitex/args/tpl_args.go | 113 ++++++++++++++++-- .../internal_pkg/generator/custom_template.go | 4 +- tool/internal_pkg/generator/generator.go | 2 - tool/internal_pkg/generator/generator_test.go | 6 +- tool/internal_pkg/util/command.go | 2 +- tool/internal_pkg/util/flag.go | 2 +- 6 files changed, 106 insertions(+), 23 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index a190ff7f2d..e0ec51b237 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -21,6 +21,8 @@ import ( "path/filepath" "strings" + "gopkg.in/yaml.v3" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" @@ -91,16 +93,16 @@ func InitTemplates(path string, templates map[string]string) error { } for name, content := range templates { - var filePath string + var dir string if name == BootstrapFileName { - bootstrapDir := filepath.Join(path, "script") - if err := MkdirIfNotExist(bootstrapDir); err != nil { - return err - } - filePath = filepath.Join(bootstrapDir, name+".tpl") + dir = filepath.Join(path, "script") } else { - filePath = filepath.Join(path, name+".tpl") + dir = path + } + if err := MkdirIfNotExist(dir); err != nil { + return err } + filePath := filepath.Join(dir, fmt.Sprintf("%s.tpl", name)) if err := createTemplate(filePath, content); err != nil { return err } @@ -151,8 +153,8 @@ func (a *Arguments) Init(cmd *util.Command, args []string) error { if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) } - path := a.InitOutputDir - initType := a.InitType + path := InitOutputDir + initType := InitType if initType == "" { initType = DefaultType } @@ -211,6 +213,9 @@ func (a *Arguments) Root(cmd *util.Command, args []string) error { } func (a *Arguments) Template(cmd *util.Command, args []string) error { + if len(args) == 0 { + return util.ErrHelp + } curpath, err := filepath.Abs(".") if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) @@ -243,7 +248,84 @@ func (a *Arguments) Template(cmd *util.Command, args []string) error { return a.checkPath(curpath) } +func parseYAMLFiles(directory string) ([]generator.Template, error) { + var templates []generator.Template + files, err := os.ReadDir(directory) + if err != nil { + return nil, err + } + for _, file := range files { + if filepath.Ext(file.Name()) == ".yaml" { + data, err := os.ReadFile(filepath.Join(directory, file.Name())) + if err != nil { + return nil, err + } + var template generator.Template + err = yaml.Unmarshal(data, &template) + if err != nil { + return nil, err + } + templates = append(templates, template) + } + } + return templates, nil +} + +func createFilesFromTemplates(templates []generator.Template, baseDirectory string) error { + for _, template := range templates { + fullPath := filepath.Join(baseDirectory, template.Path) + dir := filepath.Dir(fullPath) + err := os.MkdirAll(dir, os.ModePerm) + if err != nil { + return err + } + err = os.WriteFile(fullPath, []byte(template.Body), 0o644) + if err != nil { + return err + } + } + return nil +} + +func generateMetadata(templates []generator.Template, outputFile string) error { + var metadata generator.Meta + for _, template := range templates { + meta := generator.Template{ + Path: template.Path, + Body: template.Body, + UpdateBehavior: template.UpdateBehavior, + LoopMethod: template.LoopMethod, + LoopService: template.LoopService, + } + metadata.Templates = append(metadata.Templates, meta) + } + data, err := yaml.Marshal(&metadata) + if err != nil { + return err + } + return os.WriteFile(outputFile, data, 0o644) +} + func (a *Arguments) Render(cmd *util.Command, args []string) error { + if len(args) == 0 { + return util.ErrHelp + } + if MigratePath != "" { + templates, err := parseYAMLFiles(MigratePath) + if err != nil { + return err + } + err = createFilesFromTemplates(templates, MigratePath) + if err != nil { + return err + } + err = generateMetadata(templates, generator.KitexRenderMetaFile) + if err != nil { + return err + } + fmt.Println("Migrate successfully...") + os.Exit(0) + } curpath, err := filepath.Abs(".") if err != nil { return fmt.Errorf("get current path failed: %s", err.Error()) @@ -292,7 +374,7 @@ func (a *Arguments) Clean(cmd *util.Command, args []string) error { } content, err := os.ReadFile(path) if err != nil { - return fmt.Errorf("read file %s faild: %v", path, err) + return fmt.Errorf("read file %s failed: %v", path, err) } if strings.Contains(string(content), magicString) { if err := os.Remove(path); err != nil { @@ -309,6 +391,12 @@ func (a *Arguments) Clean(cmd *util.Command, args []string) error { return nil } +var ( + InitOutputDir string // specify the location path of init subcommand + InitType string // specify the type for init subcommand + MigratePath string // specify the path old-style template to new-style template +) + func (a *Arguments) TemplateArgs(version string) error { kitexCmd := &util.Command{ Use: "kitex", @@ -339,8 +427,8 @@ func (a *Arguments) TemplateArgs(version string) error { "Specify a code gen path.") templateCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, "Specify a code gen path.") - initCmd.Flags().StringVarP(&a.InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") - initCmd.Flags().StringVarP(&a.InitType, "type", "t", "", "Specify template init type") + initCmd.Flags().StringVarP(&InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") + initCmd.Flags().StringVarP(&InitType, "type", "t", "", "Specify template init type") renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") renderCmd.Flags().StringVar(&a.ModuleName, "module", "", "Specify the Go module name to generate go.mod.") @@ -351,6 +439,7 @@ func (a *Arguments) TemplateArgs(version string) error { renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") + renderCmd.Flags().StringVar(&MigratePath, "migrate", "", "Migrate path for old-style template") templateCmd.SetHelpFunc(func(*util.Command, []string) { fmt.Fprintln(os.Stderr, ` Template operation diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 0b4a32ab84..a68a32f849 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -305,7 +305,7 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, } } var meta *Meta - metaPath := filepath.Join(g.RenderTplDir, kitexRenderMetaFile) + metaPath := filepath.Join(g.RenderTplDir, KitexRenderMetaFile) if util.Exists(metaPath) { meta, err = readMetaFile(metaPath) if err != nil { @@ -350,7 +350,7 @@ func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, return fs, nil } -const kitexRenderMetaFile = "kitex_render_meta.yaml" +const KitexRenderMetaFile = "kitex_render_meta.yaml" // Meta represents the structure of the kitex_render_meta.yaml file. type Meta struct { diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 0a1966c421..3f69c009e1 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -137,8 +137,6 @@ type Config struct { TemplateDir string // subcommand template - InitOutputDir string // specify the location path of init subcommand - InitType string // specify the type for init subcommand RenderTplDir string // specify the path of template directory for render subcommand TemplateFiles []string // specify the path of single file or multiple file to render DebugTpl bool // turn on the debug mode diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 0cba0caca8..5523220f2e 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -56,8 +56,6 @@ func TestConfig_Pack(t *testing.T) { RecordCmd string ThriftPluginTimeLimit time.Duration TemplateDir string - InitOutputDir string - InitType string RenderTplDir string TemplateFiles []string DebugTpl bool @@ -76,7 +74,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "InitOutputDir=", "InitType=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, }, } for _, tt := range tests { @@ -104,8 +102,6 @@ func TestConfig_Pack(t *testing.T) { FrugalPretouch: tt.fields.FrugalPretouch, ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, TemplateDir: tt.fields.TemplateDir, - InitOutputDir: tt.fields.InitOutputDir, - InitType: tt.fields.InitType, RenderTplDir: tt.fields.RenderTplDir, TemplateFiles: tt.fields.TemplateFiles, IncludesTpl: tt.fields.IncludesTpl, diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go index 35fe4b609d..350e53559f 100644 --- a/tool/internal_pkg/util/command.go +++ b/tool/internal_pkg/util/command.go @@ -218,7 +218,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { // effect if errors.Is(err, ErrHelp) { cmd.HelpFunc()(cmd, args) - return cmd, nil + return cmd, err } } //if err != nil { diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go index 3c936b72ef..77d8f8069a 100644 --- a/tool/internal_pkg/util/flag.go +++ b/tool/internal_pkg/util/flag.go @@ -28,7 +28,7 @@ import ( ) // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. -var ErrHelp = errors.New("pflag: help requested") +var ErrHelp = errors.New("flag: help requested") // ErrorHandling defines how to handle flag parsing errors. type ErrorHandling int From 9c5270186aed3869a062966e6b48ada6b04714e8 Mon Sep 17 00:00:00 2001 From: shawn Date: Thu, 29 Aug 2024 21:24:27 +0800 Subject: [PATCH 40/41] fix: resolve conflicts --- tool/cmd/kitex/main.go | 3 --- tool/internal_pkg/generator/generator.go | 5 ----- tool/internal_pkg/generator/generator_test.go | 2 +- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 0b5fb02959..715ce232ba 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,10 +17,7 @@ package main import ( "bytes" "flag" - "fmt" - "io/ioutil" - "os" "path/filepath" "strings" diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index f2d9ef7712..b28da9a938 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -16,7 +16,6 @@ package generator import ( - "errors" "fmt" "go/token" "path/filepath" @@ -395,10 +394,6 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error if !g.Config.IsUsingMultipleServicesTpl() { f, err := g.generateHandler(pkg, pkg.ServiceInfo, HandlerFileName) if err != nil { - if errors.Is(err, errNoNewMethod) { - return fs, nil - } - return nil, err } // when there is no new method, f would be nil diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index ae2aad7d24..71cf64a32a 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -74,7 +74,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "BuiltinTpl="}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "BuiltinTpl="}, }, } for _, tt := range tests { From 5a2550c96bfd752a6cad1157ed5e755fc45cbb70 Mon Sep 17 00:00:00 2001 From: shawn Date: Thu, 29 Aug 2024 22:34:49 +0800 Subject: [PATCH 41/41] fix: fix migrate --- tool/cmd/kitex/args/tpl_args.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go index e0ec51b237..20bdbb8859 100644 --- a/tool/cmd/kitex/args/tpl_args.go +++ b/tool/cmd/kitex/args/tpl_args.go @@ -273,7 +273,7 @@ func parseYAMLFiles(directory string) ([]generator.Template, error) { func createFilesFromTemplates(templates []generator.Template, baseDirectory string) error { for _, template := range templates { - fullPath := filepath.Join(baseDirectory, template.Path) + fullPath := filepath.Join(baseDirectory, fmt.Sprintf("%s.tpl", template.Path)) dir := filepath.Dir(fullPath) err := os.MkdirAll(dir, os.ModePerm) if err != nil { @@ -292,7 +292,6 @@ func generateMetadata(templates []generator.Template, outputFile string) error { for _, template := range templates { meta := generator.Template{ Path: template.Path, - Body: template.Body, UpdateBehavior: template.UpdateBehavior, LoopMethod: template.LoopMethod, LoopService: template.LoopService, @@ -307,15 +306,16 @@ func generateMetadata(templates []generator.Template, outputFile string) error { } func (a *Arguments) Render(cmd *util.Command, args []string) error { - if len(args) == 0 { - return util.ErrHelp + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) } if MigratePath != "" { templates, err := parseYAMLFiles(MigratePath) if err != nil { return err } - err = createFilesFromTemplates(templates, MigratePath) + err = createFilesFromTemplates(templates, curpath) if err != nil { return err } @@ -326,9 +326,8 @@ func (a *Arguments) Render(cmd *util.Command, args []string) error { fmt.Println("Migrate successfully...") os.Exit(0) } - curpath, err := filepath.Abs(".") - if err != nil { - return fmt.Errorf("get current path failed: %s", err.Error()) + if len(args) == 0 { + return util.ErrHelp } log.Verbose = a.Verbose