From 442feb3c94e39118af4d31e49beaaa18ea2f8ee0 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 23 Sep 2024 19:32:59 -0400 Subject: [PATCH] fix: handle credential refresh and record not found errors Signed-off-by: Grant Linville --- pkg/sqlite/sqlite.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pkg/sqlite/sqlite.go b/pkg/sqlite/sqlite.go index 1131846..526cd16 100644 --- a/pkg/sqlite/sqlite.go +++ b/pkg/sqlite/sqlite.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "errors" "fmt" "log" "os" @@ -99,6 +100,20 @@ func (s Sqlite) Add(creds *credentials.Credentials) error { return fmt.Errorf("failed to encrypt credential: %w", err) } + // First, we need to check if a credential with this serverURL already exists. + // If it does, delete it first. + // This would normally happen during a credential refresh. + var existing GptscriptCredential + if err := s.db.Where("server_url = ?", cred.ServerURL).First(&existing).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to get existing credential: %w", err) + } + } else { + if err := s.db.Delete(&existing).Error; err != nil { + return fmt.Errorf("failed to delete existing credential: %w", err) + } + } + if err := s.db.Create(&cred).Error; err != nil { return fmt.Errorf("failed to create credential: %w", err) } @@ -124,6 +139,9 @@ func (s Sqlite) Get(serverURL string) (string, string, error) { err error ) if err = s.db.Where("server_url = ?", serverURL).First(&cred).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", "", nil + } return "", "", fmt.Errorf("failed to get credential: %w", err) } @@ -141,6 +159,9 @@ func (s Sqlite) List() (map[string]string, error) { err error ) if err = s.db.Find(&creds).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } return nil, fmt.Errorf("failed to list credentials: %w", err) }