Skip to content

Commit

Permalink
Remove 'RETURNING' functionality from MultiInserter
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongable committed Oct 2, 2024
1 parent 0442894 commit 48773ac
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 141 deletions.
8 changes: 0 additions & 8 deletions db/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,9 @@ type Executor interface {
OneSelector
Inserter
SelectExecer
Queryer
Delete(context.Context, ...interface{}) (int64, error)
Get(context.Context, interface{}, ...interface{}) (interface{}, error)
Update(context.Context, ...interface{}) (int64, error)
}

// Queryer offers the QueryContext method. Note that this is not read-only (i.e. not
// Selector), since a QueryContext can be `INSERT`, `UPDATE`, etc. The difference
// between QueryContext and ExecContext is that QueryContext can return rows. So for instance it is
// suitable for inserting rows and getting back ids.
type Queryer interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
}

Expand Down
78 changes: 21 additions & 57 deletions db/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,24 @@ import (
)

// MultiInserter makes it easy to construct a
// `INSERT INTO table (...) VALUES ... RETURNING id;`
// `INSERT INTO table (...) VALUES ...;`
// query which inserts multiple rows into the same table. It can also execute
// the resulting query.
type MultiInserter struct {
// These are validated by the constructor as containing only characters
// that are allowed in an unquoted identifier.
// https://mariadb.com/kb/en/identifier-names/#unquoted
table string
fields []string
returningColumn string
table string
fields []string

values [][]interface{}
}

// NewMultiInserter creates a new MultiInserter, checking for reasonable table
// name and list of fields. returningColumn is the name of a column to be used
// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz`
// clause is used. If returningColumn is present, it must refer to a column
// that can be parsed into an int64.
// Safety: `table`, `fields`, and `returningColumn` must contain only strings
// that are known at compile time. They must not contain user-controlled
// strings.
func NewMultiInserter(table string, fields []string, returningColumn string) (*MultiInserter, error) {
// name and list of fields.
// Safety: `table` and `fields` must contain only strings that are known at
// compile time. They must not contain user-controlled strings.
func NewMultiInserter(table string, fields []string) (*MultiInserter, error) {
if len(table) == 0 || len(fields) == 0 {
return nil, fmt.Errorf("empty table name or fields list")
}
Expand All @@ -44,18 +39,11 @@ func NewMultiInserter(table string, fields []string, returningColumn string) (*M
return nil, err
}
}
if returningColumn != "" {
err := validMariaDBUnquotedIdentifier(returningColumn)
if err != nil {
return nil, err
}
}

return &MultiInserter{
table: table,
fields: fields,
returningColumn: returningColumn,
values: make([][]interface{}, 0),
table: table,
fields: fields,
values: make([][]interface{}, 0),
}, nil
}

Expand Down Expand Up @@ -84,56 +72,32 @@ func (mi *MultiInserter) query() (string, []interface{}) {

questions := strings.TrimRight(questionsBuf.String(), ",")

// Safety: we are interpolating `mi.returningColumn` into an SQL query. We
// know it is a valid unquoted identifier in MariaDB because we verified
// that in the constructor.
returning := ""
if mi.returningColumn != "" {
returning = fmt.Sprintf(" RETURNING %s", mi.returningColumn)
}
// Safety: we are interpolating `mi.table` and `mi.fields` into an SQL
// query. We know they contain, respectively, a valid unquoted identifier
// and a slice of valid unquoted identifiers because we verified that in
// the constructor. We know the query overall has valid syntax because we
// generate it entirely within this function.
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s%s", mi.table, strings.Join(mi.fields, ","), questions, returning)
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions)

return query, queryArgs
}

// Insert inserts all the collected rows into the database represented by
// `queryer`. If a non-empty returningColumn was provided, then it returns
// the list of values from that column returned by the query.
func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) {
// `queryer`.
func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error {
query, queryArgs := mi.query()
rows, err := queryer.QueryContext(ctx, query, queryArgs...)
res, err := db.ExecContext(ctx, query, queryArgs...)
if err != nil {
return nil, err
return err
}

ids := make([]int64, 0, len(mi.values))
if mi.returningColumn != "" {
for rows.Next() {
var id int64
err = rows.Scan(&id)
if err != nil {
rows.Close()
return nil, err
}
ids = append(ids, id)
}
affected, err := res.RowsAffected()
if err != nil {
return err
}

// Hack: sometimes in unittests we make a mock Queryer that returns a nil
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
// on it will panic— but here we choose to treat it like an empty list,
// and skip calling `Close()` to avoid the panic.
if rows != nil {
err = rows.Close()
if err != nil {
return nil, err
}
if affected != int64(len(mi.values)) {
return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values))
}

return ids, nil
return nil
}
32 changes: 8 additions & 24 deletions db/multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,29 @@ import (
)

func TestNewMulti(t *testing.T) {
_, err := NewMultiInserter("", []string{"colA"}, "")
_, err := NewMultiInserter("", []string{"colA"})
test.AssertError(t, err, "Empty table name should fail")

_, err = NewMultiInserter("myTable", nil, "")
_, err = NewMultiInserter("myTable", nil)
test.AssertError(t, err, "Empty fields list should fail")

mi, err := NewMultiInserter("myTable", []string{"colA"}, "")
mi, err := NewMultiInserter("myTable", []string{"colA"})
test.AssertNotError(t, err, "Single-column construction should not fail")
test.AssertEquals(t, len(mi.fields), 1)

mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"}, "")
mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"})
test.AssertNotError(t, err, "Multi-column construction should not fail")
test.AssertEquals(t, len(mi.fields), 3)

_, err = NewMultiInserter("", []string{"colA"}, "colB")
test.AssertError(t, err, "expected error for empty table name")
_, err = NewMultiInserter("foo\"bar", []string{"colA"}, "colB")
_, err = NewMultiInserter("foo\"bar", []string{"colA"})
test.AssertError(t, err, "expected error for invalid table name")

_, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"}, "colB")
_, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"})
test.AssertError(t, err, "expected error for invalid column name")

_, err = NewMultiInserter("myTable", []string{"colA"}, "foo\"bar")
test.AssertError(t, err, "expected error for invalid returning column name")
}

func TestMultiAdd(t *testing.T) {
mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "")
mi, err := NewMultiInserter("table", []string{"a", "b", "c"})
test.AssertNotError(t, err, "Failed to create test MultiInserter")

err = mi.Add([]interface{}{})
Expand All @@ -57,7 +52,7 @@ func TestMultiAdd(t *testing.T) {
}

func TestMultiQuery(t *testing.T) {
mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "")
mi, err := NewMultiInserter("table", []string{"a", "b", "c"})
test.AssertNotError(t, err, "Failed to create test MultiInserter")
err = mi.Add([]interface{}{"one", "two", "three"})
test.AssertNotError(t, err, "Failed to insert test row")
Expand All @@ -67,15 +62,4 @@ func TestMultiQuery(t *testing.T) {
query, queryArgs := mi.query()
test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?)")
test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"})

mi, err = NewMultiInserter("table", []string{"a", "b", "c"}, "id")
test.AssertNotError(t, err, "Failed to create test MultiInserter")
err = mi.Add([]interface{}{"one", "two", "three"})
test.AssertNotError(t, err, "Failed to insert test row")
err = mi.Add([]interface{}{"egy", "kettö", "három"})
test.AssertNotError(t, err, "Failed to insert test row")

query, queryArgs = mi.query()
test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?) RETURNING id")
test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"})
}
7 changes: 3 additions & 4 deletions sa/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1047,12 +1047,12 @@ func deleteOrderFQDNSet(
return nil
}

func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error {
func addIssuedNames(ctx context.Context, queryer db.Execer, cert *x509.Certificate, isRenewal bool) error {
if len(cert.DNSNames) == 0 {
return berrors.InternalServerError("certificate has no DNSNames")
}

multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"}, "")
multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"})
if err != nil {
return err
}
Expand All @@ -1067,8 +1067,7 @@ func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certific
return err
}
}
_, err = multiInserter.Insert(ctx, queryer)
return err
return multiInserter.Insert(ctx, queryer)
}

func addKeyHash(ctx context.Context, db db.Inserter, cert *x509.Certificate) error {
Expand Down
59 changes: 11 additions & 48 deletions sa/sa.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/jmhodges/clock"
Expand Down Expand Up @@ -473,53 +472,17 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb

output, err := db.WithTransaction(ctx, ssa.dbMap, func(tx db.Executor) (interface{}, error) {
// First, insert all of the new authorizations and record their IDs.
newAuthzIDs := make([]int64, 0)
if features.Get().InsertAuthzsIndividually {
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
err = tx.Insert(ctx, am)
if err != nil {
return nil, err
}
newAuthzIDs = append(newAuthzIDs, am.ID)
newAuthzIDs := make([]int64, 0, len(req.NewAuthzs))
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
} else {
if len(req.NewAuthzs) != 0 {
inserter, err := db.NewMultiInserter("authz2", strings.Split(authzFields, ", "), "id")
if err != nil {
return nil, err
}
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
err = inserter.Add([]interface{}{
am.ID,
am.IdentifierType,
am.IdentifierValue,
am.RegistrationID,
statusToUint[core.StatusPending],
am.Expires,
am.Challenges,
nil,
nil,
am.Token,
nil,
nil,
})
if err != nil {
return nil, err
}
}
newAuthzIDs, err = inserter.Insert(ctx, tx)
if err != nil {
return nil, err
}
err = tx.Insert(ctx, am)
if err != nil {
return nil, err
}
newAuthzIDs = append(newAuthzIDs, am.ID)
}

// Second, insert the new order.
Expand Down Expand Up @@ -549,7 +512,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb
}

// Third, insert all of the orderToAuthz relations.
inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"}, "")
inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"})
if err != nil {
return nil, err
}
Expand All @@ -565,7 +528,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb
return nil, err
}
}
_, err = inserter.Insert(ctx, tx)
err = inserter.Insert(ctx, tx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 48773ac

Please sign in to comment.