diff --git a/cmd/kwil-cli/cmds/common/authinfo.go b/cmd/kwil-cli/cmds/common/authinfo.go index 0e4f3938e..f5509f468 100644 --- a/cmd/kwil-cli/cmds/common/authinfo.go +++ b/cmd/kwil-cli/cmds/common/authinfo.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "os" "path/filepath" + "strings" "time" "github.com/kwilteam/kwil-db/cmd/kwil-cli/config" @@ -77,6 +79,30 @@ func convertToHttpCookie(c cookie) *http.Cookie { } } +// getDomain returns the domain of the URL. +func getDomain(target string) (string, error) { + if target == "" { + return "", fmt.Errorf("target is empty") + } + + if !(strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://")) { + return "", fmt.Errorf("target missing scheme") + } + + parsedTarget, err := url.Parse(target) + if err != nil { + return "", fmt.Errorf("parse target: %w", err) + } + + return parsedTarget.Scheme + "://" + parsedTarget.Host, nil +} + +// getCookieIdentifier returns a unique identifier for a cookie, base64 encoded. +func getCookieIdentifier(domain string, userIdentifier []byte) string { + return base64.StdEncoding.EncodeToString( + append([]byte(domain+"_"), userIdentifier...)) +} + // PersistedCookies is a set of Gateway Auth cookies that can be saved to a file. // It maps a base64 user identifier to a cookie, ensuring only one cookie per wallet. // It uses a custom cookie type that is json serializable. @@ -85,7 +111,7 @@ type PersistedCookies map[string]cookie // LoadPersistedCookie loads a persisted cookie from the auth file. // It will look up the cookie for the given user identifier. // If nothing is found, it returns nil, nil. -func LoadPersistedCookie(authFile string, userIdentifier []byte) (*http.Cookie, error) { +func LoadPersistedCookie(authFile string, domain string, userIdentifier []byte) (*http.Cookie, error) { if _, err := os.Stat(authFile); os.IsNotExist(err) { return nil, nil } @@ -101,7 +127,7 @@ func LoadPersistedCookie(authFile string, userIdentifier []byte) (*http.Cookie, return nil, fmt.Errorf("unmarshal kgw auth file: %w", err) } - b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier) + b64Identifier := getCookieIdentifier(domain, userIdentifier) cookie := aInfo[b64Identifier] return convertToHttpCookie(cookie), nil @@ -109,7 +135,8 @@ func LoadPersistedCookie(authFile string, userIdentifier []byte) (*http.Cookie, // SaveCookie saves the cookie to auth file. // It will overwrite the cookie if the address already exists. -func SaveCookie(authFile string, userIdentifier []byte, originCookie *http.Cookie) error { +func SaveCookie(authFile string, domain string, userIdentifier []byte, originCookie *http.Cookie) error { + b64Identifier := getCookieIdentifier(domain, userIdentifier) cookie := convertToCookie(originCookie) authInfoBytes, err := utils.ReadOrCreateFile(authFile) @@ -117,8 +144,6 @@ func SaveCookie(authFile string, userIdentifier []byte, originCookie *http.Cooki return fmt.Errorf("read kgw auth file: %w", err) } - b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier) - var aInfo PersistedCookies if len(authInfoBytes) == 0 { aInfo = make(PersistedCookies) @@ -144,14 +169,12 @@ func SaveCookie(authFile string, userIdentifier []byte, originCookie *http.Cooki // DeleteCookie will delete a cookie that exists for a given user identifier. // If no cookie exists for the user identifier, it will do nothing. -func DeleteCookie(authFile string, userIdentifier []byte) error { +func DeleteCookie(authFile string, domain string, userIdentifier []byte) error { authInfoBytes, err := utils.ReadOrCreateFile(authFile) if err != nil { return fmt.Errorf("read kgw auth file: %w", err) } - b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier) - var aInfo PersistedCookies if len(authInfoBytes) == 0 { aInfo = make(PersistedCookies) @@ -161,6 +184,8 @@ func DeleteCookie(authFile string, userIdentifier []byte) error { return fmt.Errorf("unmarshal kgw auth file: %w", err) } } + + b64Identifier := getCookieIdentifier(domain, userIdentifier) delete(aInfo, b64Identifier) jsonBytes, err := json.MarshalIndent(&aInfo, "", " ") diff --git a/cmd/kwil-cli/cmds/common/authinfo_test.go b/cmd/kwil-cli/cmds/common/authinfo_test.go index 57060cfcd..da1134d4c 100644 --- a/cmd/kwil-cli/cmds/common/authinfo_test.go +++ b/cmd/kwil-cli/cmds/common/authinfo_test.go @@ -9,6 +9,45 @@ import ( "github.com/stretchr/testify/assert" ) +func TestLoadKGWAuthInfo_without_domain(t *testing.T) { + // this test just to show what the old behavior was + + ckA := http.Cookie{ + Name: "AAA", + Value: "AAA", + Path: "AAA", + Domain: "AAA", + Expires: time.Date(2023, 10, 27, 15, 46, 58, 651387237, time.UTC), + } + + ckB := http.Cookie{ + Name: "BBB", + Value: "BBB", + Path: "BBB", + Domain: "BBB", + Expires: time.Date(2023, 10, 27, 15, 46, 58, 651387237, time.UTC), + } + + var err error + authFile := filepath.Join(t.TempDir(), "auth.json") + domain := "" + + // authn on site A + err = SaveCookie(authFile, domain, []byte("0x123"), &ckA) + assert.NoError(t, err) + + // authn on site B + err = SaveCookie(authFile, domain, []byte("0x123"), &ckB) + assert.NoError(t, err) + + got, err := LoadPersistedCookie(authFile, domain, []byte("0x123")) + assert.NoError(t, err) + + // ckA has been overwritten by ckB + assert.NotEqualValues(t, &ckA, got) + assert.EqualValues(t, &ckB, got) +} + func TestLoadKGWAuthInfo(t *testing.T) { ck := http.Cookie{ Name: "test", @@ -27,12 +66,72 @@ func TestLoadKGWAuthInfo(t *testing.T) { var err error authFile := filepath.Join(t.TempDir(), "auth.json") + domain := "https://kgw.kwil.com" - err = SaveCookie(authFile, []byte("0x123"), &ck) + err = SaveCookie(authFile, domain, []byte("0x123"), &ck) assert.NoError(t, err) - got, err := LoadPersistedCookie(authFile, []byte("0x123")) + got, err := LoadPersistedCookie(authFile, domain, []byte("0x123")) assert.NoError(t, err) assert.EqualValues(t, &ck, got) } + +func Test_getDomain(t *testing.T) { + type args struct { + target string + } + tests := []struct { + name string + args args + wantErr bool + wantDoamin string + }{ + // TODO: Add test cases. + { + name: "empty string", + args: args{ + target: "", + }, + wantErr: true, + wantDoamin: "", + }, + { + name: "localhost with port", + args: args{ + target: "http://localhost:8080/api", + }, + wantDoamin: "http://localhost:8080", + }, + { + name: "https localhost with port", + args: args{ + target: "https://localhost:8080/api/", + }, + wantDoamin: "https://localhost:8080", + }, + { + name: "http example.com", + args: args{ + target: "http://example.com/a/b/c", + }, + wantDoamin: "http://example.com", + }, + { + name: "https example.com", + args: args{ + target: "https://example.com/a/b/c", + }, + wantDoamin: "https://example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domain, err := getDomain(tt.args.target) + if tt.wantErr { + assert.Errorf(t, err, "getDomain(%v)", tt.args.target) + } + assert.Equalf(t, tt.wantDoamin, domain, "getDomain(%v)", tt.args.target) + }) + } +} diff --git a/cmd/kwil-cli/cmds/common/roundtripper.go b/cmd/kwil-cli/cmds/common/roundtripper.go index da164c6ec..2ac1c3b07 100644 --- a/cmd/kwil-cli/cmds/common/roundtripper.go +++ b/cmd/kwil-cli/cmds/common/roundtripper.go @@ -98,14 +98,19 @@ func DialClient(ctx context.Context, cmd *cobra.Command, flags uint8, fn RoundTr return fn(ctx, client, conf) } - cookie, err := LoadPersistedCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity()) + providerDomain, err := getDomain(conf.Provider) + if err != nil { + return err + } + + cookie, err := LoadPersistedCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity()) if err == nil && cookie != nil { // if setting fails, then don't do fail usage- failure likely means that the client has // switched providers, and the cookie is no longer valid. The gatewayclient will re-authenticate. // delete the cookie if it is invalid err = client.SetAuthCookie(cookie) if err != nil { - err2 := DeleteCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity()) + err2 := DeleteCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity()) if err2 != nil { return fmt.Errorf("failed to delete cookie: %w", err2) } @@ -123,7 +128,7 @@ func DialClient(ctx context.Context, cmd *cobra.Command, flags uint8, fn RoundTr return nil } - err = SaveCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity(), cookie) + err = SaveCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity(), cookie) if err != nil { return fmt.Errorf("save cookie: %w", err) }