diff --git a/src/scitokens_cache.cpp b/src/scitokens_cache.cpp index 944f990..12536c1 100644 --- a/src/scitokens_cache.cpp +++ b/src/scitokens_cache.cpp @@ -95,38 +95,53 @@ std::string get_cache_file() { return keycache_file; } -void remove_issuer_entry(sqlite3 *db, const std::string &issuer, - bool new_transaction) { - - if (new_transaction) - sqlite3_exec(db, "BEGIN", 0, 0, 0); +// Remove a given issuer from the database. Starts a new transaction +// if `new_transaction` is true. +// If a failure occurs, then this function returns nonzero and closes +// the database handle. +int remove_issuer_entry(sqlite3 *db, const std::string &issuer, + bool new_transaction) { + + int rc; + if (new_transaction) { + if ((rc = sqlite3_exec(db, "BEGIN", 0, 0, 0)) != SQLITE_OK) { + sqlite3_close(db); + return -1; + } + } sqlite3_stmt *stmt; - int rc = sqlite3_prepare_v2(db, "DELETE FROM keycache WHERE issuer = ?", -1, - &stmt, NULL); + rc = sqlite3_prepare_v2(db, "DELETE FROM keycache WHERE issuer = ?", -1, + &stmt, NULL); if (rc != SQLITE_OK) { sqlite3_close(db); - return; + return -1; } if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) { sqlite3_finalize(stmt); sqlite3_close(db); - return; + return -1; } rc = sqlite3_step(stmt); if (rc != SQLITE_DONE) { sqlite3_finalize(stmt); sqlite3_close(db); - return; + return -1; } sqlite3_finalize(stmt); - if (new_transaction) - sqlite3_exec(db, "COMMIT", 0, 0, 0); + if (new_transaction) { + if ((rc = sqlite3_exec(db, "COMMIT", 0, 0, 0)) != SQLITE_OK) { + sqlite3_close(db); + return -1; + } + } + + return 0; } } // namespace @@ -170,27 +185,35 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer, picojson::value json_obj; auto err = picojson::parse(json_obj, metadata); if (!err.empty() || !json_obj.is()) { - remove_issuer_entry(db, issuer, true); + if (remove_issuer_entry(db, issuer, true) != 0) { + return false; + } sqlite3_close(db); return false; } auto top_obj = json_obj.get(); auto iter = top_obj.find("jwks"); if (iter == top_obj.end() || !iter->second.is()) { - remove_issuer_entry(db, issuer, true); + if (remove_issuer_entry(db, issuer, true) != 0) { + return false; + } sqlite3_close(db); return false; } auto keys_local = iter->second; iter = top_obj.find("expires"); if (iter == top_obj.end() || !iter->second.is()) { - remove_issuer_entry(db, issuer, true); + if (remove_issuer_entry(db, issuer, true) != 0) { + return false; + } sqlite3_close(db); return false; } auto expiry = iter->second.get(); if (now > expiry) { - remove_issuer_entry(db, issuer, true); + if (remove_issuer_entry(db, issuer, true) != 0) { + return false; + } sqlite3_close(db); return false; } @@ -238,9 +261,14 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer, return false; } - sqlite3_exec(db, "BEGIN", 0, 0, 0); + if ((rc = sqlite3_exec(db, "BEGIN", 0, 0, 0)) != SQLITE_OK) { + sqlite3_close(db); + return false; + } - remove_issuer_entry(db, issuer, false); + if (remove_issuer_entry(db, issuer, false) != 0) { + return false; + } sqlite3_stmt *stmt; rc = sqlite3_prepare_v2(db, "INSERT INTO keycache VALUES (?, ?)", -1, &stmt, @@ -270,10 +298,13 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer, sqlite3_close(db); return false; } + sqlite3_finalize(stmt); - sqlite3_exec(db, "COMMIT", 0, 0, 0); + if (sqlite3_exec(db, "COMMIT", 0, 0, 0) != SQLITE_OK) { + sqlite3_close(db); + return false; + } - sqlite3_finalize(stmt); sqlite3_close(db); return true; }