diff --git a/.run/Run All Tests.run.xml b/.run/Run all tests.run.xml similarity index 67% rename from .run/Run All Tests.run.xml rename to .run/Run all tests.run.xml index 6cf7d44..7946e6a 100644 --- a/.run/Run All Tests.run.xml +++ b/.run/Run all tests.run.xml @@ -1,16 +1,8 @@ - - + @@ -18,4 +10,4 @@ - + \ No newline at end of file diff --git a/Makefile b/Makefile index d1a8651..06e94e6 100644 --- a/Makefile +++ b/Makefile @@ -38,10 +38,10 @@ help: ## Display this help. @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) .PHONY: clean +img=$(shell docker images -q --filter=reference=$(image)) clean: ## Cleans up the binary, container image and other data @rm -f $(out) - @docker-compose -f $(compose) down - @docker rmi $(shell docker images -q --filter=reference=$(image)) + @[ ! -z $(img) ] && docker rmi $(img) || true .PHONY: build test container cov clean fmt fmt: ## Formats the Go source code using 'go fmt' diff --git a/build.settings b/build.settings index f1245a2..3f5338a 100644 --- a/build.settings +++ b/build.settings @@ -1,3 +1,3 @@ # Build configuration -version = 0.7.1 +version = 0.8.0 diff --git a/docs/images/datamodel.png b/docs/images/datamodel.png new file mode 100644 index 0000000..b4c49d6 Binary files /dev/null and b/docs/images/datamodel.png differ diff --git a/go.mod b/go.mod index 6c91679..34f1b35 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,11 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/massenz/slf4go v0.3.2-g4eb5504 - github.com/massenz/statemachine-proto/golang v0.6.0-ga901a76 + github.com/massenz/statemachine-proto/golang v1.1.0-beta-g1fc5dd8 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.18.1 github.com/testcontainers/testcontainers-go v0.16.0 - google.golang.org/grpc v1.49.0 + google.golang.org/grpc v1.51.0 google.golang.org/protobuf v1.28.1 ) @@ -132,7 +132,7 @@ require ( golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.1.0 // indirect golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/text v0.4.0 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad // indirect @@ -153,7 +153,6 @@ replace ( github.com/docker/cli => github.com/docker/cli v20.10.3-0.20221013132413-1d6c6e2367e2+incompatible // 22.06 master branch github.com/docker/docker => github.com/docker/docker v20.10.3-0.20221013203545-33ab36d6b304+incompatible // 22.06 branch github.com/moby/buildkit => github.com/moby/buildkit v0.10.1-0.20220816171719-55ba9d14360a // same as buildx - github.com/opencontainers/runc => github.com/opencontainers/runc v1.1.2 // Can be removed on next bump of containerd to > 1.6.4 // For k8s dependencies, we use a replace directive, to prevent them being diff --git a/go.sum b/go.sum index 36598d0..d86186f 100644 --- a/go.sum +++ b/go.sum @@ -380,8 +380,8 @@ github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/massenz/slf4go v0.3.2-g4eb5504 h1:tRrxPOKcqNKQn25eS8Dy9bW3NMNPpuK4Sla9jKfWmSs= github.com/massenz/slf4go v0.3.2-g4eb5504/go.mod h1:ZJjthXAnZMJGwXUz3Z3v5uyban00uAFFoDYODOoLFpw= -github.com/massenz/statemachine-proto/golang v0.6.0-ga901a76 h1:tik7Xn5GL+w9U5RTJZ3mieoP2sun6RDM+cUBi7WGrUU= -github.com/massenz/statemachine-proto/golang v0.6.0-ga901a76/go.mod h1:EkwQg7wD6c/cmXVxfqNaUOVSrBLlti+xYljIxaQNJqA= +github.com/massenz/statemachine-proto/golang v1.1.0-beta-g1fc5dd8 h1:Dp2yv070ogiHLwQU5LppXskUDnCoO8tDkqgszyZMNmk= +github.com/massenz/statemachine-proto/golang v1.1.0-beta-g1fc5dd8/go.mod h1:g6CkyXxfs7XF8wv6OLdMZZDUu0fn4PY6HQQ2WDbW3GU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= @@ -820,8 +820,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -954,8 +955,8 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= -google.golang.org/grpc v1.49.0 h1:WTLtQzmQori5FUH25Pq4WT22oCsv8USpQ+F6rqtsmxw= -google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= +google.golang.org/grpc v1.51.0 h1:E1eGv1FTqoLIdnBCZufiSHgKjlqG6fKFf6pPWtMTh8U= +google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/grpc/grpc_server.go b/grpc/grpc_server.go index a02683e..8e6770d 100644 --- a/grpc/grpc_server.go +++ b/grpc/grpc_server.go @@ -17,6 +17,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" "strings" "time" @@ -32,6 +33,9 @@ type Config struct { Timeout time.Duration } +type StatemachineStream = protos.StatemachineService_StreamAllInstateServer +type ConfigurationStream = protos.StatemachineService_StreamAllConfigurationsServer + var _ protos.StatemachineServiceServer = (*grpcSubscriber)(nil) const ( @@ -43,9 +47,9 @@ type grpcSubscriber struct { *Config } -func (s *grpcSubscriber) ProcessEvent(ctx context.Context, request *protos.EventRequest) (*protos. +func (s *grpcSubscriber) SendEvent(ctx context.Context, request *protos.EventRequest) (*protos. EventResponse, error) { - if request.Dest == "" { + if request.GetId() == "" { return nil, status.Error(codes.FailedPrecondition, api.MissingDestinationError.Error()) } if request.GetEvent() == nil || request.Event.GetTransition() == nil || @@ -87,31 +91,46 @@ func (s *grpcSubscriber) PutConfiguration(ctx context.Context, cfg *protos.Confi } s.Logger.Trace("configuration stored: %s", api.GetVersionId(cfg)) return &protos.PutResponse{ - Id: api.GetVersionId(cfg), - Config: cfg, + Id: api.GetVersionId(cfg), + // Note: this is the magic incantation to use a `one_of` field in Protobuf. + EntityResponse: &protos.PutResponse_Config{Config: cfg}, }, nil } +func (s *grpcSubscriber) GetAllConfigurations(ctx context.Context, req *wrapperspb.StringValue) ( + *protos.ListResponse, error) { + cfgName := req.Value + if cfgName == "" { + s.Logger.Trace("looking up all available configurations on server") + return &protos.ListResponse{Ids: s.Store.GetAllConfigs()}, nil + } + s.Logger.Trace("looking up all version for configuration %s", cfgName) + return &protos.ListResponse{Ids: s.Store.GetAllVersions(cfgName)}, nil +} -func (s *grpcSubscriber) GetConfiguration(ctx context.Context, request *protos.GetRequest) ( +func (s *grpcSubscriber) GetConfiguration(ctx context.Context, configId *wrapperspb.StringValue) ( *protos.Configuration, error) { - s.Logger.Trace("retrieving Configuration %s", request.GetId()) - cfg, found := s.Store.GetConfig(request.GetId()) + cfgId := configId.Value + s.Logger.Trace("retrieving Configuration %s", cfgId) + cfg, found := s.Store.GetConfig(cfgId) if !found { - return nil, status.Errorf(codes.NotFound, "configuration %s not found", request.GetId()) + return nil, status.Errorf(codes.NotFound, "configuration %s not found", cfgId) } return cfg, nil } func (s *grpcSubscriber) PutFiniteStateMachine(ctx context.Context, - fsm *protos.FiniteStateMachine) (*protos.PutResponse, error) { + request *protos.PutFsmRequest) (*protos.PutResponse, error) { + fsm := request.Fsm // First check that the configuration for the FSM is valid cfg, ok := s.Store.GetConfig(fsm.ConfigId) if !ok { return nil, status.Error(codes.FailedPrecondition, storage.NotFoundError( fsm.ConfigId).Error()) } - // FIXME: we need to allow clients to specify the ID of the FSM to create - id := uuid.NewString() + var id = request.Id + if id == "" { + id = uuid.NewString() + } // If the State of the FSM is not specified, // we set it to the initial state of the configuration. if fsm.State == "" { @@ -122,38 +141,52 @@ func (s *grpcSubscriber) PutFiniteStateMachine(ctx context.Context, s.Logger.Error("could not store FSM [%v]: %v", fsm, err) return nil, status.Error(codes.Internal, err.Error()) } - return &protos.PutResponse{Id: id, Fsm: fsm}, nil + if err := s.Store.UpdateState(cfg.Name, id, "", fsm.State); err != nil { + s.Logger.Error("could not store FSM in state set [%s]: %v", fsm.State, err) + return nil, status.Error(codes.Internal, err.Error()) + } + return &protos.PutResponse{Id: id, EntityResponse: &protos.PutResponse_Fsm{Fsm: fsm}}, nil } -func (s *grpcSubscriber) GetFiniteStateMachine(ctx context.Context, request *protos.GetRequest) ( +func (s *grpcSubscriber) GetFiniteStateMachine(ctx context.Context, in *protos.GetFsmRequest) ( *protos.FiniteStateMachine, error) { - // TODO: use Context to set a timeout, and then pass it on to the Store. - // This may require a pretty large refactoring of the store interface. - s.Logger.Debug("looking up FSM %s", request.GetId()) - // The ID in the request contains the FSM ID, - // prefixed by the Config Name (which defines the "type" of FSM) - splitId := strings.Split(request.GetId(), storage.KeyPrefixIDSeparator) - if len(splitId) != 2 { - return nil, status.Errorf(codes.InvalidArgument, "invalid FSM ID: %s", request.GetId()) - } - fsm, ok := s.Store.GetStateMachine(splitId[1], splitId[0]) + cfg := in.GetConfig() + if cfg == "" { + return nil, status.Error(codes.InvalidArgument, "configuration name must always be provided when looking up statemachine") + } + fsmId := in.GetId() + if fsmId == "" { + return nil, status.Error(codes.InvalidArgument, "ID must always be provided when looking up statemachine") + } + s.Logger.Debug("looking up FSM [%s] (Configuration: %s)", fsmId, cfg) + fsm, ok := s.Store.GetStateMachine(fsmId, cfg) if !ok { - return nil, status.Error(codes.NotFound, storage.NotFoundError(request.GetId()).Error()) + return nil, status.Error(codes.NotFound, storage.NotFoundError(fsmId).Error()) } return fsm, nil } -func (s *grpcSubscriber) GetEventOutcome(ctx context.Context, request *protos.GetRequest) ( - *protos.EventResponse, error) { - - s.Logger.Debug("looking up EventOutcome %s", request.GetId()) - dest := strings.Split(request.GetId(), storage.KeyPrefixIDSeparator) - if len(dest) != 2 { - return nil, status.Error(codes.InvalidArgument, - fmt.Sprintf("invalid destination [%s] expected: #", request.GetId())) +func (s *grpcSubscriber) GetAllInState(ctx context.Context, in *protos.GetFsmRequest) ( + *protos.ListResponse, error) { + cfgName := in.GetConfig() + if cfgName == "" { + return nil, status.Errorf(codes.InvalidArgument, "configuration must always be specified") + } + state := in.GetState() + if state == "" { + // TODO: implement table scanning + return nil, status.Errorf(codes.Unimplemented, "missing state, table scan not implemented") } - smType, evtId := dest[0], dest[1] - outcome, ok := s.Store.GetOutcomeForEvent(evtId, smType) + ids := s.Store.GetAllInState(cfgName, state) + return &protos.ListResponse{Ids: ids}, nil +} + +func (s *grpcSubscriber) GetEventOutcome(ctx context.Context, in *protos.EventRequest) ( + *protos.EventResponse, error) { + evtId := in.GetId() + config := in.GetConfig() + s.Logger.Debug("looking up EventOutcome %s (%s)", evtId, config) + outcome, ok := s.Store.GetOutcomeForEvent(evtId, config) if !ok { return nil, status.Error(codes.NotFound, fmt.Sprintf("outcome for event %s not found", evtId)) } @@ -163,6 +196,48 @@ func (s *grpcSubscriber) GetEventOutcome(ctx context.Context, request *protos.Ge }, nil } +func (s *grpcSubscriber) StreamAllInstate(in *protos.GetFsmRequest, stream StatemachineStream) error { + response, err := s.GetAllInState(context.Background(), in) + if err != nil { + return err + } + cfgName := in.GetConfig() + for _, id := range response.GetIds() { + fsm, found := s.Store.GetStateMachine(id, cfgName) + if !found { + return storage.NotFoundError(id) + } + if err = stream.SendMsg(&protos.PutResponse{ + Id: id, + EntityResponse: &protos.PutResponse_Fsm{Fsm: fsm}, + }); err != nil { + s.Logger.Error("could not stream response back: %s", err) + return err + } + } + return nil +} + +func (s *grpcSubscriber) StreamAllConfigurations(in *wrapperspb.StringValue, stream ConfigurationStream) error { + if in.GetValue() == "" { + return status.Errorf(codes.InvalidArgument, "must specify the Configuration name") + } + response, err := s.GetAllConfigurations(context.Background(), in) + if err != nil { + return nil + } + for _, cfgId := range response.GetIds() { + cfg, found := s.Store.GetConfig(cfgId) + if !found { + return storage.NotFoundError(cfgId) + } + if err = stream.SendMsg(cfg); err != nil { + return err + } + } + return nil +} + // NewGrpcServer creates a new gRPC server to handle incoming events and other API calls. // The `Config` can be used to configure the backing store, a timeout and the logger. func NewGrpcServer(config *Config) (*grpc.Server, error) { @@ -170,7 +245,7 @@ func NewGrpcServer(config *Config) (*grpc.Server, error) { if config.Timeout == 0 { config.Timeout = DefaultTimeout } - gsrv := grpc.NewServer() - protos.RegisterStatemachineServiceServer(gsrv, &grpcSubscriber{Config: config}) - return gsrv, nil + server := grpc.NewServer() + protos.RegisterStatemachineServiceServer(server, &grpcSubscriber{Config: config}) + return server, nil } diff --git a/grpc/grpc_server_stream_test.go b/grpc/grpc_server_stream_test.go new file mode 100644 index 0000000..781853f --- /dev/null +++ b/grpc/grpc_server_stream_test.go @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2022 AlertAvert.com. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Author: Marco Massenzio (marco@alertavert.com) + */ + +package grpc_test + +import ( + "context" + "github.com/go-redis/redis/v8" + "github.com/massenz/slf4go/logging" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + g "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/wrapperspb" + "io" + "net" + + . "github.com/massenz/go-statemachine/api" + "github.com/massenz/go-statemachine/grpc" + "github.com/massenz/go-statemachine/storage" + "github.com/massenz/statemachine-proto/golang/api" +) + +var _ = Describe("gRPC Server Streams", func() { + When("using Redis as the backing store", func() { + var ( + listener net.Listener + client api.StatemachineServiceClient + cfg *api.Configuration + done func() + store storage.StoreManager + ) + // Server setup + BeforeEach(func() { + store = storage.NewRedisStoreWithDefaults(container.Address) + store.SetLogLevel(logging.NONE) + listener, _ = net.Listen("tcp", ":0") + cc, _ := g.Dial(listener.Addr().String(), + g.WithTransportCredentials(insecure.NewCredentials())) + client = api.NewStatemachineServiceClient(cc) + // Use this to log errors when diagnosing test failures; then set to NONE once done. + l := logging.NewLog("grpc-server-test") + l.Level = logging.NONE + server, _ := grpc.NewGrpcServer(&grpc.Config{ + Store: store, + Logger: l, + }) + go func() { + Ω(server.Serve(listener)).Should(Succeed()) + }() + done = func() { + server.Stop() + } + }) + // Server shutdown & Clean up the DB + AfterEach(func() { + done() + rdb := redis.NewClient(&redis.Options{ + Addr: container.Address, + DB: storage.DefaultRedisDb, + }) + rdb.FlushDB(context.Background()) + }) + Context("streaming Configurations", func() { + var versions []string + var name = "test-conf" + // Test data setup + BeforeEach(func() { + versions = []string{"v1", "v2", "v3"} + cfg = &api.Configuration{ + Name: name, + States: []string{"start", "stop"}, + Transitions: []*api.Transition{ + {From: "start", To: "stop", Event: "shutdown"}, + }, + StartingState: "start", + } + for _, v := range versions { + cfg.Version = v + Ω(store.PutConfig(cfg)).ToNot(HaveOccurred()) + } + }) + It("should find all configurations", func() { + stream, err := client.StreamAllConfigurations(bkgnd, + &wrapperspb.StringValue{Value: name}) + Ω(err).ShouldNot(HaveOccurred()) + count := 0 + for { + item, err := stream.Recv() + if err == io.EOF { + Ω(count).Should(Equal(len(versions))) + break + } + count++ + Ω(err).ShouldNot(HaveOccurred()) + Ω(item).ShouldNot(BeNil()) + Ω(item.Name).Should(Equal(name)) + Ω(versions).Should(ContainElement(item.Version)) + } + }) + It("should fail for empty name", func() { + data, err := client.StreamAllConfigurations(bkgnd, &wrapperspb.StringValue{Value: ""}) + Ω(err).ShouldNot(HaveOccurred()) + _, err = data.Recv() + Ω(err).Should(HaveOccurred()) + AssertStatusCode(codes.InvalidArgument, err) + }) + It("should retrieve an empty stream for valid but non-existent configuration", func() { + response, err := client.StreamAllConfigurations(bkgnd, + &wrapperspb.StringValue{Value: "fake"}) + Ω(err).ShouldNot(HaveOccurred()) + Ω(response).ShouldNot(BeNil()) + _, err = response.Recv() + Ω(err).Should(Equal(io.EOF)) + }) + }) + Context("streaming Statemachines", func() { + var ids = []string{"1", "2", "3"} + // Test data setup + BeforeEach(func() { + cfg = &api.Configuration{ + Name: "test-conf", + Version: "v1", + States: []string{"start", "stop"}, + Transitions: []*api.Transition{ + {From: "start", To: "stop", Event: "shutdown"}, + }, + StartingState: "start", + } + Ω(store.PutConfig(cfg)).ShouldNot(HaveOccurred()) + for _, id := range ids { + Ω(store.PutStateMachine(id, &api.FiniteStateMachine{ + ConfigId: GetVersionId(cfg), + State: "start", + })).ShouldNot(HaveOccurred()) + Ω(store.UpdateState(cfg.Name, id, "", "start")). + ShouldNot(HaveOccurred()) + } + }) + It("should find all FSM", func() { + resp, err := client.StreamAllInstate(bkgnd, + &api.GetFsmRequest{ + Config: cfg.Name, + Query: &api.GetFsmRequest_State{State: "start"}, + }) + Ω(err).ShouldNot(HaveOccurred()) + Ω(resp).ShouldNot(BeNil()) + count := 0 + for { + item, err := resp.Recv() + if err == io.EOF { + Ω(count).Should(Equal(len(ids))) + break + } + count++ + Ω(err).ShouldNot(HaveOccurred()) + Ω(ids).Should(ContainElement(item.Id)) + fsm := item.GetFsm() + Ω(fsm).ShouldNot(BeNil()) + Ω(fsm.State).Should(Equal("start")) + Ω(fsm.ConfigId).Should(Equal(GetVersionId(cfg))) + } + }) + }) + }) +}) diff --git a/grpc/grpc_server_test.go b/grpc/grpc_server_test.go index 8095233..8d82a5e 100644 --- a/grpc/grpc_server_test.go +++ b/grpc/grpc_server_test.go @@ -10,46 +10,54 @@ package grpc_test import ( - . "github.com/JiaYongfei/respect/gomega" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "google.golang.org/grpc/codes" - "strings" - "context" "fmt" - "github.com/massenz/slf4go/logging" - g "google.golang.org/grpc" - "google.golang.org/grpc/status" "net" + "strings" "time" + "github.com/go-redis/redis/v8" + "github.com/massenz/slf4go/logging" + g "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/wrapperspb" + + . "github.com/JiaYongfei/respect/gomega" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + . "github.com/massenz/go-statemachine/api" "github.com/massenz/go-statemachine/grpc" "github.com/massenz/go-statemachine/storage" - "github.com/massenz/statemachine-proto/golang/api" + protos "github.com/massenz/statemachine-proto/golang/api" ) -var _ = Describe("GrpcServer", func() { - - Context("when processing events", func() { - var testCh chan api.EventRequest +var bkgnd = context.Background() +var _ = Describe("the gRPC Server", func() { + When("processing events", func() { + var testCh chan protos.EventRequest var listener net.Listener - var client api.StatemachineServiceClient + var client protos.StatemachineServiceClient var done func() BeforeEach(func() { var err error - testCh = make(chan api.EventRequest, 5) + testCh = make(chan protos.EventRequest, 5) listener, err = net.Listen("tcp", ":0") Ω(err).ShouldNot(HaveOccurred()) - cc, err := g.Dial(listener.Addr().String(), g.WithInsecure()) + cc, err := g.Dial(listener.Addr().String(), + g.WithTransportCredentials(insecure.NewCredentials())) Ω(err).ShouldNot(HaveOccurred()) - client = api.NewStatemachineServiceClient(cc) + client = protos.NewStatemachineServiceClient(cc) + // TODO: use GinkgoWriter for logs l := logging.NewLog("grpc-server-test") l.Level = logging.NONE + // Note the `Config` here has no store configured, because we are + // only testing events in this Context, and those are never stored + // in Redis by the gRPC server (other parts of the system do). server, err := grpc.NewGrpcServer(&grpc.Config{ EventsChannel: testCh, Logger: l, @@ -65,15 +73,16 @@ var _ = Describe("GrpcServer", func() { } }) It("should succeed for well-formed events", func() { - response, err := client.ProcessEvent(context.Background(), &api.EventRequest{ - Event: &api.Event{ + response, err := client.SendEvent(bkgnd, &protos.EventRequest{ + Event: &protos.Event{ EventId: "1", - Transition: &api.Transition{ + Transition: &protos.Transition{ Event: "test-vt", }, Originator: "test", }, - Dest: "2", + Config: "test-cfg", + Id: "2", }) Ω(err).ToNot(HaveOccurred()) Ω(response).ToNot(BeNil()) @@ -84,21 +93,22 @@ var _ = Describe("GrpcServer", func() { Ω(evt.Event.EventId).To(Equal("1")) Ω(evt.Event.Transition.Event).To(Equal("test-vt")) Ω(evt.Event.Originator).To(Equal("test")) - Ω(evt.Dest).To(Equal("2")) + Ω(evt.Id).To(Equal("2")) case <-time.After(10 * time.Millisecond): Fail("Timed out") } }) It("should create an ID for events without", func() { - response, err := client.ProcessEvent(context.Background(), &api.EventRequest{ - Event: &api.Event{ - Transition: &api.Transition{ + response, err := client.SendEvent(bkgnd, &protos.EventRequest{ + Event: &protos.Event{ + Transition: &protos.Transition{ Event: "test-vt", }, Originator: "test", }, - Dest: "123456", + Config: "test-cfg", + Id: "123456", }) Ω(err).ToNot(HaveOccurred()) Ω(response.EventId).ToNot(BeNil()) @@ -113,15 +123,15 @@ var _ = Describe("GrpcServer", func() { } }) It("should fail for missing destination", func() { - _, err := client.ProcessEvent(context.Background(), &api.EventRequest{ - Event: &api.Event{ - Transition: &api.Transition{ + _, err := client.SendEvent(bkgnd, &protos.EventRequest{ + Event: &protos.Event{ + Transition: &protos.Transition{ Event: "test-vt", }, Originator: "test", }, }) - assertStatusCode(codes.FailedPrecondition, err) + AssertStatusCode(codes.FailedPrecondition, err) done() select { case evt := <-testCh: @@ -131,16 +141,17 @@ var _ = Describe("GrpcServer", func() { } }) It("should fail for missing event", func() { - _, err := client.ProcessEvent(context.Background(), &api.EventRequest{ - Event: &api.Event{ - Transition: &api.Transition{ + _, err := client.SendEvent(bkgnd, &protos.EventRequest{ + Event: &protos.Event{ + Transition: &protos.Transition{ Event: "", }, Originator: "test", }, - Dest: "9876", + Config: "test", + Id: "9876", }) - assertStatusCode(codes.FailedPrecondition, err) + AssertStatusCode(codes.FailedPrecondition, err) done() select { case evt := <-testCh: @@ -151,24 +162,24 @@ var _ = Describe("GrpcServer", func() { }) }) - Context("when retrieving data from the store", func() { + When("using Redis as the backing store", func() { var ( listener net.Listener - client api.StatemachineServiceClient - cfg *api.Configuration - fsm *api.FiniteStateMachine + client protos.StatemachineServiceClient + cfg *protos.Configuration + fsm *protos.FiniteStateMachine done func() store storage.StoreManager ) // Server setup BeforeEach(func() { - store = storage.NewInMemoryStore() + store = storage.NewRedisStoreWithDefaults(container.Address) store.SetLogLevel(logging.NONE) - listener, _ = net.Listen("tcp", ":0") - cc, _ := g.Dial(listener.Addr().String(), g.WithInsecure()) - client = api.NewStatemachineServiceClient(cc) + cc, _ := g.Dial(listener.Addr().String(), + g.WithTransportCredentials(insecure.NewCredentials())) + client = protos.NewStatemachineServiceClient(cc) // Use this to log errors when diagnosing test failures; then set to NONE once done. l := logging.NewLog("grpc-server-test") l.Level = logging.NONE @@ -178,100 +189,213 @@ var _ = Describe("GrpcServer", func() { }) go func() { + defer GinkgoRecover() Ω(server.Serve(listener)).Should(Succeed()) }() done = func() { server.Stop() } }) - // Server shutdown + // Server shutdown & Clean up the DB AfterEach(func() { done() + rdb := redis.NewClient(&redis.Options{ + Addr: container.Address, + DB: storage.DefaultRedisDb, + }) + rdb.FlushDB(context.Background()) }) - // Test data setup - BeforeEach(func() { - cfg = &api.Configuration{ - Name: "test-conf", - Version: "v1", - States: []string{"start", "stop"}, - Transitions: []*api.Transition{ - {From: "start", To: "stop", Event: "shutdown"}, - }, - StartingState: "start", - } - fsm = &api.FiniteStateMachine{ConfigId: GetVersionId(cfg)} - }) - It("should store valid configurations", func() { - _, ok := store.GetConfig(GetVersionId(cfg)) - Ω(ok).To(BeFalse()) - response, err := client.PutConfiguration(context.Background(), cfg) - Ω(err).ToNot(HaveOccurred()) - Ω(response).ToNot(BeNil()) - Ω(response.Id).To(Equal(GetVersionId(cfg))) - found, ok := store.GetConfig(response.Id) - Ω(ok).Should(BeTrue()) - Ω(found).Should(Respect(cfg)) - }) - It("should fail for invalid configuration", func() { - invalid := &api.Configuration{ - Name: "invalid", - Version: "v1", - States: []string{}, - Transitions: nil, - StartingState: "", - } - _, err := client.PutConfiguration(context.Background(), invalid) - assertStatusCode(codes.InvalidArgument, err) - }) - It("should retrieve a valid configuration", func() { - Ω(store.PutConfig(cfg)).To(Succeed()) - response, err := client.GetConfiguration(context.Background(), - &api.GetRequest{Id: GetVersionId(cfg)}) - Ω(err).ToNot(HaveOccurred()) - Ω(response).ToNot(BeNil()) - Ω(response).Should(Respect(cfg)) - }) - It("should return an empty configuration for an invalid ID", func() { - _, err := client.GetConfiguration(context.Background(), &api.GetRequest{Id: "fake"}) - assertStatusCode(codes.NotFound, err) - }) - It("should store a valid FSM", func() { - Ω(store.PutConfig(cfg)).To(Succeed()) - resp, err := client.PutFiniteStateMachine(context.Background(), fsm) - Ω(err).ToNot(HaveOccurred()) - Ω(resp).ToNot(BeNil()) - Ω(resp.Id).ToNot(BeNil()) - Ω(resp.Fsm).Should(Respect(fsm)) - }) - It("should fail with an invalid Config ID", func() { - invalid := &api.FiniteStateMachine{ConfigId: "fake"} - _, err := client.PutFiniteStateMachine(context.Background(), invalid) - assertStatusCode(codes.FailedPrecondition, err) - }) - It("can retrieve a stored FSM", func() { - id := "123456" - Ω(store.PutConfig(cfg)) - Ω(store.PutStateMachine(id, fsm)).Should(Succeed()) - Ω(client.GetFiniteStateMachine(context.Background(), - &api.GetRequest{ - Id: strings.Join([]string{cfg.Name, id}, storage.KeyPrefixIDSeparator), - })).Should(Respect(fsm)) - }) - It("will return an Invalid error for an invalid ID", func() { - _, err := client.GetFiniteStateMachine(context.Background(), &api.GetRequest{Id: "fake"}) - assertStatusCode(codes.InvalidArgument, err) + Context("handling Configuration API requests", func() { + // Test data setup + BeforeEach(func() { + cfg = &protos.Configuration{ + Name: "test-conf", + Version: "v1", + States: []string{"start", "stop"}, + Transitions: []*protos.Transition{ + {From: "start", To: "stop", Event: "shutdown"}, + }, + StartingState: "start", + } + }) + It("should store valid configurations", func() { + _, ok := store.GetConfig(GetVersionId(cfg)) + Ω(ok).To(BeFalse()) + response, err := client.PutConfiguration(bkgnd, cfg) + Ω(err).ToNot(HaveOccurred()) + Ω(response).ToNot(BeNil()) + Ω(response.Id).To(Equal(GetVersionId(cfg))) + found, ok := store.GetConfig(response.Id) + Ω(ok).Should(BeTrue()) + Ω(found).Should(Respect(cfg)) + }) + It("should fail for invalid configuration", func() { + invalid := &protos.Configuration{ + Name: "invalid", + Version: "v1", + States: []string{}, + Transitions: nil, + StartingState: "", + } + _, err := client.PutConfiguration(bkgnd, invalid) + AssertStatusCode(codes.InvalidArgument, err) + }) + It("should retrieve a valid configuration", func() { + Ω(store.PutConfig(cfg)).To(Succeed()) + response, err := client.GetConfiguration(bkgnd, + &wrapperspb.StringValue{Value: GetVersionId(cfg)}) + Ω(err).ToNot(HaveOccurred()) + Ω(response).ToNot(BeNil()) + Ω(response).Should(Respect(cfg)) + }) + It("should return an empty configuration for an invalid ID", func() { + _, err := client.GetConfiguration(bkgnd, &wrapperspb.StringValue{Value: "fake"}) + AssertStatusCode(codes.NotFound, err) + }) + It("will find all configurations", func() { + names := []string{"orders", "devices", "users"} + for _, name := range names { + cfg = &protos.Configuration{ + Name: name, + Version: "v1", + States: []string{"start", "stop"}, + Transitions: []*protos.Transition{ + {From: "start", To: "stop", Event: "shutdown"}, + }, + StartingState: "start", + } + Ω(store.PutConfig(cfg)).Should(Succeed()) + } + found, err := client.GetAllConfigurations(bkgnd, &wrapperspb.StringValue{}) + Ω(err).Should(Succeed()) + Ω(len(found.Ids)).To(Equal(3)) + for _, value := range found.Ids { + Ω(names).To(ContainElement(value)) + } + }) + It("will find all version for a configuration", func() { + name := "store.api" + versions := []string{"v1alpha", "v1beta", "v1"} + for _, v := range versions { + cfg = &protos.Configuration{ + Name: name, + Version: v, + States: []string{"checkout", "close"}, + Transitions: []*protos.Transition{ + {From: "checkout", To: "close", Event: "payment"}, + }, + StartingState: "checkout", + } + Ω(store.PutConfig(cfg)).Should(Succeed()) + } + found, err := client.GetAllConfigurations(bkgnd, &wrapperspb.StringValue{Value: name}) + Ω(err).Should(Succeed()) + Ω(len(found.Ids)).To(Equal(3)) + for _, value := range versions { + Ω(found.Ids).To(ContainElement( + strings.Join([]string{name, value}, storage.KeyPrefixComponentsSeparator))) + } + }) }) - It("will return a NotFound error for a missing ID", func() { - _, err := client.GetFiniteStateMachine(context.Background(), - &api.GetRequest{Id: "cfg#fake"}) - assertStatusCode(codes.NotFound, err) + Context("handling Statemachine API requests", func() { + // Test data setup + BeforeEach(func() { + cfg = &protos.Configuration{ + Name: "test-conf", + Version: "v1", + States: []string{"start", "stop"}, + Transitions: []*protos.Transition{ + {From: "start", To: "stop", Event: "shutdown"}, + }, + StartingState: "start", + } + fsm = &protos.FiniteStateMachine{ConfigId: GetVersionId(cfg)} + }) + It("should store a valid FSM", func() { + Ω(store.PutConfig(cfg)).To(Succeed()) + resp, err := client.PutFiniteStateMachine(bkgnd, + &protos.PutFsmRequest{Id: "123456", Fsm: fsm}) + Ω(err).ToNot(HaveOccurred()) + Ω(resp).ToNot(BeNil()) + Ω(resp.Id).To(Equal("123456")) + Ω(resp.GetFsm()).Should(Respect(fsm)) + // As we didn't specify a state when creating the FSM, the `StartingState` + // was automatically configured. + found := store.GetAllInState(cfg.Name, cfg.StartingState) + Ω(len(found)).To(Equal(1)) + Ω(found[0]).To(Equal(resp.Id)) + }) + It("should fail with an invalid Config ID", func() { + invalid := &protos.FiniteStateMachine{ConfigId: "fake"} + _, err := client.PutFiniteStateMachine(bkgnd, + &protos.PutFsmRequest{Fsm: invalid}) + AssertStatusCode(codes.FailedPrecondition, err) + }) + It("can retrieve a stored FSM", func() { + id := "123456" + Ω(store.PutConfig(cfg)) + Ω(store.PutStateMachine(id, fsm)).Should(Succeed()) + Ω(client.GetFiniteStateMachine(bkgnd, + &protos.GetFsmRequest{ + Config: cfg.Name, + Query: &protos.GetFsmRequest_Id{Id: id}, + })).Should(Respect(fsm)) + }) + It("will return an Invalid error for missing config or ID", func() { + _, err := client.GetFiniteStateMachine(bkgnd, + &protos.GetFsmRequest{ + Query: &protos.GetFsmRequest_Id{Id: "fake"}, + }) + AssertStatusCode(codes.InvalidArgument, err) + _, err = client.GetFiniteStateMachine(bkgnd, + &protos.GetFsmRequest{ + Config: cfg.Name, + }) + AssertStatusCode(codes.InvalidArgument, err) + }) + It("will return a NotFound error for a missing ID", func() { + _, err := client.GetFiniteStateMachine(bkgnd, + &protos.GetFsmRequest{ + Config: cfg.Name, + Query: &protos.GetFsmRequest_Id{Id: "12345"}, + }) + AssertStatusCode(codes.NotFound, err) + }) + It("will find all FSMs by State", func() { + for i := 1; i <= 5; i++ { + id := fmt.Sprintf("fsm-%d", i) + Ω(store.PutStateMachine(id, + &protos.FiniteStateMachine{ + ConfigId: "test.m:v1", + State: "start", + })).Should(Succeed()) + store.UpdateState("test.m", id, "", "start") + } + for i := 10; i < 13; i++ { + id := fmt.Sprintf("fsm-%d", i) + Ω(store.PutStateMachine(id, + &protos.FiniteStateMachine{ + ConfigId: "test.m:v1", + State: "stop", + })).Should(Succeed()) + store.UpdateState("test.m", id, "", "stop") + + } + items, err := client.GetAllInState(bkgnd, &protos.GetFsmRequest{ + Config: "test.m", + Query: &protos.GetFsmRequest_State{State: "start"}, + }) + Ω(err).ShouldNot(HaveOccurred()) + Ω(len(items.GetIds())).Should(Equal(5)) + Ω(items.GetIds()).Should(ContainElements("fsm-3", "fsm-5")) + items, err = client.GetAllInState(bkgnd, &protos.GetFsmRequest{ + Config: "test.m", + Query: &protos.GetFsmRequest_State{State: "stop"}, + }) + Ω(err).ShouldNot(HaveOccurred()) + Ω(len(items.GetIds())).Should(Equal(3)) + Ω(items.GetIds()).Should(ContainElements("fsm-10", "fsm-12")) + }) }) }) }) - -func assertStatusCode(code codes.Code, err error) { - Ω(err).To(HaveOccurred()) - s, ok := status.FromError(err) - Ω(ok).To(BeTrue()) - Ω(s.Code()).To(Equal(code)) -} diff --git a/grpc/grpc_suite_test.go b/grpc/grpc_suite_test.go index b1531ca..c40b29a 100644 --- a/grpc/grpc_suite_test.go +++ b/grpc/grpc_suite_test.go @@ -10,13 +10,40 @@ package grpc_test import ( + "context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "testing" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + + internals "github.com/massenz/go-statemachine/internals/testing" ) func TestGrpc(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "gRPC Suite") + RunSpecs(t, "gRPC Server") +} + +var container *internals.Container +var _ = BeforeSuite(func() { + var err error + container, err = internals.NewRedisContainer(context.Background()) + Expect(err).ToNot(HaveOccurred()) + // Note the timeout here is in seconds (and it's not a time.Duration either) +}, 5.0) + +var _ = AfterSuite(func() { + if container != nil { + Expect(container.Terminate(context.Background())).To(Succeed()) + } +}, 2.0) + +// TODO: should be an Omega Matcher +func AssertStatusCode(code codes.Code, err error) { + Ω(err).To(HaveOccurred()) + s, ok := status.FromError(err) + Ω(ok).To(BeTrue()) + Ω(s.Code()).To(Equal(code)) } diff --git a/internals/testing/containers.go b/internals/testing/containers.go new file mode 100644 index 0000000..0d23cf5 --- /dev/null +++ b/internals/testing/containers.go @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022 AlertAvert.com. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Author: Marco Massenzio (marco@alertavert.com) + */ + +package testing + +import ( + "context" + "fmt" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + localstackImage = "localstack/localstack:1.3" + localstackEdgePort = "4566" + redisImage = "redis:6" + redisPort = "6379/tcp" + Region = "us-west-2" +) + +// Container is an internal wrapper around the `testcontainers.Container` carrying also +// the `Address` (which could be a URI) to which the server can be reached at. +type Container struct { + testcontainers.Container + Address string +} + +// NewLocalstackContainer creates a new connection to the `LocalStack` `testcontainers` +func NewLocalstackContainer(ctx context.Context) (*Container, error) { + req := testcontainers.ContainerRequest{ + Image: localstackImage, + ExposedPorts: []string{localstackEdgePort}, + WaitingFor: wait.ForLog("Ready."), + Env: map[string]string{ + "AWS_REGION": Region, + "EDGE_PORT": localstackEdgePort, + "SERVICES": "sqs", + }, + } + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + return nil, err + } + + ip, err := container.Host(ctx) + if err != nil { + return nil, err + } + + mappedPort, err := container.MappedPort(ctx, localstackEdgePort) + if err != nil { + return nil, err + } + + uri := fmt.Sprintf("http://%s:%s", ip, mappedPort.Port()) + return &Container{Container: container, Address: uri}, nil +} + +func NewRedisContainer(ctx context.Context) (*Container, error) { + req := testcontainers.ContainerRequest{ + Image: redisImage, + ExposedPorts: []string{redisPort}, + WaitingFor: wait.ForLog("* Ready to accept connections"), + } + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + return nil, err + } + + mappedPort, err := container.MappedPort(ctx, "6379") + if err != nil { + return nil, err + } + + hostIP, err := container.Host(ctx) + if err != nil { + return nil, err + } + + address := fmt.Sprintf("%s:%s", hostIP, mappedPort.Port()) + return &Container{Container: container, Address: address}, nil +} diff --git a/pubsub/listener.go b/pubsub/listener.go index 288327a..94bf1d3 100644 --- a/pubsub/listener.go +++ b/pubsub/listener.go @@ -15,7 +15,6 @@ import ( "github.com/massenz/go-statemachine/storage" log "github.com/massenz/slf4go/logging" protos "github.com/massenz/statemachine-proto/golang/api" - "strings" ) func NewEventsListener(options *ListenerOptions) *EventsListener { @@ -34,10 +33,11 @@ func (listener *EventsListener) SetLogLevel(level log.LogLevel) { func (listener *EventsListener) PostNotificationAndReportOutcome(eventResponse *protos.EventResponse) { if eventResponse.Outcome.Code != protos.EventOutcome_Ok { - listener.logger.Error("[Event ID: %s]: %s", eventResponse.EventId, eventResponse.GetOutcome().Details) + listener.logger.Error("event [%s]: %s", + eventResponse.GetEventId(), eventResponse.GetOutcome().Details) } if listener.notifications != nil { - listener.logger.Debug("Posting notification: %v", eventResponse.GetEventId()) + listener.logger.Debug("posting notification: %v", eventResponse.GetEventId()) listener.notifications <- *eventResponse } listener.logger.Debug("Reporting outcome: %v", eventResponse.GetEventId()) @@ -48,29 +48,25 @@ func (listener *EventsListener) ListenForMessages() { listener.logger.Info("Events message listener started") for request := range listener.events { listener.logger.Debug("Received request %s", request.Event.String()) - if request.Dest == "" { + fsmId := request.GetId() + if fsmId == "" { listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_MissingDestination, - fmt.Sprintf("no destination specified"))) + fmt.Sprintf("no statemachine ID specified"))) continue } - // TODO: this is an API change and needs to be documented - // Destination comprises both the type (configuration name) and ID of the statemachine, - // separated by the # character: # (e.g., `order#1234`) - dest := strings.Split(request.Dest, storage.KeyPrefixIDSeparator) - if len(dest) != 2 { + config := request.GetConfig() + if config == "" { listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_MissingDestination, - fmt.Sprintf("invalid destination [%s] expected #", - request.Dest))) + fmt.Sprintf("no Configuration name specified"))) continue } - smType, smId := dest[0], dest[1] - fsm, ok := listener.store.GetStateMachine(smId, smType) + fsm, ok := listener.store.GetStateMachine(fsmId, config) if !ok { listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_FsmNotFound, - fmt.Sprintf("statemachine [%s] could not be found", request.Dest))) + fmt.Sprintf("statemachine [%s] could not be found", fsmId))) continue } // TODO: cache the configuration locally: they are immutable anyway. @@ -86,8 +82,8 @@ func (listener *EventsListener) ListenForMessages() { Config: cfg, FSM: fsm, } - listener.logger.Debug("Preparing to send event `%s` for FSM [%s] (current state: %s)", - request.Event.Transition.Event, smId, previousState) + listener.logger.Debug("preparing to send event `%s` for FSM [%s] (current state: %s)", + request.Event.Transition.Event, fsmId, previousState) if err := cfgFsm.SendEvent(request.Event); err != nil { listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_TransitionNotAllowed, @@ -95,29 +91,34 @@ func (listener *EventsListener) ListenForMessages() { request.GetEvent().GetTransition().GetEvent(), err))) continue } - listener.logger.Debug("Event `%s` transitioned FSM [%s] to state `%s` from state `%s` - updating store", - request.Event.Transition.Event, smId, fsm.State, previousState) - err := listener.store.PutStateMachine(smId, fsm) - if err != nil { + if err := listener.store.PutStateMachine(fsmId, fsm); err != nil { + listener.PostNotificationAndReportOutcome(makeResponse(&request, + protos.EventOutcome_InternalError, + fmt.Sprintf("could not update statemachine [%s#%s] in store: %v", + config, fsmId, err))) + continue + } + if err := listener.store.UpdateState(config, fsmId, previousState, fsm.State); err != nil { listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_InternalError, - fmt.Sprintf("could not update statemachine [%s] in store: %v", - request.Dest, err))) + fmt.Sprintf("could not update statemachine state set (%s#%s): %v", + config, fsmId, err))) continue } // All good, we want to report success too. + listener.logger.Debug("Event `%s` transitioned FSM [%s] to state `%s` from state `%s` - updating store", + request.Event.Transition.Event, fsmId, fsm.State, previousState) listener.PostNotificationAndReportOutcome(makeResponse(&request, protos.EventOutcome_Ok, fmt.Sprintf("event [%s] transitioned FSM [%s] to state [%s]", - request.Event.Transition.Event, smId, fsm.State))) + request.Event.Transition.Event, fsmId, fsm.State))) } } func (listener *EventsListener) reportOutcome(response *protos.EventResponse) { - smType := strings.Split(response.Outcome.Dest, storage.KeyPrefixIDSeparator)[0] - if err := listener.store.AddEventOutcome(response.EventId, smType, response.Outcome, - storage.NeverExpire); err != nil { - listener.logger.Error("could not add outcome to store: %v", err) + if err := listener.store.AddEventOutcome(response.EventId, response.GetOutcome().GetConfig(), + response.Outcome, storage.NeverExpire); err != nil { + listener.logger.Error("could not save event outcome: %v", err) } } @@ -127,8 +128,9 @@ func makeResponse(request *protos.EventRequest, code protos.EventOutcome_StatusC EventId: request.GetEvent().GetEventId(), Outcome: &protos.EventOutcome{ Code: code, - Dest: request.Dest, Details: details, + Config: request.Config, + Id: request.Id, }, } } diff --git a/pubsub/listener_test.go b/pubsub/listener_test.go index ac06da9..e7162e9 100644 --- a/pubsub/listener_test.go +++ b/pubsub/listener_test.go @@ -66,7 +66,7 @@ var _ = Describe("A Listener", func() { case n := <-notificationsCh: Ω(n.EventId).To(Equal(msg.GetEventId())) Ω(n.Outcome).ToNot(BeNil()) - Ω(n.Outcome.Dest).To(BeEmpty()) + Ω(n.Outcome.Id).To(BeEmpty()) Ω(n.Outcome.Details).To(Equal(detail)) Ω(n.Outcome.Code).To(Equal(protos.EventOutcome_MissingDestination)) @@ -84,8 +84,9 @@ var _ = Describe("A Listener", func() { Details: "more details", } request := protos.EventRequest{ - Event: &event, - Dest: "test#12345-faa44", + Event: &event, + Config: "test", + Id: "12345-faa44", } Ω(store.PutConfig(&protos.Configuration{ Name: "test", @@ -111,7 +112,7 @@ var _ = Describe("A Listener", func() { // First we want to test that the outcome was successful Ω(notification.EventId).To(Equal(event.GetEventId())) Ω(notification.Outcome).ToNot(BeNil()) - Ω(notification.Outcome.Dest).To(Equal(request.GetDest())) + Ω(notification.Outcome.Id).To(Equal(request.GetId())) Ω(notification.Outcome.Details).To(ContainSubstring("transitioned")) Ω(notification.Outcome.Code).To(Equal(protos.EventOutcome_Ok)) @@ -136,8 +137,9 @@ var _ = Describe("A Listener", func() { Details: "more details", } request := protos.EventRequest{ - Event: &event, - Dest: "test#fake-fsm", + Event: &event, + Config: "test", + Id: "fake-fsm", } go func() { testListener.ListenForMessages() @@ -148,7 +150,7 @@ var _ = Describe("A Listener", func() { case n := <-notificationsCh: Ω(n.EventId).To(Equal(request.Event.EventId)) Ω(n.Outcome).ToNot(BeNil()) - Ω(n.Outcome.Dest).To(Equal(request.Dest)) + Ω(n.Outcome.Id).To(Equal(request.GetId())) Ω(n.Outcome.Code).To(Equal(protos.EventOutcome_FsmNotFound)) case <-time.After(timeout): Fail("the listener did not exit when the events channel was closed") @@ -159,7 +161,6 @@ var _ = Describe("A Listener", func() { Event: &protos.Event{ EventId: "feed-beef", }, - Dest: "", } go func() { testListener.ListenForMessages() }() eventsCh <- request diff --git a/pubsub/pubsub_suite_test.go b/pubsub/pubsub_suite_test.go index 9ab1238..1c15e20 100644 --- a/pubsub/pubsub_suite_test.go +++ b/pubsub/pubsub_suite_test.go @@ -13,10 +13,6 @@ import ( "context" "fmt" "github.com/golang/protobuf/proto" - "github.com/massenz/go-statemachine/pubsub" - "github.com/massenz/statemachine-proto/golang/api" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" "os" "testing" "time" @@ -27,11 +23,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" + + internals "github.com/massenz/go-statemachine/internals/testing" + "github.com/massenz/go-statemachine/pubsub" + "github.com/massenz/statemachine-proto/golang/api" ) const ( - localstackImage = "localstack/localstack:1.3" - localstackEdgePort = "4566" eventsQueue = "test-events" notificationsQueue = "test-notifications" acksQueue = "test-acks" @@ -44,63 +42,27 @@ func TestPubSub(t *testing.T) { RunSpecs(t, "Pub/Sub Suite") } -type LocalstackContainer struct { - testcontainers.Container - EndpointUri string -} -func SetupAwsLocal(ctx context.Context) (*LocalstackContainer, error) { - req := testcontainers.ContainerRequest{ - Image: localstackImage, - ExposedPorts: []string{localstackEdgePort}, - WaitingFor: wait.ForLog("Ready."), - Env: map[string]string{ - "AWS_REGION": "us-west-2", - "EDGE_PORT": "4566", - "SERVICES": "sqs", - }, - } - container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - }) - if err != nil { - return nil, err - } - - ip, err := container.Host(ctx) - if err != nil { - return nil, err - } - - mappedPort, err := container.MappedPort(ctx, localstackEdgePort) - if err != nil { - return nil, err - } - - uri := fmt.Sprintf("http://%s:%s", ip, mappedPort.Port()) - - return &LocalstackContainer{Container: container, EndpointUri: uri}, nil -} // Although these are constants, we cannot take the pointers unless we declare them vars. var ( - region = "us-west-2" - awsLocal *LocalstackContainer + awsLocal *internals.Container testSqsClient *sqs.SQS ) var _ = BeforeSuite(func() { - Expect(os.Setenv("AWS_REGION", region)).ToNot(HaveOccurred()) + Expect(os.Setenv("AWS_REGION", internals.Region)).ToNot(HaveOccurred()) var err error - awsLocal, err = SetupAwsLocal(context.Background()) + awsLocal, err = internals.NewLocalstackContainer(context.Background()) Expect(err).ToNot(HaveOccurred()) + // Can't take the address of a constant. + region := internals.Region testSqsClient = sqs.New(session.Must(session.NewSessionWithOptions(session.Options{ SharedConfigState: session.SharedConfigEnable, Config: aws.Config{ - Endpoint: &awsLocal.EndpointUri, + Endpoint: &awsLocal.Address, Region: ®ion, }, }))) @@ -134,7 +96,7 @@ var _ = AfterSuite(func() { Expect(err).NotTo(HaveOccurred()) } } - awsLocal.Terminate(context.Background()) + Expect(awsLocal.Terminate(context.Background())).ToNot(HaveOccurred()) }, 2.0) // getQueueName provides a way to obtain a process-independent name for the SQS queue, diff --git a/pubsub/sqs_pub_test.go b/pubsub/sqs_pub_test.go index 7ff8135..0ce7b9b 100644 --- a/pubsub/sqs_pub_test.go +++ b/pubsub/sqs_pub_test.go @@ -34,7 +34,7 @@ var _ = Describe("SQS Publisher", func() { ) BeforeEach(func() { notificationsCh = make(chan protos.EventResponse) - testPublisher = pubsub.NewSqsPublisher(notificationsCh, &awsLocal.EndpointUri) + testPublisher = pubsub.NewSqsPublisher(notificationsCh, &awsLocal.Address) Expect(testPublisher).ToNot(BeNil()) // Set to DEBUG when diagnosing test failures testPublisher.SetLogLevel(logging.NONE) @@ -46,8 +46,9 @@ var _ = Describe("SQS Publisher", func() { EventId: "feed-beef", Outcome: &protos.EventOutcome{ Code: protos.EventOutcome_InternalError, - Dest: "me", Details: "error details", + Config: "test-cfg", + Id: "abd-456", }, } done := make(chan interface{}) @@ -157,8 +158,9 @@ var _ = Describe("SQS Publisher", func() { EventId: evt.EventId, Outcome: &protos.EventOutcome{ Code: protos.EventOutcome_InternalError, - Dest: fmt.Sprintf("test-%d", i), Details: "more details about the error", + Config: "test-cfg", + Id: fmt.Sprintf("fsm-%d", i), }, } } @@ -178,7 +180,7 @@ var _ = Describe("SQS Publisher", func() { g.Expect(receivedEvt.EventId).To(Equal(fmt.Sprintf("event-%d", i))) g.Expect(receivedEvt.Outcome.Code).To(Equal(protos.EventOutcome_InternalError)) g.Expect(receivedEvt.Outcome.Details).To(Equal("more details about the error")) - g.Expect(receivedEvt.Outcome.Dest).To(ContainSubstring("test-")) + g.Expect(receivedEvt.Outcome.Id).To(ContainSubstring("fsm-")) }).Should(Succeed()) } }() diff --git a/pubsub/sqs_sub.go b/pubsub/sqs_sub.go index aa3f6df..1a34be0 100644 --- a/pubsub/sqs_sub.go +++ b/pubsub/sqs_sub.go @@ -139,7 +139,7 @@ func (s *SqsSubscriber) ProcessMessage(msg *sqs.Message, queueUrl *string) { return } - destId := request.Dest + destId := request.GetId() if destId == "" { errDetails := fmt.Sprintf("No Destination ID in %v", request.String()) s.logger.Error(errDetails) diff --git a/pubsub/sqs_sub_test.go b/pubsub/sqs_sub_test.go index 72b32a6..670cf2e 100644 --- a/pubsub/sqs_sub_test.go +++ b/pubsub/sqs_sub_test.go @@ -32,7 +32,7 @@ var _ = Describe("SQS Subscriber", func() { BeforeEach(func() { Expect(awsLocal).ToNot(BeNil()) eventsCh = make(chan protos.EventRequest) - testSubscriber = pubsub.NewSqsSubscriber(eventsCh, &awsLocal.EndpointUri) + testSubscriber = pubsub.NewSqsSubscriber(eventsCh, &awsLocal.Address) Expect(testSubscriber).ToNot(BeNil()) // Set to DEBUG when diagnosing failing tests testSubscriber.SetLogLevel(log.NONE) @@ -43,7 +43,7 @@ var _ = Describe("SQS Subscriber", func() { It("receives events", func() { msg := protos.EventRequest{ Event: api.NewEvent("test-event"), - Dest: "some-fsm", + Id: "some-fsm", } msg.Event.EventId = "feed-beef" msg.Event.Originator = "test-subscriber" diff --git a/server/event_handlers_test.go b/server/event_handlers_test.go index f0d6c11..f4b895b 100644 --- a/server/event_handlers_test.go +++ b/server/event_handlers_test.go @@ -119,7 +119,7 @@ var _ = Describe("Event Handlers", func() { id = uuid.NewString() outcome = &protos.EventOutcome{ Code: protos.EventOutcome_Ok, - Dest: "fake-sm", + Id: "fake-sm", Details: "something happened", } Expect(store.AddEventOutcome(id, cfgName, outcome, @@ -136,7 +136,7 @@ var _ = Describe("Event Handlers", func() { Expect(json.NewDecoder(writer.Body).Decode(&result)).ToNot(HaveOccurred()) Expect(result.StatusCode).To(Equal(outcome.Code.String())) Expect(result.Message).To(Equal(outcome.Details)) - Expect(result.Destination).To(Equal(outcome.Dest)) + Expect(result.Destination).To(Equal(outcome.Id)) }) It("with an invalid ID will return Not Found", func() { endpoint := strings.Join([]string{server.ApiPrefix, diff --git a/server/types.go b/server/types.go index 9992add..8c449c5 100644 --- a/server/types.go +++ b/server/types.go @@ -71,6 +71,6 @@ func MakeOutcomeResponse(outcome *protos.EventOutcome) *OutcomeResponse { return &OutcomeResponse{ StatusCode: outcome.Code.String(), Message: outcome.Details, - Destination: outcome.Dest, + Destination: outcome.Id, } } diff --git a/storage/memory_store.go b/storage/memory_store.go index 641c585..b89ca35 100644 --- a/storage/memory_store.go +++ b/storage/memory_store.go @@ -61,12 +61,6 @@ func (csm *InMemoryStore) put(key string, value proto.Message) error { return err } -func (csm *InMemoryStore) GetAllInState(cfg string, state string) []*protos.FiniteStateMachine { - // TODO [#33] Ability to query for all machines in a given state - csm.logger.Error("Not implemented") - return nil -} - func (csm *InMemoryStore) GetEvent(id string, cfg string) (*protos.Event, bool) { key := NewKeyForEvent(id, cfg) event := &protos.Event{} @@ -144,3 +138,26 @@ func (csm *InMemoryStore) GetTimeout() time.Duration { func (csm *InMemoryStore) Health() error { return nil } + +func (csm *InMemoryStore) GetAllInState(cfg string, state string) []string { + // TODO [#33] Ability to query for all machines in a given state + csm.logger.Error(NotImplementedError("GetAllInState").Error()) + return nil +} + +func (csm *InMemoryStore) GetAllConfigs() []string { + // TODO [#33] Ability to query for all machines in a given state + csm.logger.Error(NotImplementedError("GetAllConfigs").Error()) + return nil +} + +func (csm *InMemoryStore) GetAllVersions(name string) []string { + // TODO [#33] Ability to query for all machines in a given state + csm.logger.Error(NotImplementedError("GetAllVersions").Error()) + return nil +} +func (csm *InMemoryStore) UpdateState(cfgName string, id string, oldState string, newState string) error { + // TODO [#33] Ability to query for all machines in a given state + csm.logger.Error(NotImplementedError("GetAllVersions").Error()) + return nil +} diff --git a/storage/redis_sets_store.go b/storage/redis_sets_store.go new file mode 100644 index 0000000..98fde88 --- /dev/null +++ b/storage/redis_sets_store.go @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022 AlertAvert.com. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Author: Marco Massenzio (marco@alertavert.com) + */ + +package storage + +import ( + "context" + "fmt" +) + +func (csm *RedisStore) UpdateState(cfgName string, id string, oldState string, newState string) error { + var key string + var err error + if oldState != "" { + key = NewKeyForMachinesByState(cfgName, oldState) + err = csm.client.SRem(context.Background(), key, id).Err() + if err != nil { + return fmt.Errorf( + "cannot remove FSM [%s#%s] from state set `%s`: %s", + cfgName, id, oldState, err) + } + } + if newState != "" { + key = NewKeyForMachinesByState(cfgName, newState) + err = csm.client.SAdd(context.Background(), key, id).Err() + if err != nil { + return fmt.Errorf( + "cannot add FSM [%s#%s] to state set `%s`: %s", + cfgName, id, newState, err) + } + } + return nil +} + +func (csm *RedisStore) GetAllInState(cfg string, state string) []string { + // TODO: enable splitting results with a (cursor, count) + csm.logger.Debug("Looking up all FSMs [%s] in DB with state `%s`", cfg, state) + key := NewKeyForMachinesByState(cfg, state) + fsms, err := csm.client.SMembers(context.Background(), key).Result() + if err != nil { + csm.logger.Error("Could not retrieve FSMs for state `%s`: %s", state, err) + return nil + } + csm.logger.Debug("Returning %d items", len(fsms)) + return fsms +} + +func (csm *RedisStore) GetAllConfigs() []string { + // TODO: enable splitting results with a (cursor, count) + csm.logger.Debug("Looking up all configs in DB") + configs, err := csm.client.SMembers(context.Background(), ConfigsPrefix).Result() + if err != nil { + csm.logger.Error("Could not retrieve configurations: %s", err) + return nil + } + csm.logger.Debug("Returning %d items", len(configs)) + return configs +} + +func (csm *RedisStore) GetAllVersions(name string) []string { + csm.logger.Debug("Looking up all versions for Configurations `%s` in DB", name) + configs, err := csm.client.SMembers(context.Background(), NewKeyForConfig(name)).Result() + if err != nil { + csm.logger.Error("Could not retrieve configurations: %s", err) + return nil + } + csm.logger.Debug("Returning %d items", len(configs)) + return configs +} diff --git a/storage/redis_store.go b/storage/redis_store.go index dfd827e..ccaa75b 100644 --- a/storage/redis_store.go +++ b/storage/redis_store.go @@ -40,12 +40,6 @@ type RedisStore struct { MaxRetries int } -func (csm *RedisStore) GetAllInState(cfg string, state string) []*protos.FiniteStateMachine { - // TODO [#33] Ability to query for all machines in a given state - csm.logger.Error(NotImplementedError("GetAllInState").Error()) - return nil -} - func (csm *RedisStore) GetConfig(id string) (*protos.Configuration, bool) { key := NewKeyForConfig(id) var cfg protos.Configuration @@ -73,7 +67,7 @@ func (csm *RedisStore) GetStateMachine(id string, cfg string) (*protos.FiniteSta var stateMachine protos.FiniteStateMachine err := csm.get(key, &stateMachine) if err != nil { - csm.logger.Error("Error retrieving state machine `%s`: %s", key, err.Error()) + csm.logger.Error("cannot access store for state machine `%s`: %s", key, err.Error()) return nil, false } return &stateMachine, true @@ -87,6 +81,9 @@ func (csm *RedisStore) PutConfig(cfg *protos.Configuration) error { if csm.client.Exists(context.Background(), key).Val() == 1 { return AlreadyExistsError(key) } + // TODO: Find out whether the client allows to batch requests, instead of sending multiple server requests + csm.client.SAdd(context.Background(), ConfigsPrefix, cfg.Name) + csm.client.SAdd(context.Background(), NewKeyForConfig(cfg.Name), api.GetVersionId(cfg)) return csm.put(key, cfg, NeverExpire) } diff --git a/storage/redis_store_test.go b/storage/redis_store_test.go index ddd49c3..be13b5d 100644 --- a/storage/redis_store_test.go +++ b/storage/redis_store_test.go @@ -11,6 +11,7 @@ package storage_test import ( "context" + "fmt" . "github.com/JiaYongfei/respect/gomega" "github.com/go-redis/redis/v8" "github.com/golang/protobuf/proto" @@ -25,7 +26,7 @@ import ( var _ = Describe("RedisStore", func() { - Context("when configured locally", func() { + Context("for simple operations", func() { var store storage.StoreManager var rdb *redis.Client var cfg *protos.Configuration @@ -49,9 +50,11 @@ var _ = Describe("RedisStore", func() { Addr: container.Address, DB: storage.DefaultRedisDb, }) + }, 0.5) + AfterEach(func() { // Cleaning up the DB to prevent "dirty" store to impact test results rdb.FlushDB(context.Background()) - }) + }, 0.2) It("is healthy", func() { Expect(store.Health()).To(Succeed()) }) @@ -171,7 +174,8 @@ var _ = Describe("RedisStore", func() { cfg := "orders" response := &protos.EventOutcome{ Code: protos.EventOutcome_Ok, - Dest: "1234-feed-beef", + Config: "test", + Id: "1234-feed-beef", Details: "this was just a test", } Expect(store.AddEventOutcome(id, cfg, response, storage.NeverExpire)).ToNot(HaveOccurred()) @@ -188,8 +192,8 @@ var _ = Describe("RedisStore", func() { cfg := "orders" response := &protos.EventOutcome{ Code: protos.EventOutcome_Ok, - Dest: "1234-feed-beef", Details: "this was just a test", + Id: "1234-feed-beef", } key := storage.NewKeyForOutcome(id, cfg) val, _ := proto.Marshal(response) @@ -213,4 +217,128 @@ var _ = Describe("RedisStore", func() { storage.NeverExpire)).To(HaveOccurred()) }) }) + + When("querying for configurations", func() { + var store storage.StoreManager + var rdb *redis.Client + + BeforeEach(func() { + Expect(container).ToNot(BeNil()) + store = storage.NewRedisStoreWithDefaults(container.Address) + Expect(store).ToNot(BeNil()) + store.SetLogLevel(slf4go.NONE) + + // This is used to go "behind the back" of our StoreManager and mess with it for testing + // purposes. Do NOT do this in your code. + rdb = redis.NewClient(&redis.Options{ + Addr: container.Address, + DB: storage.DefaultRedisDb, + }) + }, 0.5) + AfterEach(func() { + // Cleaning up the DB to prevent "dirty" store to impact test results + rdb.FlushDB(context.Background()) + }, 0.2) + + It("can get all configuration names", func() { + for _, name := range []string{"orders", "devices", "users"} { + Expect(store.PutConfig(&protos.Configuration{Name: name, Version: "v3", StartingState: "start"})). + ToNot(HaveOccurred()) + } + configs := store.GetAllConfigs() + Expect(len(configs)).To(Equal(3)) + Expect(configs).To(ContainElements("orders", "devices", "users")) + }) + It("can get all versions of a configuration", func() { + for _, version := range []string{"v1alpha1", "v1beta", "v1"} { + Expect(store.PutConfig(&protos.Configuration{Name: "orders", Version: version, StartingState: "start"})). + ToNot(HaveOccurred()) + } + configs := store.GetAllVersions("orders") + Expect(len(configs)).To(Equal(3)) + Expect(configs).To(ContainElements("orders:v1alpha1", "orders:v1beta", "orders:v1")) + }) + It("returns an empty slice for a non-existent config", func() { + configs := store.GetAllVersions("fake") + Expect(len(configs)).To(Equal(0)) + }) + }) + When("querying for FSMs", func() { + var store storage.StoreManager + var rdb *redis.Client + + BeforeEach(func() { + Expect(container).ToNot(BeNil()) + store = storage.NewRedisStoreWithDefaults(container.Address) + Expect(store).ToNot(BeNil()) + store.SetLogLevel(slf4go.NONE) + + // This is used to go "behind the back" of our StoreManager and mess with it for testing + // purposes. Do NOT do this in your code. + rdb = redis.NewClient(&redis.Options{ + Addr: container.Address, + DB: storage.DefaultRedisDb, + }) + }, 0.5) + AfterEach(func() { + // Cleaning up the DB to prevent "dirty" store to impact test results + rdb.FlushDB(context.Background()) + }, 0.2) + It("finds them by state", func() { + for id := 1; id < 5; id++ { + fsm := &protos.FiniteStateMachine{ + ConfigId: "orders:v4", + State: "in_transit", + History: []*protos.Event{ + {Transition: &protos.Transition{Event: "confirmed"}, Originator: "bot"}, + {Transition: &protos.Transition{Event: "shipped"}, Originator: "bot"}, + }, + } + fsmId := fmt.Sprintf("fsm-%d", id) + Expect(store.PutStateMachine(fsmId, fsm)).ToNot(HaveOccurred()) + Expect(store.UpdateState("orders", fsmId, "", fsm.State)) + } + res := store.GetAllInState("orders", "in_transit") + Expect(len(res)).To(Equal(4)) + for id := 1; id < 5; id++ { + Expect(res).To(ContainElement(fmt.Sprintf("fsm-%d", id))) + } + }) + When("transitioning state", func() { + BeforeEach(func() { + for id := 1; id < 10; id++ { + fsm := &protos.FiniteStateMachine{ + ConfigId: "orders:v4", + State: "in_transit", + History: []*protos.Event{ + {Transition: &protos.Transition{Event: "confirmed"}, Originator: "bot"}, + {Transition: &protos.Transition{Event: "shipped"}, Originator: "bot"}, + }, + } + fsmId := fmt.Sprintf("fsm-%d", id) + Expect(store.PutStateMachine(fsmId, fsm)).ToNot(HaveOccurred()) + Expect(store.UpdateState("orders", fsmId, "", fsm.State)) + } + }) + It("finds them", func() { + for id := 3; id < 6; id++ { + fsmId := fmt.Sprintf("fsm-%d", id) + Expect(store.UpdateState("orders", fsmId, "in_transit", "shipped")) + } + res := store.GetAllInState("orders", "shipped") + Expect(len(res)).To(Equal(3)) + for id := 3; id < 6; id++ { + Expect(res).To(ContainElement(fmt.Sprintf("fsm-%d", id))) + } + res = store.GetAllInState("orders", "in_transit") + Expect(len(res)).To(Equal(6)) + }) + It("will remove with an empty newState", func() { + Expect(store.UpdateState("orders", "fsm-1", "in_transit", "")).To(Succeed()) + res := store.GetAllInState("orders", "in_transit") + Ω(res).ToNot(ContainElement("fsm-1")) + }) + }) + }) + }) diff --git a/storage/storage_suite_test.go b/storage/storage_suite_test.go index 3b58d97..fe5a929 100644 --- a/storage/storage_suite_test.go +++ b/storage/storage_suite_test.go @@ -11,13 +11,13 @@ package storage_test import ( "context" - "fmt" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" "testing" "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + internals "github.com/massenz/go-statemachine/internals/testing" ) func TestStorage(t *testing.T) { @@ -25,43 +25,10 @@ func TestStorage(t *testing.T) { RunSpecs(t, "Storage Suite") } -type RedisContainer struct { - testcontainers.Container - Address string -} - -func SetupRedis(ctx context.Context) (*RedisContainer, error) { - req := testcontainers.ContainerRequest{ - Image: "redis:6", - ExposedPorts: []string{"6379/tcp"}, - WaitingFor: wait.ForLog("* Ready to accept connections"), - } - container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - }) - if err != nil { - return nil, err - } - - mappedPort, err := container.MappedPort(ctx, "6379") - if err != nil { - return nil, err - } - - hostIP, err := container.Host(ctx) - if err != nil { - return nil, err - } - - address := fmt.Sprintf("%s:%s", hostIP, mappedPort.Port()) - return &RedisContainer{Container: container, Address: address}, nil -} - -var container *RedisContainer +var container *internals.Container var _ = BeforeSuite(func() { var err error - container, err = SetupRedis(context.Background()) + container, err = internals.NewRedisContainer(context.Background()) Expect(err).ToNot(HaveOccurred()) // Note the timeout here is in seconds (and it's not a time.Duration either) }, 5.0) diff --git a/storage/types.go b/storage/types.go index 048f050..258fa8a 100644 --- a/storage/types.go +++ b/storage/types.go @@ -32,12 +32,42 @@ var ( type ConfigurationStorageManager interface { GetConfig(versionId string) (*protos.Configuration, bool) PutConfig(cfg *protos.Configuration) error + + // GetAllConfigs returns all the `Configurations` that exist in the server, regardless of + // the version, and whether are used or not by an FSM. + GetAllConfigs() []string + + // GetAllVersions returns the full `name:version` ID of all the Configurations whose + // name matches `name`. + GetAllVersions(name string) []string } type FiniteStateMachineStorageManager interface { + // GetStateMachine will find the FSM with `id and that is configured via a `Configuration` whose + // `name` matches `cfg` (without the `version`). GetStateMachine(id string, cfg string) (*protos.FiniteStateMachine, bool) + + // PutStateMachine creates or updates the FSM whose `id` is given. + // No further action is taken: no check that the referenced `Configuration` exists, and the + // `state` SETs are not updated either: it is the caller's responsibility to call the + // `UpdateState` method (possibly with an empty `oldState`, in the case of creation). PutStateMachine(id string, fsm *protos.FiniteStateMachine) error - GetAllInState(cfg string, state string) []*protos.FiniteStateMachine + + // GetAllInState looks up all the FSMs that are currently in the given `state` and + // are configured with a `Configuration` whose name matches `cfg` (regardless of the + // configuration's version). + // + // It returns the IDs for the FSMs. + GetAllInState(cfg string, state string) []string + + // UpdateState will move the FSM's `id` from/to the respective Redis SETs. + // + // When creating or updating an FSM with `PutStateMachine`, the state SETs are not + // modified; it is the responsibility of the caller to manage the FSM state appropriately + // (or not, as the case may be). + // + // `oldState` may be empty in the case of a new FSM being created. + UpdateState(cfgName string, id string, oldState string, newState string) error } type EventStorageManager interface {