Skip to content

Commit

Permalink
kms: add Client.WriteDB API (#21)
Browse files Browse the repository at this point in the history
This commit adds a `WriteDB` for writing a database snapshot
to a KMS server. This is commonly used for restoring from a
snapshot.

It also adjusts the `ReadDB` API to take a `ReadDBRequest`.

Signed-off-by: Andreas Auernhammer <[email protected]>
  • Loading branch information
aead authored Aug 7, 2024
1 parent bc91058 commit 6af6613
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 6 deletions.
75 changes: 69 additions & 6 deletions kms/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io/fs"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -672,33 +673,50 @@ func (c *Client) RemoveNode(ctx context.Context, req *RemoveClusterNodeRequest)
}

// ReadDB returns a snapshot of current KMS server database.
// If req.Host is empty, one of the client's hosts is used.
// The returned ReadDBResponse must be closed by the caller.
//
// It requires SysAdmin privileges.
//
// The returned error is of type *HostError.
func (c *Client) ReadDB(ctx context.Context) (*ReadDBResponse, error) {
func (c *Client) ReadDB(ctx context.Context, req *ReadDBRequest) (*ReadDBResponse, error) {
const (
Method = http.MethodGet
Path = api.PathDB
StatusOK = http.StatusOK
)

url, host, err := c.lb.URL(Path)
var (
err error
reqURL string
host = req.Host
)
if host == "" {
reqURL, host, err = c.lb.URL(Path)
} else {
reqURL, err = url.JoinPath(httpsURL(host), Path)
}
if err != nil {
return nil, hostError(host, err)
}
r, err := http.NewRequestWithContext(ctx, Method, url, nil)

r, err := http.NewRequestWithContext(ctx, Method, reqURL, nil)
if err != nil {
return nil, hostError(host, err)
}
r.Header.Add(headers.Accept, headers.ContentTypeAppAny)
r.Header.Add(headers.Accept, headers.ContentEncodingGZIP)

resp, err := c.client.Do(r)
var resp *http.Response
if req.Host == "" {
resp, err = c.client.Do(r) // Without req.Host, use the client LB.
} else {
resp, err = c.direct.Do(r) // With an explicit req.Host, don't use client LB.
}
if err != nil {
return nil, hostError(host, err)
}

if resp.StatusCode != StatusOK {
defer resp.Body.Close()
return nil, hostError(host, readError(resp))
Expand All @@ -723,9 +741,54 @@ func (c *Client) ReadDB(ctx context.Context) (*ReadDBResponse, error) {
}, nil
}

// WriteDB writes a database snapshot to req.Host. The given
// host must not be empty.
//
// It requires SysAdmin privileges.
//
// The returned error is of type *HostError.
func (c *Client) WriteDB(ctx context.Context, req *WriteDBRequest) error {
const (
Method = http.MethodPut
Path = api.PathDB
StatusOK = http.StatusOK
)

host := req.Host
reqURL, err := url.JoinPath(httpsURL(host), Path)
if err != nil {
return hostError(host, err)
}

r, err := http.NewRequestWithContext(ctx, Method, reqURL, req.Body)
if err != nil {
return hostError(host, err)
}
r.Header.Add(headers.Accept, headers.ContentTypeBinary)

if r.ContentLength == 0 && r.Body != nil && r.Body != http.NoBody {
r.ContentLength = -1 // Indicate that the content length is unknown
}
if f, ok := req.Body.(fs.File); ok { // If the body is a file we can set a content length
if stat, err := f.Stat(); err == nil {
r.ContentLength = stat.Size()
}
}

resp, err := c.direct.Do(r)
if err != nil {
return hostError(host, err)
}
defer resp.Body.Close()

if resp.StatusCode != StatusOK {
return hostError(host, readError(resp))
}
return nil
}

// Logs returns a stream of server log records from req.Host.
// If req.Host is empty, the first host of the client's host
// list is used.
// If req.Host is empty, one of the client's hosts is used.
//
// The LogRequest specifies which log records are fetched.
// For example, only records with a certain log level or
Expand Down
18 changes: 18 additions & 0 deletions kms/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kms

import (
"errors"
"io"
"log/slog"
"net/netip"
"time"
Expand Down Expand Up @@ -145,6 +146,23 @@ func (r *ClusterStatusRequest) MarshalPB(*pb.ClusterStatusRequest) error { retur
// UnmarshalPB initializes the ClusterStatusRequest from its protobuf representation.
func (r *ClusterStatusRequest) UnmarshalPB(*pb.ClusterStatusRequest) error { return nil }

// ReadDBRequest contains options for fetching a KMS server database
// snapshot.
type ReadDBRequest struct {
// Host from which a database snapshot should be taken.
Host string
}

// WriteDBRequest contains options for restoring a KMS server's
// database from a snapshot.
type WriteDBRequest struct {
// Host to which the database snapshot is sent.
Host string

// The database snapshot.
Body io.Reader
}

// ProfileRequest contains options for customizing performance profiling.
type ProfileRequest struct {
// Host on which performance profiling should be enabled.
Expand Down

0 comments on commit 6af6613

Please sign in to comment.