Skip to content

Commit

Permalink
fix: do not roll back transaction on partial insert error
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Nov 14, 2024
1 parent b40606c commit de962ce
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 20 deletions.
42 changes: 31 additions & 11 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) {
})

t.Run("suite=PATCH identities", func(t *testing.T) {
t.Run("case=fails on > 100 identities", func(t *testing.T) {
t.Run("case=fails with too many patches", func(t *testing.T) {
tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1)
for i := range tooMany {
tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)}
Expand All @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expect the server to fail only the bad patches in the list.
// Finally, we send just valid patches
// --> we expect the server to succeed all patches in the list.

t.Run("case=invalid patches fail", func(t *testing.T) {
Expand All @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) {
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}
expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]}

// Create unique IDs for each patch
var patchIDs []string
patchIDs := make([]string, len(patches))
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
patchIDs[i] = id.String()
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body)
assert.Equalf(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)
actions, "%s", body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
Expand All @@ -811,6 +810,27 @@ func TestHandler(t *testing.T) {
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

var identityIDs []uuid.UUID
require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.identity").Raw), &identityIDs), "%s", body)

actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs})
require.NoError(t, err)
actualIdentityIDs := make([]uuid.UUID, len(actualIdentities))
for i, id := range actualIdentities {
actualIdentityIDs[i] = id.ID
}
assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body)

expectedTraits := make(map[string]string, len(expectedToPass))
for i, p := range expectedToPass {
expectedTraits[identityIDs[i].String()] = string(p.Create.Traits)
}
actualTraits := make(map[string]string, len(actualIdentities))
for _, id := range actualIdentities {
actualTraits[id.ID.String()] = string(id.Traits)
}

assert.Equal(t, expectedTraits, actualTraits)
})

t.Run("valid patches succeed", func(t *testing.T) {
Expand Down Expand Up @@ -1928,7 +1948,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody
identity.VerifiableAddressStatusCompleted,
}

for j := 0; j < 4; j++ {
for j := range 4 {
email := fmt.Sprintf("%s-%d-%[email protected]", prefix, i, j)
traits.Emails = append(traits.Emails, email)
verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{
Expand Down
10 changes: 8 additions & 2 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ type CreateIdentitiesError struct {
failedIdentities map[*Identity]*herodot.DefaultError
}

func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError {
return &CreateIdentitiesError{
failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity),
}
}

func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
Expand Down Expand Up @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if len(e.failedIdentities) == 0 {
if e == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
Expand All @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity,
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)

createIdentitiesError := &CreateIdentitiesError{}
createIdentitiesError := NewCreateIdentitiesError(len(identities))
validIdentities := make([]*Identity, 0, len(identities))
for _, ident := range identities {
if ident.SchemaID == "" {
Expand Down
42 changes: 39 additions & 3 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,46 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second)
assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second)
assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt)
assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt)
}
})

t.Run("create exactly the non-conflicting ones", func(t *testing.T) {
identities := make([]*identity.Identity, 100)
for i := range identities {
identities[i] = NewTestIdentity(4, "persister-create-multiple-2", i%60)
}
err := p.CreateIdentities(ctx, identities...)
errWithCtx := new(identity.CreateIdentitiesError)
require.ErrorAsf(t, err, &errWithCtx, "%#v", err)

for _, id := range identities[:60] {
require.NotZero(t, id.ID)

idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything)
require.NoError(t, err)

credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword]
assert.Equal(t, id.ID, idFromDB.ID)
assert.Equal(t, id.SchemaID, idFromDB.SchemaID)
assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL)
assert.Equal(t, id.State, idFromDB.State)

// We test that the values are plausible in the handler test already.
assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses))
assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses))

assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
assert.Equal(t, id.CreatedAt, idFromDB.CreatedAt)
assert.Equal(t, id.UpdatedAt, idFromDB.UpdatedAt)
}

require.NoError(t, p.DeleteIdentity(ctx, id.ID))
for _, id := range identities[60:] {
failed := errWithCtx.Find(id)
assert.NotNil(t, failed)
}
})
})
Expand Down
13 changes: 9 additions & 4 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,16 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}
}()

return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
var partialErr *identity.CreateIdentitiesError
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
conn := &batch.TracerConnection{
Tracer: p.r.Tracer(ctx),
Connection: tx,
}

succeededIDs = make([]uuid.UUID, 0, len(identities))
failedIdentityIDs := make(map[uuid.UUID]struct{})
partialErr = nil

// Don't use batch.WithPartialInserts, because identities have no other
// constraints other than the primary key that could cause conflicts.
Expand Down Expand Up @@ -620,7 +622,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
// If any of the batch inserts failed on conflict, let's delete the corresponding
// identities and return a list of failed identities in the error.
if len(failedIdentityIDs) > 0 {
partialErr := &identity.CreateIdentitiesError{}
partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs))
failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs))

for _, ident := range identities {
Expand All @@ -637,7 +639,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
return sqlcon.HandleError(err)
}

return partialErr
return nil
} else {
// No failures: report all identities as created.
for _, ident := range identities {
Expand All @@ -646,7 +648,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}

return nil
})
}); err != nil {
return err
}
return partialErr.ErrOrNil()
}

func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
Expand Down

0 comments on commit de962ce

Please sign in to comment.