Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mongodb): add support for object types #1086

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engine/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
// ->
// ["asdf", null]
func transformResponse(data []byte) ([]byte, error) {
logger.Debug.Printf("before transform: %s", data)
var m interface{}
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
Expand Down
39 changes: 29 additions & 10 deletions generator/ast/dmmf/dmmf.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@ const (
FieldKindEnum FieldKind = "enum"
)

// IncludeInStruct shows whether to include a field in a model struct.
func (v FieldKind) IncludeInStruct() bool {
return v == FieldKindScalar || v == FieldKindEnum
}

// IsRelation returns whether field is a relation
func (v FieldKind) IsRelation() bool {
// IsObject returns whether field is an object
func (v FieldKind) IsObject() bool {
return v == FieldKindObject
}

Expand Down Expand Up @@ -211,8 +206,19 @@ type EnumValue struct {

// Datamodel contains all types of the Prisma Datamodel.
type Datamodel struct {
Models []Model `json:"models"`
Enums []Enum `json:"enums"`
Models []Model `json:"models"`
Types []ObjectType `json:"types"`
Enums []Enum `json:"enums"`
}

// ObjectType is a MongoDB object type
type ObjectType struct {
Name types.String `json:"name"`
DbName types.String `json:"dbName"`
Fields []Field `json:"fields"`
PrimaryKey string `json:"primaryKey"`
UniqueFields []string `json:"uniqueFields"`
UniqueIndexes []string `json:"uniqueIndexes"`
}

type UniqueIndex struct {
Expand Down Expand Up @@ -254,7 +260,7 @@ func (m Model) Actions() []string {
func (m Model) RelationFieldsPlusOne() []Field {
var fields []Field
for _, field := range m.Fields {
if field.Kind.IsRelation() {
if field.IsRelation() {
fields = append(fields, field)
}
}
Expand Down Expand Up @@ -286,6 +292,19 @@ type Field struct {
HasDefaultValue bool `json:"hasDefaultValue"`
}

// IncludeInStruct shows whether to include a field in a model struct.
func (f Field) IncludeInStruct() bool {
return f.Kind == FieldKindScalar || f.Kind == FieldKindEnum || f.IsObjectType()
}

func (f Field) IsRelation() bool {
return f.RelationName != ""
}

func (f Field) IsObjectType() bool {
return f.Type != "" && !f.IsRelation() && f.Kind == FieldKindObject
}

func (f Field) RequiredOnCreate(key PrimaryKey) bool {
if !f.IsRequired || f.IsUpdatedAt || f.HasDefaultValue || f.IsReadOnly || f.IsList {
return false
Expand Down
2 changes: 1 addition & 1 deletion generator/templates/actions/actions.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var countOutput = []builder.Output{

var {{ $name }}Output = []builder.Output{
{{- range $i := $model.Fields }}
{{- if $i.Kind.IncludeInStruct }}
{{- if $i.IncludeInStruct }}
{Name: "{{ $i.Name }}"},
{{- end }}
{{- end }}
Expand Down
39 changes: 19 additions & 20 deletions generator/templates/models.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,23 @@
// Inner{{ $model.Name.GoCase }} holds the actual data
type Inner{{ $model.Name.GoCase }} struct {
{{ range $field := $model.Fields }}
{{- if not $field.Kind.IsRelation -}}
{{- if $field.IsRequired }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }}
{{- else }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }}
{{- end }}
{{- end -}}
{{- if eq $field.RelationName "" -}}
{{ $field.Name.GoCase -}}
{{- if $field.IsList }}[]{{ else }}{{- if not $field.IsRequired }}*{{ end }} {{ end -}}
{{ $field.Type.Value }}{{ if eq $field.Kind "object" }}Type{{ end -}}
{{ $field.Name.Tag $field.IsRequired }}
{{- end }}
{{ end }}
}

// Raw{{ $model.Name.GoCase }}Model is a struct for {{ $model.Name }} when used in raw queries
type Raw{{ $model.Name.GoCase }}Model struct {
{{ range $field := $model.Fields }}
{{- if not $field.Kind.IsRelation -}}
{{- if not $field.IsRelation -}}
{{- if $field.IsRequired }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}Raw{{ $field.Type.GoCase }} {{ $field.Name.Tag $field.IsRequired }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}{{ if not $field.IsObjectType }}Raw{{ end }}{{ $field.Type.GoCase }}{{ if $field.IsObjectType }}Type{{ end }} {{ $field.Name.Tag $field.IsRequired }}
{{- else }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}Raw{{ $field.Type.GoCase }} {{ $field.Name.Tag $field.IsRequired }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ if not $field.IsObjectType }}Raw{{ end }}{{ $field.Type.GoCase }}{{ if $field.IsObjectType }}Type{{ end }} {{ $field.Name.Tag $field.IsRequired }}
{{- end }}
{{- end -}}
{{ end }}
Expand All @@ -36,35 +35,35 @@
// Relations{{ $model.Name.GoCase }} holds the relation data separately
type Relations{{ $model.Name.GoCase }} struct {
{{ range $field := $model.Fields }}
{{- if $field.Kind.IsRelation }}
{{- if $field.IsRelation }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }}Model {{ $field.Name.Tag false }}
{{- end -}}
{{ end }}
}

{{/* Attach methods for nullable (non-required) fields and relations. */}}
{{- range $field := $model.Fields }}
{{- if or (not $field.IsRequired) ($field.Kind.IsRelation) }}
{{- if or (not $field.IsRequired) ($field.IsRelation) }}
func (r {{ $model.Name.GoCase }}Model) {{ $field.Name.GoCase }}() (
{{- if $field.IsList }}value []{{ else }}value{{ end }} {{ if and $field.Kind.IsRelation (not $field.IsList) }}*{{ end }}{{ $field.Type.Value }}{{ if $field.Kind.IsRelation }}Model{{ end -}}
{{- if or (not $field.Kind.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
{{- if $field.IsList }}value []{{ else }}value{{ end }} {{ if and $field.IsRelation (not $field.IsList) }}*{{ end }}{{ $field.Type.Value }}{{ if $field.IsRelation }}Model{{ end -}}{{ if $field.IsObjectType }}Type{{ end -}}
{{- if or (not $field.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
, ok bool
{{- end -}}
) {
if r.{{ if $field.Kind.IsRelation }}Relations{{ else }}Inner{{ end }}{{ $model.Name.GoCase }}.{{ $field.Name.GoCase }} == nil {
{{- if and ($field.Kind.IsRelation) ($field.IsRequired) }}
if r.{{ if $field.IsRelation }}Relations{{ else }}Inner{{ end }}{{ $model.Name.GoCase }}.{{ $field.Name.GoCase }} == nil {
{{- if and ($field.IsRelation) ($field.IsRequired) }}
panic("attempted to access {{ $field.Name.GoLowerCase }} but did not fetch it using the .With() syntax")
{{- else }}
return value
{{- if or (not $field.Kind.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
{{- if or (not $field.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
, false
{{- end -}}
{{- end }}
}
return {{ if and (not $field.Kind.IsRelation) (not $field.IsList) }}*{{ end }}r.
{{- if $field.Kind.IsRelation }}Relations{{ else }}Inner{{ end }}{{ $model.Name.GoCase }}.
return {{ if and (not $field.IsRelation) (not $field.IsList) }}*{{ end }}r.
{{- if $field.IsRelation }}Relations{{ else }}Inner{{ end }}{{ $model.Name.GoCase }}.
{{- $field.Name.GoCase -}}
{{- if or (not $field.Kind.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
{{- if or (not $field.IsRelation) (and (not $field.IsList) (not $field.IsRequired)) -}}
, true
{{- end -}}
}
Expand Down
32 changes: 16 additions & 16 deletions generator/templates/query.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
{{ if $field.Prisma }}
{{ $name = $field.Name.PrismaGoCase }}
{{ end }}
{{- if $field.Kind.IncludeInStruct -}}
{{- if $field.IncludeInStruct -}}
// {{ $name }}
//
// @{{ if $field.IsRequired }}required{{ else }}optional{{ end }}
Expand All @@ -27,7 +27,7 @@
{{ $name }} {{ $nsQuery }}{{ $field.Name.GoCase }}{{ $field.Type }}
{{ end }}

{{- if $field.Kind.IsRelation }}
{{- if $field.IsRelation }}
{{ $name }} {{ $nsQuery }}{{ $name }}Relations
{{ end }}
{{- end }}
Expand Down Expand Up @@ -87,7 +87,7 @@
{{ $setReturnStruct = (print $name "SetParam") }}
{{ end}}

{{ if $field.Kind.IsRelation }}
{{ if $field.Field.IsRelation }}
type {{ $nsQuery }}{{ $field.Name.GoCase }}Relations struct {}

{{ range $method := $field.RelationMethods }}
Expand Down Expand Up @@ -232,13 +232,13 @@
{{ end }}
{{ end }}

{{ if $field.Kind.IncludeInStruct }}
{{ if $field.IncludeInStruct }}
{{ if not $field.Prisma }}
// Set the {{ if $field.IsRequired }}required{{ else }}optional{{ end }} value of {{ $field.Name.GoCase }}
func (r {{ $struct }}) Set(value {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }}) {{ $setReturnStruct }} {
func (r {{ $struct }}) Set(value {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $setReturnStruct }} {
{{ if $field.IsList }}
if value == nil {
value = []{{ $field.Type.Value }}{}
value = []{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}{}
}
{{ end }}
{{/* if scalar list (only postgres) */}}
Expand All @@ -265,7 +265,7 @@
}

// Set the optional value of {{ $field.Name.GoCase }} dynamically
func (r {{ $struct }}) SetIfPresent(value *{{ if $field.IsList }}[]{{ else }}{{ end }}{{ $field.Type.Value }}) {{ $setReturnStruct }} {
func (r {{ $struct }}) SetIfPresent(value *{{ if $field.IsList }}[]{{ else }}{{ end }}{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $setReturnStruct }} {
if value == nil {
return {{ $setReturnStruct }}{}
}
Expand All @@ -276,10 +276,10 @@

{{ if and (not $field.IsRequired) (not $field.IsList) (not $field.Prisma) }}
// Set the optional value of {{ $field.Name.GoCase }} dynamically
func (r {{ $struct }}) SetOptional(value *{{ $field.Type.Value }}) {{ $setReturnStruct }} {
func (r {{ $struct }}) SetOptional(value *{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $setReturnStruct }} {
if value == nil {
{{/* nil value of type */}}
var v *{{ $field.Type.Value }}
var v *{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}
return {{ $setReturnStruct }}{
data: builder.Field{
Name: "{{ $field.Name }}",
Expand All @@ -300,7 +300,7 @@
{{ $type = $field.Type.Value}}
{{ end }}
// {{ $method.Name }} the {{ if $field.IsRequired }}required{{ else }}optional{{ end }} value of {{ $field.Name.GoCase }}
func (r {{ $struct }}) {{ $method.Name }}(value {{ if $method.IsList }}[]{{ end }}{{ $type }}) {{ $setReturnStruct }} {
func (r {{ $struct }}) {{ $method.Name }}(value {{ if $method.IsList }}[]{{ end }}{{ $type }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $setReturnStruct }} {
return {{ $setReturnStruct }}{
data: builder.Field{
Name: "{{ $field.Name }}",
Expand Down Expand Up @@ -331,7 +331,7 @@
{{ $returnStruct = (print $name "DefaultParam") }}
{{ end }}

{{ if and $field.Kind.IncludeInStruct (not $field.Prisma) }}
{{ if and $field.IncludeInStruct (not $field.Prisma) }}
{{/* Provide an `Equals` method for most types. */}}
{{/* Equals has a special return type for individual fields */}}
{{ $equalsReturnStruct := "" }}
Expand All @@ -340,10 +340,10 @@
{{ else }}
{{ $equalsReturnStruct = (print $name "WithPrisma" $field.Name.GoCase "EqualsParam") }}
{{ end }}
func (r {{ $struct }}) Equals(value {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }}) {{ $equalsReturnStruct }} {
func (r {{ $struct }}) Equals(value {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $equalsReturnStruct }} {
{{ if $field.IsList }}
if value == nil {
value = []{{ $field.Type.Value }}{}
value = []{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}{}
}
{{ end }}
return {{ $equalsReturnStruct }}{
Expand All @@ -359,15 +359,15 @@
}
}

func (r {{ $struct }}) EqualsIfPresent(value {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }}) {{ $equalsReturnStruct }} {
func (r {{ $struct }}) EqualsIfPresent(value {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $equalsReturnStruct }} {
if value == nil {
return {{ $equalsReturnStruct }}{}
}
return r.Equals({{ if not $field.IsList }}*{{ end }}value)
}

{{ if and (not $field.IsRequired) (not $field.Prisma) }}
func (r {{ $struct }}) EqualsOptional(value *{{ $field.Type.Value }}) {{ $returnStruct }} {
func (r {{ $struct }}) EqualsOptional(value *{{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $returnStruct }} {
return {{ $returnStruct }}{
data: builder.Field{
Name: "{{ $field.Name }}",
Expand Down Expand Up @@ -406,7 +406,7 @@
}
}

func (r {{ $struct }}) Cursor(cursor {{ $field.Type.Value }}) {{ $name }}CursorParam {
func (r {{ $struct }}) Cursor(cursor {{ $field.Type.Value }}{{ if $field.Field.IsObjectType }}Type{{ end }}) {{ $name }}CursorParam {
return {{ $name }}CursorParam{
data: builder.Field{
Name: "{{ $field.Name }}",
Expand Down
18 changes: 18 additions & 0 deletions generator/templates/types.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}}

{{/* Types are for MongoDB object types */}}

{{ range $type := $.DMMF.Datamodel.Types }}
// {{ $type.Name.GoCase }}Type
type {{ $type.Name.GoCase }}Type struct {
{{ range $field := $type.Fields }}
{{- if not $field.IsRelation -}}
{{- if $field.IsRequired }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }}
{{- else }}
{{ $field.Name.GoCase }} {{ if $field.IsList }}[]{{ else }}*{{ end }}{{ $field.Type.Value }} {{ $field.Name.Tag $field.IsRequired }}
{{- end }}
{{- end -}}
{{ end }}
}
{{ end }}
90 changes: 90 additions & 0 deletions test/databases/mongodb/objects/default_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package raw

import (
"context"
"testing"

"github.com/steebchen/prisma-client-go/test"
"github.com/steebchen/prisma-client-go/test/helpers/massert"
)

type cx = context.Context
type Func func(t *testing.T, client *PrismaClient, ctx cx)

func TestObjects(t *testing.T) {
t.Parallel()

tests := []struct {
name string
before []string
run Func
}{{
name: "types",
run: func(t *testing.T, client *PrismaClient, ctx cx) {
expected := &UserModel{
InnerUser: InnerUser{
ID: "id1",
Email: "email1",
Username: "username1",
Info: InfoType{
Age: 5,
AgeOpt: 3,
},
InfoOpt: &InfoType{
Age: 5,
AgeOpt: 3,
},
List: []InfoType{
{
Age: 5,
AgeOpt: 3,
},
},
},
}

user, err := client.User.CreateOne(
User.Email.Set("id1"),
User.Username.Set("id1"),
User.Info.Set(InfoType{
Age: 5,
AgeOpt: 3,
}),
User.InfoOpt.Set(InfoType{
Age: 5,
AgeOpt: 3,
}),
User.List.Set([]InfoType{{
Age: 5,
AgeOpt: 3,
}}),
User.ID.Set("id1"),
).Exec(ctx)
if err != nil {
t.Fatal(err)
}

massert.Equal(t, expected, user)

user, err = client.User.FindUnique(
User.ID.Equals("id1"),
).Exec(ctx)
if err != nil {
t.Fatal(err)
}

massert.Equal(t, expected, user)
},
}}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
client := NewClient()

mockDB := test.Start(t, test.MongoDB, client.Engine, tt.before)
defer test.End(t, test.MongoDB, client.Engine, mockDB)

tt.run(t, client, context.Background())
})
}
}
Loading
Loading