diff --git a/docs/pages/docs/reference/client/errors.md b/docs/pages/docs/reference/client/errors.md new file mode 100644 index 00000000..5861ca75 --- /dev/null +++ b/docs/pages/docs/reference/client/errors.md @@ -0,0 +1,40 @@ +# Errors + +## ErrNotFound + +`ErrNotFound` is returned when a query does not return any results. This error may be returned in `FindUnique`, `FindFirst`, but also when updating or deleting single records using `FindUnique().Update()` and `FindUnique().Delete()`. + +```go +post, err := client.Post.FindFirst( + db.Post.Title.Equals("hi"), +).Exec(ctx) +if err != nil { + if errors.Is(err, db.ErrNotFound) { + panic("no record with title 'hi' found") + } + panic("error occurred: %s", err) +} +``` + +## IsUniqueConstraintViolation + +A unique constraint violation happens when a query attempts to insert or update a record with a value that already exists in the database, or in other words, violates a unique constraint. + +```go +user, err := db.User.CreateOne(...).Exec(cxt) +if err != nil { + if info, err := db.IsErrUniqueConstraint(err); err != nil { + // Fields exists for Postgres and SQLite + log.Printf("unique constraint on the fields: %s", info.Fields) + + // you can also compare it with generated field names: + if info.Fields[0] == db.User.Name.Field() { + // do something + log.Printf("unique constraint on the `user.name` field") + } + + // For MySQL and MongoDB, use the constraint key + log.Printf("unique constraint on the key: %s", info.Key) + } +} +``` diff --git a/engine/mock/do.go b/engine/mock/do.go index 28d532c1..c6b08079 100644 --- a/engine/mock/do.go +++ b/engine/mock/do.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/steebchen/prisma-client-go/engine" + "github.com/steebchen/prisma-client-go/engine/protocol" ) func (e *Engine) Do(_ context.Context, payload interface{}, v interface{}) error { @@ -16,7 +16,7 @@ func (e *Engine) Do(_ context.Context, payload interface{}, v interface{}) error n := -1 for i, e := range expectations { - req := payload.(engine.GQLRequest) + req := payload.(protocol.GQLRequest) str, err := e.Query.Build() if err != nil { return err diff --git a/engine/protocol.go b/engine/protocol/protocol.go similarity index 60% rename from engine/protocol.go rename to engine/protocol/protocol.go index b20faf78..7bc6727e 100644 --- a/engine/protocol.go +++ b/engine/protocol/protocol.go @@ -1,4 +1,4 @@ -package engine +package protocol import ( "encoding/json" @@ -33,11 +33,30 @@ type GQLBatchRequest struct { Transaction bool `json:"transaction"` } -// GQLError is a GraphQL Error +type UserFacingError struct { + IsPanic bool `json:"is_panic"` + Message string `json:"message"` + Meta Meta `json:"meta"` + ErrorCode string `json:"error_code"` +} + +func (e *UserFacingError) Error() string { + return e.Message +} + +type Meta struct { + Target interface{} `json:"target"` // can be of type []string or string +} + +// GQLError is a GraphQL Message type GQLError struct { - Message string `json:"error"` // note: the query-engine uses 'error' instead of 'message' - Path []string `json:"path"` - Extensions map[string]interface{} `json:"query"` + Message string `json:"error"` + UserFacingError *UserFacingError `json:"user_facing_error"` + Path []string `json:"path"` +} + +func (e *GQLError) Error() string { + return e.Message } func (e *GQLError) RawMessage() string { diff --git a/engine/proxy.go b/engine/proxy.go index b3d55234..b93650a4 100644 --- a/engine/proxy.go +++ b/engine/proxy.go @@ -13,6 +13,7 @@ import ( "time" "github.com/steebchen/prisma-client-go/binaries" + "github.com/steebchen/prisma-client-go/engine/protocol" "github.com/steebchen/prisma-client-go/logger" "github.com/steebchen/prisma-client-go/runtime/types" ) @@ -108,7 +109,7 @@ func (e *DataProxyEngine) Do(ctx context.Context, payload interface{}, into inte startParse := time.Now() - var response GQLResponse + var response protocol.GQLResponse if err := json.Unmarshal(body, &response); err != nil { return fmt.Errorf("json gql resopnse unmarshal: %w", err) } diff --git a/engine/request.go b/engine/request.go index 8c0464d8..b5bf22ff 100644 --- a/engine/request.go +++ b/engine/request.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/steebchen/prisma-client-go/engine/protocol" "github.com/steebchen/prisma-client-go/logger" "github.com/steebchen/prisma-client-go/runtime/types" ) @@ -28,18 +29,23 @@ func (e *QueryEngine) Do(ctx context.Context, payload interface{}, v interface{} startParse := time.Now() - var response GQLResponse + var response protocol.GQLResponse if err := json.Unmarshal(body, &response); err != nil { return fmt.Errorf("json gql response unmarshal: %w", err) } if len(response.Errors) > 0 { - first := response.Errors[0] - if first.RawMessage() == internalUpdateNotFoundMessage || - first.RawMessage() == internalDeleteNotFoundMessage { + e := response.Errors[0] + if e.RawMessage() == internalUpdateNotFoundMessage || + e.RawMessage() == internalDeleteNotFoundMessage { return types.ErrNotFound } - return fmt.Errorf("pql error: %s", first.RawMessage()) + + if e.UserFacingError != nil { + return fmt.Errorf("user facing error: %w", e.UserFacingError) + } + + return fmt.Errorf("internal error: %s", e.RawMessage()) } response.Data.Result, err = transformResponse(response.Data.Result) diff --git a/generator/run.go b/generator/run.go index db3e730c..861faf63 100644 --- a/generator/run.go +++ b/generator/run.go @@ -78,6 +78,7 @@ func generateClient(input *Root) error { "client", "enums", "errors", + "fields", "mock", "models", "query", diff --git a/generator/templates/errors.gotpl b/generator/templates/errors.gotpl index 2a46cdd4..ffd9a943 100644 --- a/generator/templates/errors.gotpl +++ b/generator/templates/errors.gotpl @@ -1 +1,29 @@ +{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}} + var ErrNotFound = types.ErrNotFound +var IsErrNotFound = types.IsErrNotFound + +type ErrUniqueConstraint = types.ErrUniqueConstraint[prismaFields] + +// IsErrUniqueConstraint returns on a unique constraint error or violation with error info +// Use as follows: +// +// user, err := db.User.CreateOne(...).Exec(cxt) +// if err != nil { +// if info, err := db.IsErrUniqueConstraint(err); err != nil { +// // Fields exists for Postgres and SQLite +// log.Printf("unique constraint on the fields: %s", info.Fields) +// +// // you can also compare it with generated field names: +// if info.Fields[0] == db.User.Name.Field() { +// // do something +// } +// +// // For MySQL, use the constraint key +// log.Printf("unique constraint on the key: %s", info.Key) +// } +// } +// +func IsErrUniqueConstraint(err error) (*types.ErrUniqueConstraint[prismaFields], bool) { + return types.CheckUniqueConstraint[prismaFields](err) +} diff --git a/generator/templates/fields.gotpl b/generator/templates/fields.gotpl new file mode 100644 index 00000000..a26f325f --- /dev/null +++ b/generator/templates/fields.gotpl @@ -0,0 +1,11 @@ +{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}} + +type prismaFields string + +{{ range $model := $.AST.Models }} + type {{ $model.Name.GoLowerCase }}PrismaFields = prismaFields + + {{ range $field := $model.Fields }} + const {{ $model.Name.GoLowerCase }}Field{{ $field.Name.GoCase }} {{ $model.Name.GoLowerCase }}PrismaFields = "{{ $field.Name }}" + {{ end }} +{{ end }} diff --git a/generator/templates/query.gotpl b/generator/templates/query.gotpl index 47b1d279..1fd250dc 100644 --- a/generator/templates/query.gotpl +++ b/generator/templates/query.gotpl @@ -452,5 +452,10 @@ } {{ end }} {{ end }} + + {{/* Returns static field names */}} + func (r {{ $struct }}) Field() {{ $model.Name.GoLowerCase }}PrismaFields { + return {{ $model.Name.GoLowerCase }}Field{{ $field.Name.GoCase }} + } {{ end }} {{ end }} diff --git a/go.mod b/go.mod index e724a96c..4fbc1472 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,15 @@ module github.com/steebchen/prisma-client-go -go 1.16 +go 1.18 require ( github.com/joho/godotenv v1.5.1 github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.4 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 6ad1b28d..6c4e7344 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -7,15 +6,9 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/runtime/builder/builder.go b/runtime/builder/builder.go index 9dc0d34f..af5d82e9 100644 --- a/runtime/builder/builder.go +++ b/runtime/builder/builder.go @@ -8,6 +8,7 @@ import ( "time" "github.com/steebchen/prisma-client-go/engine" + "github.com/steebchen/prisma-client-go/engine/protocol" "github.com/steebchen/prisma-client-go/logger" ) @@ -294,7 +295,7 @@ func (q Query) Exec(ctx context.Context, into interface{}) error { if err != nil { return err } - payload := engine.GQLRequest{ + payload := protocol.GQLRequest{ Query: str, Variables: map[string]interface{}{}, } diff --git a/runtime/transaction/transaction.go b/runtime/transaction/transaction.go index 3c2e04d9..8afd4e8c 100644 --- a/runtime/transaction/transaction.go +++ b/runtime/transaction/transaction.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/steebchen/prisma-client-go/engine" + "github.com/steebchen/prisma-client-go/engine/protocol" "github.com/steebchen/prisma-client-go/runtime/builder" ) @@ -27,17 +28,17 @@ func (r TX) Transaction(queries ...Param) Exec { type Exec struct { queries []Param engine engine.Engine - requests []engine.GQLRequest + requests []protocol.GQLRequest } func (r Exec) Exec(ctx context.Context) error { - r.requests = make([]engine.GQLRequest, len(r.queries)) + r.requests = make([]protocol.GQLRequest, len(r.queries)) for i, query := range r.queries { str, err := query.ExtractQuery().Build() if err != nil { return err } - r.requests[i] = engine.GQLRequest{ + r.requests[i] = protocol.GQLRequest{ Query: str, Variables: map[string]interface{}{}, } @@ -48,8 +49,8 @@ func (r Exec) Exec(ctx context.Context) error { defer close(q.ExtractQuery().TxResult) } - var result engine.GQLBatchResponse - payload := engine.GQLBatchRequest{ + var result protocol.GQLBatchResponse + payload := protocol.GQLBatchRequest{ Batch: r.requests, Transaction: true, } diff --git a/runtime/types/errors.go b/runtime/types/errors.go index 2a98352d..7888cc91 100644 --- a/runtime/types/errors.go +++ b/runtime/types/errors.go @@ -1,6 +1,69 @@ package types -import "errors" +import ( + "errors" + + "github.com/steebchen/prisma-client-go/engine/protocol" +) // ErrNotFound gets returned when a database record does not exist var ErrNotFound = errors.New("ErrNotFound") + +// IsErrNotFound is true if the error is a ErrNotFound, which gets returned when a database record does not exist +// This can happen when you call `FindUnique` on a record, or update or delete a single record which doesn't exist. +func IsErrNotFound(err error) bool { + return errors.Is(err, ErrNotFound) +} + +type F interface { + ~string +} + +type ErrUniqueConstraint[T F] struct { + // Message is the error message + Message string + // Fields only shows on Postgres + Fields []T + // Key only shows on MySQL + Key string +} + +// CheckUniqueConstraint returns on a unique constraint error or violation with error info +// Ideally this will be replaced with Prisma-generated errors in the future +func CheckUniqueConstraint[T F](err error) (*ErrUniqueConstraint[T], bool) { + if err == nil { + return nil, false + } + + var ufr *protocol.UserFacingError + if ok := errors.As(err, &ufr); !ok { + return nil, false + } + + if ufr.ErrorCode != "P2002" { + return nil, false + } + + // postgres + if items, ok := ufr.Meta.Target.([]interface{}); ok { + var fields []T + for _, f := range items { + field, ok := f.(string) + if ok { + fields = append(fields, T(field)) + } + } + return &ErrUniqueConstraint[T]{ + Fields: fields, + }, true + } + + // mysql + if item, ok := ufr.Meta.Target.(string); ok { + return &ErrUniqueConstraint[T]{ + Key: item, + }, true + } + + return nil, false +} diff --git a/test/errors/unique/schema.prisma b/test/errors/unique/schema.prisma new file mode 100644 index 00000000..c8a70a2b --- /dev/null +++ b/test/errors/unique/schema.prisma @@ -0,0 +1,19 @@ +datasource db { + provider = "sqlite" + url = env("__REPLACE__") +} + +generator db { + provider = "go run github.com/steebchen/prisma-client-go" + output = "." + disableGoBinaries = true + package = "db" +} + +model User { + id String @id @default(cuid()) @map("_id") + email String @unique + username String + name String? + age Int? +} diff --git a/test/errors/unique/unique_test.go b/test/errors/unique/unique_test.go new file mode 100644 index 00000000..19fbd125 --- /dev/null +++ b/test/errors/unique/unique_test.go @@ -0,0 +1,140 @@ +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/steebchen/prisma-client-go/test" +) + +type cx = context.Context +type Func func(t *testing.T, client *PrismaClient, ctx cx) + +func TestUniqueConstraintViolation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dbs []test.Database + before []string + run Func + }{{ + name: "postgres unique constraint violation", + dbs: []test.Database{test.PostgreSQL}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + assert.Equal(t, nil, err) + + _, err = client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + + violation, ok := IsErrUniqueConstraint(err) + // assert.Equal(t, &ErrUniqueConstraint{ + // Field: User.Email.Field(), + // }, violation) + assert.Equal(t, User.Email.Field(), violation.Fields[0]) + + assert.Equal(t, true, ok) + }, + }, { + name: "mysql unique constraint violation", + dbs: []test.Database{test.MySQL}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + assert.Equal(t, nil, err) + + _, err = client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + + violation, ok := IsErrUniqueConstraint(err) + // assert.Equal(t, &ErrUniqueConstraint{ + // Key: "User_email_key", + // }, violation) + assert.Equal(t, "User_email_key", violation.Key) + + assert.Equal(t, true, ok) + }, + }, { + name: "sqlite unique constraint violation", + dbs: []test.Database{test.SQLite}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + assert.Equal(t, nil, err) + + _, err = client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + + violation, ok := IsErrUniqueConstraint(err) + // assert.Equal(t, &ErrUniqueConstraint{ + // Field: User.Email.Field(), + // }, violation) + assert.Equal(t, User.Email.Field(), violation.Fields[0]) + + assert.Equal(t, true, ok) + }, + }, { + name: "mongodb unique constraint violation", + dbs: []test.Database{test.MongoDB}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + assert.Equal(t, nil, err) + + _, err = client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + + violation, ok := IsErrUniqueConstraint(err) + // assert.Equal(t, &ErrUniqueConstraint{ + // Key: "User_email_key", + // }, violation) + assert.Equal(t, "User_email_key", violation.Key) + + assert.Equal(t, true, ok) + }, + }, { + name: "nil error should succeed", + dbs: []test.Database{test.SQLite}, + run: func(t *testing.T, client *PrismaClient, ctx cx) { + _, err := client.User.CreateOne( + User.Email.Set("john@example.com"), + User.Username.Set("username"), + ).Exec(ctx) + + _, ok := IsErrUniqueConstraint(err) + + assert.Equal(t, false, ok) + }, + }} + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + test.RunSerial(t, tt.dbs, func(t *testing.T, db test.Database, ctx context.Context) { + client := NewClient() + mockDBName := test.Start(t, db, client.Engine, tt.before) + defer test.End(t, db, client.Engine, mockDBName) + tt.run(t, client, context.Background()) + }) + }) + } +} diff --git a/test/integration/go.mod b/test/integration/go.mod index 897bb2f5..8575358a 100644 --- a/test/integration/go.mod +++ b/test/integration/go.mod @@ -1,5 +1,5 @@ module integration -go 1.16 +go 1.18 replace github.com/steebchen/prisma-client-go => ../../ diff --git a/test/test.go b/test/test.go index 6eebf380..63d3e767 100644 --- a/test/test.go +++ b/test/test.go @@ -10,6 +10,7 @@ import ( "github.com/steebchen/prisma-client-go/cli" "github.com/steebchen/prisma-client-go/engine" + "github.com/steebchen/prisma-client-go/engine/protocol" "github.com/steebchen/prisma-client-go/test/cmd" "github.com/steebchen/prisma-client-go/test/setup/mongodb" "github.com/steebchen/prisma-client-go/test/setup/mysql" @@ -71,8 +72,8 @@ func Start(t *testing.T, db Database, e engine.Engine, queries []string) string } for _, q := range queries { - var response engine.GQLResponse - payload := engine.GQLRequest{ + var response protocol.GQLResponse + payload := protocol.GQLRequest{ Query: q, Variables: map[string]interface{}{}, }