From d4720e4f522ecb169f89b561b0223623661b70a6 Mon Sep 17 00:00:00 2001 From: Stuart Douglas Date: Tue, 26 Nov 2024 09:07:32 +1100 Subject: [PATCH] fix: fix pg proxy issues, and remove hard coded DSNs (#3501) --- backend/controller/controller.go | 12 +-- .../provisioner_integration_test.go | 5 +- .../db/echodb/20241103205514_postgres.sql | 4 + backend/provisioner/testdata/go/echo/echo.go | 11 +-- ftl-project.toml | 9 --- internal/pgproxy/pgproxy.go | 76 +++++++++++-------- internal/pgproxy/pgproxy_test.go | 8 +- .../ftl/deployment/DatasourceProcessor.java | 13 +++- .../xyz/block/ftl/runtime/FTLController.java | 12 +++ .../xyz/block/ftl/runtime/FTLRecorder.java | 5 ++ 10 files changed, 90 insertions(+), 65 deletions(-) create mode 100644 backend/provisioner/testdata/go/echo/db/echodb/20241103205514_postgres.sql diff --git a/backend/controller/controller.go b/backend/controller/controller.go index b38c1e3a8d..c1e95d1a3e 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -1063,18 +1063,8 @@ func (s *Service) CreateDeployment(ctx context.Context, req *connect.Request[ftl return nil, fmt.Errorf("invalid module schema: %w", err) } - for _, d := range module.Decls { - if db, ok := d.(*schema.Database); ok && db.Runtime != nil { - key := dsnSecretKey(module.Name, db.Name) - - if err := s.sm.Set(ctx, configuration.NewRef(module.Name, key), db.Runtime.DSN); err != nil { - return nil, fmt.Errorf("could not set database secret %s: %w", key, err) - } - logger.Infof("Database declaration: %s -> %s type %s", db.Name, db.Runtime.DSN, db.Type) - } - } - dkey, err := s.dal.CreateDeployment(ctx, ms.Runtime.Language, module, artefacts) + if err != nil { logger.Errorf(err, "Could not create deployment") return nil, fmt.Errorf("could not create deployment: %w", err) diff --git a/backend/provisioner/provisioner_integration_test.go b/backend/provisioner/provisioner_integration_test.go index 6d965de3b9..e893a5f09d 100644 --- a/backend/provisioner/provisioner_integration_test.go +++ b/backend/provisioner/provisioner_integration_test.go @@ -6,11 +6,12 @@ import ( "fmt" "testing" - in "github.com/TBD54566975/ftl/internal/integration" "github.com/alecthomas/assert/v2" + + in "github.com/TBD54566975/ftl/internal/integration" ) -func TestDeploymentThrougDevProvisionerCreatePostgresDB(t *testing.T) { +func TestDeploymentThroughDevProvisionerCreatePostgresDB(t *testing.T) { in.Run(t, in.WithFTLConfig("./ftl-project.toml"), in.CopyModule("echo"), diff --git a/backend/provisioner/testdata/go/echo/db/echodb/20241103205514_postgres.sql b/backend/provisioner/testdata/go/echo/db/echodb/20241103205514_postgres.sql new file mode 100644 index 0000000000..37bc9c1ddd --- /dev/null +++ b/backend/provisioner/testdata/go/echo/db/echodb/20241103205514_postgres.sql @@ -0,0 +1,4 @@ +-- migrate:up +CREATE TABLE messages( message TEXT ); +-- migrate:down +DROP TABLE messages; \ No newline at end of file diff --git a/backend/provisioner/testdata/go/echo/echo.go b/backend/provisioner/testdata/go/echo/echo.go index c4a98d24ef..b9c5d1ad32 100644 --- a/backend/provisioner/testdata/go/echo/echo.go +++ b/backend/provisioner/testdata/go/echo/echo.go @@ -19,19 +19,12 @@ func (EchoDBConfig) Name() string { return "echodb" } // //ftl:verb export func Echo(ctx context.Context, req string, db ftl.DatabaseHandle[EchoDBConfig]) (string, error) { - _, err := db.Get(ctx).Exec(`CREATE TABLE IF NOT EXISTS messages( - message TEXT - );`) + _, err := db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req) if err != nil { return "", err } - _, err = db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req) - if err != nil { - return "", err - } - - rows, err := db.Get(ctx).Query(`SELECT message FROM messages;`) + rows, err := db.Get(ctx).Query(`SELECT DISTINCT message FROM messages;`) if err != nil { return "", err } diff --git a/ftl-project.toml b/ftl-project.toml index f4967c7d0e..fe3a4b6fdf 100644 --- a/ftl-project.toml +++ b/ftl-project.toml @@ -9,18 +9,9 @@ disable-ide-integration = true key = "inline://InZhbHVlIg" [modules] - [modules.database] - [modules.database.secrets] - FTL_DSN_DATABASE_TESTDB = "inline://InBvc3RncmVzOi8vMTI3LjAuMC4xOjE1NDMyL2RhdGFiYXNlX3Rlc3RkYj9zc2xtb2RlPWRpc2FibGVcdTAwMjZ1c2VyPXBvc3RncmVzXHUwMDI2cGFzc3dvcmQ9c2VjcmV0Ig" [modules.echo] [modules.echo.configuration] default = "inline://ImFub255bW91cyI" - [modules.mysql] - [modules.mysql.secrets] - FTL_DSN_MYSQL_TESTDB = "inline://InJvb3Q6c2VjcmV0QHRjcCgxMjcuMC4wLjE6MTMzMDYpL215c3FsX3Rlc3RkYj9hbGxvd05hdGl2ZVBhc3N3b3Jkcz1UcnVlIg" - [modules.test] - [modules.test.configuration] - [modules.test.secrets] [commands] startup = ["echo 'FTL startup command ⚡️'"] diff --git a/internal/pgproxy/pgproxy.go b/internal/pgproxy/pgproxy.go index 37ce587cc2..6b4dbbec70 100644 --- a/internal/pgproxy/pgproxy.go +++ b/internal/pgproxy/pgproxy.go @@ -77,9 +77,11 @@ func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error { // It will block until the connection is closed. func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) { defer conn.Close() + ctx, cancel := context.WithCancel(ctx) + defer cancel() logger := log.FromContext(ctx) - logger.Infof("new connection established: %s", conn.RemoteAddr()) + logger.Debugf("new connection established: %s", conn.RemoteAddr()) backend, startup, err := connectBackend(ctx, conn) if err != nil { @@ -90,30 +92,33 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr logger.Infof("client disconnected without startup message: %s", conn.RemoteAddr()) return } - logger.Debugf("startup message: %+v", startup) - logger.Debugf("backend connected: %s", conn.RemoteAddr()) + logger.Tracef("startup message: %+v", startup) + logger.Tracef("backend connected: %s", conn.RemoteAddr()) - frontend, err := connectFrontend(ctx, connectionFn, startup) + hijacked, err := connectFrontend(ctx, connectionFn, startup) if err != nil { // try again, in case there was a credential rotation - logger.Warnf("failed to connect frontend: %s, trying again", err) + logger.Debugf("failed to connect frontend: %s, trying again", err) - frontend, err = connectFrontend(ctx, connectionFn, startup) + hijacked, err = connectFrontend(ctx, connectionFn, startup) if err != nil { handleBackendError(ctx, backend, err) return } } + backend.Send(&pgproto3.AuthenticationOk{}) logger.Debugf("frontend connected") + for key, value := range hijacked.ParameterStatuses { + backend.Send(&pgproto3.ParameterStatus{Name: key, Value: value}) + } - backend.Send(&pgproto3.AuthenticationOk{}) - backend.Send(&pgproto3.ReadyForQuery{}) + backend.Send(&pgproto3.ReadyForQuery{TxStatus: 'I'}) if err := backend.Flush(); err != nil { logger.Errorf(err, "failed to flush backend authentication ok") return } - if err := proxy(ctx, backend, frontend); err != nil { + if err := proxy(ctx, backend, hijacked.Frontend); err != nil { logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err) return } @@ -171,7 +176,7 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp } } -func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) { +func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgconn.HijackedConn, error) { dsn, err := connectionFn(ctx, startup.Parameters) if err != nil { return nil, fmt.Errorf("failed to construct dsn: %w", err) @@ -181,38 +186,61 @@ func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup * if err != nil { return nil, fmt.Errorf("failed to connect to backend: %w", err) } - frontend := pgproto3.NewFrontend(conn.Conn(), conn.Conn()) - - return frontend, nil + hijacked, err := conn.Hijack() + if err != nil { + return nil, fmt.Errorf("failed to hijack backend: %w", err) + } + return hijacked, nil } func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error { logger := log.FromContext(ctx) - frontendMessages := make(chan pgproto3.BackendMessage) - backendMessages := make(chan pgproto3.FrontendMessage) errors := make(chan error, 2) go func() { for { msg, err := backend.Receive() + select { + case <-ctx.Done(): + return + default: + } if err != nil { errors <- fmt.Errorf("failed to receive backend message: %w", err) return } logger.Tracef("backend message: %T", msg) - backendMessages <- msg + frontend.Send(msg) + err = frontend.Flush() + if err != nil { + errors <- fmt.Errorf("failed to receive backend message: %w", err) + return + } + if _, ok := msg.(*pgproto3.Terminate); ok { + return + } } }() go func() { for { msg, err := frontend.Receive() + select { + case <-ctx.Done(): + return + default: + } if err != nil { errors <- fmt.Errorf("failed to receive frontend message: %w", err) return } logger.Tracef("frontend message: %T", msg) - frontendMessages <- msg + backend.Send(msg) + err = backend.Flush() + if err != nil { + errors <- fmt.Errorf("failed to receive backend message: %w", err) + return + } } }() @@ -220,20 +248,6 @@ func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Fr select { case <-ctx.Done(): return fmt.Errorf("context done: %w", ctx.Err()) - case msg := <-backendMessages: - frontend.Send(msg) - if err := frontend.Flush(); err != nil { - return fmt.Errorf("failed to flush frontend message: %w", err) - } - - if _, ok := msg.(*pgproto3.Terminate); ok { - return nil - } - case msg := <-frontendMessages: - backend.Send(msg) - if err := backend.Flush(); err != nil { - return fmt.Errorf("failed to flush backend message: %w", err) - } case err := <-errors: return err } diff --git a/internal/pgproxy/pgproxy_test.go b/internal/pgproxy/pgproxy_test.go index 2ef6633e5a..5d73f0d7f3 100644 --- a/internal/pgproxy/pgproxy_test.go +++ b/internal/pgproxy/pgproxy_test.go @@ -5,11 +5,12 @@ import ( "net" "testing" + "github.com/alecthomas/assert/v2" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/TBD54566975/ftl/internal/dev" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/pgproxy" - "github.com/alecthomas/assert/v2" - "github.com/jackc/pgx/v5/pgproto3" ) func TestPgProxy(t *testing.T) { @@ -48,6 +49,9 @@ func TestPgProxy(t *testing.T) { assert.NoError(t, frontend.Flush()) assertResponseType[*pgproto3.AuthenticationOk](t, frontend) + for range 13 { + assertResponseType[*pgproto3.ParameterStatus](t, frontend) + } assertResponseType[*pgproto3.ReadyForQuery](t, frontend) }) diff --git a/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java b/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java index a198a3d697..a629504d5e 100644 --- a/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java +++ b/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java @@ -9,10 +9,14 @@ import io.quarkus.agroal.spi.JdbcDataSourceBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.GeneratedResourceBuildItem; import io.quarkus.deployment.builditem.SystemPropertyBuildItem; import xyz.block.ftl.runtime.FTLDatasourceCredentials; +import xyz.block.ftl.runtime.FTLRecorder; import xyz.block.ftl.runtime.config.FTLConfigSource; +import xyz.block.ftl.v1.ModuleContextResponse; import xyz.block.ftl.v1.schema.Database; import xyz.block.ftl.v1.schema.Decl; @@ -21,10 +25,12 @@ public class DatasourceProcessor { private static final Logger log = Logger.getLogger(DatasourceProcessor.class); @BuildStep + @Record(ExecutionTime.STATIC_INIT) public SchemaContributorBuildItem registerDatasources( List datasources, BuildProducer systemPropProducer, - BuildProducer generatedResourceBuildItemBuildProducer) { + BuildProducer generatedResourceBuildItemBuildProducer, + FTLRecorder recorder) { log.infof("Processing %d datasource annotations into decls", datasources.size()); List decls = new ArrayList<>(); List namedDatasources = new ArrayList<>(); @@ -37,6 +43,11 @@ public SchemaContributorBuildItem registerDatasources( // FTL and quarkus use slightly different names dbKind = "postgres"; } + if (dbKind.equals("mysql")) { + recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.MYSQL); + } else { + recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.POSTGRES); + } //default name is which is not a valid name String sanitisedName = ds.getName().replace("<", "").replace(">", ""); //we use a dynamic credentials provider diff --git a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java index f51ca8d98e..ccc429e35c 100644 --- a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java +++ b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java @@ -5,7 +5,9 @@ import java.time.Duration; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; @@ -36,6 +38,8 @@ public class FTLController implements LeaseClient { private static volatile FTLController controller; + private final Map databases = new ConcurrentHashMap<>(); + /** * TODO: look at how init should work, this is terrible and will break dev mode */ @@ -71,6 +75,10 @@ public static FTLController instance() { verbService = VerbServiceGrpc.newStub(channel); } + public void registerDatabase(String name, ModuleContextResponse.DBType type) { + databases.put(name, type); + } + public byte[] getSecret(String secretName) { var context = getModuleContext(); if (context.containsSecrets(secretName)) { @@ -88,6 +96,10 @@ public byte[] getConfig(String secretName) { } public Datasource getDatasource(String name) { + if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) { + var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS"); + return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl"); + } List databasesList = getModuleContext().getDatabasesList(); for (var i : databasesList) { if (i.getName().equals(name)) { diff --git a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java index 1ec9cf0ca1..f3c497777d 100644 --- a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java +++ b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java @@ -18,6 +18,7 @@ import xyz.block.ftl.runtime.http.FTLHttpHandler; import xyz.block.ftl.runtime.http.HTTPVerbInvoker; import xyz.block.ftl.v1.CallRequest; +import xyz.block.ftl.v1.ModuleContextResponse; @Recorder public class FTLRecorder { @@ -171,4 +172,8 @@ public void run() { } }); } + + public void registerDatabase(String dbKind, ModuleContextResponse.DBType name) { + FTLController.instance().registerDatabase(dbKind, name); + } }