Skip to content

Commit

Permalink
compute_ctl: Add endpoint that allows extensions to be installed (#9344)
Browse files Browse the repository at this point in the history
Adds endpoint to install extensions:

**POST** `/extensions`
```
{"extension":"pg_sessions_jwt","database":"neondb","version":"1.0.0"}
```

Will be used by `local-proxy`.
Example, for the JWT authentication to work the database needs to have
the pg_session_jwt extension and also to enable JWT to work in RLS
policies.

---------

Co-authored-by: Conrad Ludgate <[email protected]>
  • Loading branch information
devjv and conradludgate authored Oct 18, 2024
1 parent 15fecff commit 3532ae7
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 7 deletions.
52 changes: 51 additions & 1 deletion compute_tools/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use utils::lsn::Lsn;

use compute_api::privilege::Privilege;
use compute_api::responses::{ComputeMetrics, ComputeStatus};
use compute_api::spec::{ComputeFeature, ComputeMode, ComputeSpec};
use compute_api::spec::{ComputeFeature, ComputeMode, ComputeSpec, ExtVersion};
use utils::measured_stream::MeasuredReader;

use nix::sys::signal::{kill, Signal};
Expand Down Expand Up @@ -1416,6 +1416,56 @@ LIMIT 100",
Ok(())
}

pub async fn install_extension(
&self,
ext_name: &PgIdent,
db_name: &PgIdent,
ext_version: ExtVersion,
) -> Result<ExtVersion> {
use tokio_postgres::config::Config;
use tokio_postgres::NoTls;

let mut conf = Config::from_str(self.connstr.as_str()).unwrap();
conf.dbname(db_name);

let (db_client, conn) = conf
.connect(NoTls)
.await
.context("Failed to connect to the database")?;
tokio::spawn(conn);

let version_query = "SELECT extversion FROM pg_extension WHERE extname = $1";
let version: Option<ExtVersion> = db_client
.query_opt(version_query, &[&ext_name])
.await
.with_context(|| format!("Failed to execute query: {}", version_query))?
.map(|row| row.get(0));

// sanitize the inputs as postgres idents.
let ext_name: String = ext_name.pg_quote();
let quoted_version: String = ext_version.pg_quote();

if let Some(installed_version) = version {
if installed_version == ext_version {
return Ok(installed_version);
}
let query = format!("ALTER EXTENSION {ext_name} UPDATE TO {quoted_version}");
db_client
.simple_query(&query)
.await
.with_context(|| format!("Failed to execute query: {}", query))?;
} else {
let query =
format!("CREATE EXTENSION IF NOT EXISTS {ext_name} WITH VERSION {quoted_version}");
db_client
.simple_query(&query)
.await
.with_context(|| format!("Failed to execute query: {}", query))?;
}

Ok(ext_version)
}

#[tokio::main]
pub async fn prepare_preload_libraries(
&self,
Expand Down
37 changes: 35 additions & 2 deletions compute_tools/src/http/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use crate::catalog::SchemaDumpError;
use crate::catalog::{get_database_schema, get_dbs_and_roles};
use crate::compute::forward_termination_signal;
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
use compute_api::requests::{ConfigurationRequest, SetRoleGrantsRequest};
use compute_api::requests::{ConfigurationRequest, ExtensionInstallRequest, SetRoleGrantsRequest};
use compute_api::responses::{
ComputeStatus, ComputeStatusResponse, GenericAPIError, SetRoleGrantsResponse,
ComputeStatus, ComputeStatusResponse, ExtensionInstallResult, GenericAPIError,
SetRoleGrantsResponse,
};

use anyhow::Result;
Expand Down Expand Up @@ -100,6 +101,38 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}

(&Method::POST, "/extensions") => {
info!("serving /extensions POST request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for extensions request: {:?}",
status
);
error!(msg);
return render_json_error(&msg, StatusCode::PRECONDITION_FAILED);
}

let request = hyper::body::to_bytes(req.into_body()).await.unwrap();
let request = serde_json::from_slice::<ExtensionInstallRequest>(&request).unwrap();
let res = compute
.install_extension(&request.extension, &request.database, request.version)
.await;
match res {
Ok(version) => render_json(Body::from(
serde_json::to_string(&ExtensionInstallResult {
extension: request.extension,
version,
})
.unwrap(),
)),
Err(e) => {
error!("install_extension failed: {}", e);
render_json_error(&e.to_string(), StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Expand Down
69 changes: 68 additions & 1 deletion compute_tools/src/http/openapi_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,41 @@ paths:
description: Error text or 'true' if check passed.
example: "true"

/extensions:
post:
tags:
- Extensions
summary: Install extension if possible.
description: ""
operationId: installExtension
requestBody:
description: Extension name and database to install it to.
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExtensionInstallRequest"
responses:
200:
description: Result from extension installation
content:
application/json:
schema:
$ref: "#/components/schemas/ExtensionInstallResult"
412:
description: |
Compute is in the wrong state for processing the request.
content:
application/json:
schema:
$ref: "#/components/schemas/GenericError"
500:
description: Error during extension installation.
content:
application/json:
schema:
$ref: "#/components/schemas/GenericError"

/configure:
post:
tags:
Expand Down Expand Up @@ -404,7 +439,7 @@ components:
moment, when spec was received.
example: "2022-10-12T07:20:50.52Z"
status:
$ref: '#/components/schemas/ComputeStatus'
$ref: "#/components/schemas/ComputeStatus"
last_active:
type: string
description: |
Expand Down Expand Up @@ -444,6 +479,38 @@ components:
- configuration
example: running

ExtensionInstallRequest:
type: object
required:
- extension
- database
- version
properties:
extension:
type: string
description: Extension name.
example: "pg_session_jwt"
version:
type: string
description: Version of the extension.
example: "1.0.0"
database:
type: string
description: Database name.
example: "neondb"

ExtensionInstallResult:
type: object
properties:
extension:
description: Name of the extension.
type: string
example: "pg_session_jwt"
version:
description: Version of the extension.
type: string
example: "1.0.0"

InstalledExtensions:
type: object
properties:
Expand Down
10 changes: 8 additions & 2 deletions libs/compute_api/src/requests.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use crate::{
privilege::Privilege,
spec::{ComputeSpec, PgIdent},
spec::{ComputeSpec, ExtVersion, PgIdent},
};
use serde::Deserialize;

Expand All @@ -16,6 +15,13 @@ pub struct ConfigurationRequest {
pub spec: ComputeSpec,
}

#[derive(Deserialize, Debug)]
pub struct ExtensionInstallRequest {
pub extension: PgIdent,
pub database: PgIdent,
pub version: ExtVersion,
}

#[derive(Deserialize, Debug)]
pub struct SetRoleGrantsRequest {
pub database: PgIdent,
Expand Down
7 changes: 6 additions & 1 deletion libs/compute_api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize, Serializer};

use crate::{
privilege::Privilege,
spec::{ComputeSpec, Database, PgIdent, Role},
spec::{ComputeSpec, Database, ExtVersion, PgIdent, Role},
};

#[derive(Serialize, Debug, Deserialize)]
Expand Down Expand Up @@ -172,6 +172,11 @@ pub struct InstalledExtensions {
pub extensions: Vec<InstalledExtension>,
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct ExtensionInstallResult {
pub extension: PgIdent,
pub version: ExtVersion,
}
#[derive(Clone, Debug, Default, Serialize)]
pub struct SetRoleGrantsResponse {
pub database: PgIdent,
Expand Down
3 changes: 3 additions & 0 deletions libs/compute_api/src/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use remote_storage::RemotePath;
/// intended to be used for DB / role names.
pub type PgIdent = String;

/// String type alias representing Postgres extension version
pub type ExtVersion = String;

/// Cluster spec or configuration represented as an optional number of
/// delta operations + final cluster state description.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
Expand Down
10 changes: 10 additions & 0 deletions test_runner/fixtures/endpoint/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def installed_extensions(self):
res.raise_for_status()
return res.json()

def extensions(self, extension: str, version: str, database: str):
body = {
"extension": extension,
"version": version,
"database": database,
}
res = self.post(f"http://localhost:{self.port}/extensions", json=body)
res.raise_for_status()
return res.json()

def set_role_grants(self, database: str, role: str, schema: str, privileges: list[str]):
res = self.post(
f"http://localhost:{self.port}/grants",
Expand Down
50 changes: 50 additions & 0 deletions test_runner/regress/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from logging import info

from fixtures.neon_fixtures import NeonEnv


def test_extensions(neon_simple_env: NeonEnv):
"""basic test for the extensions endpoint testing installing extensions"""

env = neon_simple_env

env.create_branch("test_extensions")

endpoint = env.endpoints.create_start("test_extensions")
extension = "neon_test_utils"
database = "test_extensions"

endpoint.safe_psql("CREATE DATABASE test_extensions")

with endpoint.connect(dbname=database) as pg_conn:
with pg_conn.cursor() as cur:
cur.execute(
"SELECT default_version FROM pg_available_extensions WHERE name = 'neon_test_utils'"
)
res = cur.fetchone()
assert res is not None
version = res[0]

with pg_conn.cursor() as cur:
cur.execute(
"SELECT extname, extversion FROM pg_extension WHERE extname = 'neon_test_utils'",
)
res = cur.fetchone()
assert not res, "The 'neon_test_utils' extension is installed"

client = endpoint.http_client()
install_res = client.extensions(extension, version, database)

info("Extension install result: %s", res)
assert install_res["extension"] == extension and install_res["version"] == version

with endpoint.connect(dbname=database) as pg_conn:
with pg_conn.cursor() as cur:
cur.execute(
"SELECT extname, extversion FROM pg_extension WHERE extname = 'neon_test_utils'",
)
res = cur.fetchone()
assert res is not None
(db_extension_name, db_extension_version) = res

assert db_extension_name == extension and db_extension_version == version

1 comment on commit 3532ae7

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5229 tests run: 5012 passed, 3 failed, 214 skipped (full report)


Failures on Postgres 16

  • test_compaction_downloads_on_demand_with_image_creation: release-x86-64

Failures on Postgres 14

# Run all failed tests locally:
scripts/pytest -vv -n $(nproc) -k "test_close_on_connections_exit[release-pg14] or test_sql_over_http_serverless_driver[release-pg14] or test_compaction_downloads_on_demand_with_image_creation[release-pg16]"
Flaky tests (1)

Postgres 17

Test coverage report is not available

The comment gets automatically updated with the latest test results
3532ae7 at 2024-10-18T13:21:07.818Z :recycle:

Please sign in to comment.