diff --git a/lotman/lotman_linux.go b/lotman/lotman_linux.go index 8feabcec1..932babcf5 100644 --- a/lotman/lotman_linux.go +++ b/lotman/lotman_linux.go @@ -28,6 +28,7 @@ import ( "encoding/json" "fmt" "os" + "reflect" "runtime" "strconv" "strings" @@ -37,6 +38,7 @@ import ( "unsafe" "github.com/ebitengine/purego" + "github.com/mitchellh/mapstructure" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -77,7 +79,7 @@ var ( type ( Int64FromFloat struct { - Value int64 + Value int64 `mapstructure:"Value"` } LotPath struct { @@ -170,12 +172,12 @@ type ( } PurgePolicy struct { - PurgeOrder []string `json:"purge_order"` - PolicyName string `json:"policy_name"` - DiscoverPrefixes bool `json:"discover_prefixes"` - MergeLocalWithDiscovered bool `json:"merge_local_with_discovered"` - DivideUnallocated bool `json:"divide_unallocated"` - Lots []Lot `json:"lots"` + PurgeOrder []string `mapstructure:"PurgeOrder"` + PolicyName string `mapstructure:"PolicyName"` + DiscoverPrefixes bool `mapstructure:"DiscoverPrefixes"` + MergeLocalWithDiscovered bool `mapstructure:"MergeLocalWithDiscovered"` + DivideUnallocated bool `mapstructure:"DivideUnallocated"` + Lots []Lot `mapstructure:"Lots"` } ) @@ -428,13 +430,40 @@ func mergeLotMaps(map1, map2 map[string]Lot) (map[string]Lot, error) { return result, nil } +// A hook function for mapstructure that validates that all fields in the map are present in the struct. +// Used to verify the user's input for PolicyDefinitions, since these aren't top-level fields in parameters.yaml +func validateFieldsHook() mapstructure.DecodeHookFunc { + return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.Map || to.Kind() != reflect.Struct { + return data, nil + } + + mapKeys := reflect.ValueOf(data).MapKeys() + structFields := make(map[string]struct{}) + for i := 0; i < to.NumField(); i++ { + field := to.Field(i) + // Normalize the field name to lowercase + structFields[strings.ToLower(field.Tag.Get("mapstructure"))] = struct{}{} + } + + // Check for unknown fields + for _, key := range mapKeys { + if _, ok := structFields[strings.ToLower(key.String())]; !ok { + return nil, fmt.Errorf("unknown configuration field in Lotman policy definitions: %s", key.String()) + } + } + + return data, nil + } +} + // Grab a map of policy definitions from the config file, where the policy // name is the key and its attributes comprise the value. func getPolicyMap() (map[string]PurgePolicy, error) { policyMap := make(map[string]PurgePolicy) var policies []PurgePolicy - err := viper.UnmarshalKey("Lotman.PolicyDefinitions", &policies) - if err != nil { + // Use custom decoder hook to validate fields. This validates all the way down to the bottom of the lot object. + if err := viper.UnmarshalKey(param.Lotman_PolicyDefinitions.GetName(), &policies, viper.DecodeHook(validateFieldsHook())); err != nil { return policyMap, errors.Wrap(err, "error unmarshaling Lotman policy definitions") } diff --git a/lotman/lotman_test.go b/lotman/lotman_test.go index a03f47917..d872d4b02 100644 --- a/lotman/lotman_test.go +++ b/lotman/lotman_test.go @@ -47,6 +47,9 @@ import ( //go:embed resources/lots-config.yaml var yamlMockup string +//go:embed resources/malformed-lots-config.yaml +var badYamlMockup string + // Helper function for determining policy index from lot config yaml func findPolicyIndex(policyName string, policies []PurgePolicy) int { for i, policy := range policies { @@ -523,18 +526,48 @@ func TestLotMerging(t *testing.T) { func TestGetPolicyMap(t *testing.T) { server_utils.ResetTestState() defer server_utils.ResetTestState() - viper.SetConfigType("yaml") - err := viper.ReadConfig(strings.NewReader(yamlMockup)) - if err != nil { - t.Fatalf("Error reading config: %v", err) + + testCases := []struct { + name string + yamlConfig string + expectErr bool + expectedPolicies []string + }{ + { + name: "ValidConfig", + yamlConfig: yamlMockup, + expectErr: false, + expectedPolicies: []string{"different-policy", "another policy"}, + }, + { + name: "InvalidConfig", + yamlConfig: badYamlMockup, + expectErr: true, + expectedPolicies: nil, + }, } - policyMap, err := getPolicyMap() - require.NoError(t, err) - require.Equal(t, 2, len(policyMap)) - require.Contains(t, policyMap, "different-policy") - require.Contains(t, policyMap, "another policy") - require.Equal(t, "different-policy", viper.GetString("Lotman.EnabledPolicy")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + viper.SetConfigType("yaml") + err := viper.ReadConfig(strings.NewReader(tc.yamlConfig)) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + policyMap, err := getPolicyMap() + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, len(tc.expectedPolicies), len(policyMap)) + for _, policy := range tc.expectedPolicies { + require.Contains(t, policyMap, policy) + } + require.Equal(t, "different-policy", viper.GetString("Lotman.EnabledPolicy")) + } + }) + } } func TestByteConversions(t *testing.T) { diff --git a/lotman/resources/malformed-lots-config.yaml b/lotman/resources/malformed-lots-config.yaml new file mode 100644 index 000000000..b1f2fd91d --- /dev/null +++ b/lotman/resources/malformed-lots-config.yaml @@ -0,0 +1,25 @@ +# *************************************************************** +# +# Copyright (C) 2024, Pelican Project, Morgridge Institute for Research +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# *************************************************************** + +Lotman: + EnabledPolicy: "my-bad-policy" + PolicyDefinitions: + - PolicyName: "my-bad-policy" + IShouldCreateAnUnmarshalError: true + NoReallyImBad: ["ded", "opp", "exp", "del"] + PleaseDontLetMeWork: true