diff --git a/backend/flow_api/handler.go b/backend/flow_api/handler.go index a77497a2c..6a8ab6390 100644 --- a/backend/flow_api/handler.go +++ b/backend/flow_api/handler.go @@ -124,7 +124,8 @@ func (h *FlowPilotHandler) executeFlow(c echo.Context, flow flowpilot.Flow) erro flowResult, err = flow.Execute(models.NewFlowDB(tx), flowpilot.WithQueryParamKey(queryParamKey), flowpilot.WithQueryParamValue(c.QueryParam(queryParamKey)), - flowpilot.WithInputData(inputData)) + flowpilot.WithInputData(inputData), + flowpilot.UseCompression(!h.Cfg.Debug)) return err } diff --git a/backend/flowpilot/context.go b/backend/flowpilot/context.go index 2d111eb81..05dcff2fc 100644 --- a/backend/flowpilot/context.go +++ b/backend/flowpilot/context.go @@ -121,6 +121,8 @@ func createAndInitializeFlow(db FlowDB, flow defaultFlow) (FlowResult, error) { return nil, fmt.Errorf("failed to initialize a new stash: %w", err) } + s.useCompression(flow.useCompression) + p := newPayload() csrfToken, err := generateRandomString(32) @@ -189,6 +191,8 @@ func executeFlowAction(db FlowDB, flow defaultFlow) (FlowResult, error) { return nil, fmt.Errorf("failed to parse stash from flow: %w", err) } + s.useCompression(flow.useCompression) + // Initialize JSONManagers for payload and flash data. p := newPayload() diff --git a/backend/flowpilot/flow.go b/backend/flowpilot/flow.go index faee41b4c..e934a40d6 100644 --- a/backend/flowpilot/flow.go +++ b/backend/flowpilot/flow.go @@ -36,6 +36,13 @@ func WithInputData(inputData InputData) func(*defaultFlow) { } } +// UseCompression causes the flow data to be compressed before stored to the db. +func UseCompression(b bool) func(*defaultFlow) { + return func(f *defaultFlow) { + f.useCompression = b + } +} + // StateName represents the name of a state in a flow. type StateName string @@ -191,6 +198,7 @@ type defaultFlow struct { queryParam queryParam // TODO contextValues contextValues // Values to be used within the flow context. inputData InputData + useCompression bool queryParamKey string queryParamValue string diff --git a/backend/flowpilot/stash.go b/backend/flowpilot/stash.go index e83cd00fd..2fd8e66d5 100644 --- a/backend/flowpilot/stash.go +++ b/backend/flowpilot/stash.go @@ -1,11 +1,15 @@ package flowpilot import ( + "bytes" + "compress/gzip" + "encoding/base64" "errors" "fmt" "github.com/teamhanko/hanko/backend/flowpilot/jsonmanager" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "io" ) const ( @@ -27,6 +31,7 @@ type stash interface { getPreviousStateName() StateName addScheduledStateNames(...StateName) getNextStateName() StateName + useCompression(bool) jsonmanager.JSONManager } @@ -35,6 +40,7 @@ type defaultStash struct { jm jsonmanager.JSONManager data jsonmanager.JSONManager scheduledStateNames []StateName + compressionEnabled bool } // newStashFromJSONManager creates a new instance of stash with a given JSONManager. @@ -44,6 +50,7 @@ func newStashFromJSONManager(jm jsonmanager.JSONManager) stash { jm: jm, data: data, scheduledStateNames: make([]StateName, 0), + compressionEnabled: false, } } @@ -51,6 +58,10 @@ func newStashFromJSONManager(jm jsonmanager.JSONManager) stash { func newStash(nextStates ...StateName) (stash, error) { jm := jsonmanager.NewJSONManager() + if len(nextStates) == 0 { + return nil, errors.New("can't create a new stash without a state name") + } + if err := jm.Set(stashKeyState, nextStates[0]); err != nil { return nil, err } @@ -68,6 +79,14 @@ func newStash(nextStates ...StateName) (stash, error) { // newStashFromString creates a new instance of Stash with the given JSON data. func newStashFromString(data string) (stash, error) { + var err error + + if len(data) > 0 && !startsWithCurlyBrace(data) { + if data, err = decodeData(data); err != nil { + return nil, fmt.Errorf("faiiled to decode stash data: %w", err) + } + } + jm, err := jsonmanager.NewJSONManagerFromString(data) return newStashFromJSONManager(jm), err } @@ -80,6 +99,53 @@ func reverseStateNames(slice []StateName) []StateName { return reversed } +func startsWithCurlyBrace(s string) bool { + // Check if the string is not empty + if len(s) == 0 { + return false + } + // Check if the first character is '{' + return s[0] == '{' +} + +func encodeData(jsonData string) (string, error) { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write([]byte(jsonData)); err != nil { + return "", err + } + + if err := gw.Close(); err != nil { + return "", err + } + + gzippedData := buf.Bytes() + base64GzippedData := base64.StdEncoding.EncodeToString(gzippedData) + return base64GzippedData, nil +} + +func decodeData(base64GzippedData string) (string, error) { + gzippedData, err := base64.StdEncoding.DecodeString(base64GzippedData) + if err != nil { + return "", err + } + + buf := bytes.NewBuffer(gzippedData) + gr, err := gzip.NewReader(buf) + if err != nil { + return "", err + } + + defer gr.Close() + + decompressedData, err := io.ReadAll(gr) + if err != nil { + return "", err + } + + return string(decompressedData), nil +} + // Get retrieves the value at the specified path in the JSON data. func (h *defaultStash) Get(path string) gjson.Result { return h.data.Get(path) @@ -97,6 +163,10 @@ func (h *defaultStash) Delete(path string) error { // String returns the JSON data as a string. func (h *defaultStash) String() string { + if h.compressionEnabled { + s, _ := encodeData(h.jm.String()) + return s + } return h.jm.String() } @@ -235,3 +305,7 @@ func (h *defaultStash) isRevertible() bool { lastHistItemIndex := h.jm.Get(fmt.Sprintf("%s.#", stashKeyHistory)).Int() - 1 return h.jm.Get(fmt.Sprintf("%s.%d.%s", stashKeyHistory, lastHistItemIndex, stashKeyRevertible)).Bool() } + +func (h *defaultStash) useCompression(b bool) { + h.compressionEnabled = b +}