diff --git a/parameters/query_parameters.go b/parameters/query_parameters.go index 89659a4..01a97aa 100644 --- a/parameters/query_parameters.go +++ b/parameters/query_parameters.go @@ -108,43 +108,15 @@ doneLooking: switch ty { case helpers.String: - - // check if the param is within an enum - if sch.Enum != nil { - matchFound := false - for _, enumVal := range sch.Enum { - if strings.TrimSpace(ef) == fmt.Sprint(enumVal.Value) { - matchFound = true - break - } - } - if !matchFound { - validationErrors = append(validationErrors, - errors.IncorrectQueryParamEnum(params[p], ef, sch)) - } - } - + validationErrors = v.validateSimpleParam(sch, ef, ef, params[p]) case helpers.Integer, helpers.Number: - if _, err := strconv.ParseFloat(ef, 64); err != nil { + efF, err := strconv.ParseFloat(ef, 64) + if err != nil { validationErrors = append(validationErrors, errors.InvalidQueryParamNumber(params[p], ef, sch)) break } - // check if the param is within an enum - if sch.Enum != nil { - matchFound := false - for _, enumVal := range sch.Enum { - if strings.TrimSpace(ef) == fmt.Sprint(enumVal.Value) { - matchFound = true - break - } - } - if !matchFound { - validationErrors = append(validationErrors, - errors.IncorrectQueryParamEnum(params[p], ef, sch)) - } - } - + validationErrors = v.validateSimpleParam(sch, ef, efF, params[p]) case helpers.Boolean: if _, err := strconv.ParseBool(ef); err != nil { validationErrors = append(validationErrors, @@ -245,3 +217,29 @@ doneLooking: } return true, nil } + +func (v *paramValidator) validateSimpleParam(sch *base.Schema, rawParam string, parsedParam any, parameter *v3.Parameter) (validationErrors []*errors.ValidationError) { + // check if the param is within an enum + if sch.Enum != nil { + matchFound := false + for _, enumVal := range sch.Enum { + if strings.TrimSpace(rawParam) == fmt.Sprint(enumVal.Value) { + matchFound = true + break + } + } + if !matchFound { + return []*errors.ValidationError{errors.IncorrectQueryParamEnum(parameter, rawParam, sch)} + } + } + + return ValidateSingleParameterSchema( + sch, + parsedParam, + "Query parameter", + "The query parameter", + parameter.Name, + helpers.ParameterValidation, + helpers.ParameterValidationQuery, + ) +} diff --git a/parameters/query_parameters_test.go b/parameters/query_parameters_test.go index 269c158..84d47d2 100644 --- a/parameters/query_parameters_test.go +++ b/parameters/query_parameters_test.go @@ -10,6 +10,7 @@ import ( "github.com/pb33f/libopenapi" "github.com/pb33f/libopenapi-validator/paths" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewValidator_QueryParamMissing(t *testing.T) { @@ -68,6 +69,66 @@ paths: assert.Nil(t, errors) } +func TestNewValidator_QueryParamMinimum_violation(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /a/fishy/on/a/dishy: + get: + parameters: + - name: fishy + in: query + required: true + schema: + type: string + minLength: 4 + operationId: locateFishy +` + + doc, err := libopenapi.NewDocument([]byte(spec)) + require.NoError(t, err) + m, errs := doc.BuildV3Model() + require.Len(t, errs, 0) + + v := NewParameterValidator(&m.Model) + + request, _ := http.NewRequest(http.MethodGet, "https://things.com/a/fishy/on/a/dishy?fishy=cod", nil) + + valid, errors := v.ValidateQueryParams(request) + assert.False(t, valid) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "Query parameter 'fishy' failed to validate", errors[0].Message) +} + +func TestNewValidator_QueryParamMinimum(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /a/fishy/on/a/dishy: + get: + parameters: + - name: fishy + in: query + required: true + schema: + type: string + minLength: 4 + operationId: locateFishy +` + + doc, err := libopenapi.NewDocument([]byte(spec)) + require.NoError(t, err) + m, errs := doc.BuildV3Model() + require.Len(t, errs, 0) + + v := NewParameterValidator(&m.Model) + + request, _ := http.NewRequest(http.MethodGet, "https://things.com/a/fishy/on/a/dishy?fishy=salmon", nil) + + valid, errors := v.ValidateQueryParams(request) + assert.True(t, valid) + + assert.Nil(t, errors) +} + func TestNewValidator_QueryParamPost(t *testing.T) { spec := `openapi: 3.1.0 paths: @@ -348,6 +409,66 @@ paths: assert.Nil(t, errors) } +func TestNewValidator_QueryParamMinimumNumber(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /a/fishy/on/a/dishy: + get: + parameters: + - name: fishy + in: query + required: true + schema: + type: number + minimum: 200 + operationId: locateFishy +` + + doc, err := libopenapi.NewDocument([]byte(spec)) + require.NoError(t, err) + m, errs := doc.BuildV3Model() + require.Len(t, errs, 0) + + v := NewParameterValidator(&m.Model) + + request, _ := http.NewRequest(http.MethodGet, "https://things.com/a/fishy/on/a/dishy?fishy=300", nil) + + valid, errors := v.ValidateQueryParams(request) + assert.True(t, valid) + + assert.Nil(t, errors) +} + +func TestNewValidator_QueryParamMinimumNumber_violation(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /a/fishy/on/a/dishy: + get: + parameters: + - name: fishy + in: query + required: true + schema: + type: number + minimum: 200 + operationId: locateFishy +` + + doc, err := libopenapi.NewDocument([]byte(spec)) + require.NoError(t, err) + m, errs := doc.BuildV3Model() + require.Len(t, errs, 0) + + v := NewParameterValidator(&m.Model) + + request, _ := http.NewRequest(http.MethodGet, "https://things.com/a/fishy/on/a/dishy?fishy=123", nil) + + valid, errors := v.ValidateQueryParams(request) + assert.False(t, valid) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "Query parameter 'fishy' failed to validate", errors[0].Message) +} + func TestNewValidator_QueryParamValidTypeFloat(t *testing.T) { spec := `openapi: 3.1.0 paths: @@ -2533,7 +2654,7 @@ components: doc, _ := libopenapi.NewDocument([]byte(spec)) m, err := doc.BuildV3Model() - assert.Len(t, err, 0) //no patch build here + assert.Len(t, err, 0) // no patch build here v := NewParameterValidator(&m.Model) diff --git a/parameters/validate_parameter.go b/parameters/validate_parameter.go index 2625245..a6f5a28 100644 --- a/parameters/validate_parameter.go +++ b/parameters/validate_parameter.go @@ -5,16 +5,52 @@ package parameters import ( "encoding/json" + stdError "errors" "fmt" + "net/url" + "reflect" + "strings" + "github.com/pb33f/libopenapi-validator/errors" "github.com/pb33f/libopenapi/datamodel/high/base" "github.com/pb33f/libopenapi/utils" "github.com/santhosh-tekuri/jsonschema/v5" - "net/url" - "reflect" - "strings" ) +func ValidateSingleParameterSchema( + schema *base.Schema, + rawObject any, + entity string, + reasonEntity string, + name string, + validationType string, + subValType string, +) (validationErrors []*errors.ValidationError) { + jsch := compileSchema(name, buildJsonRender(schema)) + + scErrs := jsch.Validate(rawObject) + var werras *jsonschema.ValidationError + if stdError.As(scErrs, &werras) { + validationErrors = formatJsonSchemaValidationError(schema, werras, entity, reasonEntity, name, validationType, subValType) + } + return validationErrors +} + +// compileSchema create a new json schema compiler and add the schema to it. +func compileSchema(name string, jsonSchema []byte) *jsonschema.Schema { + compiler := jsonschema.NewCompiler() + _ = compiler.AddResource(fmt.Sprintf("%s.json", name), strings.NewReader(string(jsonSchema))) + jsch, _ := compiler.Compile(fmt.Sprintf("%s.json", name)) + return jsch +} + +// buildJsonRender build a JSON render of the schema. +func buildJsonRender(schema *base.Schema) []byte { + renderedSchema, _ := schema.Render() + jsonSchema, _ := utils.ConvertYAMLtoJSON(renderedSchema) + return jsonSchema +} + // ValidateParameterSchema will validate a parameter against a raw object, or a blob of json/yaml. // It will return a list of validation errors, if any. // @@ -108,35 +144,9 @@ func ValidateParameterSchema( } } } - if scErrs != nil { - jk := scErrs.(*jsonschema.ValidationError) - - // flatten the validationErrors - schFlatErrs := jk.BasicOutput().Errors - var schemaValidationErrors []*errors.SchemaValidationFailure - for q := range schFlatErrs { - er := schFlatErrs[q] - if er.KeywordLocation == "" || strings.HasPrefix(er.Error, "doesn't validate with") { - continue // ignore this error, it's not useful - } - schemaValidationErrors = append(schemaValidationErrors, &errors.SchemaValidationFailure{ - Reason: er.Error, - Location: er.KeywordLocation, - OriginalError: jk, - }) - } - // add the error to the list - validationErrors = append(validationErrors, &errors.ValidationError{ - ValidationType: validationType, - ValidationSubType: subValType, - Message: fmt.Sprintf("%s '%s' failed to validate", entity, name), - Reason: fmt.Sprintf("%s '%s' is defined as an object, "+ - "however it failed to pass a schema validation", reasonEntity, name), - SpecLine: schema.GoLow().Type.KeyNode.Line, - SpecCol: schema.GoLow().Type.KeyNode.Column, - SchemaValidationErrors: schemaValidationErrors, - HowToFix: errors.HowToFixInvalidSchema, - }) + var werras *jsonschema.ValidationError + if stdError.As(scErrs, &werras) { + validationErrors = formatJsonSchemaValidationError(schema, werras, entity, reasonEntity, name, validationType, subValType) } // if there are no validationErrors, check that the supplied value is even JSON @@ -159,3 +169,36 @@ func ValidateParameterSchema( } return validationErrors } + +func formatJsonSchemaValidationError(schema *base.Schema, scErrs *jsonschema.ValidationError, entity string, reasonEntity string, name string, validationType string, subValType string) (validationErrors []*errors.ValidationError) { + // flatten the validationErrors + schFlatErrs := scErrs.BasicOutput().Errors + var schemaValidationErrors []*errors.SchemaValidationFailure + for q := range schFlatErrs { + er := schFlatErrs[q] + if er.KeywordLocation == "" || strings.HasPrefix(er.Error, "doesn't validate with") { + continue // ignore this error, it's not useful + } + schemaValidationErrors = append(schemaValidationErrors, &errors.SchemaValidationFailure{ + Reason: er.Error, + Location: er.KeywordLocation, + OriginalError: scErrs, + }) + } + schemaType := "undefined" + if len(schema.Type) > 0 { + schemaType = schema.Type[0] + } + validationErrors = append(validationErrors, &errors.ValidationError{ + ValidationType: validationType, + ValidationSubType: subValType, + Message: fmt.Sprintf("%s '%s' failed to validate", entity, name), + Reason: fmt.Sprintf("%s '%s' is defined as an %s, "+ + "however it failed to pass a schema validation", reasonEntity, name, schemaType), + SpecLine: schema.GoLow().Type.KeyNode.Line, + SpecCol: schema.GoLow().Type.KeyNode.Column, + SchemaValidationErrors: schemaValidationErrors, + HowToFix: errors.HowToFixInvalidSchema, + }) + return validationErrors +}