diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index d5bfcfa5..0347351f 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -19,11 +19,12 @@ jobs: - uses: actions/setup-go@v2 with: go-version: "1.x" - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | + ${{ runner.os }}-go-${{ env.GITHUB_JOB }}- ${{ runner.os }}-go- - run: make test - name: uncommitted changes? diff --git a/query/sorting.go b/query/sorting.go index 4259be1c..16c1f624 100644 --- a/query/sorting.go +++ b/query/sorting.go @@ -2,6 +2,7 @@ package query import ( "fmt" + "regexp" "strings" ) @@ -22,6 +23,11 @@ func (c SortCriteria) GoString() string { return fmt.Sprintf("%s %s", c.Tag, c.Order) } +// FieldIdentifierRegex is a regular expression that matches valid field +// identifiers. It is used to validate field names in sorting criteria. This can be +// overridden at init() time to allow for custom field name formats. +var FieldIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_\.]*$`) + // ParseSorting parses raw string that represent sort criteria into a Sorting // data structure. // Provided string is supposed to be in accordance with the sorting collection @@ -47,6 +53,15 @@ func ParseSorting(s string) (*Sorting, error) { return nil, fmt.Errorf("invalid sort criteria: %s", craw) } + // check if tag is not valid + if !FieldIdentifierRegex.MatchString(c.Tag) { + return nil, fmt.Errorf("invalid field name: %s", c.Tag) + } + // check if tag is not empty + if c.Tag == "" { + return nil, fmt.Errorf("empty field name") + } + sorting.Criterias = append(sorting.Criterias, &c) } diff --git a/query/sorting_test.go b/query/sorting_test.go index 335c9f85..c6376815 100644 --- a/query/sorting_test.go +++ b/query/sorting_test.go @@ -1,6 +1,7 @@ package query import ( + "reflect" "testing" ) @@ -59,3 +60,35 @@ func TestParseSorting(t *testing.T) { t.Errorf("invalid error message: %s - expected: %s", err, "invalid sort order - \"dask\" in \"name dask\"") } } + +func TestParseSortingInjection(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want *Sorting + wantErr bool + }{ + { + name: "subquery", + args: args{ + s: "(SELECT/**/1)::int", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseSorting(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("ParseSorting() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseSorting() = %v, want %v", got, tt.want) + } + }) + } +}