From b64214020adc179738acac35155b71ffab0c317f Mon Sep 17 00:00:00 2001 From: Michael Adler Date: Tue, 5 Sep 2023 13:53:55 +0200 Subject: [PATCH] feat: support for external middleware plugins This commit introduces possibility to integrate external middleware plugins. These plugins have the possibility to modify incoming HTTP requests prior to routing. This includes modifying headers, routing and denying requests. Due to the lack of a proper plugin system in Go, external middleware plugins are implemented as standalone sub-processes. They communicate with wfx via protobuf message over stdin/stdout, hence it's crucial for the plugin to read from stdin promptly to avoid blocking wfx. An example plugin written in Go is included in the `example/plugin` subdirectory, taking special care that stdin does not block. Signed-off-by: Michael Adler --- .github/workflows/ci.yml | 13 +- .gitlab-ci.yml | 5 +- .golangci.yml | 3 + .goreleaser.yml | 1 + CHANGELOG.md | 1 + Makefile | 4 +- cmd/wfx/cmd/root/cmd_test.go | 25 ++ cmd/wfx/cmd/root/flags.go | 11 + cmd/wfx/cmd/root/northbound.go | 34 ++- cmd/wfx/cmd/root/plugins.go | 68 +++++ cmd/wfx/cmd/root/plugins_test.go | 176 ++++++++++++ cmd/wfx/cmd/root/southbound.go | 34 ++- docs/installation.md | 1 + docs/operations.md | 43 +++ example/plugin/.gitignore | 1 + example/plugin/go.mod | 9 + example/plugin/go.sum | 7 + example/plugin/main.go | 106 ++++++++ fbs/client/header.fbs | 13 + fbs/client/request.fbs | 27 ++ fbs/client/response.fbs | 25 ++ fbs/request.fbs | 20 ++ fbs/response.fbs | 23 ++ generated/plugin/Payload.go | 69 +++++ generated/plugin/PluginRequest.go | 130 +++++++++ generated/plugin/PluginResponse.go | 146 ++++++++++ generated/plugin/client/Action.go | 35 +++ generated/plugin/client/Envelope.go | 133 +++++++++ generated/plugin/client/Request.go | 221 +++++++++++++++ generated/plugin/client/Response.go | 185 +++++++++++++ generated/plugin/client/ResponseStatus.go | 32 +++ go.mod | 1 + go.sum | 2 + go.work | 1 + justfile | 9 +- middleware/logging/log.go | 33 ++- middleware/logging/log_test.go | 28 +- middleware/logging/reader.go | 41 --- middleware/logging/reader_test.go | 31 --- middleware/plugin/disabled.go | 26 ++ middleware/plugin/disabled_test.go | 24 ++ middleware/plugin/ioutil/io.go | 86 ++++++ middleware/plugin/ioutil/io_test.go | 95 +++++++ middleware/plugin/ioutil/main_test.go | 19 ++ middleware/plugin/main_test.go | 19 ++ middleware/plugin/middleware.go | 140 ++++++++++ middleware/plugin/middleware_test.go | 240 +++++++++++++++++ middleware/plugin/plugin.go | 311 ++++++++++++++++++++++ middleware/plugin/plugin_test.go | 198 ++++++++++++++ middleware/plugin/process_unix.go | 64 +++++ middleware/plugin/process_windows.go | 27 ++ shell.nix | 3 +- 52 files changed, 2887 insertions(+), 112 deletions(-) create mode 100644 cmd/wfx/cmd/root/plugins.go create mode 100644 cmd/wfx/cmd/root/plugins_test.go create mode 100644 example/plugin/.gitignore create mode 100644 example/plugin/go.mod create mode 100644 example/plugin/go.sum create mode 100644 example/plugin/main.go create mode 100644 fbs/client/header.fbs create mode 100644 fbs/client/request.fbs create mode 100644 fbs/client/response.fbs create mode 100644 fbs/request.fbs create mode 100644 fbs/response.fbs create mode 100644 generated/plugin/Payload.go create mode 100644 generated/plugin/PluginRequest.go create mode 100644 generated/plugin/PluginResponse.go create mode 100644 generated/plugin/client/Action.go create mode 100644 generated/plugin/client/Envelope.go create mode 100644 generated/plugin/client/Request.go create mode 100644 generated/plugin/client/Response.go create mode 100644 generated/plugin/client/ResponseStatus.go delete mode 100644 middleware/logging/reader.go delete mode 100644 middleware/logging/reader_test.go create mode 100644 middleware/plugin/disabled.go create mode 100644 middleware/plugin/disabled_test.go create mode 100644 middleware/plugin/ioutil/io.go create mode 100644 middleware/plugin/ioutil/io_test.go create mode 100644 middleware/plugin/ioutil/main_test.go create mode 100644 middleware/plugin/main_test.go create mode 100644 middleware/plugin/middleware.go create mode 100644 middleware/plugin/middleware_test.go create mode 100644 middleware/plugin/plugin.go create mode 100644 middleware/plugin/plugin_test.go create mode 100644 middleware/plugin/process_unix.go create mode 100644 middleware/plugin/process_windows.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d7957068..75a778bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,7 +73,7 @@ jobs: --health-retries 20 steps: - uses: actions/checkout@v4 - - run: go test -timeout 180s -race -coverprofile=coverage.out -covermode=atomic -tags testing,integration,postgres,sqlite ./... + - run: go test -timeout 180s -race -coverprofile=coverage.out -covermode=atomic -tags testing,integration,postgres,sqlite,plugin ./... env: PGHOST: postgres PGPORT: 5432 @@ -109,7 +109,7 @@ jobs: --health-retries 20 steps: - uses: actions/checkout@v4 - - run: go test -timeout 180s -race -coverprofile=coverage.out -covermode=atomic -tags testing,integration,mysql,sqlite ./... + - run: go test -timeout 180s -race -coverprofile=coverage.out -covermode=atomic -tags testing,integration,mysql,sqlite,plugin ./... env: MYSQL_DATABASE: wfx MYSQL_ROOT_PASSWORD: root @@ -196,22 +196,23 @@ jobs: - uses: dominikh/staticcheck-action@v1.3.0 with: install-go: false - build-tags: sqlite,testing + build-tags: sqlite,testing,plugin - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: version: latest - args: --build-tags=sqlite,testing + args: --build-tags=sqlite,testing,plugin skip-cache: true generate: name: Generate Code runs-on: ubuntu-latest container: - image: quay.io/goswagger/swagger + image: archlinux steps: + - name: Install packages + run: pacman -Syu --noconfirm python-yaml git just go flatbuffers go-swagger gofumpt - uses: actions/checkout@v4 - - run: apk add --no-cache py3-yaml git just bash go - name: Disable git security features run: git config --global safe.directory '*' - uses: brokeyourbike/go-mockery-action@v0 diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 73889046..4bfaf31e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -60,10 +60,9 @@ generate: stage: lint needs: [] image: - name: quay.io/goswagger/swagger - entrypoint: [""] # needed to get a shell + name: archlinux:latest before_script: - - apk add --no-cache py3-yaml git just bash go git-lfs + - pacman -Syu --noconfirm python-yaml git just go flatbuffers go-swagger gofumpt - git lfs install && git submodule update script: - just generate diff --git a/.golangci.yml b/.golangci.yml index b4417231..b8ba536b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -88,6 +88,7 @@ linters: - tparallel - usestdlibvars - wrapcheck + - zerologlint linters-settings: staticcheck: @@ -101,3 +102,5 @@ linters-settings: wrapcheck: ignorePackageGlobs: - github.com/siemens/wfx/internal/errutil + - google.golang.org/protobuf/* + - io diff --git a/.goreleaser.yml b/.goreleaser.yml index 5d4b4881..0c0ecfb2 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -18,6 +18,7 @@ builds: - sqlite - postgres - mysql + - plugin flags: - -trimpath - -mod=readonly diff --git a/CHANGELOG.md b/CHANGELOG.md index d5a1440a..fe4bf105 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add optional `description` field to workflows - Job event notifications via server-sent events (see #11) +- Plugin System for External Middlewares (see #43) ### Fixed diff --git a/Makefile b/Makefile index cc54fdc3..2a3ce3c1 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ MAKEFLAGS += --jobs=$(shell nproc) DESTDIR ?= prefix ?= /usr/local -GO_TAGS = sqlite,postgres,mysql +GO_TAGS = sqlite,postgres,mysql,plugin export CGO_ENABLED=1 @@ -31,7 +31,7 @@ default: .PHONY: test test: - go test -race -coverprofile=coverage.out -covermode=atomic -timeout 30s ./... "--tags=sqlite,testing" + go test -race -coverprofile=coverage.out -covermode=atomic -timeout 30s ./... "--tags=sqlite,testing,plugin" .PHONY: install install: diff --git a/cmd/wfx/cmd/root/cmd_test.go b/cmd/wfx/cmd/root/cmd_test.go index 2cc89724..17c6796e 100644 --- a/cmd/wfx/cmd/root/cmd_test.go +++ b/cmd/wfx/cmd/root/cmd_test.go @@ -25,6 +25,7 @@ import ( "github.com/knadh/koanf/v2" "github.com/rs/zerolog" "github.com/siemens/wfx/cmd/wfxctl/flags" + "github.com/siemens/wfx/persistence" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -225,3 +226,27 @@ func waitForLogLevel(t *testing.T, expected zerolog.Level) { } require.Equal(t, expected.String(), zerolog.GlobalLevel().String()) } + +func TestCreateNorthboundCollection_PluginsDir(t *testing.T) { + dir, _ := os.MkdirTemp("", "TestCreateNorthboundCollection_PluginsDir.*") + k.Write(func(k *koanf.Koanf) { + _ = k.Set(mgmtPluginsDirFlag, dir) + }) + dbMock := persistence.NewMockStorage(t) + sc, err := createNorthboundCollection([]string{"http"}, dbMock) + t.Cleanup(func() { sc.Shutdown(context.Background()) }) + assert.NoError(t, err) + assert.NotNil(t, sc) +} + +func TestCreateSouthboundCollection_PluginsDir(t *testing.T) { + dir, _ := os.MkdirTemp("", "TestCreateSouthboundCollection_PluginsDir.*") + k.Write(func(k *koanf.Koanf) { + _ = k.Set(clientPluginsDirFlag, dir) + }) + dbMock := persistence.NewMockStorage(t) + sc, err := createSouthboundCollection([]string{"http"}, dbMock) + t.Cleanup(func() { sc.Shutdown(context.Background()) }) + assert.NoError(t, err) + assert.NotNil(t, sc) +} diff --git a/cmd/wfx/cmd/root/flags.go b/cmd/wfx/cmd/root/flags.go index 81e25aa0..b4a2976a 100644 --- a/cmd/wfx/cmd/root/flags.go +++ b/cmd/wfx/cmd/root/flags.go @@ -60,6 +60,10 @@ const ( preferedStorage = "sqlite" defaultStorageOpts = "file:wfx.db?_fk=1&_journal=WAL" + + // Plugins + clientPluginsDirFlag = "client-plugins-dir" + mgmtPluginsDirFlag = "mgmt-plugins-dir" ) func init() { @@ -109,6 +113,13 @@ func init() { f.StringSlice(configFlag, config.DefaultConfigFiles(), "path to one or more .yaml config files") _ = Command.MarkPersistentFlagFilename(configFlag, "yml", "yaml") + // plugins + _ = Command.MarkPersistentFlagDirname(clientPluginsDirFlag) + f.String(clientPluginsDirFlag, "", "directory containing client plugins") + + _ = Command.MarkPersistentFlagDirname(mgmtPluginsDirFlag) + f.String(mgmtPluginsDirFlag, "", "directory containing management plugins") + { var defaultStorage string supportedStorages := persistence.Storages() diff --git a/cmd/wfx/cmd/root/northbound.go b/cmd/wfx/cmd/root/northbound.go index bf45bacd..70bf36da 100644 --- a/cmd/wfx/cmd/root/northbound.go +++ b/cmd/wfx/cmd/root/northbound.go @@ -27,12 +27,15 @@ import ( func createNorthboundCollection(schemes []string, storage persistence.Storage) (*serverCollection, error) { var settings server.HTTPSettings + var pluginsDir string k.Read(func(k *koanf.Koanf) { settings.Host = k.String(mgmtHostFlag) settings.TLSHost = k.String(mgmtTLSHostFlag) settings.Port = k.Int(mgmtPortFlag) settings.TLSPort = k.Int(mgmtTLSPortFlag) settings.UDSPath = k.String(mgmtUnixSocketFlag) + + pluginsDir = k.String(mgmtPluginsDirFlag) }) api := api.NewNorthboundAPI(storage) fsMW, err := fileserver.NewFileServerMiddleware(k) @@ -41,18 +44,27 @@ func createNorthboundCollection(schemes []string, storage persistence.Storage) ( } swaggerJSON, _ := restapi.SwaggerJSON.MarshalJSON() - mw := middleware.NewGlobalMiddleware(restapi.ConfigureAPI(api), - []middleware.IntermediateMW{ - // LIFO - logging.MW{}, - jq.MW{}, - fsMW, - swagger.NewSpecMiddleware(api.Context().BasePath(), swaggerJSON), - health.NewHealthMiddleware(storage), - version.MW{}, - middleware.PromoteWrapper(cors.AllowAll().Handler), - }) + // LIFO, i.e. middlewares are applied in reverse order + intermdiateMws := []middleware.IntermediateMW{ + jq.MW{}, + fsMW, + swagger.NewSpecMiddleware(api.Context().BasePath(), swaggerJSON), + health.NewHealthMiddleware(storage), + version.MW{}, + middleware.PromoteWrapper(cors.AllowAll().Handler), + } + + if pluginsDir != "" { + mws, err := createPluginMiddlewares(pluginsDir) + if err != nil { + return nil, fault.Wrap(err) + } + intermdiateMws = append(intermdiateMws, mws...) + } + intermdiateMws = append(intermdiateMws, logging.MW{}) + + mw := middleware.NewGlobalMiddleware(restapi.ConfigureAPI(api), intermdiateMws) servers, err := createServers(schemes, mw, settings) if err != nil { return nil, fault.Wrap(err) diff --git a/cmd/wfx/cmd/root/plugins.go b/cmd/wfx/cmd/root/plugins.go new file mode 100644 index 00000000..5b14b162 --- /dev/null +++ b/cmd/wfx/cmd/root/plugins.go @@ -0,0 +1,68 @@ +package root + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "os" + "path" + "path/filepath" + "sort" + + "github.com/Southclaws/fault" + "github.com/rs/zerolog/log" + "github.com/siemens/wfx/middleware" + "github.com/siemens/wfx/middleware/plugin" +) + +func createPluginMiddlewares(pluginsDir string) ([]middleware.IntermediateMW, error) { + pluginMws, err := loadPlugins(pluginsDir) + if err != nil { + return nil, fault.Wrap(err) + } + result := make([]middleware.IntermediateMW, 0, len(pluginMws)) + for _, p := range pluginMws { + mw, err := plugin.NewMiddleware(p) + if err != nil { + return nil, fault.Wrap(err) + } + result = append(result, mw) + } + return result, nil +} + +func loadPlugins(dir string) ([]plugin.Plugin, error) { + log.Debug().Msg("Loading plugins") + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fault.Wrap(err) + } + + result := make([]plugin.Plugin, 0, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + dest, err := filepath.EvalSymlinks(path.Join(dir, entry.Name())) + if err != nil { + return nil, fault.Wrap(err) + } + info, err := os.Stat(dest) + if err != nil { + return nil, fault.Wrap(err) + } + // check if file is executable + if (info.Mode() & 0o111) != 0 { + result = append(result, plugin.NewFBPlugin(dest)) + } else { + log.Warn().Str("dest", dest).Msg("Ignoring non-executable file") + } + } + } + sort.Slice(result, func(i int, j int) bool { return result[i].Name() < result[j].Name() }) + log.Debug().Int("count", len(result)).Msg("Loaded plugins") + return result, nil +} diff --git a/cmd/wfx/cmd/root/plugins_test.go b/cmd/wfx/cmd/root/plugins_test.go new file mode 100644 index 00000000..1f735a4b --- /dev/null +++ b/cmd/wfx/cmd/root/plugins_test.go @@ -0,0 +1,176 @@ +package root + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "io" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadPluginsEmpty(t *testing.T) { + t.Parallel() + + dir, _ := os.MkdirTemp("", "TestLoadPluginsEmpty") + t.Cleanup(func() { + _ = os.Remove(dir) + }) + plugins, err := loadPlugins(dir) + require.NoError(t, err) + assert.Empty(t, plugins) +} + +func TestLoadPlugins(t *testing.T) { + t.Parallel() + + dir, _ := os.MkdirTemp("", "TestLoadPlugins") + t.Cleanup(func() { + _ = os.RemoveAll(dir) + }) + + f, _ := os.CreateTemp(dir, "plugin") + _ = f.Close() + _ = os.Chmod(f.Name(), os.FileMode(0o700)) + + plugins, err := loadPlugins(dir) + require.NoError(t, err) + assert.Len(t, plugins, 1) + assert.Equal(t, f.Name(), plugins[0].Name()) +} + +func TestLoadPluginsIgnoreNonExecutable(t *testing.T) { + t.Parallel() + + dir, _ := os.MkdirTemp("", "TestLoadPluginsIgnoreNonExecutable") + t.Cleanup(func() { + _ = os.RemoveAll(dir) + }) + + f, _ := os.CreateTemp(dir, "plugin") + _ = f.Close() + + plugins, err := loadPlugins(dir) + require.NoError(t, err) + assert.Len(t, plugins, 0) +} + +func TestLoadPluginsSymlink(t *testing.T) { + t.Parallel() + + baseDir, _ := os.MkdirTemp("", "TestLoadPluginsSymlink") + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + + first, _ := os.MkdirTemp(baseDir, "first") + second, _ := os.MkdirTemp(baseDir, "second") + + f, _ := os.CreateTemp(first, "plugin") + _ = f.Close() + _ = os.Chmod(f.Name(), os.FileMode(0o700)) + + // create symlink + dest := path.Join(second, "example") + _ = os.Symlink(f.Name(), dest) + + plugins, err := loadPlugins(second) + require.NoError(t, err) + assert.Len(t, plugins, 1) + assert.Equal(t, f.Name(), plugins[0].Name()) +} + +func TestLoadPluginsSymlinkIgnoreNonExecutable(t *testing.T) { + t.Parallel() + + baseDir, _ := os.MkdirTemp("", "TestLoadPluginsSymlinkIgnoreNonExecutable") + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + + first, _ := os.MkdirTemp(baseDir, "first") + second, _ := os.MkdirTemp(baseDir, "second") + + f, _ := os.CreateTemp(first, "plugin") + _ = f.Close() + + // create symlink + dest := path.Join(second, "example") + _ = os.Symlink(f.Name(), dest) + + plugins, err := loadPlugins(second) + require.NoError(t, err) + assert.Len(t, plugins, 0) +} + +func TestCreatePluginMiddlewares_InvalidDir(t *testing.T) { + mws, err := createPluginMiddlewares("") + assert.Nil(t, mws) + assert.NotNil(t, err) +} + +func TestCreatePluginMiddlewares_EmptydDir(t *testing.T) { + baseDir, _ := os.MkdirTemp("", "TestCreatePluginMiddlewares_EmptydDir") + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + mws, err := createPluginMiddlewares(baseDir) + assert.Empty(t, mws) + assert.NoError(t, err) +} + +func TestCreatePluginMiddlewares_PluginFailure(t *testing.T) { + baseDir, _ := os.MkdirTemp("", "TestCreatePluginMiddlewares_PluginFailure") + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + + f, err := os.CreateTemp(baseDir, "plugin*.sh") + require.NoError(t, err) + t.Cleanup(func() { + _ = os.Remove(f.Name()) + }) + _, _ = io.WriteString(f, "no shebang") + fname := f.Name() + _ = f.Close() + _ = os.Chmod(fname, os.FileMode(0o700)) + + mws, err := createPluginMiddlewares(baseDir) + assert.Nil(t, mws) + assert.Error(t, err) +} + +func TestCreatePluginMiddlewares(t *testing.T) { + baseDir, _ := os.MkdirTemp("", "TestCreatePluginMiddlewares") + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + + f, err := os.CreateTemp(baseDir, "plugin*.sh") + require.NoError(t, err) + t.Cleanup(func() { + _ = os.Remove(f.Name()) + }) + _, _ = io.WriteString(f, `#!/bin/sh +while true; do + sleep 1 +done +`) + fname := f.Name() + _ = f.Close() + _ = os.Chmod(fname, os.FileMode(0o700)) + + mws, err := createPluginMiddlewares(baseDir) + assert.Len(t, mws, 1) + assert.NoError(t, err) + mws[0].Shutdown() +} diff --git a/cmd/wfx/cmd/root/southbound.go b/cmd/wfx/cmd/root/southbound.go index adaa96d4..58f7ad35 100644 --- a/cmd/wfx/cmd/root/southbound.go +++ b/cmd/wfx/cmd/root/southbound.go @@ -27,12 +27,15 @@ import ( func createSouthboundCollection(schemes []string, storage persistence.Storage) (*serverCollection, error) { var settings server.HTTPSettings + var pluginsDir string k.Read(func(k *koanf.Koanf) { settings.Host = k.String(clientHostFlag) settings.TLSHost = k.String(clientTLSHostFlag) settings.Port = k.Int(clientPortFlag) settings.TLSPort = k.Int(clientTLSPortFlag) settings.UDSPath = k.String(clientUnixSocket) + + pluginsDir = k.String(clientPluginsDirFlag) }) api := api.NewSouthboundAPI(storage) @@ -42,18 +45,27 @@ func createSouthboundCollection(schemes []string, storage persistence.Storage) ( } swaggerJSON, _ := restapi.SwaggerJSON.MarshalJSON() - mw := middleware.NewGlobalMiddleware(restapi.ConfigureAPI(api), - []middleware.IntermediateMW{ - // LIFO - logging.MW{}, - jq.MW{}, - fsMW, - swagger.NewSpecMiddleware(api.Context().BasePath(), swaggerJSON), - health.NewHealthMiddleware(storage), - version.MW{}, - middleware.PromoteWrapper(cors.AllowAll().Handler), - }) + // LIFO, i.e. middlewares are applied in reverse order + intermdiateMws := []middleware.IntermediateMW{ + jq.MW{}, + fsMW, + swagger.NewSpecMiddleware(api.Context().BasePath(), swaggerJSON), + health.NewHealthMiddleware(storage), + version.MW{}, + middleware.PromoteWrapper(cors.AllowAll().Handler), + } + + if pluginsDir != "" { + mws, err := createPluginMiddlewares(pluginsDir) + if err != nil { + return nil, fault.Wrap(err) + } + intermdiateMws = append(intermdiateMws, mws...) + } + intermdiateMws = append(intermdiateMws, logging.MW{}) + + mw := middleware.NewGlobalMiddleware(restapi.ConfigureAPI(api), intermdiateMws) servers, err := createServers(schemes, mw, settings) if err != nil { return nil, fault.Wrap(err) diff --git a/docs/installation.md b/docs/installation.md index f6deb122..e86d8f57 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -30,6 +30,7 @@ The following persistent storage selection build tags are available: | `libsqlite3` | Dynamically link against `libsqlite3` | | `postgres` | Enable built-in [PostgreSQL](https://www.postgresql.org) support | | `mysql` | Enable built-in [MySQL](https://www.mysql.com/) support | +| `plugin` | Enable support for [external plugins](operations.md#Plugins) | By default, all built-in persistent storage options are enabled. diff --git a/docs/operations.md b/docs/operations.md index 27d6faa0..131a79af 100644 --- a/docs/operations.md +++ b/docs/operations.md @@ -254,6 +254,49 @@ access. └──────────────────┘ ``` +## Plugins + +wfx provides a flexible (out-of-tree) plugin mechanism to extend its HTTP request processing capabilities. +A plugin operates as a subprocess initiated and supervised by wfx, thereby allowing plugins to be written in _any_ +programming language. + +### Enabling Plugins + +To enable plugins, wfx must be built with the `plugin` tag (enabled by default) and started with the +`--client-plugins-dir`/`--mgmt-plugins-dir` flag. +_Any executable_ in the respective directory is considered to be a plugin (symlinks to executables work as well). +Plugins are sorted and executed in lexicographic order. + +### Developing Plugins + +Communication between wfx and a plugin is achieved by exchanging [flatbuffer](https://flatbuffers.dev/) messages via +stdin/stdout. The flatbuffer specification is available in the [fbs](../fbs) directory. + +For every incoming HTTP request, wfx generates a unique string called `cookie`. This string, along with the complete +HTTP request (including headers and body), is written to the plugin's stdin. The plugin then sends its response, paired +with the same `cookie`, back to wfx by writing to its stdout. This `cookie` mechanism ensures that wfx can accurately +associate responses with their corresponding requests. + +**Note**: It is crucial for the plugin to read data from its stdin descriptor promptly to prevent blocking writes by +wfx. The `cookie` mechanism facilitates asynchronous processing. + +Based on the plugin's response, wfx can: + +- Modify the incoming HTTP request before it undergoes further processing by wfx in the usual manner. +- Send a preemptive HTTP response back to the client, such as a "permission denied" message. +- Leave the request unchanged. + +### Use Cases + +Plugins are typically used for: + +- Enforcing authentication and authorization for API endpoints. +- Handling URL rewriting and redirection tasks. + +### Example + +An [example plugin](../example/plugin) written in Go demonstrates denying access to the `/api/wfx/v1/workflows` endpoint. + ## Telemetry No telemetry or user data is collected or processed by wfx. diff --git a/example/plugin/.gitignore b/example/plugin/.gitignore new file mode 100644 index 00000000..6d635cb1 --- /dev/null +++ b/example/plugin/.gitignore @@ -0,0 +1 @@ +/plugin diff --git a/example/plugin/go.mod b/example/plugin/go.mod new file mode 100644 index 00000000..23095adf --- /dev/null +++ b/example/plugin/go.mod @@ -0,0 +1,9 @@ +module github.com/siemens/wfx/example/plugin + +replace github.com/siemens/wfx => ../.. + +go 1.19 + +require github.com/siemens/wfx v0.0.0-00010101000000-000000000000 + +require github.com/google/flatbuffers v23.5.26+incompatible // indirect diff --git a/example/plugin/go.sum b/example/plugin/go.sum new file mode 100644 index 00000000..b5d1ba0d --- /dev/null +++ b/example/plugin/go.sum @@ -0,0 +1,7 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= +github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/example/plugin/main.go b/example/plugin/main.go new file mode 100644 index 00000000..32ad247d --- /dev/null +++ b/example/plugin/main.go @@ -0,0 +1,106 @@ +package main + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "log" + "os" + "strings" + "time" + + "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/siemens/wfx/middleware/plugin/ioutil" +) + +const queueSize = 64 + +type QueueEntry struct { + start time.Time + request *plugin.PluginRequestT +} + +func main() { + // NOTE: The working directory for this process is inherited from the wfx parent process + file, err := os.OpenFile("plugin.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + log.Fatal(err) + } + defer file.Close() + log.SetOutput(file) + log.Println("[INFO] Plugin starting...") + + queue := make(chan QueueEntry, queueSize) + + // reader goroutine: it's only purpose is to read from stdin as soon as + // something comes in, to prevent the other side from blocking + go func() { + for { + req, err := ioutil.ReadRequest(os.Stdin) + if err != nil { + log.Println("[ERROR] Failed to receive message", err) + continue + } + + entry := newQueueEntry(req) + // this is important: we rather discard the message instead of + // blocking, since we want to read the next message from stdin + select { + case queue <- entry: + log.Println("[DEBUG] Message enqueued") + // success + default: + log.Println("[ERROR] Queue full. Message discarded.") + } + } + }() + + // here we decide what to do with the request and send the response + for entry := range queue { + req := entry.request + + destination := req.Request.Destination + body := "\n" + if len(req.Request.Content) > 0 { + body = string(req.Request.Content) + } + log.Printf("[DEBUG] Processing request: cookie=%d, destination=%s, body=%s", req.Cookie, destination, body) + + // prepare response, it's important to use the cookie from the request + resp := plugin.PluginResponseT{Cookie: req.Cookie} + + // this just an example; prevent access to /workflows + if strings.Contains(destination, "/api/wfx/v1/workflows") { + log.Println("[DEBUG] Denying request") + + resp.Payload = &plugin.PayloadT{ + Type: plugin.Payloadgenerated_plugin_client_Response, + Value: &client.ResponseT{ + Status: client.ResponseStatusDeny, + Content: []byte("You are not allowed to access the workflows resource.\n"), + }, + } + } else { + log.Println("[DEBUG] Allowing request") + } + + if err := ioutil.WriteResponse(os.Stdout, &resp); err != nil { + log.Println("[ERROR] Failed to send message", err) + } + delta := time.Since(entry.start) + log.Printf("[INFO] Processed request in %0.02f us\n", float64(delta.Nanoseconds())/1_000.) + } +} + +func newQueueEntry(request *plugin.PluginRequestT) QueueEntry { + return QueueEntry{ + start: time.Now(), + request: request, + } +} diff --git a/fbs/client/header.fbs b/fbs/client/header.fbs new file mode 100644 index 00000000..b22951b4 --- /dev/null +++ b/fbs/client/header.fbs @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2023 Siemens AG +// +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Michael Adler + +namespace generated.plugin.client; + +// Envelope consists of a name and a (possibly empty) list of values. +table Envelope { + name: string; + values: [string]; +} diff --git a/fbs/client/request.fbs b/fbs/client/request.fbs new file mode 100644 index 00000000..5351aae8 --- /dev/null +++ b/fbs/client/request.fbs @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2023 Siemens AG +// +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Michael Adler + +include "header.fbs"; + +namespace generated.plugin.client; + +enum Action: byte { + Read = 0, + Create = 1, + Update = 2, + Delete = 3, +} + +table Request { + // Unique identifier for the request used to correlate requests and responses. + cookie: ulong; + + destination: string; + // The action requested by the client. This field is read-only. + action: Action; + envelope: [Envelope]; + content: [ubyte]; +} diff --git a/fbs/client/response.fbs b/fbs/client/response.fbs new file mode 100644 index 00000000..7f9ed39c --- /dev/null +++ b/fbs/client/response.fbs @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2023 Siemens AG +// +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Michael Adler + +include "header.fbs"; + +namespace generated.plugin.client; + +enum ResponseStatus: byte { + // The request was denied and may not proceed. + Deny = 0, + // THe request may proceed. + Accept = 1, + // The request was modifified and may proceed. + Modified = 2, +} + +// Response contains the response details for a client. +table Response { + status: ResponseStatus; + envelope: [Envelope]; + content: [ubyte]; +} diff --git a/fbs/request.fbs b/fbs/request.fbs new file mode 100644 index 00000000..7d78d137 --- /dev/null +++ b/fbs/request.fbs @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2023 Siemens AG +// +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Michael Adler + +include "client/request.fbs"; + +namespace generated.plugin; + +// PluginRequest sent to the plugin. +table PluginRequest { + // Version identifier for the schema. + version: ulong; + // Cookie is a unique value for the specific request. + cookie: ulong; + request: client.Request; +} + +root_type PluginRequest; diff --git a/fbs/response.fbs b/fbs/response.fbs new file mode 100644 index 00000000..a5f83045 --- /dev/null +++ b/fbs/response.fbs @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2023 Siemens AG +// +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Michael Adler + +include "client/request.fbs"; +include "client/response.fbs"; + +namespace generated.plugin; + +union Payload { generated.plugin.client.Request, generated.plugin.client.Response } + +// PluginResponse is the response of a plugin. +table PluginResponse { + // Version identifier for the schema. + version: ulong; + // Cookie used to correlate the response with the original request. + cookie: ulong; + payload: Payload; +} + +root_type PluginResponse; diff --git a/generated/plugin/Payload.go b/generated/plugin/Payload.go new file mode 100644 index 00000000..2960b682 --- /dev/null +++ b/generated/plugin/Payload.go @@ -0,0 +1,69 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package plugin + +import ( + flatbuffers "github.com/google/flatbuffers/go" + "strconv" + + generated__plugin__client "github.com/siemens/wfx/generated/plugin/client" +) + +type Payload byte + +const ( + PayloadNONE Payload = 0 + Payloadgenerated_plugin_client_Request Payload = 1 + Payloadgenerated_plugin_client_Response Payload = 2 +) + +var EnumNamesPayload = map[Payload]string{ + PayloadNONE: "NONE", + Payloadgenerated_plugin_client_Request: "generated_plugin_client_Request", + Payloadgenerated_plugin_client_Response: "generated_plugin_client_Response", +} + +var EnumValuesPayload = map[string]Payload{ + "NONE": PayloadNONE, + "generated_plugin_client_Request": Payloadgenerated_plugin_client_Request, + "generated_plugin_client_Response": Payloadgenerated_plugin_client_Response, +} + +func (v Payload) String() string { + if s, ok := EnumNamesPayload[v]; ok { + return s + } + return "Payload(" + strconv.FormatInt(int64(v), 10) + ")" +} + +type PayloadT struct { + Type Payload + Value interface{} +} + +func (t *PayloadT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + switch t.Type { + case Payloadgenerated_plugin_client_Request: + return t.Value.(*generated__plugin__client.RequestT).Pack(builder) + case Payloadgenerated_plugin_client_Response: + return t.Value.(*generated__plugin__client.ResponseT).Pack(builder) + } + return 0 +} + +func (rcv Payload) UnPack(table flatbuffers.Table) *PayloadT { + switch rcv { + case Payloadgenerated_plugin_client_Request: + var x generated__plugin__client.Request + x.Init(table.Bytes, table.Pos) + return &PayloadT{Type: Payloadgenerated_plugin_client_Request, Value: x.UnPack()} + case Payloadgenerated_plugin_client_Response: + var x generated__plugin__client.Response + x.Init(table.Bytes, table.Pos) + return &PayloadT{Type: Payloadgenerated_plugin_client_Response, Value: x.UnPack()} + } + return nil +} diff --git a/generated/plugin/PluginRequest.go b/generated/plugin/PluginRequest.go new file mode 100644 index 00000000..92fd4335 --- /dev/null +++ b/generated/plugin/PluginRequest.go @@ -0,0 +1,130 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package plugin + +import ( + flatbuffers "github.com/google/flatbuffers/go" + + generated__plugin__client "github.com/siemens/wfx/generated/plugin/client" +) + +type PluginRequestT struct { + Version uint64 `json:"version"` + Cookie uint64 `json:"cookie"` + Request *generated__plugin__client.RequestT `json:"request"` +} + +func (t *PluginRequestT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + requestOffset := t.Request.Pack(builder) + PluginRequestStart(builder) + PluginRequestAddVersion(builder, t.Version) + PluginRequestAddCookie(builder, t.Cookie) + PluginRequestAddRequest(builder, requestOffset) + return PluginRequestEnd(builder) +} + +func (rcv *PluginRequest) UnPackTo(t *PluginRequestT) { + t.Version = rcv.Version() + t.Cookie = rcv.Cookie() + t.Request = rcv.Request(nil).UnPack() +} + +func (rcv *PluginRequest) UnPack() *PluginRequestT { + if rcv == nil { + return nil + } + t := &PluginRequestT{} + rcv.UnPackTo(t) + return t +} + +type PluginRequest struct { + _tab flatbuffers.Table +} + +func GetRootAsPluginRequest(buf []byte, offset flatbuffers.UOffsetT) *PluginRequest { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &PluginRequest{} + x.Init(buf, n+offset) + return x +} + +func FinishPluginRequestBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsPluginRequest(buf []byte, offset flatbuffers.UOffsetT) *PluginRequest { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &PluginRequest{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedPluginRequestBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *PluginRequest) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *PluginRequest) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *PluginRequest) Version() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *PluginRequest) MutateVersion(n uint64) bool { + return rcv._tab.MutateUint64Slot(4, n) +} + +func (rcv *PluginRequest) Cookie() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *PluginRequest) MutateCookie(n uint64) bool { + return rcv._tab.MutateUint64Slot(6, n) +} + +func (rcv *PluginRequest) Request(obj *generated__plugin__client.Request) *generated__plugin__client.Request { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + x := rcv._tab.Indirect(o + rcv._tab.Pos) + if obj == nil { + obj = new(generated__plugin__client.Request) + } + obj.Init(rcv._tab.Bytes, x) + return obj + } + return nil +} + +func PluginRequestStart(builder *flatbuffers.Builder) { + builder.StartObject(3) +} +func PluginRequestAddVersion(builder *flatbuffers.Builder, version uint64) { + builder.PrependUint64Slot(0, version, 0) +} +func PluginRequestAddCookie(builder *flatbuffers.Builder, cookie uint64) { + builder.PrependUint64Slot(1, cookie, 0) +} +func PluginRequestAddRequest(builder *flatbuffers.Builder, request flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(request), 0) +} +func PluginRequestEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/generated/plugin/PluginResponse.go b/generated/plugin/PluginResponse.go new file mode 100644 index 00000000..8ced1a45 --- /dev/null +++ b/generated/plugin/PluginResponse.go @@ -0,0 +1,146 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package plugin + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type PluginResponseT struct { + Version uint64 `json:"version"` + Cookie uint64 `json:"cookie"` + Payload *PayloadT `json:"payload"` +} + +func (t *PluginResponseT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + payloadOffset := t.Payload.Pack(builder) + + PluginResponseStart(builder) + PluginResponseAddVersion(builder, t.Version) + PluginResponseAddCookie(builder, t.Cookie) + if t.Payload != nil { + PluginResponseAddPayloadType(builder, t.Payload.Type) + } + PluginResponseAddPayload(builder, payloadOffset) + return PluginResponseEnd(builder) +} + +func (rcv *PluginResponse) UnPackTo(t *PluginResponseT) { + t.Version = rcv.Version() + t.Cookie = rcv.Cookie() + payloadTable := flatbuffers.Table{} + if rcv.Payload(&payloadTable) { + t.Payload = rcv.PayloadType().UnPack(payloadTable) + } +} + +func (rcv *PluginResponse) UnPack() *PluginResponseT { + if rcv == nil { + return nil + } + t := &PluginResponseT{} + rcv.UnPackTo(t) + return t +} + +type PluginResponse struct { + _tab flatbuffers.Table +} + +func GetRootAsPluginResponse(buf []byte, offset flatbuffers.UOffsetT) *PluginResponse { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &PluginResponse{} + x.Init(buf, n+offset) + return x +} + +func FinishPluginResponseBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsPluginResponse(buf []byte, offset flatbuffers.UOffsetT) *PluginResponse { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &PluginResponse{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedPluginResponseBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *PluginResponse) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *PluginResponse) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *PluginResponse) Version() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *PluginResponse) MutateVersion(n uint64) bool { + return rcv._tab.MutateUint64Slot(4, n) +} + +func (rcv *PluginResponse) Cookie() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *PluginResponse) MutateCookie(n uint64) bool { + return rcv._tab.MutateUint64Slot(6, n) +} + +func (rcv *PluginResponse) PayloadType() Payload { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return Payload(rcv._tab.GetByte(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *PluginResponse) MutatePayloadType(n Payload) bool { + return rcv._tab.MutateByteSlot(8, byte(n)) +} + +func (rcv *PluginResponse) Payload(obj *flatbuffers.Table) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + rcv._tab.Union(obj, o) + return true + } + return false +} + +func PluginResponseStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func PluginResponseAddVersion(builder *flatbuffers.Builder, version uint64) { + builder.PrependUint64Slot(0, version, 0) +} +func PluginResponseAddCookie(builder *flatbuffers.Builder, cookie uint64) { + builder.PrependUint64Slot(1, cookie, 0) +} +func PluginResponseAddPayloadType(builder *flatbuffers.Builder, payloadType Payload) { + builder.PrependByteSlot(2, byte(payloadType), 0) +} +func PluginResponseAddPayload(builder *flatbuffers.Builder, payload flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(payload), 0) +} +func PluginResponseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/generated/plugin/client/Action.go b/generated/plugin/client/Action.go new file mode 100644 index 00000000..7fdd5909 --- /dev/null +++ b/generated/plugin/client/Action.go @@ -0,0 +1,35 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package client + +import "strconv" + +type Action int8 + +const ( + ActionRead Action = 0 + ActionCreate Action = 1 + ActionUpdate Action = 2 + ActionDelete Action = 3 +) + +var EnumNamesAction = map[Action]string{ + ActionRead: "Read", + ActionCreate: "Create", + ActionUpdate: "Update", + ActionDelete: "Delete", +} + +var EnumValuesAction = map[string]Action{ + "Read": ActionRead, + "Create": ActionCreate, + "Update": ActionUpdate, + "Delete": ActionDelete, +} + +func (v Action) String() string { + if s, ok := EnumNamesAction[v]; ok { + return s + } + return "Action(" + strconv.FormatInt(int64(v), 10) + ")" +} diff --git a/generated/plugin/client/Envelope.go b/generated/plugin/client/Envelope.go new file mode 100644 index 00000000..f9a403fc --- /dev/null +++ b/generated/plugin/client/Envelope.go @@ -0,0 +1,133 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package client + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type EnvelopeT struct { + Name string `json:"name"` + Values []string `json:"values"` +} + +func (t *EnvelopeT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + nameOffset := flatbuffers.UOffsetT(0) + if t.Name != "" { + nameOffset = builder.CreateString(t.Name) + } + valuesOffset := flatbuffers.UOffsetT(0) + if t.Values != nil { + valuesLength := len(t.Values) + valuesOffsets := make([]flatbuffers.UOffsetT, valuesLength) + for j := 0; j < valuesLength; j++ { + valuesOffsets[j] = builder.CreateString(t.Values[j]) + } + EnvelopeStartValuesVector(builder, valuesLength) + for j := valuesLength - 1; j >= 0; j-- { + builder.PrependUOffsetT(valuesOffsets[j]) + } + valuesOffset = builder.EndVector(valuesLength) + } + EnvelopeStart(builder) + EnvelopeAddName(builder, nameOffset) + EnvelopeAddValues(builder, valuesOffset) + return EnvelopeEnd(builder) +} + +func (rcv *Envelope) UnPackTo(t *EnvelopeT) { + t.Name = string(rcv.Name()) + valuesLength := rcv.ValuesLength() + t.Values = make([]string, valuesLength) + for j := 0; j < valuesLength; j++ { + t.Values[j] = string(rcv.Values(j)) + } +} + +func (rcv *Envelope) UnPack() *EnvelopeT { + if rcv == nil { + return nil + } + t := &EnvelopeT{} + rcv.UnPackTo(t) + return t +} + +type Envelope struct { + _tab flatbuffers.Table +} + +func GetRootAsEnvelope(buf []byte, offset flatbuffers.UOffsetT) *Envelope { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Envelope{} + x.Init(buf, n+offset) + return x +} + +func FinishEnvelopeBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsEnvelope(buf []byte, offset flatbuffers.UOffsetT) *Envelope { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Envelope{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedEnvelopeBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *Envelope) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Envelope) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Envelope) Name() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Envelope) Values(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j*4)) + } + return nil +} + +func (rcv *Envelope) ValuesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func EnvelopeStart(builder *flatbuffers.Builder) { + builder.StartObject(2) +} +func EnvelopeAddName(builder *flatbuffers.Builder, name flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(name), 0) +} +func EnvelopeAddValues(builder *flatbuffers.Builder, values flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(values), 0) +} +func EnvelopeStartValuesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func EnvelopeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/generated/plugin/client/Request.go b/generated/plugin/client/Request.go new file mode 100644 index 00000000..a533d32f --- /dev/null +++ b/generated/plugin/client/Request.go @@ -0,0 +1,221 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package client + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type RequestT struct { + Cookie uint64 `json:"cookie"` + Destination string `json:"destination"` + Action Action `json:"action"` + Envelope []*EnvelopeT `json:"envelope"` + Content []byte `json:"content"` +} + +func (t *RequestT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + destinationOffset := flatbuffers.UOffsetT(0) + if t.Destination != "" { + destinationOffset = builder.CreateString(t.Destination) + } + envelopeOffset := flatbuffers.UOffsetT(0) + if t.Envelope != nil { + envelopeLength := len(t.Envelope) + envelopeOffsets := make([]flatbuffers.UOffsetT, envelopeLength) + for j := 0; j < envelopeLength; j++ { + envelopeOffsets[j] = t.Envelope[j].Pack(builder) + } + RequestStartEnvelopeVector(builder, envelopeLength) + for j := envelopeLength - 1; j >= 0; j-- { + builder.PrependUOffsetT(envelopeOffsets[j]) + } + envelopeOffset = builder.EndVector(envelopeLength) + } + contentOffset := flatbuffers.UOffsetT(0) + if t.Content != nil { + contentOffset = builder.CreateByteString(t.Content) + } + RequestStart(builder) + RequestAddCookie(builder, t.Cookie) + RequestAddDestination(builder, destinationOffset) + RequestAddAction(builder, t.Action) + RequestAddEnvelope(builder, envelopeOffset) + RequestAddContent(builder, contentOffset) + return RequestEnd(builder) +} + +func (rcv *Request) UnPackTo(t *RequestT) { + t.Cookie = rcv.Cookie() + t.Destination = string(rcv.Destination()) + t.Action = rcv.Action() + envelopeLength := rcv.EnvelopeLength() + t.Envelope = make([]*EnvelopeT, envelopeLength) + for j := 0; j < envelopeLength; j++ { + x := Envelope{} + rcv.Envelope(&x, j) + t.Envelope[j] = x.UnPack() + } + t.Content = rcv.ContentBytes() +} + +func (rcv *Request) UnPack() *RequestT { + if rcv == nil { + return nil + } + t := &RequestT{} + rcv.UnPackTo(t) + return t +} + +type Request struct { + _tab flatbuffers.Table +} + +func GetRootAsRequest(buf []byte, offset flatbuffers.UOffsetT) *Request { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Request{} + x.Init(buf, n+offset) + return x +} + +func FinishRequestBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsRequest(buf []byte, offset flatbuffers.UOffsetT) *Request { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Request{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedRequestBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *Request) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Request) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Request) Cookie() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Request) MutateCookie(n uint64) bool { + return rcv._tab.MutateUint64Slot(4, n) +} + +func (rcv *Request) Destination() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Request) Action() Action { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return Action(rcv._tab.GetInt8(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *Request) MutateAction(n Action) bool { + return rcv._tab.MutateInt8Slot(8, int8(n)) +} + +func (rcv *Request) Envelope(obj *Envelope, j int) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + x := rcv._tab.Vector(o) + x += flatbuffers.UOffsetT(j) * 4 + x = rcv._tab.Indirect(x) + obj.Init(rcv._tab.Bytes, x) + return true + } + return false +} + +func (rcv *Request) EnvelopeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Request) Content(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Request) ContentLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Request) ContentBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Request) MutateContent(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func RequestStart(builder *flatbuffers.Builder) { + builder.StartObject(5) +} +func RequestAddCookie(builder *flatbuffers.Builder, cookie uint64) { + builder.PrependUint64Slot(0, cookie, 0) +} +func RequestAddDestination(builder *flatbuffers.Builder, destination flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(destination), 0) +} +func RequestAddAction(builder *flatbuffers.Builder, action Action) { + builder.PrependInt8Slot(2, int8(action), 0) +} +func RequestAddEnvelope(builder *flatbuffers.Builder, envelope flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(envelope), 0) +} +func RequestStartEnvelopeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func RequestAddContent(builder *flatbuffers.Builder, content flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(content), 0) +} +func RequestStartContentVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func RequestEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/generated/plugin/client/Response.go b/generated/plugin/client/Response.go new file mode 100644 index 00000000..d31defc9 --- /dev/null +++ b/generated/plugin/client/Response.go @@ -0,0 +1,185 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package client + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type ResponseT struct { + Status ResponseStatus `json:"status"` + Envelope []*EnvelopeT `json:"envelope"` + Content []byte `json:"content"` +} + +func (t *ResponseT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + envelopeOffset := flatbuffers.UOffsetT(0) + if t.Envelope != nil { + envelopeLength := len(t.Envelope) + envelopeOffsets := make([]flatbuffers.UOffsetT, envelopeLength) + for j := 0; j < envelopeLength; j++ { + envelopeOffsets[j] = t.Envelope[j].Pack(builder) + } + ResponseStartEnvelopeVector(builder, envelopeLength) + for j := envelopeLength - 1; j >= 0; j-- { + builder.PrependUOffsetT(envelopeOffsets[j]) + } + envelopeOffset = builder.EndVector(envelopeLength) + } + contentOffset := flatbuffers.UOffsetT(0) + if t.Content != nil { + contentOffset = builder.CreateByteString(t.Content) + } + ResponseStart(builder) + ResponseAddStatus(builder, t.Status) + ResponseAddEnvelope(builder, envelopeOffset) + ResponseAddContent(builder, contentOffset) + return ResponseEnd(builder) +} + +func (rcv *Response) UnPackTo(t *ResponseT) { + t.Status = rcv.Status() + envelopeLength := rcv.EnvelopeLength() + t.Envelope = make([]*EnvelopeT, envelopeLength) + for j := 0; j < envelopeLength; j++ { + x := Envelope{} + rcv.Envelope(&x, j) + t.Envelope[j] = x.UnPack() + } + t.Content = rcv.ContentBytes() +} + +func (rcv *Response) UnPack() *ResponseT { + if rcv == nil { + return nil + } + t := &ResponseT{} + rcv.UnPackTo(t) + return t +} + +type Response struct { + _tab flatbuffers.Table +} + +func GetRootAsResponse(buf []byte, offset flatbuffers.UOffsetT) *Response { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Response{} + x.Init(buf, n+offset) + return x +} + +func FinishResponseBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.Finish(offset) +} + +func GetSizePrefixedRootAsResponse(buf []byte, offset flatbuffers.UOffsetT) *Response { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Response{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func FinishSizePrefixedResponseBuffer(builder *flatbuffers.Builder, offset flatbuffers.UOffsetT) { + builder.FinishSizePrefixed(offset) +} + +func (rcv *Response) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Response) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Response) Status() ResponseStatus { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return ResponseStatus(rcv._tab.GetInt8(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *Response) MutateStatus(n ResponseStatus) bool { + return rcv._tab.MutateInt8Slot(4, int8(n)) +} + +func (rcv *Response) Envelope(obj *Envelope, j int) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + x := rcv._tab.Vector(o) + x += flatbuffers.UOffsetT(j) * 4 + x = rcv._tab.Indirect(x) + obj.Init(rcv._tab.Bytes, x) + return true + } + return false +} + +func (rcv *Response) EnvelopeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Response) Content(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Response) ContentLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Response) ContentBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Response) MutateContent(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func ResponseStart(builder *flatbuffers.Builder) { + builder.StartObject(3) +} +func ResponseAddStatus(builder *flatbuffers.Builder, status ResponseStatus) { + builder.PrependInt8Slot(0, int8(status), 0) +} +func ResponseAddEnvelope(builder *flatbuffers.Builder, envelope flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(envelope), 0) +} +func ResponseStartEnvelopeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func ResponseAddContent(builder *flatbuffers.Builder, content flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(content), 0) +} +func ResponseStartContentVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func ResponseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/generated/plugin/client/ResponseStatus.go b/generated/plugin/client/ResponseStatus.go new file mode 100644 index 00000000..40468223 --- /dev/null +++ b/generated/plugin/client/ResponseStatus.go @@ -0,0 +1,32 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package client + +import "strconv" + +type ResponseStatus int8 + +const ( + ResponseStatusDeny ResponseStatus = 0 + ResponseStatusAccept ResponseStatus = 1 + ResponseStatusModified ResponseStatus = 2 +) + +var EnumNamesResponseStatus = map[ResponseStatus]string{ + ResponseStatusDeny: "Deny", + ResponseStatusAccept: "Accept", + ResponseStatusModified: "Modified", +} + +var EnumValuesResponseStatus = map[string]ResponseStatus{ + "Deny": ResponseStatusDeny, + "Accept": ResponseStatusAccept, + "Modified": ResponseStatusModified, +} + +func (v ResponseStatus) String() string { + if s, ok := EnumNamesResponseStatus[v]; ok { + return s + } + return "ResponseStatus(" + strconv.FormatInt(int64(v), 10) + ")" +} diff --git a/go.mod b/go.mod index fc115b3b..97c9c99a 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-openapi/validate v0.22.3 github.com/go-sql-driver/mysql v1.7.1 github.com/golang-migrate/migrate/v4 v4.16.2 + github.com/google/flatbuffers v23.5.26+incompatible github.com/gookit/color v1.5.4 github.com/itchyny/gojq v0.12.13 github.com/jackc/pgx/v5 v5.5.0 diff --git a/go.sum b/go.sum index d414e78d..d7cb4e6f 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA= github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= +github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= diff --git a/go.work b/go.work index 34633556..77bc2aea 100644 --- a/go.work +++ b/go.work @@ -4,4 +4,5 @@ use ( . ./contrib/remote-access/client ./contrib/config-deployment/client + ./example/plugin ) diff --git a/justfile b/justfile index f9be6c72..9740d42c 100644 --- a/justfile +++ b/justfile @@ -116,8 +116,15 @@ _generate-ent: _generate-mockery: mockery --all +_generate-flatbuffers: + #!/usr/bin/env bash + set -euo pipefail + rm -rf generated/plugin + find fbs -name "*.fbs" | xargs flatc -g --gen-object-api --go-module-name github.com/siemens/wfx + gofumpt -l -w generated/plugin + # Generate code -generate: _generate-swagger _generate-ent _generate-mockery +generate: _generate-swagger _generate-ent _generate-mockery _generate-flatbuffers # Start PostgreSQL container postgres-start VERSION="15": diff --git a/middleware/logging/log.go b/middleware/logging/log.go index e4d6df69..5c76039e 100644 --- a/middleware/logging/log.go +++ b/middleware/logging/log.go @@ -9,10 +9,13 @@ package logging */ import ( + "bytes" "context" + "io" "net/http" "time" + "github.com/Southclaws/fault" "github.com/google/uuid" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -30,11 +33,15 @@ func (mw MW) Wrap(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() reqID := uuid.New().String() + var path string + if r.URL != nil { + path = r.URL.Path + } contextLogger := log.With(). Str("reqID", reqID). Str("remoteAddr", r.RemoteAddr). Str("method", r.Method). - Str("path", r.URL.Path). + Str("path", path). Str("host", r.Host). Bool("tls", r.TLS != nil). Logger() @@ -45,13 +52,15 @@ func (mw MW) Wrap(next http.Handler) http.Handler { r = r.WithContext(ctx) if contextLogger.GetLevel() <= zerolog.TraceLevel { + request, err := PeekBody(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } myResponseWriter := newMyResponseWriter(w) - myRequestReader := newMyRequestReader(r) next.ServeHTTP(myResponseWriter, r) - request := myRequestReader.requestBody.Bytes() - contextLogger.Trace(). Bytes("request", request). Msg("Request") @@ -77,3 +86,19 @@ func LoggerFromCtx(ctx context.Context) zerolog.Logger { } return log.Logger } + +func PeekBody(r *http.Request) ([]byte, error) { + // consume request body + var request []byte + if r.Body != nil { + var err error + request, err = io.ReadAll(r.Body) + if err != nil { + return nil, fault.Wrap(err) + } + _ = r.Body.Close() + // restore original body for other middlewares + r.Body = io.NopCloser(bytes.NewBuffer(request)) + } + return request, nil +} diff --git a/middleware/logging/log_test.go b/middleware/logging/log_test.go index 263e8203..cc7e8958 100644 --- a/middleware/logging/log_test.go +++ b/middleware/logging/log_test.go @@ -10,6 +10,7 @@ package logging import ( "context" + "errors" "fmt" "io" "net/http" @@ -22,7 +23,8 @@ import ( ) func TestLog(t *testing.T) { - handler := MW{}.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mw := MW{} + handler := mw.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, client") })) @@ -37,6 +39,7 @@ func TestLog(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Hello, client\n", string(greeting)) + mw.Shutdown() } func TestLogDebug(t *testing.T) { @@ -68,3 +71,26 @@ func TestLoggerFomCtx_Default(t *testing.T) { actual := LoggerFromCtx(context.Background()) assert.Equal(t, log.Logger, actual) } + +type FaultyReadCloser struct{} + +func (r FaultyReadCloser) Read([]byte) (n int, err error) { + return 0, errors.New("failed to read") +} + +func (r FaultyReadCloser) Close() error { + return nil +} + +func TestPeekBody_ReadFailure(t *testing.T) { + var body FaultyReadCloser + r := &http.Request{Body: body} + _, err := PeekBody(r) + assert.NotNil(t, err) + + handler := MW{}.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, r) +} diff --git a/middleware/logging/reader.go b/middleware/logging/reader.go deleted file mode 100644 index 8fd0cd25..00000000 --- a/middleware/logging/reader.go +++ /dev/null @@ -1,41 +0,0 @@ -package logging - -/* - * SPDX-FileCopyrightText: 2023 Siemens AG - * - * SPDX-License-Identifier: Apache-2.0 - * - * Author: Michael Adler - */ - -import ( - "bytes" - "io" - "net/http" - - "github.com/Southclaws/fault" -) - -type requestReader struct { - requestBody *bytes.Buffer - originalReader io.ReadCloser - teeReader *io.Reader -} - -func newMyRequestReader(r *http.Request) requestReader { - var buf bytes.Buffer - tee := io.TeeReader(r.Body, &buf) - myReader := requestReader{requestBody: &buf, originalReader: r.Body, teeReader: &tee} - r.Body = myReader - return myReader -} - -// Read reads up to len(p) bytes into p. -func (r requestReader) Read(p []byte) (int, error) { - n, err := (*r.teeReader).Read(p) - return n, fault.Wrap(err) -} - -func (r requestReader) Close() error { - return fault.Wrap(r.originalReader.Close()) -} diff --git a/middleware/logging/reader_test.go b/middleware/logging/reader_test.go deleted file mode 100644 index 45cbbbc4..00000000 --- a/middleware/logging/reader_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package logging - -/* - * SPDX-FileCopyrightText: 2023 Siemens AG - * - * SPDX-License-Identifier: Apache-2.0 - * - * Author: Michael Adler - */ - -import ( - "bytes" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestReader(t *testing.T) { - body := new(bytes.Buffer) - body.WriteString("Hello world") - req := httptest.NewRequest(http.MethodGet, "http://localhost", body) - - out := make([]byte, 1024) - reader := newMyRequestReader(req) - defer reader.Close() - _, err := reader.Read(out) - assert.NoError(t, err) - assert.Equal(t, "Hello world", reader.requestBody.String()) -} diff --git a/middleware/plugin/disabled.go b/middleware/plugin/disabled.go new file mode 100644 index 00000000..bbc78612 --- /dev/null +++ b/middleware/plugin/disabled.go @@ -0,0 +1,26 @@ +//go:build !plugin + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "errors" + "net/http" + "time" +) + +type DummyMW struct{} + +func (mw *DummyMW) Wrap(next http.Handler) http.Handler { return next } +func (mw *DummyMW) Shutdown() {} + +func NewMiddleware(Plugin, time.Duration, time.Duration) (*DummyMW, error) { + return nil, errors.New("this version of wfx was compiled without support for plugins") +} diff --git a/middleware/plugin/disabled_test.go b/middleware/plugin/disabled_test.go new file mode 100644 index 00000000..5495e949 --- /dev/null +++ b/middleware/plugin/disabled_test.go @@ -0,0 +1,24 @@ +//go:build !plugin + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewMiddleware(t *testing.T) { + mw, err := NewMiddleware(nil, time.Second, time.Second) + assert.Nil(t, mw) + assert.ErrorContains(t, err, "this version of wfx was compiled without support for plugins") +} diff --git a/middleware/plugin/ioutil/io.go b/middleware/plugin/ioutil/io.go new file mode 100644 index 00000000..3ee79f48 --- /dev/null +++ b/middleware/plugin/ioutil/io.go @@ -0,0 +1,86 @@ +package ioutil + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "encoding/binary" + "errors" + "io" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/siemens/wfx/generated/plugin" +) + +// buffer size large enough for typical requests +const initialSize = 1 << 14 + +type Packer interface { + Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT +} + +// ReadRequest reads a request from the provided io.Reader. +func ReadRequest(r io.Reader) (*plugin.PluginRequestT, error) { + buf, err := readBytes(r) + if err != nil { + return nil, err + } + return plugin.GetRootAsPluginRequest(buf, 0).UnPack(), nil +} + +// ReadResponse reads a response from the provided io.Reader. +func ReadResponse(r io.Reader) (*plugin.PluginResponseT, error) { + buf, err := readBytes(r) + if err != nil { + return nil, err + } + return plugin.GetRootAsPluginResponse(buf, 0).UnPack(), nil +} + +// WriteRequest writes the given request to an io.Writer. +func WriteRequest(w io.Writer, req *plugin.PluginRequestT) error { + return writeHelper(w, req) +} + +// WriteResponse writes the given request to an io.Writer. +func WriteResponse(w io.Writer, resp *plugin.PluginResponseT) error { + return writeHelper(w, resp) +} + +func readPrefix(r io.Reader) (uint32, error) { + // see https://github.com/dvidelabs/flatcc/blob/master/doc/binary-format.md + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + return binary.LittleEndian.Uint32(buf), nil +} + +func readBytes(r io.Reader) ([]byte, error) { + size, err := readPrefix(r) + if err != nil { + return nil, err + } + buf := make([]byte, size) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return buf, nil +} + +func writeHelper(w io.Writer, packer Packer) error { + builder := flatbuffers.NewBuilder(initialSize) + end := packer.Pack(builder) + builder.FinishSizePrefixed(end) + buf := builder.FinishedBytes() + n, err := w.Write(buf) + if n != len(buf) { + return errors.New("incomplete write") + } + return err +} diff --git a/middleware/plugin/ioutil/io_test.go b/middleware/plugin/ioutil/io_test.go new file mode 100644 index 00000000..93cdd424 --- /dev/null +++ b/middleware/plugin/ioutil/io_test.go @@ -0,0 +1,95 @@ +package ioutil + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "bytes" + "errors" + "testing" + "testing/iotest" + + "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteAndReadRequest(t *testing.T) { + t.Parallel() + + expected := plugin.PluginRequestT{ + Cookie: 1, + Request: &client.RequestT{ + Action: client.ActionRead, + Destination: "http://localhost/foo/bar/offset=0&limit=42", + Envelope: []*client.EnvelopeT{ + {Name: "Foo", Values: []string{"Bar", "Baz"}}, + }, + }, + } + + buf := new(bytes.Buffer) + err := WriteRequest(buf, &expected) + require.NoError(t, err) + + actual, err := ReadRequest(buf) + + require.NoError(t, err) + assert.EqualValues(t, expected, *actual) +} + +func TestWriteAndReadResponse(t *testing.T) { + t.Parallel() + + expected := plugin.PluginResponseT{ + Cookie: 1, + Payload: &plugin.PayloadT{ + Type: plugin.Payloadgenerated_plugin_client_Request, + Value: &client.RequestT{ + Action: client.ActionRead, + Destination: "http://localhost/foo/bar/offset=0&limit=42", + Envelope: []*client.EnvelopeT{ + {Name: "Foo", Values: []string{"Bar", "Baz"}}, + }, + }, + }, + } + + buf := new(bytes.Buffer) + err := WriteResponse(buf, &expected) + require.NoError(t, err) + + actual, err := ReadResponse(buf) + + require.NoError(t, err) + assert.EqualValues(t, expected, *actual) +} + +func TestFaultyReader(t *testing.T) { + t.Parallel() + + myErr := errors.New("this is a fake error") + r := iotest.ErrReader(myErr) + + t.Run("ReadRequest", func(t *testing.T) { + t.Parallel() + + req, err := ReadRequest(r) + assert.ErrorIs(t, err, myErr) + assert.Nil(t, req) + }) + + t.Run("ReadResponse", func(t *testing.T) { + t.Parallel() + + req, err := ReadResponse(r) + assert.ErrorIs(t, err, myErr) + assert.Nil(t, req) + }) +} diff --git a/middleware/plugin/ioutil/main_test.go b/middleware/plugin/ioutil/main_test.go new file mode 100644 index 00000000..73268b0d --- /dev/null +++ b/middleware/plugin/ioutil/main_test.go @@ -0,0 +1,19 @@ +package ioutil + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/middleware/plugin/main_test.go b/middleware/plugin/main_test.go new file mode 100644 index 00000000..88497e65 --- /dev/null +++ b/middleware/plugin/main_test.go @@ -0,0 +1,19 @@ +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/middleware/plugin/middleware.go b/middleware/plugin/middleware.go new file mode 100644 index 00000000..d661a9c9 --- /dev/null +++ b/middleware/plugin/middleware.go @@ -0,0 +1,140 @@ +//go:build plugin + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "net/http" + "net/url" + + "github.com/Southclaws/fault" + "github.com/rs/zerolog/log" + genPlugin "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/siemens/wfx/middleware/logging" +) + +type MW struct { + plugin Plugin +} + +func NewMiddleware(plugin Plugin) (*MW, error) { + log.Debug().Str("plugin", plugin.Name()).Msg("Creating new plugin middleware") + + if err := plugin.Start(); err != nil { + _ = plugin.Wait() + log.Err(err).Msg("Failed to start plugin") + return nil, fault.Wrap(err) + } + mw := MW{plugin: plugin} + + go func() { + for { + p := mw.plugin + + err := p.Wait() + log.Warn().Err(err).Msg("Plugin stopped") + if p.IsStopped() { + log.Info().Msg("Plugin was stopped, not restarting it") + break + } + + if err := p.Start(); err != nil { + log.Err(err).Msg("Failed to start plugin") + break + } + } + }() + + return &mw, nil +} + +func (mw *MW) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log := logging.LoggerFromCtx(r.Context()).With().Str("plugin", mw.plugin.Name()).Logger() + + log.Debug().Msg("Sending request to plugin") + pending, err := mw.plugin.Enqueue(r) + if err != nil { + msg := "Failed to send plugin request" + log.Err(err).Msg(msg) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(msg)) + return + } + + resp, err := pending.Await() + if err != nil { + log.Err(err).Msg("Failed to receive plugin response") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + return + } + + if resp.Payload != nil { + switch resp.Payload.Type { + case genPlugin.Payloadgenerated_plugin_client_Response: + log.Info().Msg("Sending client response provided by plugin") + val := resp.Payload.Value.(*client.ResponseT) + for _, h := range val.Envelope { + for _, value := range h.Values { + w.Header().Add(h.Name, value) + } + } + switch val.Status { + case client.ResponseStatusAccept: + w.WriteHeader(http.StatusOK) + case client.ResponseStatusModified: + w.WriteHeader(http.StatusOK) + case client.ResponseStatusDeny: + w.WriteHeader(http.StatusForbidden) + } + _, _ = w.Write(val.Content) + return + case genPlugin.Payloadgenerated_plugin_client_Request: + log.Info().Msg("Request was modified by plugin") + // override http.Request with the response + val := resp.Payload.Value.(*client.RequestT) + + if parsedURL, err := url.Parse(val.Destination); err != nil { + log.Err(err).Str("destination", val.Destination).Msg("Failed to parse destination") + } else { + r.URL = parsedURL + } + + // delete existing headers + for k := range r.Header { + delete(r.Header, k) + } + if len(val.Envelope) > 0 { + if r.Header == nil { + r.Header = make(http.Header) + } + for _, h := range val.Envelope { + for _, value := range h.Values { + r.Header.Add(h.Name, value) + } + } + } + default: + log.Warn().Int("type", int(resp.Payload.Type)).Msg("Unhandled payload type") + } + } + log.Debug().Msg("Request may continue") + next.ServeHTTP(w, r) + }) +} + +func (mw *MW) Shutdown() { + if err := mw.plugin.Stop(); err != nil { + log.Err(err).Str("path", mw.plugin.Name()).Msg("Failed to stop plugin") + } + _ = mw.plugin.Wait() +} diff --git a/middleware/plugin/middleware_test.go b/middleware/plugin/middleware_test.go new file mode 100644 index 00000000..956da8c9 --- /dev/null +++ b/middleware/plugin/middleware_test.go @@ -0,0 +1,240 @@ +//go:build plugin + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "errors" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/stretchr/testify/assert" +) + +type StartFailPlugin struct{} + +func (p StartFailPlugin) Name() string { return "StartFailPlugin" } +func (p StartFailPlugin) Start() error { return errors.New("failed to start plugin") } +func (p StartFailPlugin) Stop() error { return nil } +func (p StartFailPlugin) IsStopped() bool { return true } +func (p StartFailPlugin) Wait() error { return nil } +func (p StartFailPlugin) Enqueue(*http.Request) (*PendingResponse, error) { + return nil, errors.New("not supported") +} + +func TestNewMiddleware_StartFails(t *testing.T) { + p := StartFailPlugin{} + mw, err := NewMiddleware(p) + assert.Error(t, err) + assert.Nil(t, mw) +} + +type TestPlugin struct { + mutex sync.RWMutex + counter uint64 + started bool + responses []PendingResponse + // error behavior + failEnqueue bool + failAwait bool +} + +func NewTestPlugin() *TestPlugin { + return &TestPlugin{ + responses: make([]PendingResponse, 0), + } +} + +func (p *TestPlugin) Name() string { return "TestPlugin" } + +func (p *TestPlugin) Start() error { + p.mutex.Lock() + defer p.mutex.Unlock() + p.started = true + return nil +} + +func (p *TestPlugin) Stop() error { + p.mutex.Lock() + defer p.mutex.Unlock() + p.started = false + return nil +} + +func (p *TestPlugin) IsStopped() bool { + p.mutex.RLock() + defer p.mutex.RUnlock() + return p.started +} +func (p *TestPlugin) Wait() error { return nil } + +func (p *TestPlugin) Enqueue(*http.Request) (*PendingResponse, error) { + if p.failEnqueue { + return nil, errors.New("enqueuing failed on purpose") + } + cookie := p.counter + p.counter++ + pending := PendingResponse{ + chResp: make(chan *plugin.PluginResponseT), + entry: &QueueEntry{ + request: &plugin.PluginRequestT{Cookie: cookie}, + }, + timeoutFn: func() {}, + } + if p.failAwait { + close(pending.chResp) + } + p.mutex.Lock() + p.responses = append(p.responses, pending) + p.mutex.Unlock() + return &pending, nil +} + +func TestNewMiddleware_ModifyRequest(t *testing.T) { + p := NewTestPlugin() + go func() { + for { + p.mutex.RLock() + n := len(p.responses) + p.mutex.RUnlock() + if n > 0 { + break + } + time.Sleep(time.Millisecond * 10) + } + p.mutex.Lock() + pending := p.responses[0] + p.responses = p.responses[1:] + p.mutex.Unlock() + + t.Log("Sending response") + pending.chResp <- &plugin.PluginResponseT{ + Cookie: pending.entry.request.Cookie, + Payload: &plugin.PayloadT{ + Type: plugin.Payloadgenerated_plugin_client_Request, + Value: &client.RequestT{ + Action: client.ActionRead, + Envelope: []*client.EnvelopeT{ + {Name: "User-Agent", Values: []string{"gotest"}}, + }, + Destination: "localhost/foo/bar", + }, + }, + } + }() + + mw, err := NewMiddleware(p) + assert.Nil(t, err) + + handler := mw.Wrap(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + + recorder := httptest.NewRecorder() + httpReq := &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + handler.ServeHTTP(recorder, httpReq) + mw.Shutdown() +} + +func TestNewMiddleware_SendResponse(t *testing.T) { + p := NewTestPlugin() + go func() { + for { + p.mutex.RLock() + n := len(p.responses) + p.mutex.RUnlock() + if n > 0 { + break + } + time.Sleep(time.Millisecond * 10) + } + pending := p.responses[0] + t.Log("Sending response") + pending.chResp <- &plugin.PluginResponseT{ + Cookie: pending.entry.request.Cookie, + Payload: &plugin.PayloadT{ + Type: plugin.Payloadgenerated_plugin_client_Response, + Value: &client.ResponseT{ + Status: client.ResponseStatusAccept, + Envelope: []*client.EnvelopeT{ + {Name: "User-Agent", Values: []string{"gotest"}}, + }, + Content: []byte{}, + }, + }, + } + }() + + mw, err := NewMiddleware(p) + assert.Nil(t, err) + + handler := mw.Wrap(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + + recorder := httptest.NewRecorder() + httpReq := &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + handler.ServeHTTP(recorder, httpReq) + mw.Shutdown() +} + +func TestNewMiddleware_FailEnqueue(t *testing.T) { + p := NewTestPlugin() + p.failEnqueue = true + mw, err := NewMiddleware(p) + assert.Nil(t, err) + + handler := mw.Wrap(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + + recorder := httptest.NewRecorder() + httpReq := &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + handler.ServeHTTP(recorder, httpReq) + mw.Shutdown() +} + +func TestNewMiddleware_FailAwait(t *testing.T) { + p := NewTestPlugin() + p.failAwait = true + mw, err := NewMiddleware(p) + assert.Nil(t, err) + + handler := mw.Wrap(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + + recorder := httptest.NewRecorder() + httpReq := &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + handler.ServeHTTP(recorder, httpReq) + mw.Shutdown() +} diff --git a/middleware/plugin/plugin.go b/middleware/plugin/plugin.go new file mode 100644 index 00000000..fb23d163 --- /dev/null +++ b/middleware/plugin/plugin.go @@ -0,0 +1,311 @@ +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "bufio" + "errors" + "io" + "net/http" + "os/exec" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/Southclaws/fault" + "github.com/rs/zerolog/log" + genPlugin "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/siemens/wfx/middleware/logging" + "github.com/siemens/wfx/middleware/plugin/ioutil" +) + +type Plugin interface { + Name() string + Start() error + Stop() error + IsStopped() bool + Wait() error + Enqueue(r *http.Request) (*PendingResponse, error) +} + +// FBPlugin is a plugin which communicates using FlatBuffer messages. +type FBPlugin struct { + path string + + queue chan *QueueEntry + queueClosed atomic.Bool + + responses map[uint64]PendingResponse + responsesMutex sync.Mutex + + cookieCounter atomic.Uint64 + cmd atomic.Pointer[exec.Cmd] + finished atomic.Bool + stopped atomic.Bool + + errSender chan error + errReceiver chan error +} + +type QueueEntry struct { + created time.Time + received time.Time + request *genPlugin.PluginRequestT +} + +func newQueueEntry(request *genPlugin.PluginRequestT) *QueueEntry { + return &QueueEntry{ + created: time.Now(), + request: request, + } +} + +// PendingResponse represents a response that is pending from the plugin. +// Use the Await() method to block and wait for the actual response. +type PendingResponse struct { + chResp chan *genPlugin.PluginResponseT + entry *QueueEntry + timeoutFn func() +} + +// NewFBPlugin creates a new plugin instance. In order to start the plugin, call +// the Start() function. +func NewFBPlugin(path string) *FBPlugin { + return &FBPlugin{path: path} +} + +// Start starts the plugin but does not wait for it to complete. +// +// After a successful call to Start the Wait method must be called in +// order to release associated system resources. +func (p *FBPlugin) Start() error { + if p.finished.Load() || p.stopped.Load() { + return errors.New("cannot start a finished plugin") + } + + log.Info().Str("path", p.path).Msg("Starting plugin") + cmd := exec.Command(p.path) + if cmd.Path == "" { + return errors.New("Plugin not found") + } + + // this ensures that a process group is created (needed to kill all child processes) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + p.queue = make(chan *QueueEntry) + p.responses = make(map[uint64]PendingResponse) + // we need a buffer because we only read from one of the channels (and then stop the processing pipeline) + p.errSender = make(chan error, 1) + p.errReceiver = make(chan error, 1) + + stdin, err := cmd.StdinPipe() + if err != nil { + return fault.Wrap(err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return fault.Wrap(err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return fault.Wrap(err) + } + + go p.sender(stdin) + go p.receiver(stdout) + go p.forwardLogs(stderr) + + go func() { + // if we receive some error, terminate the process; this results in the + // process being restarted + select { + case err := <-p.errSender: + log.Err(err).Msg("Received error from sender") + case err := <-p.errReceiver: + log.Err(err).Msg("Received error from receiver") + } + // do not call Stop() since we want the process to be restarted + _ = p.terminateProcess() + }() + + if err := cmd.Start(); err != nil { + return fault.Wrap(err) + } + log.Debug().Str("path", cmd.Path).Msg("Plugin started") + p.cmd.Store(cmd) + return nil +} + +func (p *FBPlugin) IsStopped() bool { + return p.stopped.Load() +} + +// Wait for the plugin to finish. +func (p *FBPlugin) Wait() error { + var err error + if cmd := p.cmd.Load(); cmd != nil { + err = cmd.Wait() + } + p.finished.Store(true) + + if !p.queueClosed.Swap(true) { + // ensure we only close the queue once + close(p.queue) + } + return fault.Wrap(err) +} + +func (p *FBPlugin) Name() string { + return p.path +} + +// Enqueue adds a new request to be sent to the plugin. +func (p *FBPlugin) Enqueue(r *http.Request) (*PendingResponse, error) { + cmd := p.cmd.Load() + if cmd == nil || cmd.Process == nil { + return nil, errors.New("plugin not started") + } + if p.finished.Load() || p.stopped.Load() { + return nil, errors.New("plugin finished, cannot send any more requests") + } + + cookie := p.cookieCounter.Add(1) + req, err := convertRequest(r, cookie) + if err != nil { + return nil, fault.Wrap(err) + } + + entry := newQueueEntry(req) + pr := PendingResponse{ + chResp: make(chan *genPlugin.PluginResponseT), + entry: entry, + timeoutFn: func() { p.responsesMutex.Lock(); delete(p.responses, cookie); p.responsesMutex.Unlock() }, + } + + p.responsesMutex.Lock() + p.responses[cookie] = pr + p.responsesMutex.Unlock() + + p.queue <- entry + log.Debug().Uint64("cookie", req.Cookie).Msg("Enqueued request") + return &pr, nil +} + +// Stop stops the plugin. +func (p *FBPlugin) Stop() error { + stopped := p.stopped.Swap(true) + if stopped || p.finished.Load() { + return nil + } + log.Info().Str("path", p.path).Msg("Stopping plugin") + if err := p.terminateProcess(); err != nil { + return fault.Wrap(err) + } + return nil +} + +func (p *FBPlugin) sender(w io.Writer) { + for entry := range p.queue { + if err := ioutil.WriteRequest(w, entry.request); err != nil { + p.errSender <- err + close(p.errSender) + return + } + log.Debug().Uint64("cookie", entry.request.Cookie).Msg("Request sent to plugin") + } +} + +func (p *FBPlugin) receiver(r io.Reader) { + for !p.finished.Load() { + resp, err := ioutil.ReadResponse(r) + if err != nil { + p.errReceiver <- err + close(p.errReceiver) + return + } + now := time.Now() + + cookie := resp.Cookie + p.responsesMutex.Lock() + pr, ok := p.responses[cookie] + delete(p.responses, cookie) + p.responsesMutex.Unlock() + if !ok { + log.Warn().Uint64("id", cookie).Msg("Response channel not found. Discarding plugin response.") + continue + } + pr.entry.received = now + pr.chResp <- resp + close(pr.chResp) // there can only be one response + } +} + +func (p *FBPlugin) forwardLogs(r io.Reader) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + log.Debug().Str("path", p.path).Str("line", line).Msg("Received log message") + } + if err := scanner.Err(); err != nil { + log.Warn().Err(err).Msg("Failed to read stderr from plugin") + } +} + +func convertRequest(r *http.Request, cookie uint64) (*genPlugin.PluginRequestT, error) { + envelope := make([]*client.EnvelopeT, 0, len(r.Header)) + for name, values := range r.Header { + header := client.EnvelopeT{Name: name, Values: values} + envelope = append(envelope, &header) + } + + req := genPlugin.PluginRequestT{ + Cookie: cookie, + Request: &client.RequestT{ + Action: httpMethodToAction(r.Method), + Destination: r.URL.String(), + Envelope: envelope, + }, + } + + body, _ := logging.PeekBody(r) + if body != nil { + req.Request.Content = body + } + return &req, nil +} + +// Await blocks and waits for the plugin response until a specified timeout is reached. +func (pr PendingResponse) Await() (*genPlugin.PluginResponseT, error) { + resp, ok := <-pr.chResp + duration := pr.entry.received.Sub(pr.entry.created) + if ok { + log.Debug(). + Uint64("cookie", pr.entry.request.Cookie). + Dur("duration", duration). + Msg("Received response from plugin") + return resp, nil + } + return nil, errors.New("response channel was closed prematurely") +} + +func httpMethodToAction(method string) client.Action { + switch method { + case http.MethodPost: + return client.ActionCreate + case http.MethodPut: + return client.ActionUpdate + case http.MethodPatch: + return client.ActionUpdate + case http.MethodDelete: + return client.ActionDelete + default: + return client.ActionRead + } +} diff --git a/middleware/plugin/plugin_test.go b/middleware/plugin/plugin_test.go new file mode 100644 index 00000000..ca002044 --- /dev/null +++ b/middleware/plugin/plugin_test.go @@ -0,0 +1,198 @@ +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "bytes" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "sync" + "testing" + + "github.com/siemens/wfx/generated/plugin" + "github.com/siemens/wfx/generated/plugin/client" + "github.com/siemens/wfx/middleware/plugin/ioutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPlugin(t *testing.T) { + p := NewFBPlugin("true") + assert.NotNil(t, p) +} + +func TestNewPluginEmpty(t *testing.T) { + p := NewFBPlugin("") + assert.NotNil(t, p) +} + +func TestStopWithoutStart(t *testing.T) { + p := NewFBPlugin("true") + err := p.Stop() + assert.NoError(t, err) +} + +func TestStop(t *testing.T) { + p := NewFBPlugin("cat") + + err := p.Start() + require.NoError(t, err) + + var g sync.WaitGroup + g.Add(1) + go func() { + defer g.Done() + err := p.Stop() + assert.NoError(t, err) + }() + + err = p.Wait() + assert.ErrorContains(t, err, "signal: terminated") + g.Wait() +} + +func TestEnqueueWithoutStarting(t *testing.T) { + p := NewFBPlugin("true") + + req := http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + resp, err := p.Enqueue(&req) + assert.ErrorContains(t, err, "plugin not started") + assert.Nil(t, resp) +} + +func TestEnqueue(t *testing.T) { + p := NewFBPlugin("cat") + err := p.Start() + require.NoError(t, err) + + headers := make(map[string][]string) + headers["Content-Type"] = []string{"application/json"} + + req := http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + Header: headers, + } + resp, err := p.Enqueue(&req) + require.NoError(t, err) + assert.Equal(t, uint64(1), resp.entry.request.Cookie) + + err = p.Stop() + assert.NoError(t, err) + _ = p.Wait() +} + +func TestAwait(t *testing.T) { + pluginResp := plugin.PluginResponseT{ + Cookie: 1, + Payload: &plugin.PayloadT{ + Type: plugin.Payloadgenerated_plugin_client_Request, + Value: &client.RequestT{ + Action: client.ActionRead, + Destination: "http://localhost/foo/bar", + }, + }, + } + httpReq := http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "localhost", + Path: "/foo", + }, + } + fname := createTrivialPlugin(t, &httpReq, &pluginResp) + + p := NewFBPlugin(fname) + err := p.Start() + require.NoError(t, err) + + resp, err := p.Enqueue(&httpReq) + require.NoError(t, err) + + actual, err := resp.Await() + require.NoError(t, err) + assert.Equal(t, pluginResp.Cookie, actual.Cookie) + + go func() { + err := p.Stop() + assert.NoError(t, err) + }() + + p.errSender <- errors.New("stop, test is over") + _ = p.Wait() +} + +func TestIsStopped(t *testing.T) { + p := NewFBPlugin("true") + assert.False(t, p.IsStopped()) + _ = p.Stop() + assert.True(t, p.IsStopped()) +} + +func TestName(t *testing.T) { + p := NewFBPlugin("true") + assert.Equal(t, "true", p.Name()) +} + +func createTrivialPlugin(t *testing.T, httpReq *http.Request, resp *plugin.PluginResponseT) string { + var blobFile string + { + f, err := os.CreateTemp("", "flatbuf.*.bin") + require.NoError(t, err) + t.Cleanup(func() { + _ = os.Remove(f.Name()) + }) + blobFile = f.Name() + + err = ioutil.WriteResponse(f, resp) + require.NoError(t, err) + + _ = f.Close() + } + + req, _ := convertRequest(httpReq, 1) + rawRequest := new(bytes.Buffer) + // figure out exact size of req which will be written to stdin + _ = ioutil.WriteRequest(rawRequest, req) + + scriptContent := fmt.Sprintf(`#!/bin/sh +while true; do + # read request + dd bs=1 count=%d 1>&2 + # write response + cat "%s" +done +`, rawRequest.Len(), blobFile) + + f, err := os.CreateTemp("", "protobuf.*.sh") + require.NoError(t, err) + t.Cleanup(func() { + _ = os.Remove(f.Name()) + }) + fname := f.Name() + _, err = f.WriteString(scriptContent) + require.NoError(t, err) + _ = f.Close() + + err = os.Chmod(fname, os.FileMode(0o700)) + require.NoError(t, err) + return fname +} diff --git a/middleware/plugin/process_unix.go b/middleware/plugin/process_unix.go new file mode 100644 index 00000000..cbae6a12 --- /dev/null +++ b/middleware/plugin/process_unix.go @@ -0,0 +1,64 @@ +//go:build !windows + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "syscall" + "time" + + "github.com/Southclaws/fault" + "github.com/rs/zerolog/log" +) + +const gracefulTimeout = 15 * time.Second + +func (p *FBPlugin) terminateProcess() error { + cmd := p.cmd.Load() + if cmd == nil || cmd.Process == nil { + return nil + } + pid := cmd.Process.Pid + + if err := syscall.Kill(-pid, 0); err != nil { + // process already gone, do nothing + return nil + } + + // signal is sent to *every* process in the process group + log.Debug().Int("pid", pid).Msg("Sending SIGTERM") + if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil { + return fault.Wrap(err) + } + + done := make(chan bool) + go func() { + for i := 0; i < 100; i++ { + if p.finished.Load() { + done <- true + break + } + time.Sleep(10 * time.Millisecond) + } + }() + + select { + case <-done: + log.Debug().Msg("Process terminated gracefully") + case <-time.After(gracefulTimeout): + // check if process is still alive + if err := syscall.Kill(-pid, 0); err == nil { + // process is still alive + log.Warn().Int("pid", pid).Msg("Process is still alive, sending SIGKILL") + _ = syscall.Kill(-pid, syscall.SIGKILL) + } + } + return nil +} diff --git a/middleware/plugin/process_windows.go b/middleware/plugin/process_windows.go new file mode 100644 index 00000000..e3fb46ac --- /dev/null +++ b/middleware/plugin/process_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package plugin + +/* + * SPDX-FileCopyrightText: 2023 Siemens AG + * + * SPDX-License-Identifier: Apache-2.0 + * + * Author: Michael Adler + */ + +import ( + "os" + + "github.com/Southclaws/fault" +) + +func (p *Plugin) terminateProcess() error { + pid := p.cmd.Process.Pid + proc, err := os.FindProcess(pid) + if err != nil { + return fault.Wrap(err) + } + // note: this does not kill child processes + return fault.Wrap(proc.Kill()) +} diff --git a/shell.nix b/shell.nix index daf07152..99909515 100644 --- a/shell.nix +++ b/shell.nix @@ -33,10 +33,11 @@ mkShell { just git go + flatbuffers ]; shellHook = '' - export GOFLAGS="-tags=sqlite,mysql,postgres,testing,integration" + export GOFLAGS="-tags=sqlite,mysql,postgres,testing,integration,plugin" export LUA_PATH="$(pwd)/hugo/filters/?.lua;;" export PATH="$(pwd):$PATH"