Skip to content

Commit

Permalink
adaptions
Browse files Browse the repository at this point in the history
  • Loading branch information
steebchen committed Oct 28, 2023
1 parent a2a6381 commit 4ade1f0
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 37 deletions.
6 changes: 3 additions & 3 deletions generator/templates/errors.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

var ErrNotFound = types.ErrNotFound

type ErrUniqueConstraint = types.ErrUniqueConstraint[PrismaFields]
type ErrUniqueConstraint = types.ErrUniqueConstraint[prismaFields]

func IsUniqueConstraint(err error) (*types.ErrUniqueConstraint[PrismaFields], bool) {
return types.CheckUniqueConstraint[PrismaFields](err)
func IsUniqueConstraint(err error) (*types.ErrUniqueConstraint[prismaFields], bool) {
return types.CheckUniqueConstraint[prismaFields](err)
}
8 changes: 6 additions & 2 deletions generator/templates/fields.gotpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}}

type PrismaFields string
type prismaFields string

{{ range $model := $.AST.Models }}
type {{ $model.Name.GoCase }}PrismaFields string
type {{ $model.Name.GoLowerCase }}PrismaFields = prismaFields

{{ range $field := $model.Fields }}
const {{ $model.Name.GoLowerCase }}Field{{ $field.Name.GoCase }} {{ $model.Name.GoLowerCase }}PrismaFields = "{{ $field.Name }}"
{{ end }}
{{ end }}
5 changes: 5 additions & 0 deletions generator/templates/query.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -452,5 +452,10 @@
}
{{ end }}
{{ end }}

{{/* Returns static field names */}}
func (r {{ $struct }}) Field() {{ $model.Name.GoLowerCase }}PrismaFields {
return {{ $model.Name.GoLowerCase }}Field{{ $field.Name.GoCase }}
}
{{ end }}
{{ end }}
50 changes: 26 additions & 24 deletions runtime/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,27 @@ package types

import (
"errors"
"strings"
"regexp"
)

// ErrNotFound gets returned when a database record does not exist
var ErrNotFound = errors.New("ErrNotFound")

type A string
type B string

type F interface {
~string
}

type ErrUniqueConstraint[T F] struct {
// Field only shows on Postgres
Field T
// Key only shows on MySQL
Key string
}

const prismaUniqueConstraint = "Unique constraint failed on the fields: (`%s`)"
const fieldKey = "field"

var prismaMySQLUniqueConstraint = regexp.MustCompile("Unique constraint failed on the constraint: `(?P<" + fieldKey + ">.+)`")
var prismaPostgresUniqueConstraint = regexp.MustCompile("Unique constraint failed on the fields: \\(`(?P<" + fieldKey + ">.+)`\\)")

// CheckUniqueConstraint returns on a unique constraint error or violation with error info
// Use as follows:
Expand All @@ -33,27 +36,26 @@ const prismaUniqueConstraint = "Unique constraint failed on the fields: (`%s`)"
//
// Ideally this will be replaced with Prisma-generated errors in the future
func CheckUniqueConstraint[T F](err error) (*ErrUniqueConstraint[T], bool) {
// TODO use regex
if !strings.Contains(err.Error(), prismaUniqueConstraint) {
return nil, false
if match, ok := findMatch(err, prismaMySQLUniqueConstraint); ok {
return &ErrUniqueConstraint[T]{
Key: match,
}, true
}
return &ErrUniqueConstraint[T]{
Field: "asdf",
}, true
if match, ok := findMatch(err, prismaPostgresUniqueConstraint); ok {
return &ErrUniqueConstraint[T]{
Field: T(match),
}, true
}
return nil, false
}

// ----------
// THIS IS GENERATED CODE
// ----------

type Fields string

// TODO check what JS client uses for fields exports

const UserModelNameField Fields = "user.name"

type RealErrUniqueConstraint = ErrUniqueConstraint[Fields]
func findMatch(err error, regex *regexp.Regexp) (string, bool) {
result := regex.FindStringSubmatch(err.Error())
if result == nil {
return "", false
}

func CheckUniqueConstraintError(err error) (*ErrUniqueConstraint[Fields], bool) {
return CheckUniqueConstraint[Fields](err)
index := regex.SubexpIndex(fieldKey)
field := result[index]
return field, true
}
41 changes: 33 additions & 8 deletions test/errors/unique/unique_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package db

import (
"context"
"github.com/steebchen/prisma-client-go/runtime/types"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -13,15 +12,17 @@ import (
type cx = context.Context
type Func func(t *testing.T, client *PrismaClient, ctx cx)

func TestNotFound(t *testing.T) {
func TestUniqueConstraintViolation(t *testing.T) {
t.Parallel()

tests := []struct {
name string
dbs []test.Database
before []string
run Func
}{{
name: "unique constraint violation on email",
name: "postgres unique constraint violation",
dbs: []test.Database{test.PostgreSQL},
run: func(t *testing.T, client *PrismaClient, ctx cx) {
_, err := client.User.CreateOne(
User.Email.Set("[email protected]"),
Expand All @@ -34,18 +35,42 @@ func TestNotFound(t *testing.T) {
User.Username.Set("username"),
).Exec(ctx)

violation, ok := types.CheckUniqueConstraintError(err) // IsUniqueConstraint
assert.Equal(t, types.RealErrUniqueConstraint{
Field: types.UserModelNameField, // User.Name.Field()
}, violation)
violation, ok := IsUniqueConstraint(err)
// assert.Equal(t, &ErrUniqueConstraint{
// Field: User.Email.Field(),
// }, violation)
assert.Equal(t, User.Email.Field(), violation.Field)

assert.Equal(t, true, ok)
},
}, {
name: "mysql unique constraint violation",
dbs: []test.Database{test.MySQL},
run: func(t *testing.T, client *PrismaClient, ctx cx) {
_, err := client.User.CreateOne(
User.Email.Set("[email protected]"),
User.Username.Set("username"),
).Exec(ctx)
assert.Equal(t, nil, err)

_, err = client.User.CreateOne(
User.Email.Set("[email protected]"),
User.Username.Set("username"),
).Exec(ctx)

violation, ok := IsUniqueConstraint(err)
// assert.Equal(t, &ErrUniqueConstraint{
// Key: "User_email_key",
// }, violation)
assert.Equal(t, "User_email_key", violation.Key)

assert.Equal(t, true, ok)
},
}}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
test.RunSerial(t, test.Databases, func(t *testing.T, db test.Database, ctx context.Context) {
test.RunSerial(t, tt.dbs, func(t *testing.T, db test.Database, ctx context.Context) {
client := NewClient()
mockDBName := test.Start(t, db, client.Engine, tt.before)
defer test.End(t, db, client.Engine, mockDBName)
Expand Down

0 comments on commit 4ade1f0

Please sign in to comment.