From 6af6613f4b13601e636c35e7e19cdfba45f6de77 Mon Sep 17 00:00:00 2001 From: Andreas Auernhammer Date: Wed, 7 Aug 2024 13:46:26 +0200 Subject: [PATCH] kms: add `Client.WriteDB` API (#21) This commit adds a `WriteDB` for writing a database snapshot to a KMS server. This is commonly used for restoring from a snapshot. It also adjusts the `ReadDB` API to take a `ReadDBRequest`. Signed-off-by: Andreas Auernhammer --- kms/client.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++---- kms/request.go | 18 ++++++++++++ 2 files changed, 87 insertions(+), 6 deletions(-) diff --git a/kms/client.go b/kms/client.go index 255a143..7964a59 100644 --- a/kms/client.go +++ b/kms/client.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "errors" "fmt" + "io/fs" "net" "net/http" "net/url" @@ -672,33 +673,50 @@ func (c *Client) RemoveNode(ctx context.Context, req *RemoveClusterNodeRequest) } // ReadDB returns a snapshot of current KMS server database. +// If req.Host is empty, one of the client's hosts is used. // The returned ReadDBResponse must be closed by the caller. // // It requires SysAdmin privileges. // // The returned error is of type *HostError. -func (c *Client) ReadDB(ctx context.Context) (*ReadDBResponse, error) { +func (c *Client) ReadDB(ctx context.Context, req *ReadDBRequest) (*ReadDBResponse, error) { const ( Method = http.MethodGet Path = api.PathDB StatusOK = http.StatusOK ) - url, host, err := c.lb.URL(Path) + var ( + err error + reqURL string + host = req.Host + ) + if host == "" { + reqURL, host, err = c.lb.URL(Path) + } else { + reqURL, err = url.JoinPath(httpsURL(host), Path) + } if err != nil { return nil, hostError(host, err) } - r, err := http.NewRequestWithContext(ctx, Method, url, nil) + + r, err := http.NewRequestWithContext(ctx, Method, reqURL, nil) if err != nil { return nil, hostError(host, err) } r.Header.Add(headers.Accept, headers.ContentTypeAppAny) r.Header.Add(headers.Accept, headers.ContentEncodingGZIP) - resp, err := c.client.Do(r) + var resp *http.Response + if req.Host == "" { + resp, err = c.client.Do(r) // Without req.Host, use the client LB. + } else { + resp, err = c.direct.Do(r) // With an explicit req.Host, don't use client LB. + } if err != nil { return nil, hostError(host, err) } + if resp.StatusCode != StatusOK { defer resp.Body.Close() return nil, hostError(host, readError(resp)) @@ -723,9 +741,54 @@ func (c *Client) ReadDB(ctx context.Context) (*ReadDBResponse, error) { }, nil } +// WriteDB writes a database snapshot to req.Host. The given +// host must not be empty. +// +// It requires SysAdmin privileges. +// +// The returned error is of type *HostError. +func (c *Client) WriteDB(ctx context.Context, req *WriteDBRequest) error { + const ( + Method = http.MethodPut + Path = api.PathDB + StatusOK = http.StatusOK + ) + + host := req.Host + reqURL, err := url.JoinPath(httpsURL(host), Path) + if err != nil { + return hostError(host, err) + } + + r, err := http.NewRequestWithContext(ctx, Method, reqURL, req.Body) + if err != nil { + return hostError(host, err) + } + r.Header.Add(headers.Accept, headers.ContentTypeBinary) + + if r.ContentLength == 0 && r.Body != nil && r.Body != http.NoBody { + r.ContentLength = -1 // Indicate that the content length is unknown + } + if f, ok := req.Body.(fs.File); ok { // If the body is a file we can set a content length + if stat, err := f.Stat(); err == nil { + r.ContentLength = stat.Size() + } + } + + resp, err := c.direct.Do(r) + if err != nil { + return hostError(host, err) + } + defer resp.Body.Close() + + if resp.StatusCode != StatusOK { + return hostError(host, readError(resp)) + } + return nil +} + // Logs returns a stream of server log records from req.Host. -// If req.Host is empty, the first host of the client's host -// list is used. +// If req.Host is empty, one of the client's hosts is used. // // The LogRequest specifies which log records are fetched. // For example, only records with a certain log level or diff --git a/kms/request.go b/kms/request.go index 71c2f41..edd62ce 100644 --- a/kms/request.go +++ b/kms/request.go @@ -6,6 +6,7 @@ package kms import ( "errors" + "io" "log/slog" "net/netip" "time" @@ -145,6 +146,23 @@ func (r *ClusterStatusRequest) MarshalPB(*pb.ClusterStatusRequest) error { retur // UnmarshalPB initializes the ClusterStatusRequest from its protobuf representation. func (r *ClusterStatusRequest) UnmarshalPB(*pb.ClusterStatusRequest) error { return nil } +// ReadDBRequest contains options for fetching a KMS server database +// snapshot. +type ReadDBRequest struct { + // Host from which a database snapshot should be taken. + Host string +} + +// WriteDBRequest contains options for restoring a KMS server's +// database from a snapshot. +type WriteDBRequest struct { + // Host to which the database snapshot is sent. + Host string + + // The database snapshot. + Body io.Reader +} + // ProfileRequest contains options for customizing performance profiling. type ProfileRequest struct { // Host on which performance profiling should be enabled.