From a001152658d25cba2078e870f166e4370a28a5e1 Mon Sep 17 00:00:00 2001 From: Doug Clark Date: Sat, 8 Feb 2025 00:45:31 -0600 Subject: [PATCH] feat: add support for table column mapping --- expr.go | 8 +++++++- expr_test.go | 20 +++++++++++++++----- mql.go | 2 +- mql_test.go | 13 +++++++++++++ options.go | 17 +++++++++++++++-- 5 files changed, 51 insertions(+), 9 deletions(-) diff --git a/expr.go b/expr.go index 1e4943f..0ef7771 100644 --- a/expr.go +++ b/expr.go @@ -77,7 +77,7 @@ func (e *comparisonExpr) isComplete() bool { // defaultValidateConvert will validate the comparison expr value, and then convert the // expr to its SQL equivalence. -func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) { +func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opts options) (*WhereClause, error) { const op = "mql.(comparisonExpr).convertToSql" switch { case columnName == "": @@ -103,6 +103,12 @@ func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, column if err != nil { return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter) } + newCol, ok := opts.withTableColumnMap[columnName] + if ok { + // override our column name with the mapped column name + columnName = newCol + } + if validator.typ == "time" { columnName = fmt.Sprintf("%s::date", columnName) } diff --git a/expr_test.go b/expr_test.go index b16e365..937dd37 100644 --- a/expr_test.go +++ b/expr_test.go @@ -93,39 +93,49 @@ func Test_defaultValidateConvert(t *testing.T) { t.Parallel() fValidators, err := fieldValidators(reflect.ValueOf(testModel{})) require.NoError(t, err) + opts := getDefaultOptions() t.Run("missing-column", func(t *testing.T) { - e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"]) + e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"], opts) require.Error(t, err) assert.Empty(t, e) assert.ErrorIs(t, err, ErrMissingColumn) assert.ErrorContains(t, err, "missing column") }) t.Run("missing-comparison-op", func(t *testing.T) { - e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"]) + e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"], opts) require.Error(t, err) assert.Empty(t, e) assert.ErrorIs(t, err, ErrMissingComparisonOp) assert.ErrorContains(t, err, "missing comparison operator") }) t.Run("missing-value", func(t *testing.T) { - e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"]) + e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"], opts) require.Error(t, err) assert.Empty(t, e) assert.ErrorIs(t, err, ErrMissingComparisonValue) assert.ErrorContains(t, err, "missing comparison value") }) t.Run("missing-validator-func", func(t *testing.T) { - e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"}) + e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"}, opts) require.Error(t, err) assert.Empty(t, e) assert.ErrorIs(t, err, ErrInvalidParameter) assert.ErrorContains(t, err, "missing validator function") }) t.Run("missing-validator-typ", func(t *testing.T) { - e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn}) + e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn}, opts) require.Error(t, err) assert.Empty(t, e) assert.ErrorIs(t, err, ErrInvalidParameter) assert.ErrorContains(t, err, "missing validator type") }) + t.Run("success-with-table-override", func(t *testing.T) { + opts.withTableColumnMap["name"] = "users.name" + e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn, typ: "default"}, opts) + assert.Empty(t, err) + assert.NotEmpty(t, e) + assert.Equal(t, "users.name=?", e.Condition, "condition") + assert.Len(t, e.Args, 1, "args") + assert.Equal(t, "alice", e.Args[0], "args[0]") + }) } diff --git a/mql.go b/mql.go index aa8154a..7abeea0 100644 --- a/mql.go +++ b/mql.go @@ -87,7 +87,7 @@ func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option) } return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols) } - w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...) + w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opts) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } diff --git a/mql_test.go b/mql_test.go index 98d0f0c..3f8b024 100644 --- a/mql_test.go +++ b/mql_test.go @@ -306,6 +306,19 @@ func TestParse(t *testing.T) { wantErrIs: mql.ErrInvalidParameter, wantErrContains: "missing ConvertToSqlFunc: invalid parameter", }, + { + name: "success-with-table-column-map", + query: "custom_name=\"alice\"", + model: testModel{}, + opts: []mql.Option{ + mql.WithColumnMap(map[string]string{"custom_name": "name"}), + mql.WithTableColumnMap(map[string]string{"name": "users.custom->>'name'"}), + }, + want: &mql.WhereClause{ + Condition: "users.custom->>'name'=?", + Args: []any{"alice"}, + }, + }, } for _, tc := range tests { tc := tc diff --git a/options.go b/options.go index 54d2bdb..827dc63 100644 --- a/options.go +++ b/options.go @@ -13,6 +13,7 @@ type options struct { withValidateConvertFns map[string]ValidateConvertFunc withIgnoredFields []string withPgPlaceholder bool + withTableColumnMap map[string]string // map of model field names to their table.column name } // Option - how options are passed as args @@ -22,6 +23,7 @@ func getDefaultOptions() options { return options{ withColumnMap: make(map[string]string), withValidateConvertFns: make(map[string]ValidateConvertFunc), + withTableColumnMap: make(map[string]string), } } @@ -44,8 +46,8 @@ func withSkipWhitespace() Option { } } -// WithColumnMap provides an optional map of columns from a column in the user -// provided query to a column in the database model +// WithColumnMap provides an optional map of columns from the user +// provided query to a field in the given model func WithColumnMap(m map[string]string) Option { return func(o *options) error { if !isNil(m) { @@ -100,3 +102,14 @@ func WithPgPlaceholders() Option { return nil } } + +// WithTableColumnMap provides an optional map of columns from the +// model to the table.column name in the generated where clause +func WithTableColumnMap(m map[string]string) Option { + return func(o *options) error { + if !isNil(m) { + o.withTableColumnMap = m + } + return nil + } +}