Skip to content

Commit

Permalink
feat: add support for table column mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
dlclark committed Feb 8, 2025
1 parent 0148865 commit a001152
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
8 changes: 7 additions & 1 deletion expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (e *comparisonExpr) isComplete() bool {

// defaultValidateConvert will validate the comparison expr value, and then convert the
// expr to its SQL equivalence.
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) {
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opts options) (*WhereClause, error) {
const op = "mql.(comparisonExpr).convertToSql"
switch {
case columnName == "":
Expand All @@ -103,6 +103,12 @@ func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, column
if err != nil {
return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter)
}
newCol, ok := opts.withTableColumnMap[columnName]
if ok {
// override our column name with the mapped column name
columnName = newCol
}

if validator.typ == "time" {
columnName = fmt.Sprintf("%s::date", columnName)
}
Expand Down
20 changes: 15 additions & 5 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,39 +93,49 @@ func Test_defaultValidateConvert(t *testing.T) {
t.Parallel()
fValidators, err := fieldValidators(reflect.ValueOf(testModel{}))
require.NoError(t, err)
opts := getDefaultOptions()
t.Run("missing-column", func(t *testing.T) {
e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"])
e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"], opts)
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingColumn)
assert.ErrorContains(t, err, "missing column")
})
t.Run("missing-comparison-op", func(t *testing.T) {
e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"])
e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"], opts)
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingComparisonOp)
assert.ErrorContains(t, err, "missing comparison operator")
})
t.Run("missing-value", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"])
e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"], opts)
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingComparisonValue)
assert.ErrorContains(t, err, "missing comparison value")
})
t.Run("missing-validator-func", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"})
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"}, opts)
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "missing validator function")
})
t.Run("missing-validator-typ", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn})
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn}, opts)
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "missing validator type")
})
t.Run("success-with-table-override", func(t *testing.T) {
opts.withTableColumnMap["name"] = "users.name"
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn, typ: "default"}, opts)
assert.Empty(t, err)
assert.NotEmpty(t, e)
assert.Equal(t, "users.name=?", e.Condition, "condition")
assert.Len(t, e.Args, 1, "args")
assert.Equal(t, "alice", e.Args[0], "args[0]")
})
}
2 changes: 1 addition & 1 deletion mql.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option)
}
return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols)
}
w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...)
w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opts)
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
Expand Down
13 changes: 13 additions & 0 deletions mql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,19 @@ func TestParse(t *testing.T) {
wantErrIs: mql.ErrInvalidParameter,
wantErrContains: "missing ConvertToSqlFunc: invalid parameter",
},
{
name: "success-with-table-column-map",
query: "custom_name=\"alice\"",
model: testModel{},
opts: []mql.Option{
mql.WithColumnMap(map[string]string{"custom_name": "name"}),
mql.WithTableColumnMap(map[string]string{"name": "users.custom->>'name'"}),
},
want: &mql.WhereClause{
Condition: "users.custom->>'name'=?",
Args: []any{"alice"},
},
},
}
for _, tc := range tests {
tc := tc
Expand Down
17 changes: 15 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type options struct {
withValidateConvertFns map[string]ValidateConvertFunc
withIgnoredFields []string
withPgPlaceholder bool
withTableColumnMap map[string]string // map of model field names to their table.column name
}

// Option - how options are passed as args
Expand All @@ -22,6 +23,7 @@ func getDefaultOptions() options {
return options{
withColumnMap: make(map[string]string),
withValidateConvertFns: make(map[string]ValidateConvertFunc),
withTableColumnMap: make(map[string]string),
}
}

Expand All @@ -44,8 +46,8 @@ func withSkipWhitespace() Option {
}
}

// WithColumnMap provides an optional map of columns from a column in the user
// provided query to a column in the database model
// WithColumnMap provides an optional map of columns from the user
// provided query to a field in the given model
func WithColumnMap(m map[string]string) Option {
return func(o *options) error {
if !isNil(m) {
Expand Down Expand Up @@ -100,3 +102,14 @@ func WithPgPlaceholders() Option {
return nil
}
}

// WithTableColumnMap provides an optional map of columns from the
// model to the table.column name in the generated where clause
func WithTableColumnMap(m map[string]string) Option {
return func(o *options) error {
if !isNil(m) {
o.withTableColumnMap = m
}
return nil
}
}

0 comments on commit a001152

Please sign in to comment.