diff --git a/README.md b/README.md index 3b9d84d..63796a2 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ DynamoDB Adapter currently supports the following DynamoDB data types | `BOOL` (boolean) | `BOOL` | `B` (binary type) | `BYTES(MAX)` | `S` (string and data values) | `STRING(MAX)` +| `SS` (string set) | `ARRAY` ## Configuration diff --git a/api/v1/condition.go b/api/v1/condition.go index 7bcd7d1..29c1a1f 100644 --- a/api/v1/condition.go +++ b/api/v1/condition.go @@ -454,7 +454,7 @@ func ChangeResponseColumn(obj map[string]interface{}) map[string]interface{} { // ChangeColumnToSpanner converts original column name to spanner supported column names func ChangeColumnToSpanner(obj map[string]interface{}) map[string]interface{} { rs := make(map[string]interface{}) - + for k, v := range obj { if k1, ok := models.ColumnToOriginalCol[k]; ok { @@ -519,19 +519,47 @@ func convertFrom(a *dynamodb.AttributeValue, tableName string) interface{} { return a.B } if a.SS != nil { - l := make([]interface{}, len(a.SS)) - for index, v := range a.SS { - l[index] = *v + uniqueStrings := make(map[string]struct{}) + for _, v := range a.SS { + uniqueStrings[*v] = struct{}{} + } + + // Convert map keys to a slice + l := make([]string, 0, len(uniqueStrings)) + for str := range uniqueStrings { + l = append(l, str) } + return l } if a.NS != nil { - l := make([]interface{}, len(a.NS)) - for index, v := range a.NS { - l[index], _ = strconv.ParseFloat(*v, 64) + l := []float64{} + numberMap := make(map[string]struct{}) + for _, v := range a.NS { + if _, exists := numberMap[*v]; !exists { + numberMap[*v] = struct{}{} + n, err := strconv.ParseFloat(*v, 64) + if err != nil { + panic(fmt.Sprintf("Invalid number in NS: %s", *v)) + } + l = append(l, n) + } } return l } + if a.BS != nil { + // Handle Binary Set + binarySet := [][]byte{} + binaryMap := make(map[string]struct{}) + for _, v := range a.BS { + key := string(v) + if _, exists := binaryMap[key]; !exists { + binaryMap[key] = struct{}{} + binarySet = append(binarySet, v) + } + } + return binarySet + } panic(fmt.Sprintf("%#v is not a supported dynamodb.AttributeValue", a)) } @@ -631,7 +659,7 @@ func ChangeQueryResponseColumn(tableName string, obj map[string]interface{}) map return obj } -//ChangeMaptoDynamoMap converts simple map into dynamo map +// ChangeMaptoDynamoMap converts simple map into dynamo map func ChangeMaptoDynamoMap(in interface{}) (map[string]interface{}, error) { if in == nil { return nil, nil @@ -699,6 +727,34 @@ func convertSlice(output map[string]interface{}, v reflect.Value) error { return nil } output["B"] = append([]byte{}, b...) + case reflect.String: + listVal := []string{} + count := 0 + for i := 0; i < v.Len(); i++ { + listVal = append(listVal, v.Index(i).String()) + count++ + } + output["SS"] = listVal + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64: + listVal := []string{} + for i := 0; i < v.Len(); i++ { + listVal = append(listVal, fmt.Sprintf("%v", v.Index(i).Interface())) + } + output["NS"] = listVal + + case reflect.Slice: + if v.Type().Elem().Kind() == reflect.Uint8 { + binarySet := [][]byte{} + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + if elem.Kind() == reflect.Slice && elem.IsValid() && !elem.IsNil() { + binarySet = append(binarySet, elem.Bytes()) + } + } + output["BS"] = binarySet + } + default: listVal := make([]map[string]interface{}, 0, v.Len()) diff --git a/api/v1/condition_test.go b/api/v1/condition_test.go index 6d12f64..4b30739 100644 --- a/api/v1/condition_test.go +++ b/api/v1/condition_test.go @@ -327,11 +327,13 @@ func TestConvertDynamoToMap(t *testing.T) { "address": {S: aws.String("Ney York")}, "first_name": {S: aws.String("Catalina")}, "last_name": {S: aws.String("Smith")}, + "titles": {SS: aws.StringSlice([]string{"Mr", "Dr"})}, }, map[string]interface{}{ "address": "Ney York", "first_name": "Catalina", "last_name": "Smith", + "titles": []string{"Mr", "Dr"}, }, }, { @@ -402,11 +404,7 @@ func TestChangeMaptoDynamoMap(t *testing.T) { "age": map[string]interface{}{"N": "20"}, "value": map[string]interface{}{"N": "10"}, "array": map[string]interface{}{ - "L": []map[string]interface{}{ - {"S": "first"}, - {"S": "second"}, - {"S": "third"}, - }, + "SS": []string{"first", "second", "third"}, }, }, }, diff --git a/api/v1/db.go b/api/v1/db.go index e499c65..dce1f0a 100644 --- a/api/v1/db.go +++ b/api/v1/db.go @@ -59,7 +59,7 @@ func RouteRequest(c *gin.Context) { case "UpdateItem": Update(c) default: - c.JSON(errors.New("ValidationException", "Invalid X-Amz-Target header value of" + amzTarget). + c.JSON(errors.New("ValidationException", "Invalid X-Amz-Target header value of"+amzTarget). HTTPResponse("X-Amz-Target Header not supported")) } } diff --git a/storage/spanner.go b/storage/spanner.go index 544f857..625aa4b 100755 --- a/storage/spanner.go +++ b/storage/spanner.go @@ -612,7 +612,6 @@ func (s Storage) performPutOperation(ctx context.Context, t *spanner.ReadWriteTr m[k] = ba } } - mutation := spanner.InsertOrUpdateMap(table, m) mutations := []*spanner.Mutation{mutation} err := t.BufferWrite(mutations) @@ -733,7 +732,7 @@ func evaluateStatementFromRowMap(conditionalExpression, colName string, rowMap m return true } _, ok := rowMap[colName] - return !ok + return !ok } if strings.HasPrefix(conditionalExpression, "attribute_exists") || strings.HasPrefix(conditionalExpression, "if_exists") { if len(rowMap) == 0 { @@ -745,7 +744,6 @@ func evaluateStatementFromRowMap(conditionalExpression, colName string, rowMap m return rowMap[conditionalExpression] } -//parseRow - Converts Spanner row and datatypes to a map removing null columns from the result. func parseRow(r *spanner.Row, colDDL map[string]string) (map[string]interface{}, error) { singleRow := make(map[string]interface{}) if r == nil { @@ -761,131 +759,190 @@ func parseRow(r *spanner.Row, colDDL map[string]string) (map[string]interface{}, if !ok { return nil, errors.New("ResourceNotFoundException", k) } + + var err error switch v { case "STRING(MAX)": - var s spanner.NullString - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) - } - if !s.IsNull() { - singleRow[k] = s.StringVal - } + err = parseStringColumn(r, i, k, singleRow) case "BYTES(MAX)": - var s []byte - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) - } - if len(s) > 0 { - var m interface{} - err := json.Unmarshal(s, &m) - if err != nil { - logger.LogError(err, string(s)) - singleRow[k] = string(s) - continue - } - val1, ok := m.(string) - if ok { - if base64Regexp.MatchString(val1) { - ba, err := base64.StdEncoding.DecodeString(val1) - if err == nil { - var sample interface{} - err = json.Unmarshal(ba, &sample) - if err == nil { - singleRow[k] = sample - continue - } else { - singleRow[k] = string(s) - continue - } - } - } - } - - if mp, ok := m.(map[string]interface{}); ok { - for k, v := range mp { - if val, ok := v.(string); ok { - if base64Regexp.MatchString(val) { - ba, err := base64.StdEncoding.DecodeString(val) - if err == nil { - var sample interface{} - err = json.Unmarshal(ba, &sample) - if err == nil { - mp[k] = sample - m = mp - } - } - } - } - } - } - singleRow[k] = m - } + err = parseBytesColumn(r, i, k, singleRow) case "INT64": - var s spanner.NullInt64 - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) - } - if !s.IsNull() { - singleRow[k] = s.Int64 - } + err = parseInt64Column(r, i, k, singleRow) case "FLOAT64": - var s spanner.NullFloat64 - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) - - } - if !s.IsNull() { - singleRow[k] = s.Float64 - } + err = parseFloat64Column(r, i, k, singleRow) case "NUMERIC": - var s spanner.NullNumeric - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) - } - if !s.IsNull() { - if s.Numeric.IsInt() { - tmp, _ := s.Numeric.Float64() - singleRow[k] = int64(tmp) - } else { - singleRow[k], _ = s.Numeric.Float64() - } - } + err = parseNumericColumn(r, i, k, singleRow) case "BOOL": - var s spanner.NullBool - err := r.Column(i, &s) - if err != nil { - if strings.Contains(err.Error(), "ambiguous column name") { - continue - } - return nil, errors.New("ValidationException", err, k) + err = parseBoolColumn(r, i, k, singleRow) + case "ARRAY": + err = parseStringArrayColumn(r, i, k, singleRow) + case "ARRAY": + err = parseByteArrayColumn(r, i, k, singleRow) + case "ARRAY": + err = parseNumberArrayColumn(r, i, k, singleRow) + default: + return nil, errors.New("TypeNotFound", err, k) + } + + if err != nil { + return nil, errors.New("ValidationException", err, k) + } + } + return singleRow, nil +} + +func parseStringColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s spanner.NullString + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if !s.IsNull() { + row[col] = s.StringVal + } + return nil +} + +func parseBytesColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s []byte + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if len(s) > 0 { + var m interface{} + if err := json.Unmarshal(s, &m); err != nil { + logger.LogError(err, string(s)) + row[col] = string(s) + return nil + } + m = processDecodedData(m) + row[col] = m + } + return nil +} + +func parseInt64Column(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s spanner.NullInt64 + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if !s.IsNull() { + row[col] = s.Int64 + } + return nil +} + +func parseFloat64Column(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s spanner.NullFloat64 + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if !s.IsNull() { + row[col] = s.Float64 + } + return nil +} + +func parseNumericColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s spanner.NullNumeric + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if !s.IsNull() { + val, _ := s.Numeric.Float64() + if s.Numeric.IsInt() { + row[col] = int64(val) + } else { + row[col] = val + } + } + return nil +} + +func parseBoolColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s spanner.NullBool + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if !s.IsNull() { + row[col] = s.Bool + } + return nil +} + +func parseStringArrayColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var s []spanner.NullString + err := r.Column(idx, &s) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + var temp []string + for _, val := range s { + temp = append(temp, val.StringVal) + } + if len(s) > 0 { + row[col] = temp + } + return nil +} + +func parseByteArrayColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var b [][]byte + err := r.Column(idx, &b) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + if len(b) > 0 { + row[col] = b + } + return nil +} +func parseNumberArrayColumn(r *spanner.Row, idx int, col string, row map[string]interface{}) error { + var nums []spanner.NullFloat64 + err := r.Column(idx, &nums) + if err != nil && !strings.Contains(err.Error(), "ambiguous column name") { + return err + } + var temp []float64 + for _, val := range nums { + if val.Valid { + temp = append(temp, val.Float64) + } + } + if len(nums) > 0 { + row[col] = temp + } + return nil +} + +func processDecodedData(m interface{}) interface{} { + if val, ok := m.(string); ok && base64Regexp.MatchString(val) { + if ba, err := base64.StdEncoding.DecodeString(val); err == nil { + var sample interface{} + if err := json.Unmarshal(ba, &sample); err == nil { + return sample } - if !s.IsNull() { - singleRow[k] = s.Bool + } + } + if mp, ok := m.(map[string]interface{}); ok { + for k, v := range mp { + if val, ok := v.(string); ok && base64Regexp.MatchString(val) { + if ba, err := base64.StdEncoding.DecodeString(val); err == nil { + var sample interface{} + if err := json.Unmarshal(ba, &sample); err == nil { + mp[k] = sample + } + } } } } - return singleRow, nil + return m } func checkInifinty(value float64, logData interface{}) error { diff --git a/storage/spanner_test.go b/storage/spanner_test.go index d2a40be..9663162 100644 --- a/storage/spanner_test.go +++ b/storage/spanner_test.go @@ -31,7 +31,15 @@ func Test_parseRow(t *testing.T) { removeNullRow, _ := spanner.NewRow([]string{"strCol", "nullCol"}, []interface{}{"my-text", spanner.NullString{}}) skipCommitTimestampRow, _ := spanner.NewRow([]string{"strCol", "commit_timestamp"}, []interface{}{"my-text", "2021-01-01"}) multipleValuesRow, _ := spanner.NewRow([]string{"strCol", "intCol", "nullCol", "boolCol"}, []interface{}{"my-text", int64(32), spanner.NullString{}, true}) - + simpleArrayRow, _ := spanner.NewRow([]string{"arrayCol"}, []interface{}{ + []spanner.NullString{ + {StringVal: "element1", Valid: true}, + {StringVal: "element2", Valid: true}, + {StringVal: "element3", Valid: true}, + }, + }) + invalidTypeRow, _ := spanner.NewRow([]string{"strCol"}, []interface{}{1234}) // Invalid data type + missingColumnRow, _ := spanner.NewRow([]string{"missingCol"}, []interface{}{"value"}) type args struct { r *spanner.Row @@ -45,58 +53,82 @@ func Test_parseRow(t *testing.T) { }{ { "ParseStringValue", - args{simpleStringRow, map[string]string{"strCol": "STRING(MAX)"}}, + args{simpleStringRow, map[string]string{"strCol": "STRING(MAX)"}}, map[string]interface{}{"strCol": "my-text"}, false, }, { "ParseIntValue", - args{simpleIntRow, map[string]string{"intCol": "INT64"}}, + args{simpleIntRow, map[string]string{"intCol": "INT64"}}, map[string]interface{}{"intCol": int64(314)}, false, }, { "ParseFloatValue", - args{simpleFloatRow, map[string]string{"floatCol": "FLOAT64"}}, + args{simpleFloatRow, map[string]string{"floatCol": "FLOAT64"}}, map[string]interface{}{"floatCol": 3.14}, false, }, { "ParseNumericIntValue", - args{simpleNumericIntRow, map[string]string{"numericCol": "NUMERIC"}}, + args{simpleNumericIntRow, map[string]string{"numericCol": "NUMERIC"}}, map[string]interface{}{"numericCol": int64(314)}, false, }, { "ParseNumericFloatValue", - args{simpleNumericFloatRow, map[string]string{"numericCol": "NUMERIC"}}, + args{simpleNumericFloatRow, map[string]string{"numericCol": "NUMERIC"}}, map[string]interface{}{"numericCol": 3.25}, false, }, { "ParseBoolValue", - args{simpleBoolRow, map[string]string{"boolCol": "BOOL"}}, + args{simpleBoolRow, map[string]string{"boolCol": "BOOL"}}, map[string]interface{}{"boolCol": true}, false, }, { "RemoveNulls", - args{removeNullRow, map[string]string{"strCol": "STRING(MAX)", "nullCol": "STRING(MAX)"}}, + args{removeNullRow, map[string]string{"strCol": "STRING(MAX)", "nullCol": "STRING(MAX)"}}, map[string]interface{}{"strCol": "my-text"}, false, }, { "SkipCommitTimestamp", - args{skipCommitTimestampRow, map[string]string{"strCol": "STRING(MAX)", "commit_timestamp": "TIMESTAMP"}}, + args{skipCommitTimestampRow, map[string]string{"strCol": "STRING(MAX)", "commit_timestamp": "TIMESTAMP"}}, map[string]interface{}{"strCol": "my-text"}, false, }, { "MultiValueRow", - args{multipleValuesRow, map[string]string{"strCol": "STRING(MAX)", "intCol": "INT64", "nullCol": "STRING(MAX)", "boolCol": "BOOL"}}, + args{multipleValuesRow, map[string]string{"strCol": "STRING(MAX)", "intCol": "INT64", "nullCol": "STRING(MAX)", "boolCol": "BOOL"}}, map[string]interface{}{"strCol": "my-text", "intCol": int64(32), "boolCol": true}, false, }, + { + "ParseStringArray", + args{simpleArrayRow, map[string]string{"arrayCol": "ARRAY"}}, + map[string]interface{}{"arrayCol": []string{"element1", "element2", "element3"}}, + false, + }, + { + "MissingColumnTypeInDDL", + args{simpleStringRow, map[string]string{"strCol": ""}}, // Missing type in DDL + nil, + true, + }, + { + "InvalidTypeConversion", + args{invalidTypeRow, map[string]string{"strCol": "STRING(MAX)"}}, // Incorrectly trying to parse an int as a string + nil, + true, + }, + { + "ColumnNotInDDL", + args{missingColumnRow, map[string]string{"strCol": "STRING(MAX)"}}, // Column not defined in DDL + nil, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {