Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ jobs:
- uses: actions/setup-go@v2
with:
go-version: "1.x"
- uses: actions/cache@v2
- uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-${{ env.GITHUB_JOB }}-
${{ runner.os }}-go-
- run: make test
- name: uncommitted changes?
Expand Down
15 changes: 15 additions & 0 deletions query/sorting.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package query

import (
"fmt"
"regexp"
"strings"
)

Expand All @@ -22,6 +23,11 @@ func (c SortCriteria) GoString() string {
return fmt.Sprintf("%s %s", c.Tag, c.Order)
}

// FieldIdentifierRegex is a regular expression that matches valid field
// identifiers. It is used to validate field names in sorting criteria. This can be
// overridden at init() time to allow for custom field name formats.
var FieldIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_\.]*$`)

// ParseSorting parses raw string that represent sort criteria into a Sorting
// data structure.
// Provided string is supposed to be in accordance with the sorting collection
Expand All @@ -47,6 +53,15 @@ func ParseSorting(s string) (*Sorting, error) {
return nil, fmt.Errorf("invalid sort criteria: %s", craw)
}

// check if tag is not valid
if !FieldIdentifierRegex.MatchString(c.Tag) {
return nil, fmt.Errorf("invalid field name: %s", c.Tag)
}
// check if tag is not empty
if c.Tag == "" {
return nil, fmt.Errorf("empty field name")
}

sorting.Criterias = append(sorting.Criterias, &c)
}

Expand Down
33 changes: 33 additions & 0 deletions query/sorting_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package query

import (
"reflect"
"testing"
)

Expand Down Expand Up @@ -59,3 +60,35 @@ func TestParseSorting(t *testing.T) {
t.Errorf("invalid error message: %s - expected: %s", err, "invalid sort order - \"dask\" in \"name dask\"")
}
}

func TestParseSortingInjection(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want *Sorting
wantErr bool
}{
{
name: "subquery",
args: args{
s: "(SELECT/**/1)::int",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseSorting(tt.args.s)
if (err != nil) != tt.wantErr {
t.Errorf("ParseSorting() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseSorting() = %v, want %v", got, tt.want)
}
})
}
}
Loading