diff --git a/config/variable.go b/config/variable.go index 77853a666..e35c2774a 100644 --- a/config/variable.go +++ b/config/variable.go @@ -3,15 +3,16 @@ package config import ( "bytes" "encoding/json" + "errors" "fmt" - "math/big" "os" "path/filepath" + "reflect" "golang.org/x/exp/constraints" ) -const autoTFVarsJson = "generated.auto.tfvars.json" +const AutoTFVarsJson = "generated.auto.tfvars.json" // Variable interface is an alias to json.Marshaler. type Variable interface { @@ -23,7 +24,7 @@ type Variable interface { type Variables map[string]Variable // Write iterates over each element in v and assembles a JSON -// file which is named autoTFVarsJson and written to dest. +// file which is named AutoTFVarsJson and written to dest. func (v Variables) Write(dest string) error { buf := bytes.NewBuffer(nil) @@ -47,7 +48,7 @@ func (v Variables) Write(dest string) error { buf.Write([]byte(`}`)) - outFilename := filepath.Join(dest, autoTFVarsJson) + outFilename := filepath.Join(dest, AutoTFVarsJson) err := os.WriteFile(outFilename, buf.Bytes(), 0700) @@ -65,8 +66,8 @@ type boolVariable struct { } // MarshalJSON returns the JSON encoding of boolVariable. -func (t boolVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v boolVariable) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) } // BoolVariable instantiates an instance of boolVariable, @@ -84,8 +85,12 @@ type listVariable struct { } // MarshalJSON returns the JSON encoding of listVariable. -func (t listVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v listVariable) MarshalJSON() ([]byte, error) { + if !typesEq(v.value) { + return nil, errors.New("lists must contain the same type") + } + + return json.Marshal(v.value) } // ListVariable instantiates an instance of listVariable, @@ -103,8 +108,18 @@ type mapVariable struct { } // MarshalJSON returns the JSON encoding of mapVariable. -func (t mapVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v mapVariable) MarshalJSON() ([]byte, error) { + var variables []Variable + + for _, variable := range v.value { + variables = append(variables, variable) + } + + if !typesEq(variables) { + return nil, errors.New("maps must contain the same type") + } + + return json.Marshal(v.value) } // MapVariable instantiates an instance of mapVariable, @@ -122,8 +137,20 @@ type objectVariable struct { } // MarshalJSON returns the JSON encoding of objectVariable. -func (t objectVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v objectVariable) MarshalJSON() ([]byte, error) { + b, err := json.Marshal(v.value) + + if err != nil { + innerErr := err + + for errors.Unwrap(innerErr) != nil { + innerErr = errors.Unwrap(err) + } + + return nil, innerErr + } + + return b, nil } // ObjectVariable instantiates an instance of objectVariable, @@ -137,7 +164,7 @@ func ObjectVariable(value map[string]Variable) objectVariable { var _ Variable = numberVariable{} type number interface { - constraints.Float | constraints.Integer | *big.Float + constraints.Float | constraints.Integer | string } type numberVariable struct { @@ -145,18 +172,19 @@ type numberVariable struct { } // MarshalJSON returns the JSON encoding of numberVariable. -// If the value of numberVariable is *bigFloat then the -// representation of the value is the smallest number of -// digits required to uniquely identify the value using the -// precision of the *bigFloat that was supplied when -// numberVariable was instantiated. -func (t numberVariable) MarshalJSON() ([]byte, error) { - switch v := t.value.(type) { - case *big.Float: - return []byte(v.Text('g', -1)), nil +// NumberVariable allows initialising a number with any floating +// point or integer type. NumberVariable can be initialised +// with a string for values that do not fit into a floating point +// or integer type. +// TODO: Impose restrictions on what can be held in numberVariable +// to match restrictions imposed by Terraform. +func (v numberVariable) MarshalJSON() ([]byte, error) { + switch v := v.value.(type) { + case string: + return []byte(v), nil } - return json.Marshal(t.value) + return json.Marshal(v.value) } // NumberVariable instantiates an instance of numberVariable, @@ -174,8 +202,33 @@ type setVariable struct { } // MarshalJSON returns the JSON encoding of setVariable. -func (t setVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v setVariable) MarshalJSON() ([]byte, error) { + for kx, x := range v.value { + for ky, y := range v.value { + if kx == ky { + continue + } + + if _, ok := x.(setVariable); !ok { + continue + } + + if _, ok := y.(setVariable); !ok { + continue + } + + if reflect.DeepEqual(x, y) { + return nil, errors.New("sets must contain unique elements") + } + } + + } + + if !typesEq(v.value) { + return nil, errors.New("sets must contain the same type") + } + + return json.Marshal(v.value) } // SetVariable instantiates an instance of setVariable, @@ -193,8 +246,8 @@ type stringVariable struct { } // MarshalJSON returns the JSON encoding of stringVariable. -func (t stringVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v stringVariable) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) } // StringVariable instantiates an instance of stringVariable, @@ -210,8 +263,8 @@ type tupleVariable struct { } // MarshalJSON returns the JSON encoding of tupleVariable. -func (t tupleVariable) MarshalJSON() ([]byte, error) { - return json.Marshal(t.value) +func (v tupleVariable) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) } // TupleVariable instantiates an instance of tupleVariable, @@ -221,3 +274,43 @@ func TupleVariable(value ...Variable) tupleVariable { value: value, } } + +func typesEq(variables []Variable) bool { + var t reflect.Type + + for _, variable := range variables { + switch x := variable.(type) { + case listVariable: + if !typesEq(x.value) { + return false + } + case mapVariable: + var vars []Variable + + for _, v := range x.value { + vars = append(vars, v) + } + + if !typesEq(vars) { + return false + } + case setVariable: + if !typesEq(x.value) { + return false + } + } + + typeOfVariable := reflect.TypeOf(variable) + + if t == nil { + t = typeOfVariable + continue + } + + if t != typeOfVariable { + return false + } + } + + return true +} diff --git a/config/variable_test.go b/config/variable_test.go new file mode 100644 index 000000000..05d568cec --- /dev/null +++ b/config/variable_test.go @@ -0,0 +1,356 @@ +package config_test + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/hashicorp/terraform-plugin-testing/config" +) + +func TestMarshalJSON(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + variable config.Variable + expected []byte + expectedError string + }{ + "bool": { + variable: config.BoolVariable(true), + expected: []byte(`true`), + }, + "list_bool": { + variable: config.ListVariable( + config.BoolVariable(false), + config.BoolVariable(false), + config.BoolVariable(true), + ), + expected: []byte(`[false,false,true]`), + }, + "list_list": { + variable: config.ListVariable( + config.ListVariable( + config.BoolVariable(false), + config.BoolVariable(false), + config.BoolVariable(true), + ), + config.ListVariable( + config.BoolVariable(true), + config.BoolVariable(true), + config.BoolVariable(false), + ), + ), + expected: []byte(`[[false,false,true],[true,true,false]]`), + }, + "list_mixed_types": { + variable: config.ListVariable( + config.BoolVariable(false), + config.StringVariable("str"), + ), + expectedError: "lists must contain the same type", + }, + "list_list_mixed_types": { + variable: config.ListVariable( + config.ListVariable( + config.BoolVariable(false), + config.StringVariable("str"), + ), + ), + expectedError: "lists must contain the same type", + }, + "list_list_mixed_types_multiple_lists": { + variable: config.ListVariable( + config.ListVariable( + config.BoolVariable(false), + config.BoolVariable(false), + ), + config.ListVariable( + config.StringVariable("str"), + config.BoolVariable(false), + ), + ), + expectedError: "lists must contain the same type", + }, + "map_bool": { + variable: config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.BoolVariable(false), + "three": config.BoolVariable(true), + }, + ), + expected: []byte(`{"one":false,"three":true,"two":false}`), + }, + "map_map": { + variable: config.ListVariable( + config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.BoolVariable(false), + "three": config.BoolVariable(true), + }, + ), + config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(true), + "two": config.BoolVariable(true), + "three": config.BoolVariable(false), + }, + ), + ), + expected: []byte(`[{"one":false,"three":true,"two":false},{"one":true,"three":false,"two":true}]`), + }, + "map_mixed_types": { + variable: config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.StringVariable("str"), + }, + ), + expectedError: "maps must contain the same type", + }, + "map_map_mixed_types": { + variable: config.MapVariable( + map[string]config.Variable{ + "mapA": config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.StringVariable("str"), + }, + ), + }, + ), + expectedError: "maps must contain the same type", + }, + "map_map_mixed_types_multiple_maps": { + variable: config.MapVariable( + map[string]config.Variable{ + "mapA": config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.BoolVariable(true), + }, + ), + "mapB": config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.StringVariable("str"), + }, + ), + }, + ), + expectedError: "maps must contain the same type", + }, + "object": { + variable: config.ObjectVariable( + map[string]config.Variable{ + "bool": config.BoolVariable(true), + "list": config.ListVariable( + config.BoolVariable(false), + config.BoolVariable(true), + ), + "map": config.MapVariable( + map[string]config.Variable{ + "one": config.StringVariable("str_one"), + "two": config.StringVariable("str_two"), + }, + ), + }, + ), + expected: []byte(`{"bool":true,"list":[false,true],"map":{"one":"str_one","two":"str_two"}}`), + }, + "object_map_mixed_types": { + variable: config.ObjectVariable( + map[string]config.Variable{ + "bool": config.BoolVariable(true), + "list": config.ListVariable( + config.BoolVariable(false), + config.BoolVariable(true), + ), + "map": config.MapVariable( + map[string]config.Variable{ + "one": config.BoolVariable(false), + "two": config.StringVariable("str_two"), + }, + ), + }, + ), + expectedError: "maps must contain the same type", + }, + "number_float": { + variable: config.NumberVariable(1.2), + expected: []byte(`1.2`), + }, + "number_int": { + variable: config.NumberVariable(12), + expected: []byte(`12`), + }, + "number_big_float": { + variable: config.NumberVariable("1.2000000000000000000000000000000000000000000000000001"), + expected: []byte(`1.2000000000000000000000000000000000000000000000000001`), + }, + "set_bool": { + variable: config.SetVariable( + config.BoolVariable(false), + config.BoolVariable(false), + config.BoolVariable(true), + ), + expected: []byte(`[false,false,true]`), + }, + "set_set": { + variable: config.SetVariable( + config.SetVariable( + config.BoolVariable(false), + config.BoolVariable(false), + config.BoolVariable(true), + ), + config.SetVariable( + config.BoolVariable(true), + config.BoolVariable(true), + config.BoolVariable(false), + ), + ), + expected: []byte(`[[false,false,true],[true,true,false]]`), + }, + "set_mixed_types": { + variable: config.SetVariable( + config.BoolVariable(false), + config.StringVariable("str"), + ), + expectedError: "sets must contain the same type", + }, + "set_set_mixed_types": { + variable: config.SetVariable( + config.SetVariable( + config.BoolVariable(false), + config.StringVariable("str"), + ), + ), + expectedError: "sets must contain the same type", + }, + "set_set_mixed_types_multiple_sets": { + variable: config.SetVariable( + config.SetVariable( + config.BoolVariable(false), + config.BoolVariable(false), + ), + config.SetVariable( + config.StringVariable("str"), + config.BoolVariable(false), + ), + ), + expectedError: "sets must contain the same type", + }, + "set_non_unique": { + variable: config.SetVariable( + config.SetVariable( + config.BoolVariable(false), + config.BoolVariable(false), + ), + config.SetVariable( + config.BoolVariable(false), + config.BoolVariable(false), + ), + ), + expectedError: "sets must contain unique elements", + }, + "string": { + variable: config.StringVariable("str"), + expected: []byte(`"str"`), + }, + "tuple": { + variable: config.TupleVariable( + config.BoolVariable(true), + config.NumberVariable(1.2), + config.StringVariable("str"), + ), + expected: []byte(`[true,1.2,"str"]`), + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := testCase.variable.MarshalJSON() + + if testCase.expectedError == "" && err != nil { + t.Errorf("unexpected error %s", err) + } + + if testCase.expectedError != "" && err == nil { + t.Errorf("expected error but got none") + } + + if testCase.expectedError != "" && err != nil { + if diff := cmp.Diff(err.Error(), testCase.expectedError); diff != "" { + t.Errorf("expected error %s, got error %s", testCase.expectedError, err) + } + } + + if !bytes.Equal(testCase.expected, got) { + t.Errorf("expected %s, got %s", testCase.expected, got) + } + }) + } +} + +func TestVariablesWrite(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + + testCases := map[string]struct { + variables config.Variables + expected []byte + expectedError string + }{ + "write": { + variables: map[string]config.Variable{ + "bool": config.BoolVariable(true), + "string": config.StringVariable("str"), + }, + expected: []byte(`{"bool": true,"string": "str"}`), + }, + } + + for name, testCase := range testCases { + name, testCase := name, testCase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := testCase.variables.Write(tempDir) + + if testCase.expectedError == "" && err != nil { + t.Errorf("unexpected error %s", err) + } + + if testCase.expectedError != "" && err == nil { + t.Errorf("expected error but got none") + } + + if testCase.expectedError != "" && err != nil { + if diff := cmp.Diff(err.Error(), testCase.expectedError); diff != "" { + t.Errorf("expected error %s, got error %s", testCase.expectedError, err) + } + } + + b, err := os.ReadFile(filepath.Join(tempDir, config.AutoTFVarsJson)) + + if err != nil { + t.Errorf("error reading tfvars file: %s", err) + } + + if !bytes.Equal(testCase.expected, b) { + t.Errorf("expected %s, got %s", testCase.expected, b) + } + }) + } +}