Skip to content

Commit

Permalink
Merge pull request #13 from dispatchrun/serialize-primitive-slices
Browse files Browse the repository at this point in the history
Native serialization of JSON-like slices & maps
  • Loading branch information
chriso authored Jun 28, 2024
2 parents 9371a7a + d4984b4 commit 01636bf
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 25 deletions.
2 changes: 2 additions & 0 deletions dispatchhttp/client.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/header.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/request.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/response.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
156 changes: 153 additions & 3 deletions dispatchproto/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ func Duration(v time.Duration) Any {
// Primitive values (booleans, integers, floats, strings, bytes, timestamps,
// durations) are supported, along with values that implement either
// proto.Message, json.Marshaler, encoding.TextMarshaler or
// encoding.BinaryMarshaler.
// encoding.BinaryMarshaler. Slices and maps are also supported, as long
// as they are JSON-like in shape.
func Marshal(v any) (Any, error) {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer && rv.IsNil() {
return Nil(), nil
}
var m proto.Message
Expand Down Expand Up @@ -160,7 +162,10 @@ func Marshal(v any) (Any, error) {
case []byte:
m = wrapperspb.Bytes(vv)
default:
return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v)
var err error
if m, err = newStructpbValue(rv); err != nil {
return Any{}, fmt.Errorf("cannot serialize %v: %w", v, err)
}
}

proto, err := anypb.New(m)
Expand Down Expand Up @@ -386,6 +391,10 @@ func (a Any) Unmarshal(v any) error {
}
}

if s, ok := m.(*structpb.Value); ok {
return fromStructpbValue(elem, s)
}

return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind())
}

Expand All @@ -404,3 +413,144 @@ func (a Any) String() string {
func (a Any) Equal(other Any) bool {
return proto.Equal(a.proto, other.proto)
}

func newStructpbValue(rv reflect.Value) (*structpb.Value, error) {
switch rv.Kind() {
case reflect.Bool:
return structpb.NewBoolValue(rv.Bool()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := rv.Int()
f := float64(n)
if int64(f) != n {
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
}
return structpb.NewNumberValue(f), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := rv.Uint()
f := float64(n)
if uint64(f) != n {
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
}
return structpb.NewNumberValue(f), nil
case reflect.Float32, reflect.Float64:
return structpb.NewNumberValue(rv.Float()), nil
case reflect.String:
return structpb.NewStringValue(rv.String()), nil
case reflect.Interface:
if rv.NumMethod() == 0 { // interface{} aka. any
v := rv.Interface()
if v == nil {
return structpb.NewNullValue(), nil
}
return newStructpbValue(reflect.ValueOf(v))
}
case reflect.Slice:
list := &structpb.ListValue{Values: make([]*structpb.Value, rv.Len())}
for i := range list.Values {
elem := rv.Index(i)
var err error
list.Values[i], err = newStructpbValue(elem)
if err != nil {
return nil, err
}
}
return structpb.NewListValue(list), nil
case reflect.Map:
strct := &structpb.Struct{Fields: make(map[string]*structpb.Value, rv.Len())}
iter := rv.MapRange()
for iter.Next() {
k := iter.Key()

var strKey string
var hasStrKey bool
switch k.Kind() {
case reflect.String:
strKey = k.String()
hasStrKey = true
case reflect.Interface:
if s, ok := k.Interface().(string); ok {
strKey = s
hasStrKey = true
}
}
if !hasStrKey {
return nil, fmt.Errorf("cannot serialize map with %s (%s) key", k.Type(), k.Kind())
}

v, err := newStructpbValue(iter.Value())
if err != nil {
return nil, err
}
strct.Fields[strKey] = v
}
return structpb.NewStructValue(strct), nil
}
return nil, fmt.Errorf("not implemented: %s", rv.Type())
}

func fromStructpbValue(rv reflect.Value, s *structpb.Value) error {
switch rv.Kind() {
case reflect.Bool:
if b, ok := s.Kind.(*structpb.Value_BoolValue); ok {
rv.SetBool(b.BoolValue)
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetInt(int64(n.NumberValue))
return nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetUint(uint64(n.NumberValue))
return nil
}
case reflect.Float32, reflect.Float64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetFloat(n.NumberValue)
return nil
}
case reflect.String:
if str, ok := s.Kind.(*structpb.Value_StringValue); ok {
rv.SetString(str.StringValue)
return nil
}
case reflect.Slice:
if l, ok := s.Kind.(*structpb.Value_ListValue); ok {
values := l.ListValue.GetValues()
rv.Grow(len(values))
rv.SetLen(len(values))
for i, value := range values {
if err := fromStructpbValue(rv.Index(i), value); err != nil {
return err
}
}
return nil
}
case reflect.Map:
if strct, ok := s.Kind.(*structpb.Value_StructValue); ok {
fields := strct.StructValue.Fields
rv.Set(reflect.MakeMapWithSize(rv.Type(), len(fields)))
valueType := rv.Type().Elem()
for key, value := range fields {
mv := reflect.New(valueType).Elem()
if err := fromStructpbValue(mv, value); err != nil {
return err
}
rv.SetMapIndex(reflect.ValueOf(key), mv)
}
return nil
}
case reflect.Interface:
if rv.NumMethod() == 0 { // interface{} aka. any
v := s.AsInterface()
if v == nil {
rv.SetZero()
} else {
rv.Set(reflect.ValueOf(s.AsInterface()))
}
return nil
}
}
return fmt.Errorf("cannot deserialize %T into %v (%v kind)", s, rv.Type(), rv.Kind())
}
15 changes: 15 additions & 0 deletions dispatchproto/any_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"net/http"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -355,6 +356,20 @@ func TestAny(t *testing.T) {
List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}},
Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}},
}},

// slices
[]string{"foo", "bar"},
[]int{-1, 1, 111},
[]bool{true, false, true},
[]float64{3.14, 1.25},
[][]string{{"foo", "bar"}, {"abc", "xyz"}},
[]any{3.14, true, "x", nil},

// maps
map[string]string{"abc": "xyz", "foo": "bar"},
map[string]int{"n": 3},
map[string]http.Header{"original": {"X-Foo": []string{"bar"}}},
map[any]any{"foo": "bar", "pi": 3.14},
} {
t.Run(fmt.Sprintf("%v", v), func(t *testing.T) {
boxed, err := dispatchproto.Marshal(v)
Expand Down
29 changes: 7 additions & 22 deletions examples/fanout/main.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package main

import (
Expand All @@ -12,14 +14,14 @@ import (

func main() {
getRepo := dispatch.Func("getRepo", func(ctx context.Context, name string) (*dispatchhttp.Response, error) {
return dispatchhttp.Get(context.Background(), "https://api.github.com/repos/dispatchrun/"+name)
return dispatchhttp.Get(ctx, "https://api.github.com/repos/dispatchrun/"+name)
})

getStargazers := dispatch.Func("getStargazers", func(ctx context.Context, url string) (*dispatchhttp.Response, error) {
return dispatchhttp.Get(context.Background(), url)
return dispatchhttp.Get(ctx, url)
})

reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs strings) (strings, error) {
reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs []string) ([]string, error) {
responses, err := getStargazers.Gather(stargazerURLs)
if err != nil {
return nil, err
Expand All @@ -39,7 +41,7 @@ func main() {
return maps.Keys(stargazers), nil
})

fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames strings) (strings, error) {
fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames []string) ([]string, error) {
responses, err := getRepo.Gather(repoNames)
if err != nil {
return nil, err
Expand All @@ -65,7 +67,7 @@ func main() {
}

go func() {
if _, err := fanout.Dispatch(context.Background(), strings{"coroutine", "dispatch-py"}); err != nil {
if _, err := fanout.Dispatch(context.Background(), []string{"coroutine", "dispatch-py"}); err != nil {
log.Fatalf("failed to dispatch call: %v", err)
}
}()
Expand All @@ -74,20 +76,3 @@ func main() {
log.Fatalf("failed to serve endpoint: %v", err)
}
}

// TODO: update dispatchproto.Marshal to support serializing slices/maps
// natively (if they can be sent on the wire as structpb.Value)
type strings []string

func (s strings) MarshalJSON() ([]byte, error) {
return json.Marshal([]string(s))
}

func (s *strings) UnmarshalJSON(b []byte) error {
var c []string
if err := json.Unmarshal(b, &c); err != nil {
return err
}
*s = c
return nil
}

0 comments on commit 01636bf

Please sign in to comment.