Skip to content

Commit

Permalink
kgw jsonrpc client
Browse files Browse the repository at this point in the history
  • Loading branch information
Yaiba committed May 10, 2024
1 parent 151be7e commit 56f6519
Show file tree
Hide file tree
Showing 20 changed files with 345 additions and 281 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
run:
timeout: 10m
go: '1.21'

issues:
exclude-dirs:
Expand Down
27 changes: 14 additions & 13 deletions cmd/kwil-cli/cmds/common/authinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const (
kgwAuthTokenFileName = "auth.json"
)

// KGWAuthTokenFilePath returns the path to the file that stores the Gateway Auth cookies.
// KGWAuthTokenFilePath returns the path to the file that stores the Gateway Authn cookies.
func KGWAuthTokenFilePath() string {
return filepath.Join(config.DefaultConfigDir, kgwAuthTokenFileName)
}
Expand Down Expand Up @@ -79,6 +79,7 @@ func convertToHttpCookie(c cookie) *http.Cookie {
}
}

// PersistedCookies is a set of Gateway Authn cookies that can be saved to a file.
// getDomain returns the domain of the URL.
func getDomain(target string) (string, error) {
if target == "" {
Expand Down Expand Up @@ -108,7 +109,7 @@ func getCookieIdentifier(domain string, userIdentifier []byte) string {
// It uses a custom cookie type that is json serializable.
type PersistedCookies map[string]cookie

// LoadPersistedCookie loads a persisted cookie from the auth file.
// LoadPersistedCookie loads a persisted cookie from the authn file.
// It will look up the cookie for the given user identifier.
// If nothing is found, it returns nil, nil.
func LoadPersistedCookie(authFile string, domain string, userIdentifier []byte) (*http.Cookie, error) {
Expand All @@ -118,13 +119,13 @@ func LoadPersistedCookie(authFile string, domain string, userIdentifier []byte)

file, err := utils.CreateOrOpenFile(authFile)
if err != nil {
return nil, fmt.Errorf("open kgw auth file: %w", err)
return nil, fmt.Errorf("open kgw authn file: %w", err)
}

var aInfo PersistedCookies
err = json.NewDecoder(file).Decode(&aInfo)
if err != nil {
return nil, fmt.Errorf("unmarshal kgw auth file: %w", err)
return nil, fmt.Errorf("unmarshal kgw authn file: %w", err)
}

b64Identifier := getCookieIdentifier(domain, userIdentifier)
Expand All @@ -133,15 +134,15 @@ func LoadPersistedCookie(authFile string, domain string, userIdentifier []byte)
return convertToHttpCookie(cookie), nil
}

// SaveCookie saves the cookie to auth file.
// SaveCookie saves the cookie to authn file.
// It will overwrite the cookie if the address already exists.
func SaveCookie(authFile string, domain string, userIdentifier []byte, originCookie *http.Cookie) error {
b64Identifier := getCookieIdentifier(domain, userIdentifier)
cookie := convertToCookie(originCookie)

authInfoBytes, err := utils.ReadOrCreateFile(authFile)
if err != nil {
return fmt.Errorf("read kgw auth file: %w", err)
return fmt.Errorf("read kgw authn file: %w", err)
}

var aInfo PersistedCookies
Expand All @@ -150,19 +151,19 @@ func SaveCookie(authFile string, domain string, userIdentifier []byte, originCoo
} else {
err = json.Unmarshal(authInfoBytes, &aInfo)
if err != nil {
return fmt.Errorf("unmarshal kgw auth file: %w", err)
return fmt.Errorf("unmarshal kgw authn file: %w", err)
}
}
aInfo[b64Identifier] = cookie

jsonBytes, err := json.MarshalIndent(&aInfo, "", " ")
if err != nil {
return fmt.Errorf("marshal kgw auth info: %w", err)
return fmt.Errorf("marshal kgw authn info: %w", err)
}

err = os.WriteFile(authFile, jsonBytes, 0600)
if err != nil {
return fmt.Errorf("write kgw auth file: %w", err)
return fmt.Errorf("write kgw authn file: %w", err)
}
return nil
}
Expand All @@ -172,7 +173,7 @@ func SaveCookie(authFile string, domain string, userIdentifier []byte, originCoo
func DeleteCookie(authFile string, domain string, userIdentifier []byte) error {
authInfoBytes, err := utils.ReadOrCreateFile(authFile)
if err != nil {
return fmt.Errorf("read kgw auth file: %w", err)
return fmt.Errorf("read kgw authn file: %w", err)
}

var aInfo PersistedCookies
Expand All @@ -181,7 +182,7 @@ func DeleteCookie(authFile string, domain string, userIdentifier []byte) error {
} else {
err = json.Unmarshal(authInfoBytes, &aInfo)
if err != nil {
return fmt.Errorf("unmarshal kgw auth file: %w", err)
return fmt.Errorf("unmarshal kgw authn file: %w", err)
}
}

Expand All @@ -190,12 +191,12 @@ func DeleteCookie(authFile string, domain string, userIdentifier []byte) error {

jsonBytes, err := json.MarshalIndent(&aInfo, "", " ")
if err != nil {
return fmt.Errorf("marshal kgw auth info: %w", err)
return fmt.Errorf("marshal kgw authn info: %w", err)
}

err = utils.WriteFile(authFile, jsonBytes)
if err != nil {
return fmt.Errorf("write kgw auth file: %w", err)
return fmt.Errorf("write kgw authn file: %w", err)
}
return nil
}
Empty file added core/gatewayclient/README.md
Empty file.
70 changes: 42 additions & 28 deletions core/gatewayclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@ import (
"github.com/kwilteam/kwil-db/core/client"
rpcClient "github.com/kwilteam/kwil-db/core/rpc/client"
"github.com/kwilteam/kwil-db/core/rpc/client/gateway"
httpGateway "github.com/kwilteam/kwil-db/core/rpc/client/gateway/http"
httpTx "github.com/kwilteam/kwil-db/core/rpc/client/user/http"
jsonrpcGateway "github.com/kwilteam/kwil-db/core/rpc/client/gateway/jsonrpc"
rpcclient "github.com/kwilteam/kwil-db/core/rpc/client/user/jsonrpc"
jsonrpc "github.com/kwilteam/kwil-db/core/rpc/json"
clientType "github.com/kwilteam/kwil-db/core/types/client"
gatewayTypes "github.com/kwilteam/kwil-db/core/types/gateway"
)

// GatewayClient is a client that is made to interact with a kwil gateway.
// It inherits the functionality of the main Kwil client, but also provides
// authentication cookies to the gateway.
// It automatically handles the authentication process with the gateway.
type GatewayClient struct {
client.Client
client.Client // tx client

target *url.URL

httpClient *http.Client
gatewayClient gateway.GatewayClient
conn *http.Client // the "connection"
gatewayClient gateway.Client

gatewaySigner GatewayAuthSignFunc // a hook for when the gateway authentication is needed

Expand Down Expand Up @@ -79,7 +79,7 @@ func NewClient(ctx context.Context, target string, opts *GatewayOptions) (*Gatew

persistJar := &customAuthCookieJar{jar: cookieJar}

httpClient := &http.Client{
httpConn := &http.Client{
Jar: persistJar,
}

Expand All @@ -88,25 +88,32 @@ func NewClient(ctx context.Context, target string, opts *GatewayOptions) (*Gatew
return nil, fmt.Errorf("parse target: %w", err)
}

// TODO: use jsonrpc, and maybe make an option in GatewayOptions to use the
// old endpoints rather than json-rpc?
txClient := httpTx.NewClient(parsedTarget, httpTx.WithHTTPClient(httpClient))
jsonrpcClientOpts := []rpcclient.Opts{}
if options != nil {
jsonrpcClientOpts = append(jsonrpcClientOpts,
rpcclient.WithLogger(options.Logger),
// so txClient and gatewayClient can share the connection
rpcclient.WithHTTPClient(httpConn),
)
}

txClient := rpcclient.NewClient(parsedTarget, jsonrpcClientOpts...)
coreClient, err := client.WrapClient(ctx, txClient, &options.Options)
if err != nil {
return nil, fmt.Errorf("wrap client: %w", err)
}

gatewayRPC, err := httpGateway.NewGatewayHttpClient(parsedTarget, httpGateway.WithHTTPClient(httpClient))
gatewayClient, err := jsonrpcGateway.NewClient(parsedTarget,
gateway.WithHTTPClient(httpConn))
if err != nil {
return nil, fmt.Errorf("create gateway rpc client: %w", err)
}

g := &GatewayClient{
Client: *coreClient,
httpClient: httpClient,
conn: httpConn,
gatewaySigner: options.AuthSignFunc,
gatewayClient: gatewayRPC,
gatewayClient: gatewayClient,
target: parsedTarget,
}

Expand Down Expand Up @@ -137,10 +144,18 @@ func (c *GatewayClient) Call(ctx context.Context, dbid string, action string, in
if err == nil {
return res, nil
}
if !errors.Is(err, rpcClient.ErrUnauthorized) {
return nil, err

var jsonRPCErr *jsonrpc.Error
if errors.As(err, &jsonRPCErr) {
if jsonRPCErr.Code != jsonrpc.ErrorKGWNotAuthorized {
return nil, err
}
}

//if !errors.Is(err, rpcClient.ErrUnauthorized) {
// return nil, err
//}

// we need to authenticate
err = c.authenticate(ctx)
if err != nil {
Expand All @@ -153,21 +168,20 @@ func (c *GatewayClient) Call(ctx context.Context, dbid string, action string, in

// authenticate authenticates the client with the gateway.
func (c *GatewayClient) authenticate(ctx context.Context) error {
authParam, err := c.gatewayClient.GetAuthParameter(ctx)
if errors.Is(err, rpcClient.ErrNotFound) {
return fmt.Errorf("failed to get auth parameter. are you sure you're talking to a gateway? err: %w", err)
} else if err != nil {
return fmt.Errorf("get auth parameter: %w", err)
}

authURI, err := url.JoinPath(c.target.String(), gateway.AuthEndpoint)
authParam, err := c.gatewayClient.GetAuthnParameter(ctx)
if err != nil {
return fmt.Errorf("join path: %w", err)
if errors.Is(err, rpcClient.ErrNotFound) {
return fmt.Errorf("failed to get auth parameter. are you sure you're talking to a gateway? err: %w", err)
}
return fmt.Errorf("get authn parameter: %w", err)
}

// remove trailing slash, avoid the confusing case like "http://example.com/" != "http://example.com"
// This is also done in the kgw, https://github.com/kwilteam/kgw/pull/42
targetDomain := strings.TrimSuffix(c.target.String(), "/")
// With switching to JSON rpc in KGW, the domain should not include the path.
targetDomain := c.target.Scheme + "://" + c.target.Host
authURI := targetDomain + gateway.AuthnEndpoint
//targetDomain := strings.TrimSuffix(c.target.String(), "/")
// backward compatibility if the Domain is not returned by the gateway
// Those fields are returned from kgw in https://github.com/kwilteam/kgw/pull/40
if authParam.Domain != "" && authParam.Domain != targetDomain {
Expand Down Expand Up @@ -196,13 +210,13 @@ func (c *GatewayClient) authenticate(ctx context.Context) error {
}

// send the auth request
err = c.gatewayClient.Auth(ctx, &gatewayTypes.GatewayAuth{
err = c.gatewayClient.Authn(ctx, &gateway.AuthnRequest{
Nonce: authParam.Nonce,
Sender: c.Signer.Identity(),
Signature: sig,
})
if err != nil {
return fmt.Errorf("gateway auth: %w", err)
return fmt.Errorf("gateway authn: %w", err)
}

return nil
Expand All @@ -229,7 +243,7 @@ func (c *GatewayClient) SetAuthCookie(cookie *http.Cookie) error {
return fmt.Errorf("cookie name %s not valid", cookie.Name)
}

c.httpClient.Jar.SetCookies(c.target, []*http.Cookie{cookie})
c.conn.Jar.SetCookies(c.target, []*http.Cookie{cookie})

c.authCookie = cookie

Expand Down
14 changes: 8 additions & 6 deletions core/gatewayclient/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@ import (
"fmt"

"github.com/kwilteam/kwil-db/core/crypto/auth"
types "github.com/kwilteam/kwil-db/core/types/gateway"
"github.com/kwilteam/kwil-db/core/rpc/client/gateway"
)

// kgw handles the authentication with KGW provider.
// KGW is a LB that als provides authentication for Kwil. It only supports HTTP.
// KGW is a LB that provides authentication for Kwil. It only supports HTTP.
// This is not part of core Kwil API, thus we implement it here.
//
// The authentication process is as follows:
// 1. Client starts an authentication session to KGW provider by sending a GET
// request to /auth endpoint, and KGW will return authn parameters.
// 1. Client starts an authentication session to KGW provider by sending `kgw.authn_param`
// request to /rpc/v1 endpoint, and KGW will return authn parameters.
// 2. Client composes a message using returned parameters and configuration,
// then presents the message to the user to sign.
// 3. Then user signs the message and passes the signature back to the client.
// 3. User signs the message and passes the signature back to the client.
// 4. Client identifies itself by sending a POST request to the KGW provider,
// and KGW will return a cookie if the signature is valid.
// 5. Following requests to the KGW provider should include the cookie for
// authentication required endpoints.
//
// Note in step 3, it's

const (
kgwAuthVersion = "1"
Expand All @@ -41,7 +43,7 @@ func defaultGatewayAuthSignFunc(message string, signer auth.Signer) (*auth.Signa
// composeGatewayAuthMessage composes the SIWE-like message to sign.
// param is the result of GET request for authentication.
// ALl the other parameters are expected from config.
func composeGatewayAuthMessage(param *types.GatewayAuthParameter, domain string, uri string,
func composeGatewayAuthMessage(param *gateway.AuthnParameterResponse, domain string, uri string,
version string, chainID string) string {
var msg bytes.Buffer
msg.WriteString(
Expand Down
3 changes: 1 addition & 2 deletions core/gatewayclient/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ type GatewayOptions struct {
// DefaultOptions returns the default options for the gateway client.
func DefaultOptions() *GatewayOptions {
return &GatewayOptions{
Options: *clientType.DefaultOptions(),

Options: *clientType.DefaultOptions(),
AuthSignFunc: defaultGatewayAuthSignFunc,
}
}
Expand Down
5 changes: 3 additions & 2 deletions core/rpc/client/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ var (
// ErrUnauthorized is returned when the client is not authenticated
// It is the equivalent of http status code 401
ErrUnauthorized = errors.New("unauthorized")
ErrNotFound = errors.New("not found") // resource not found
ErrMethodNotFound = errors.New("method not found")
ErrNotFound = errors.New("not found")
ErrInvalidRequest = errors.New("invalid request")
ErrNotAllowed = errors.New("not allowed")
)

// RPCError is a common error type used by any RPC client implementation to
Expand Down
Loading

0 comments on commit 56f6519

Please sign in to comment.