diff --git a/engine/mock/do.go b/engine/mock/do.go index 744c30f1..28d532c1 100644 --- a/engine/mock/do.go +++ b/engine/mock/do.go @@ -17,7 +17,11 @@ func (e *Engine) Do(_ context.Context, payload interface{}, v interface{}) error n := -1 for i, e := range expectations { req := payload.(engine.GQLRequest) - if e.Query.Build() == req.Query { + str, err := e.Query.Build() + if err != nil { + return err + } + if str == req.Query { n = i break } diff --git a/engine/mock/mock.go b/engine/mock/mock.go index 56398cc0..e6b1f54b 100644 --- a/engine/mock/mock.go +++ b/engine/mock/mock.go @@ -27,7 +27,11 @@ func (m *Mock) Ensure(t *testing.T) { } for _, e := range *m.Expectations { if !e.Success { - t.Fatalf("expectation not met for query `%s` and result `%s`, error `%s`", e.Query.Build(), e.Want, e.WantErr) + str, err := e.Query.Build() + if err != nil { + t.Fatalf("could not build query: %s", err) + } + t.Fatalf("expectation not met for query `%s` and result `%s`, error `%s`", str, e.Want, e.WantErr) } } } diff --git a/generator/ast/dmmf/dmmf.go b/generator/ast/dmmf/dmmf.go index a1b0d940..26845738 100644 --- a/generator/ast/dmmf/dmmf.go +++ b/generator/ast/dmmf/dmmf.go @@ -98,6 +98,14 @@ func (Document) Operators() []Operator { }} } +func (d Document) OperatorActions() []string { + var operators []string + for _, operator := range d.Operators() { + operators = append(operators, operator.Action) + } + return operators +} + // Action describes a CRUD operation. type Action struct { // Type describes a query or a mutation diff --git a/runtime/builder/builder.go b/runtime/builder/builder.go index 6a7f3ad2..2566977e 100644 --- a/runtime/builder/builder.go +++ b/runtime/builder/builder.go @@ -32,10 +32,10 @@ type Field struct { // The Name of the field. Name string - // List saves whether the fields is a list of items + // List saves whether the fields are a list of items List bool - // WrapList saves whether the a list field should be wrapped in an object + // WrapList saves whether the field should be wrapped in an individual object WrapList bool // Value contains the field value. if nil, fields will contain a subselection. @@ -79,39 +79,51 @@ type Query struct { TxResult chan []byte } -func (q Query) Build() string { +func (q Query) Build() (string, error) { var builder strings.Builder builder.WriteString(q.Operation + " " + q.Name) builder.WriteString("{") builder.WriteString("result: ") - builder.WriteString(q.BuildInner()) + str, err := q.BuildInner() + if err != nil { + return "", err + } + builder.WriteString(str) builder.WriteString("}") - return builder.String() + return builder.String(), nil } -func (q Query) BuildInner() string { +func (q Query) BuildInner() (string, error) { var builder strings.Builder builder.WriteString(q.Method + q.Model) if len(q.Inputs) > 0 { - builder.WriteString(q.buildInputs(q.Inputs)) + str, err := q.buildInputs(q.Inputs) + if err != nil { + return "", err + } + builder.WriteString(str) } builder.WriteString(" ") if len(q.Outputs) > 0 { - builder.WriteString(q.buildOutputs(q.Outputs)) + str, err := q.buildOutputs(q.Outputs) + if err != nil { + return "", err + } + builder.WriteString(str) } - return builder.String() + return builder.String(), nil } -func (q Query) buildInputs(inputs []Input) string { +func (q Query) buildInputs(inputs []Input) (string, error) { var builder strings.Builder builder.WriteString("(") @@ -127,7 +139,11 @@ func (q Query) buildInputs(inputs []Input) string { if i.WrapList { builder.WriteString("[") } - builder.WriteString(q.buildFields(i.WrapList, i.WrapList, i.Fields)) + str, err := q.buildFields(i.WrapList, i.WrapList, i.Fields) + if err != nil { + return "", err + } + builder.WriteString(str) if i.WrapList { builder.WriteString("]") } @@ -138,10 +154,10 @@ func (q Query) buildInputs(inputs []Input) string { builder.WriteString(")") - return builder.String() + return builder.String(), nil } -func (q Query) buildOutputs(outputs []Output) string { +func (q Query) buildOutputs(outputs []Output) (string, error) { var builder strings.Builder builder.WriteString("{") @@ -150,20 +166,30 @@ func (q Query) buildOutputs(outputs []Output) string { builder.WriteString(o.Name + " ") if len(o.Inputs) > 0 { - builder.WriteString(q.buildInputs(o.Inputs)) + str, err := q.buildInputs(o.Inputs) + if err != nil { + return "", err + } + builder.WriteString(str) } if len(o.Outputs) > 0 { - builder.WriteString(q.buildOutputs(o.Outputs)) + str, err := q.buildOutputs(o.Outputs) + if err != nil { + return "", err + } + builder.WriteString(str) } } builder.WriteString("}") - return builder.String() + return builder.String(), nil } -func (q Query) buildFields(list bool, wrapList bool, fields []Field) string { +var ErrDuplicateField = fmt.Errorf("duplicate field") + +func (q Query) buildFields(list bool, wrapList bool, fields []Field) (string, error) { var builder strings.Builder if !list { @@ -200,6 +226,10 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string { } for _, f := range final { + if err := checkFields(f, f.Fields); err != nil { + return "", err + } + if wrapList { builder.WriteString("{") } @@ -217,7 +247,11 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string { } if f.Fields != nil { - builder.WriteString(q.buildFields(f.List, f.WrapList, f.Fields)) + str, err := q.buildFields(f.List, f.WrapList, f.Fields) + if err != nil { + return "", err + } + builder.WriteString(str) } if f.Value != nil { @@ -239,12 +273,29 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string { builder.WriteString("}") } - return builder.String() + return builder.String(), nil +} + +func checkFields(parent Field, fields []Field) error { + uniqueObjectFields := make(map[string]Field) + for _, f := range fields { + if f.Value != nil && !f.List && !parent.List { + if _, ok := uniqueObjectFields[f.Name]; ok { + return fmt.Errorf("%w: %q", ErrDuplicateField, f.Name) + } + uniqueObjectFields[f.Name] = f + } + } + return nil } func (q Query) Exec(ctx context.Context, into interface{}) error { + str, err := q.Build() + if err != nil { + return err + } payload := engine.GQLRequest{ - Query: q.Build(), + Query: str, Variables: map[string]interface{}{}, } return q.Do(ctx, payload, into) diff --git a/runtime/transaction/transaction.go b/runtime/transaction/transaction.go index e783c990..3c2e04d9 100644 --- a/runtime/transaction/transaction.go +++ b/runtime/transaction/transaction.go @@ -18,17 +18,9 @@ type Param interface { } func (r TX) Transaction(queries ...Param) Exec { - requests := make([]engine.GQLRequest, len(queries)) - for i, query := range queries { - requests[i] = engine.GQLRequest{ - Query: query.ExtractQuery().Build(), - Variables: map[string]interface{}{}, - } - } return Exec{ - engine: r.Engine, - requests: requests, - queries: queries, + engine: r.Engine, + queries: queries, } } @@ -39,6 +31,18 @@ type Exec struct { } func (r Exec) Exec(ctx context.Context) error { + r.requests = make([]engine.GQLRequest, len(r.queries)) + for i, query := range r.queries { + str, err := query.ExtractQuery().Build() + if err != nil { + return err + } + r.requests[i] = engine.GQLRequest{ + Query: str, + Variables: map[string]interface{}{}, + } + } + for _, q := range r.queries { //goland:noinspection GoDeferInLoop defer close(q.ExtractQuery().TxResult) diff --git a/test/features/enums/enums_test.go b/test/features/enums/enums_test.go index a0f979cc..177a7afe 100644 --- a/test/features/enums/enums_test.go +++ b/test/features/enums/enums_test.go @@ -2,8 +2,12 @@ package enums import ( "context" + "errors" "testing" + "github.com/stretchr/testify/assert" + + "github.com/steebchen/prisma-client-go/runtime/builder" "github.com/steebchen/prisma-client-go/test" "github.com/steebchen/prisma-client-go/test/helpers/massert" ) @@ -52,7 +56,7 @@ func TestEnums(t *testing.T) { massert.Equal(t, expected, created) - actual, err := client.User.FindMany( + actual, err := client.User.FindFirst( User.Role.Equals(RoleAdmin), User.Role.In([]Role{RoleAdmin}), User.RoleOpt.Equals(RoleModerator), @@ -62,7 +66,154 @@ func TestEnums(t *testing.T) { t.Fatalf("fail %s", err) } - massert.Equal(t, []UserModel{*expected}, actual) + massert.Equal(t, expected, actual) + }, + }, { + name: "many or with and wrapper", + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Role.Set(RoleAdmin), + User.ID.Set("123"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleModerator), + User.ID.Set("456"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleUser), + User.ID.Set("789"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + actual, err := client.User.FindMany( + User.Or( + User.And( + User.Role.Equals(RoleUser), + ), + User.And( + User.Role.Equals(RoleAdmin), + ), + ), + ).OrderBy( + User.ID.Order(SortOrderAsc), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + massert.Equal(t, []UserModel{ + { + InnerUser: InnerUser{ + ID: "123", + Role: RoleAdmin, + }, + }, + { + InnerUser: InnerUser{ + ID: "789", + Role: RoleUser, + }, + }, + }, actual) + }, + }, { + name: "many or direct", + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Role.Set(RoleAdmin), + User.ID.Set("123"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleModerator), + User.ID.Set("456"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleUser), + User.ID.Set("789"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.FindMany( + User.Or( + User.Role.Equals(RoleUser), + User.Role.Equals(RoleAdmin), + ), + ).OrderBy( + User.ID.Order(SortOrderAsc), + ).Exec(ctx) + + assert.Equal(t, builder.ErrDuplicateField, errors.Unwrap(err)) + }, + }, { + name: "in", + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Role.Set(RoleAdmin), + User.ID.Set("123"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleModerator), + User.ID.Set("456"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + _, err = client.User.CreateOne( + User.Role.Set(RoleUser), + User.ID.Set("789"), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + actual, err := client.User.FindMany( + User.Role.In([]Role{RoleUser, RoleAdmin}), + ).OrderBy( + User.ID.Order(SortOrderAsc), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + massert.Equal(t, []UserModel{ + { + InnerUser: InnerUser{ + ID: "123", + Role: RoleAdmin, + }, + }, + { + InnerUser: InnerUser{ + ID: "789", + Role: RoleUser, + }, + }, + }, actual) }, }} for _, tt := range tests { diff --git a/test/projects/basic/basic_test.go b/test/projects/basic/basic_test.go index 5eb0c64b..4fb69c59 100644 --- a/test/projects/basic/basic_test.go +++ b/test/projects/basic/basic_test.go @@ -3,8 +3,12 @@ package basic import ( "context" "encoding/json" + "errors" "testing" + "github.com/stretchr/testify/assert" + + "github.com/steebchen/prisma-client-go/runtime/builder" "github.com/steebchen/prisma-client-go/test" "github.com/steebchen/prisma-client-go/test/helpers/massert" ) @@ -569,6 +573,63 @@ func TestBasic(t *testing.T) { massert.Equal(t, expected, actual) }, }, { + name: "OR operation", + // language=GraphQL + before: []string{` + mutation { + result: createOneUser(data: { + id: "id1", + email: "email1", + username: "a", + }) { + id + } + } + `, ` + mutation { + result: createOneUser(data: { + id: "id2", + email: "email2", + username: "b", + }) { + id + } + } + `}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + actual, err := client.User.FindMany( + User.Or( + User.And( + User.Email.Equals("email1"), + ), + User.And( + User.ID.Equals("id2"), + ), + ), + ).OrderBy( + User.ID.Order(SortOrderAsc), + ).Exec(ctx) + if err != nil { + t.Fatalf("fail %s", err) + } + + expected := []UserModel{{ + InnerUser: InnerUser{ + ID: "id1", + Email: "email1", + Username: "a", + }, + }, { + InnerUser: InnerUser{ + ID: "id2", + Email: "email2", + Username: "b", + }, + }} + + massert.Equal(t, expected, actual) + }, + }, { name: "OR operation", // language=GraphQL before: []string{` @@ -619,10 +680,10 @@ func TestBasic(t *testing.T) { }, }} - massert.Equal(t, expected, actual) + assert.Equal(t, expected, actual) }, }, { - name: "OR operationc complex", + name: "OR operations complex with and", // language=GraphQL before: []string{` mutation { @@ -678,12 +739,26 @@ func TestBasic(t *testing.T) { ), User.And( User.Or( - User.Email.Equals("email4"), - User.Email.Equals("email999"), + User.And( + User.Email.Equals("email999"), + ), + User.And( + User.Email.Equals("email4"), + ), + User.And( + User.Email.Equals("email999"), + ), ), User.Or( - User.ID.Equals("id4"), - User.ID.Equals("id999"), + User.And( + User.ID.Equals("id999"), + ), + User.And( + User.ID.Equals("id4"), + ), + User.And( + User.ID.Equals("id999"), + ), ), ), ), @@ -716,6 +791,75 @@ func TestBasic(t *testing.T) { massert.Equal(t, expected, actual) }, + }, { + name: "OR operations complex no wrap", + // language=GraphQL + before: []string{` + mutation { + result: createOneUser(data: { + id: "id1", + email: "email1", + username: "a", + }) { + id + } + } + `, ` + mutation { + result: createOneUser(data: { + id: "id2", + email: "email2", + username: "b", + }) { + id + } + } + `, ` + mutation { + result: createOneUser(data: { + id: "id3", + email: "email3", + username: "c", + }) { + id + } + } + `, ` + mutation { + result: createOneUser(data: { + id: "id4", + email: "email4", + username: "d", + }) { + id + } + } + `}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.FindMany( + User.Or( + User.And( + User.Email.Equals("email1"), + User.ID.Equals("id1"), + ), + User.And( + User.Email.Equals("email2"), + User.ID.Equals("id2"), + ), + User.And( + User.Or( + User.Email.Equals("email999"), + User.Email.Equals("email4"), + User.Email.Equals("email999"), + ), + ), + ), + ).OrderBy( + User.ID.Order(SortOrderAsc), + ).Exec(ctx) + + assert.Equal(t, builder.ErrDuplicateField, errors.Unwrap(err)) + }, }, { name: "id in", // language=GraphQL