Skip to content

Commit

Permalink
fix(builder): disallow duplicate fields (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
steebchen authored Nov 10, 2023
1 parent 59bd915 commit 4f1b6df
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 40 deletions.
6 changes: 5 additions & 1 deletion engine/mock/do.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ func (e *Engine) Do(_ context.Context, payload interface{}, v interface{}) error
n := -1
for i, e := range expectations {
req := payload.(engine.GQLRequest)
if e.Query.Build() == req.Query {
str, err := e.Query.Build()
if err != nil {
return err
}
if str == req.Query {
n = i
break
}
Expand Down
6 changes: 5 additions & 1 deletion engine/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ func (m *Mock) Ensure(t *testing.T) {
}
for _, e := range *m.Expectations {
if !e.Success {
t.Fatalf("expectation not met for query `%s` and result `%s`, error `%s`", e.Query.Build(), e.Want, e.WantErr)
str, err := e.Query.Build()
if err != nil {
t.Fatalf("could not build query: %s", err)
}
t.Fatalf("expectation not met for query `%s` and result `%s`, error `%s`", str, e.Want, e.WantErr)
}
}
}
8 changes: 8 additions & 0 deletions generator/ast/dmmf/dmmf.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func (Document) Operators() []Operator {
}}
}

func (d Document) OperatorActions() []string {
var operators []string
for _, operator := range d.Operators() {
operators = append(operators, operator.Action)
}
return operators
}

// Action describes a CRUD operation.
type Action struct {
// Type describes a query or a mutation
Expand Down
91 changes: 71 additions & 20 deletions runtime/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ type Field struct {
// The Name of the field.
Name string

// List saves whether the fields is a list of items
// List saves whether the fields are a list of items
List bool

// WrapList saves whether the a list field should be wrapped in an object
// WrapList saves whether the field should be wrapped in an individual object
WrapList bool

// Value contains the field value. if nil, fields will contain a subselection.
Expand Down Expand Up @@ -79,39 +79,51 @@ type Query struct {
TxResult chan []byte
}

func (q Query) Build() string {
func (q Query) Build() (string, error) {
var builder strings.Builder

builder.WriteString(q.Operation + " " + q.Name)
builder.WriteString("{")
builder.WriteString("result: ")

builder.WriteString(q.BuildInner())
str, err := q.BuildInner()
if err != nil {
return "", err
}
builder.WriteString(str)

builder.WriteString("}")

return builder.String()
return builder.String(), nil
}

func (q Query) BuildInner() string {
func (q Query) BuildInner() (string, error) {
var builder strings.Builder

builder.WriteString(q.Method + q.Model)

if len(q.Inputs) > 0 {
builder.WriteString(q.buildInputs(q.Inputs))
str, err := q.buildInputs(q.Inputs)
if err != nil {
return "", err
}
builder.WriteString(str)
}

builder.WriteString(" ")

if len(q.Outputs) > 0 {
builder.WriteString(q.buildOutputs(q.Outputs))
str, err := q.buildOutputs(q.Outputs)
if err != nil {
return "", err
}
builder.WriteString(str)
}

return builder.String()
return builder.String(), nil
}

func (q Query) buildInputs(inputs []Input) string {
func (q Query) buildInputs(inputs []Input) (string, error) {
var builder strings.Builder

builder.WriteString("(")
Expand All @@ -127,7 +139,11 @@ func (q Query) buildInputs(inputs []Input) string {
if i.WrapList {
builder.WriteString("[")
}
builder.WriteString(q.buildFields(i.WrapList, i.WrapList, i.Fields))
str, err := q.buildFields(i.WrapList, i.WrapList, i.Fields)
if err != nil {
return "", err
}
builder.WriteString(str)
if i.WrapList {
builder.WriteString("]")
}
Expand All @@ -138,10 +154,10 @@ func (q Query) buildInputs(inputs []Input) string {

builder.WriteString(")")

return builder.String()
return builder.String(), nil
}

func (q Query) buildOutputs(outputs []Output) string {
func (q Query) buildOutputs(outputs []Output) (string, error) {
var builder strings.Builder

builder.WriteString("{")
Expand All @@ -150,20 +166,30 @@ func (q Query) buildOutputs(outputs []Output) string {
builder.WriteString(o.Name + " ")

if len(o.Inputs) > 0 {
builder.WriteString(q.buildInputs(o.Inputs))
str, err := q.buildInputs(o.Inputs)
if err != nil {
return "", err
}
builder.WriteString(str)
}

if len(o.Outputs) > 0 {
builder.WriteString(q.buildOutputs(o.Outputs))
str, err := q.buildOutputs(o.Outputs)
if err != nil {
return "", err
}
builder.WriteString(str)
}
}

builder.WriteString("}")

return builder.String()
return builder.String(), nil
}

func (q Query) buildFields(list bool, wrapList bool, fields []Field) string {
var ErrDuplicateField = fmt.Errorf("duplicate field (https://github.com/steebchen/prisma-client-go/issues/1095)")

func (q Query) buildFields(list bool, wrapList bool, fields []Field) (string, error) {
var builder strings.Builder

if !list {
Expand Down Expand Up @@ -200,6 +226,10 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string {
}

for _, f := range final {
if err := checkFields(f, f.Fields); err != nil {
return "", err
}

if wrapList {
builder.WriteString("{")
}
Expand All @@ -217,7 +247,11 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string {
}

if f.Fields != nil {
builder.WriteString(q.buildFields(f.List, f.WrapList, f.Fields))
str, err := q.buildFields(f.List, f.WrapList, f.Fields)
if err != nil {
return "", err
}
builder.WriteString(str)
}

if f.Value != nil {
Expand All @@ -239,12 +273,29 @@ func (q Query) buildFields(list bool, wrapList bool, fields []Field) string {
builder.WriteString("}")
}

return builder.String()
return builder.String(), nil
}

func checkFields(parent Field, fields []Field) error {
uniqueObjectFields := make(map[string]Field)
for _, f := range fields {
if f.Value != nil && !f.List && !parent.List {
if _, ok := uniqueObjectFields[f.Name]; ok {
return fmt.Errorf("%w: %q", ErrDuplicateField, f.Name)
}
uniqueObjectFields[f.Name] = f
}
}
return nil
}

func (q Query) Exec(ctx context.Context, into interface{}) error {
str, err := q.Build()
if err != nil {
return err
}
payload := engine.GQLRequest{
Query: q.Build(),
Query: str,
Variables: map[string]interface{}{},
}
return q.Do(ctx, payload, into)
Expand Down
24 changes: 14 additions & 10 deletions runtime/transaction/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,9 @@ type Param interface {
}

func (r TX) Transaction(queries ...Param) Exec {
requests := make([]engine.GQLRequest, len(queries))
for i, query := range queries {
requests[i] = engine.GQLRequest{
Query: query.ExtractQuery().Build(),
Variables: map[string]interface{}{},
}
}
return Exec{
engine: r.Engine,
requests: requests,
queries: queries,
engine: r.Engine,
queries: queries,
}
}

Expand All @@ -39,6 +31,18 @@ type Exec struct {
}

func (r Exec) Exec(ctx context.Context) error {
r.requests = make([]engine.GQLRequest, len(r.queries))
for i, query := range r.queries {
str, err := query.ExtractQuery().Build()
if err != nil {
return err
}
r.requests[i] = engine.GQLRequest{
Query: str,
Variables: map[string]interface{}{},
}
}

for _, q := range r.queries {
//goland:noinspection GoDeferInLoop
defer close(q.ExtractQuery().TxResult)
Expand Down
Loading

0 comments on commit 4f1b6df

Please sign in to comment.