From 69008e86eadef4fbf4f70bff3782c63c525f2332 Mon Sep 17 00:00:00 2001 From: Seunghyun Hwang Date: Sat, 11 May 2024 00:00:37 +0900 Subject: [PATCH] support separate package --- entproto/cmd/protoc-gen-entgrpc/converter.go | 4 +- entproto/cmd/protoc-gen-entgrpc/main.go | 57 +++++++++++++------ .../protoc-gen-entgrpc/template/enums.tmpl | 6 +- .../template/method_get.tmpl | 9 ++- .../template/method_list.tmpl | 9 ++- .../template/method_mutate.tmpl | 7 ++- .../protoc-gen-entgrpc/template/service.tmpl | 6 +- .../protoc-gen-entgrpc/template/to_proto.tmpl | 19 +++++-- 8 files changed, 77 insertions(+), 40 deletions(-) diff --git a/entproto/cmd/protoc-gen-entgrpc/converter.go b/entproto/cmd/protoc-gen-entgrpc/converter.go index 1f12078b7..d3adc8e8b 100644 --- a/entproto/cmd/protoc-gen-entgrpc/converter.go +++ b/entproto/cmd/protoc-gen-entgrpc/converter.go @@ -65,7 +65,7 @@ func (g *serviceGenerator) newConverter(fld *entproto.FieldMappingDescriptor) (* case dpb.FieldDescriptorProto_TYPE_ENUM: enumName := fld.PbFieldDescriptor.GetEnumType().GetName() method := fmt.Sprintf("toProto%s_%s", g.EntType.Name, enumName) - out.ToProtoConstructor = g.File.GoImportPath.Ident(method) + out.ToProtoConstructor = g.EntgrpcPackage.Ident(method) case dpb.FieldDescriptorProto_TYPE_MESSAGE: if fld.IsEdgeField { if err := basicTypeConversion(fld.EdgeIDPbStructFieldDesc(), fld.EntEdge.Type.ID, out); err != nil { @@ -110,7 +110,7 @@ func (g *serviceGenerator) newConverter(fld *entproto.FieldMappingDescriptor) (* case efld.IsEnum(): enumName := fld.PbFieldDescriptor.GetEnumType().GetName() method := fmt.Sprintf("toEnt%s_%s", g.EntType.Name, enumName) - out.ToEntConstructor = g.File.GoImportPath.Ident(method) + out.ToEntConstructor = g.EntgrpcPackage.Ident(method) case efld.IsJSON(): switch efld.Type.Ident { case "[]string": diff --git a/entproto/cmd/protoc-gen-entgrpc/main.go b/entproto/cmd/protoc-gen-entgrpc/main.go index 5e98698ed..ee39e7646 100644 --- a/entproto/cmd/protoc-gen-entgrpc/main.go +++ b/entproto/cmd/protoc-gen-entgrpc/main.go @@ -30,19 +30,29 @@ import ( ) var ( - entSchemaPath *string - snake = gen.Funcs["snake"].(func(string) string) - status = protogen.GoImportPath("google.golang.org/grpc/status") - codes = protogen.GoImportPath("google.golang.org/grpc/codes") + entgrpcPackage *string + entSchemaPath *string + entPackagePath *string + + snake = gen.Funcs["snake"].(func(string) string) + status = protogen.GoImportPath("google.golang.org/grpc/status") + codes = protogen.GoImportPath("google.golang.org/grpc/codes") ) func main() { var flags flag.FlagSet + entgrpcPackage = flags.String("package", "", "package path to be generated") entSchemaPath = flags.String("schema_path", "", "ent schema path") + entPackagePath = flags.String("entity_package", "", "ent entity package path") protogen.Options{ ParamFunc: flags.Set, }.Run(func(plg *protogen.Plugin) error { - g, err := entc.LoadGraph(*entSchemaPath, &gen.Config{}) + conf := gen.Config{} + if entPackagePath != nil { + conf.Package = *entPackagePath + } + + g, err := entc.LoadGraph(*entSchemaPath, &conf) if err != nil { return err } @@ -99,19 +109,30 @@ func newServiceGenerator(plugin *protogen.Plugin, file *protogen.File, graph *ge if err != nil { return nil, err } + + entgrpcImportPath := file.GoImportPath + entgrpcPackageName := file.GoPackageName + if entgrpcPackage != nil { + entgrpcImportPath = protogen.GoImportPath(*entgrpcPackage) + entgrpcPackageName = protogen.GoPackageName(path.Base(*entgrpcPackage)) + } + filename := file.GeneratedFilenamePrefix + "_" + snake(service.GoName) + ".go" - g := plugin.NewGeneratedFile(filename, file.GoImportPath) + g := plugin.NewGeneratedFile(filename, entgrpcImportPath) fieldMap, err := adapter.FieldMap(typ.Name) if err != nil { return nil, err } + return &serviceGenerator{ - GeneratedFile: g, - EntPackage: protogen.GoImportPath(graph.Config.Package), - File: file, - Service: service, - EntType: typ, - FieldMap: fieldMap, + GeneratedFile: g, + EntgrpcPackageName: entgrpcPackageName, + EntgrpcPackage: entgrpcImportPath, + EntPackage: protogen.GoImportPath(graph.Config.Package), + File: file, + Service: service, + EntType: typ, + FieldMap: fieldMap, }, nil } @@ -161,11 +182,13 @@ func (g *serviceGenerator) generate() error { type ( serviceGenerator struct { *protogen.GeneratedFile - EntPackage protogen.GoImportPath - File *protogen.File - Service *protogen.Service - EntType *gen.Type - FieldMap entproto.FieldMap + EntgrpcPackageName protogen.GoPackageName + EntgrpcPackage protogen.GoImportPath + EntPackage protogen.GoImportPath + File *protogen.File + Service *protogen.Service + EntType *gen.Type + FieldMap entproto.FieldMap } methodInput struct { G *serviceGenerator diff --git a/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl index 40a2c3d6b..7831bab7f 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/enums.tmpl @@ -14,14 +14,14 @@ } func toProto{{ $pbEnumIdent.GoName }} (e {{ ident $entEnumIdent }}) {{ ident $pbEnumIdent }} { - if v, ok := {{ $pbEnumIdent.GoName }}_value[{{ qualify "strings" "ToUpper" }}({{ if not $omitPrefix }}"{{ $enumFieldPrefix }}" + {{ end }}protoIdentNormalize{{ $pbEnumIdent.GoName }}(string(e)))]; ok { + if v, ok := {{ ident $pbEnumIdent }}_value[{{ qualify "strings" "ToUpper" }}({{ if not $omitPrefix }}"{{ $enumFieldPrefix }}" + {{ end }}protoIdentNormalize{{ $pbEnumIdent.GoName }}(string(e)))]; ok { return {{ $pbEnumIdent | ident }}(v) } return {{ $pbEnumIdent | ident }}(0) } func toEnt{{ $pbEnumIdent.GoName }}(e {{ ident $pbEnumIdent }}) {{ ident $entEnumIdent }} { - if v, ok := {{ $pbEnumIdent.GoName }}_name[int32(e)]; ok { + if v, ok := {{ ident $pbEnumIdent }}_name[int32(e)]; ok { entVal := map[string]string{ {{- range .EntField.Enums }} "{{ if not $omitPrefix }}{{ $enumFieldPrefix }}{{ end }}{{ protoIdentNormalize .Value }}": "{{ .Value }}", @@ -32,4 +32,4 @@ return "" } {{ end}} -{{ end }} \ No newline at end of file +{{ end }} diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl index 4191e83e1..533829f0a 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/method_get.tmpl @@ -1,5 +1,6 @@ {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.methodInput*/ -}} {{ define "method_get" }} + {{ $importPath := .G.File.GoImportPath }} {{- $idField := .G.FieldMap.ID -}} {{- $varName := $idField.EntField.Name -}} {{- $inputName := .Method.Input.GoIdent.GoName -}} @@ -9,9 +10,11 @@ ) {{- template "field_to_ent" dict "Field" $idField "VarName" $idField.EntField.Name "Ident" (print "req.Get" $idField.PbStructField "()") }} switch req.GetView() { - case {{ $inputName }}_VIEW_UNSPECIFIED, {{ $inputName }}_BASIC: + case {{ $importPath.Ident ( print $inputName "_VIEW_UNSPECIFIED" ) | ident }}: + fallthrough + case {{ $importPath.Ident ( print $inputName "_BASIC" ) | ident }}: get, err = svc.client.{{ .G.EntType.Name }}.Get(ctx, {{ $varName }}) - case {{ $inputName }}_WITH_EDGE_IDS: + case {{ $importPath.Ident ( print $inputName "_WITH_EDGE_IDS" ) | ident }}: get, err = svc.client.{{ .G.EntType.Name }}.Query(). Where({{ qualify (print (unquote .G.EntPackage.String) "/" .G.EntType.Package) "ID" }}({{ $varName }})). {{ range .G.FieldMap.Edges }} @@ -32,4 +35,4 @@ default: return nil, {{ statusErrf "Internal" "internal error: %s" "err" }} } -{{ end }} \ No newline at end of file +{{ end }} diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl index 053233f89..79764208e 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/method_list.tmpl @@ -1,5 +1,6 @@ {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.methodInput*/ -}} {{ define "method_list" }} + {{ $importPath := .G.File.GoImportPath }} {{- $inputName := .Method.Input.GoIdent.GoName -}} var ( err error @@ -40,9 +41,11 @@ Where({{ qualify (print (unquote .G.EntPackage.String) "/" .G.EntType.Package) "IDLTE" }}(pageToken)) } switch req.GetView() { - case {{ $inputName }}_VIEW_UNSPECIFIED, {{ $inputName }}_BASIC: + case {{ $importPath.Ident ( print $inputName "_VIEW_UNSPECIFIED" ) | ident }}: + fallthrough + case {{ $importPath.Ident ( print $inputName "_BASIC" ) | ident }}: entList, err = listQuery.All(ctx) - case {{ $inputName }}_WITH_EDGE_IDS: + case {{ $importPath.Ident ( print $inputName "_WITH_EDGE_IDS" ) | ident }}: entList, err = listQuery. {{ range .G.FieldMap.Edges }} {{- $et := .EntEdge.Type -}} @@ -64,7 +67,7 @@ if err != nil { return nil, {{ statusErrf "Internal" "internal error: %s" "err" }} } - return &List{{ .G.EntType.Name }}Response{ + return &{{ $importPath.Ident ( print "List" .G.EntType.Name "Response" ) | ident }}{ {{ .G.EntType.Name }}List: protoList, NextPageToken: nextPageToken, }, nil diff --git a/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl index b68f8a86f..9a3ac8330 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/method_mutate.tmpl @@ -38,10 +38,11 @@ {{ define "create_builder_func" }} {{- $entType := .Method.G.EntType.Name -}} + {{- $pbType := .Method.G.File.GoImportPath.Ident $entType | ident -}} {{- $inputVar := camel $entType -}} - {{- $outputType := printf "%s%s" $entType "Create" -}} + {{- $outputType := .Method.G.EntPackage.Ident ( printf "%s%s" $entType "Create" ) | ident -}} - func (svc *{{ .ServiceName }}) createBuilder({{ $inputVar }} *{{ $entType }}) (*ent.{{ $outputType }}, error) { + func (svc *{{ .ServiceName }}) createBuilder({{ $inputVar }} *{{ $pbType }}) (*{{ $outputType }}, error) { m := svc.client.{{ $entType }}.Create() {{- template "mutate_helper" .Method -}} return m, nil @@ -85,4 +86,4 @@ } {{- end }} {{- end }} -{{ end }} \ No newline at end of file +{{ end }} diff --git a/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl index d4b916672..aef1645e7 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/service.tmpl @@ -1,12 +1,12 @@ {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.serviceGenerator*/ -}} {{ define "service" }} // Code generated by protoc-gen-entgrpc. DO NOT EDIT. -package {{ .File.GoPackageName }} +package {{ .EntgrpcPackageName }} // {{ .Service.GoName }} implements {{ .Service.GoName }}Server type {{ .Service.GoName }} struct { - client *{{ .EntPackage.Ident "Client" | ident }} - Unimplemented{{ .Service.GoName }}Server + client *{{ .EntPackage.Ident "Client" | ident }} + {{ .File.GoImportPath.Ident (print "Unimplemented" .Service.GoName "Server") | ident }} } // New{{ .Service.GoName }} returns a new {{ .Service.GoName }} diff --git a/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl b/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl index 64e093280..ac79d80ae 100644 --- a/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl +++ b/entproto/cmd/protoc-gen-entgrpc/template/to_proto.tmpl @@ -1,8 +1,11 @@ {{- /*gotype: entgo.io/contrib/entproto/cmd/protoc-gen-entgrpc.serviceGenerator*/ -}} {{ define "to_proto_func" }} + {{- $importPath := .File.GoImportPath -}} + {{- $pbTypeName := $importPath.Ident .EntType.Name | ident -}} + // toProto{{ .EntType.Name }} transforms the ent type to the pb type - func toProto{{ .EntType.Name }}(e *{{ .EntPackage.Ident .EntType.Name | ident }}) (*{{ .EntType.Name }}, error) { - v := &{{ .EntType.Name }}{} + func toProto{{ .EntType.Name }}(e *{{ .EntPackage.Ident .EntType.Name | ident }}) (*{{ $pbTypeName }}, error) { + v := &{{ $pbTypeName }}{} {{- range .FieldMap.Fields }} {{- $varName := .EntField.BuilderField -}} {{- $f := print "e." .EntField.StructField -}} @@ -17,20 +20,21 @@ {{- end }} {{- end }} {{- range .FieldMap.Edges }} + {{ $edgeTypeName := $importPath.Ident .EntEdge.Type.Name | ident }} {{- $varName := camel .EntEdge.Type.ID.StructField -}} {{- $id := print "edg." .EntEdge.Type.ID.StructField -}} {{- $name := .EntEdge.StructField -}} {{- if .EntEdge.Unique }} if edg := e.Edges.{{ $name }}; edg != nil { {{- template "field_to_proto" dict "Field" . "VarName" $varName "Ident" $id }} - v.{{ .PbStructField }} = &{{ .EntEdge.Type.Name }}{ + v.{{ .PbStructField }} = &{{ $edgeTypeName }}{ {{ .EdgeIDPbStructField }}: {{ $varName }}, } } {{- else }} for _, edg := range e.Edges.{{ $name }} { {{- template "field_to_proto" dict "Field" . "VarName" $varName "Ident" $id }} - v.{{ .PbStructField }} = append(v.{{ .PbStructField }}, &{{ .EntEdge.Type.Name }}{ + v.{{ .PbStructField }} = append(v.{{ .PbStructField }}, &{{ $edgeTypeName }}{ {{ .EdgeIDPbStructField }}: {{ $varName }}, }) } @@ -41,9 +45,12 @@ {{ end }} {{ define "to_proto_list_func" }} + {{- $importPath := .File.GoImportPath -}} + {{- $pbTypeName := $importPath.Ident .EntType.Name | ident -}} + // toProto{{ .EntType.Name }}List transforms a list of ent type to a list of pb type - func toProto{{ .EntType.Name }}List(e []*{{ .EntPackage.Ident .EntType.Name | ident }}) ([]*{{ .EntType.Name }}, error) { - var pbList []*{{ .EntType.Name }} + func toProto{{ .EntType.Name }}List(e []*{{ .EntPackage.Ident .EntType.Name | ident }}) ([]*{{ $pbTypeName }}, error) { + var pbList []*{{ $pbTypeName }} for _, entEntity := range e { pbEntity, err := toProto{{ .EntType.Name }}(entEntity) if err != nil {