diff --git a/cmd/router/main.go b/cmd/router/main.go index ee752d03f..3d80b020f 100644 --- a/cmd/router/main.go +++ b/cmd/router/main.go @@ -366,6 +366,15 @@ var runCmd = &cobra.Command{ wg.Done() }(wg) + wg.Add(1) + go func(wg *sync.WaitGroup) { + err := app.ServceUnixSocket(ctx) + if err != nil { + spqrlog.Zero.Error().Err(err).Msg("") + } + wg.Done() + }(wg) + wg.Wait() return nil diff --git a/coordinator/app/app.go b/coordinator/app/app.go index 622a93a72..f56ce3e42 100644 --- a/coordinator/app/app.go +++ b/coordinator/app/app.go @@ -2,10 +2,14 @@ package app import ( "context" + "fmt" "net" + "os" + "path" "sync" "github.com/pg-sharding/spqr/pkg/spqrlog" + "github.com/pg-sharding/spqr/router/port" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -55,6 +59,12 @@ func (app *App) Run(withPsql bool) error { } }(wg) } + wg.Add(1) + go func(wg *sync.WaitGroup) { + if err := app.ServeUnixSocket(wg); err != nil { + spqrlog.Zero.Error().Err(err).Msg("") + } + }(wg) wg.Wait() @@ -104,7 +114,7 @@ func (app *App) ServeCoordinator(wg *sync.WaitGroup) error { go func() { defer app.sem.Release(1) - err := app.coordinator.ProcClient(context.TODO(), conn) + err := app.coordinator.ProcClient(context.TODO(), conn, port.DefaultRouterPortType) if err != nil { spqrlog.Zero.Error().Err(err).Msg("failed to serve client") } @@ -150,3 +160,39 @@ func (app *App) ServeGrpcApi(wg *sync.WaitGroup) error { return serv.Serve(listener) } + +func (app *App) ServeUnixSocket(wg *sync.WaitGroup) error { + defer wg.Done() + + if err := os.MkdirAll(config.UnixSocketDirectory, 0777); err != nil { + return err + } + socketPath := path.Join(config.UnixSocketDirectory, fmt.Sprintf(".s.PGSQL.%s", config.CoordinatorConfig().CoordinatorPort)) + lAddr := &net.UnixAddr{Name: socketPath, Net: "unix"} + listener, err := net.ListenUnix("unix", lAddr) + if err != nil { + return err + } + + for { + conn, err := listener.Accept() + if err != nil { + spqrlog.Zero.Error().Err(err).Msg("") + continue + } + + if err := app.sem.Acquire(context.Background(), 1); err != nil { + spqrlog.Zero.Error().Err(err).Msg("") + continue + } + + go func() { + defer app.sem.Release(1) + + err := app.coordinator.ProcClient(context.TODO(), conn, port.UnixSocketPortType) + if err != nil { + spqrlog.Zero.Error().Err(err).Msg("failed to serve client") + } + }() + } +} diff --git a/coordinator/provider/coordinator.go b/coordinator/provider/coordinator.go index ea8e945b7..9a3b39caf 100644 --- a/coordinator/provider/coordinator.go +++ b/coordinator/provider/coordinator.go @@ -1051,10 +1051,14 @@ func (qc *qdbCoordinator) RemoveTaskGroup(ctx context.Context) error { } // TODO : unit tests -func (qc *qdbCoordinator) PrepareClient(nconn net.Conn) (CoordinatorClient, error) { - cl := psqlclient.NewPsqlClient(nconn, port.DefaultRouterPortType, "") +func (qc *qdbCoordinator) PrepareClient(nconn net.Conn, pt port.RouterPortType) (CoordinatorClient, error) { + cl := psqlclient.NewPsqlClient(nconn, pt, "") - if err := cl.Init(qc.tlsconfig); err != nil { + tlsconfig := qc.tlsconfig + if pt == port.UnixSocketPortType { + tlsconfig = nil + } + if err := cl.Init(tlsconfig); err != nil { return nil, err } @@ -1094,8 +1098,8 @@ func (qc *qdbCoordinator) PrepareClient(nconn net.Conn) (CoordinatorClient, erro } // TODO : unit tests -func (qc *qdbCoordinator) ProcClient(ctx context.Context, nconn net.Conn) error { - cl, err := qc.PrepareClient(nconn) +func (qc *qdbCoordinator) ProcClient(ctx context.Context, nconn net.Conn, pt port.RouterPortType) error { + cl, err := qc.PrepareClient(nconn, pt) if err != nil { spqrlog.Zero.Error().Err(err).Msg("") return err diff --git a/docker/coordinator/Dockerfile b/docker/coordinator/Dockerfile index 05c4f90c0..21bb5a2f8 100644 --- a/docker/coordinator/Dockerfile +++ b/docker/coordinator/Dockerfile @@ -1,3 +1,5 @@ FROM spqr-base-image +RUN apt-get update && apt-get install -y postgresql-client + ENTRYPOINT /spqr/spqr-coordinator -c ${COORDINATOR_CONFIG=/spqr/docker/coordinator/cfg.yaml} \ No newline at end of file diff --git a/pkg/clientinteractor/interactor.go b/pkg/clientinteractor/interactor.go index 39c89b48e..cc5864165 100644 --- a/pkg/clientinteractor/interactor.go +++ b/pkg/clientinteractor/interactor.go @@ -18,6 +18,7 @@ import ( "github.com/pg-sharding/spqr/pkg/pool" "github.com/pg-sharding/spqr/pkg/shard" "github.com/pg-sharding/spqr/pkg/txstatus" + "github.com/pg-sharding/spqr/router/port" "github.com/pg-sharding/spqr/router/statistics" spqrparser "github.com/pg-sharding/spqr/yacc/console" @@ -31,7 +32,7 @@ import ( ) type Interactor interface { - ProcClient(ctx context.Context, nconn net.Conn) error + ProcClient(ctx context.Context, nconn net.Conn, pt port.RouterPortType) error } type PSQLInteractor struct { diff --git a/pkg/config/auth.go b/pkg/config/auth.go index 032fa62ff..ba10951a3 100644 --- a/pkg/config/auth.go +++ b/pkg/config/auth.go @@ -1,5 +1,9 @@ package config +const ( + UnixSocketDirectory = "/var/run/postgresql" +) + type AuthMethod string const ( diff --git a/router/app/app.go b/router/app/app.go index 4a66d5140..6a89ecf8e 100644 --- a/router/app/app.go +++ b/router/app/app.go @@ -2,10 +2,14 @@ package app import ( "context" + "fmt" "net" + "os" + "path" "sync" reuse "github.com/libp2p/go-reuseport" + "github.com/pg-sharding/spqr/pkg/config" "github.com/pg-sharding/spqr/pkg/spqrlog" rgrpc "github.com/pg-sharding/spqr/router/grpc" "github.com/pg-sharding/spqr/router/instance" @@ -108,3 +112,27 @@ func (app *App) ServeGrpcApi(ctx context.Context) error { server.GracefulStop() return nil } + +func (app *App) ServceUnixSocket(ctx context.Context) error { + if err := os.MkdirAll(config.UnixSocketDirectory, 0777); err != nil { + return err + } + socketPath := path.Join(config.UnixSocketDirectory, fmt.Sprintf(".s.PGSQL.%s", app.spqr.Config().RouterPort)) + lAddr := &net.UnixAddr{Name: socketPath, Net: "unix"} + listener, err := net.ListenUnix("unix", lAddr) + if err != nil { + return err + } + defer func(listener net.Listener) { + _ = listener.Close() + }(listener) + + spqrlog.Zero.Info(). + Msg("SPQR Router is ready by unix socket") + go func() { + _ = app.spqr.Run(ctx, listener, port.UnixSocketPortType) + }() + + <-ctx.Done() + return nil +} diff --git a/router/port/port.go b/router/port/port.go index 8d26c212d..21b0c66f2 100644 --- a/router/port/port.go +++ b/router/port/port.go @@ -8,4 +8,6 @@ const ( RORouterPortType = RouterPortType(1) ADMRouterPortType = RouterPortType(2) + + UnixSocketPortType = RouterPortType(3) ) diff --git a/router/rulerouter/rulerouter.go b/router/rulerouter/rulerouter.go index fe3774408..71523bbaa 100644 --- a/router/rulerouter/rulerouter.go +++ b/router/rulerouter/rulerouter.go @@ -3,10 +3,11 @@ package rulerouter import ( "crypto/tls" "fmt" - "github.com/pg-sharding/spqr/pkg/models/spqrerror" "net" "sync" + "github.com/pg-sharding/spqr/pkg/models/spqrerror" + "github.com/jackc/pgx/v5/pgproto3" "github.com/pg-sharding/spqr/pkg/auth" "github.com/pg-sharding/spqr/pkg/client" @@ -171,7 +172,11 @@ func NewRouter(tlsconfig *tls.Config, rcfg *config.Router, notifier *notifier.No func (r *RuleRouterImpl) PreRoute(conn net.Conn, pt port.RouterPortType) (rclient.RouterClient, error) { cl := rclient.NewPsqlClient(conn, pt, r.Config().Qr.DefaultRouteBehaviour) - if err := cl.Init(r.tlsconfig); err != nil { + tlsConfig := r.tlsconfig + if pt == port.UnixSocketPortType { + tlsConfig = nil + } + if err := cl.Init(tlsConfig); err != nil { return cl, err }