diff --git a/charms/jimm-k8s/config.yaml b/charms/jimm-k8s/config.yaml index 20d8241a5..bbf92dde2 100644 --- a/charms/jimm-k8s/config.yaml +++ b/charms/jimm-k8s/config.yaml @@ -85,7 +85,17 @@ options: type: string description: | Duration for the JWT expiry (defaults to 5 minutes). + This is the JWT JIMM sends to a Juju controller to authenticate + model related commands. Increase this if long running websocket + connections are failing due to authentication errors. default: 5m + session-expiry-duration: + type: string + default: 6h + description: | + Expiry duration for JIMM session tokens. These tokens are used + by clients and their expiry determines how frequently a user + must login. macaroon-expiry-duration: type: string default: 24h diff --git a/charms/jimm-k8s/lib/charms/hydra/v0/oauth.py b/charms/jimm-k8s/lib/charms/hydra/v0/oauth.py new file mode 100644 index 000000000..6d8ed1ef9 --- /dev/null +++ b/charms/jimm-k8s/lib/charms/hydra/v0/oauth.py @@ -0,0 +1,767 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""# Oauth Library. + +This library is designed to enable applications to register OAuth2/OIDC +clients with an OIDC Provider through the `oauth` interface. + +## Getting started + +To get started using this library you just need to fetch the library using `charmcraft`. **Note +that you also need to add `jsonschema` to your charm's `requirements.txt`.** + +```shell +cd some-charm +charmcraft fetch-lib charms.hydra.v0.oauth +EOF +``` + +Then, to initialize the library: +```python +# ... +from charms.hydra.v0.oauth import ClientConfig, OAuthRequirer + +OAUTH = "oauth" +OAUTH_SCOPES = "openid email" +OAUTH_GRANT_TYPES = ["authorization_code"] + +class SomeCharm(CharmBase): + def __init__(self, *args): + # ... + self.oauth = OAuthRequirer(self, client_config, relation_name=OAUTH) + + self.framework.observe(self.oauth.on.oauth_info_changed, self._configure_application) + # ... + + def _on_ingress_ready(self, event): + self.external_url = "https://example.com" + self._set_client_config() + + def _set_client_config(self): + client_config = ClientConfig( + urljoin(self.external_url, "/oauth/callback"), + OAUTH_SCOPES, + OAUTH_GRANT_TYPES, + ) + self.oauth.update_client_config(client_config) +``` +""" + +import inspect +import json +import logging +import re +from dataclasses import asdict, dataclass, field +from typing import Dict, List, Mapping, Optional + +import jsonschema +from ops.charm import ( + CharmBase, + RelationBrokenEvent, + RelationChangedEvent, + RelationCreatedEvent, + RelationDepartedEvent, +) +from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents +from ops.model import Relation, Secret, TooManyRelatedAppsError + +# The unique Charmhub library identifier, never change it +LIBID = "a3a301e325e34aac80a2d633ef61fe97" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 5 + +logger = logging.getLogger(__name__) + +DEFAULT_RELATION_NAME = "oauth" +ALLOWED_GRANT_TYPES = ["authorization_code", "refresh_token", "client_credentials"] +ALLOWED_CLIENT_AUTHN_METHODS = ["client_secret_basic", "client_secret_post"] +CLIENT_SECRET_FIELD = "secret" + +url_regex = re.compile( + r"(^http://)|(^https://)" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|" + r"[A-Z0-9-]{2,}\.?)|" # domain... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) + +OAUTH_PROVIDER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/oauth/schemas/provider.json", + "type": "object", + "properties": { + "issuer_url": { + "type": "string", + }, + "authorization_endpoint": { + "type": "string", + }, + "token_endpoint": { + "type": "string", + }, + "introspection_endpoint": { + "type": "string", + }, + "userinfo_endpoint": { + "type": "string", + }, + "jwks_endpoint": { + "type": "string", + }, + "scope": { + "type": "string", + }, + "client_id": { + "type": "string", + }, + "client_secret_id": { + "type": "string", + }, + "groups": {"type": "string", "default": None}, + "ca_chain": {"type": "array", "items": {"type": "string"}, "default": []}, + }, + "required": [ + "issuer_url", + "authorization_endpoint", + "token_endpoint", + "introspection_endpoint", + "userinfo_endpoint", + "jwks_endpoint", + "scope", + ], +} +OAUTH_REQUIRER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/oauth/schemas/requirer.json", + "type": "object", + "properties": { + "redirect_uri": { + "type": "string", + "default": None, + }, + "audience": {"type": "array", "default": [], "items": {"type": "string"}}, + "scope": {"type": "string", "default": None}, + "grant_types": { + "type": "array", + "default": None, + "items": { + "enum": ["authorization_code", "client_credentials", "refresh_token"], + "type": "string", + }, + }, + "token_endpoint_auth_method": { + "type": "string", + "enum": ["client_secret_basic", "client_secret_post"], + "default": "client_secret_basic", + }, + }, + "required": ["redirect_uri", "audience", "scope", "grant_types", "token_endpoint_auth_method"], +} + + +class ClientConfigError(Exception): + """Emitted when invalid client config is provided.""" + + +class DataValidationError(RuntimeError): + """Raised when data validation fails on relation data.""" + + +def _load_data(data: Mapping, schema: Optional[Dict] = None) -> Dict: + """Parses nested fields and checks whether `data` matches `schema`.""" + ret = {} + for k, v in data.items(): + try: + ret[k] = json.loads(v) + except json.JSONDecodeError: + ret[k] = v + + if schema: + _validate_data(ret, schema) + return ret + + +def _dump_data(data: Dict, schema: Optional[Dict] = None) -> Dict: + if schema: + _validate_data(data, schema) + + ret = {} + for k, v in data.items(): + if isinstance(v, (list, dict)): + try: + ret[k] = json.dumps(v) + except json.JSONDecodeError as e: + raise DataValidationError(f"Failed to encode relation json: {e}") + else: + ret[k] = v + return ret + + +class OAuthRelation(Object): + """A class containing helper methods for oauth relation.""" + + def _pop_relation_data(self, relation_id: Relation) -> None: + if not self.model.unit.is_leader(): + return + + if len(self.model.relations) == 0: + return + + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + if not relation or not relation.app: + return + + try: + for data in list(relation.data[self.model.app]): + relation.data[self.model.app].pop(data, "") + except Exception as e: + logger.info(f"Failed to pop the relation data: {e}") + + +def _validate_data(data: Dict, schema: Dict) -> None: + """Checks whether `data` matches `schema`. + + Will raise DataValidationError if the data is not valid, else return None. + """ + try: + jsonschema.validate(instance=data, schema=schema) + except jsonschema.ValidationError as e: + raise DataValidationError(data, schema) from e + + +@dataclass +class ClientConfig: + """Helper class containing a client's configuration.""" + + redirect_uri: str + scope: str + grant_types: List[str] + audience: List[str] = field(default_factory=lambda: []) + token_endpoint_auth_method: str = "client_secret_basic" + client_id: Optional[str] = None + + def validate(self) -> None: + """Validate the client configuration.""" + # Validate redirect_uri + if not re.match(url_regex, self.redirect_uri): + raise ClientConfigError(f"Invalid URL {self.redirect_uri}") + + if self.redirect_uri.startswith("http://"): + logger.warning("Provided Redirect URL uses http scheme. Don't do this in production") + + # Validate grant_types + for grant_type in self.grant_types: + if grant_type not in ALLOWED_GRANT_TYPES: + raise ClientConfigError( + f"Invalid grant_type {grant_type}, must be one " f"of {ALLOWED_GRANT_TYPES}" + ) + + # Validate client authentication methods + if self.token_endpoint_auth_method not in ALLOWED_CLIENT_AUTHN_METHODS: + raise ClientConfigError( + f"Invalid client auth method {self.token_endpoint_auth_method}, " + f"must be one of {ALLOWED_CLIENT_AUTHN_METHODS}" + ) + + def to_dict(self) -> Dict: + """Convert object to dict.""" + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass +class OauthProviderConfig: + """Helper class containing provider's configuration.""" + + issuer_url: str + authorization_endpoint: str + token_endpoint: str + introspection_endpoint: str + userinfo_endpoint: str + jwks_endpoint: str + scope: str + client_id: Optional[str] = None + client_secret: Optional[str] = None + groups: Optional[str] = None + ca_chain: Optional[str] = None + + @classmethod + def from_dict(cls, dic: Dict) -> "OauthProviderConfig": + """Generate OauthProviderConfig instance from dict.""" + return cls(**{k: v for k, v in dic.items() if k in inspect.signature(cls).parameters}) + + +class OAuthInfoChangedEvent(EventBase): + """Event to notify the charm that the information in the databag changed.""" + + def __init__(self, handle: Handle, client_id: str, client_secret_id: str): + super().__init__(handle) + self.client_id = client_id + self.client_secret_id = client_secret_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "client_id": self.client_id, + "client_secret_id": self.client_secret_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.client_id = snapshot["client_id"] + self.client_secret_id = snapshot["client_secret_id"] + + +class InvalidClientConfigEvent(EventBase): + """Event to notify the charm that the client configuration is invalid.""" + + def __init__(self, handle: Handle, error: str): + super().__init__(handle) + self.error = error + + def snapshot(self) -> Dict: + """Save event.""" + return { + "error": self.error, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.error = snapshot["error"] + + +class OAuthInfoRemovedEvent(EventBase): + """Event to notify the charm that the provider data was removed.""" + + def snapshot(self) -> Dict: + """Save event.""" + return {} + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + pass + + +class OAuthRequirerEvents(ObjectEvents): + """Event descriptor for events raised by `OAuthRequirerEvents`.""" + + oauth_info_changed = EventSource(OAuthInfoChangedEvent) + oauth_info_removed = EventSource(OAuthInfoRemovedEvent) + invalid_client_config = EventSource(InvalidClientConfigEvent) + + +class OAuthRequirer(OAuthRelation): + """Register an oauth client.""" + + on = OAuthRequirerEvents() + + def __init__( + self, + charm: CharmBase, + client_config: Optional[ClientConfig] = None, + relation_name: str = DEFAULT_RELATION_NAME, + ) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + self._client_config = client_config + events = self._charm.on[relation_name] + self.framework.observe(events.relation_created, self._on_relation_created_event) + self.framework.observe(events.relation_changed, self._on_relation_changed_event) + self.framework.observe(events.relation_broken, self._on_relation_broken_event) + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + try: + self._update_relation_data(self._client_config, event.relation.id) + except ClientConfigError as e: + self.on.invalid_client_config.emit(e.args[0]) + + def _on_relation_broken_event(self, event: RelationBrokenEvent) -> None: + # Workaround for https://github.com/canonical/operator/issues/888 + self._pop_relation_data(event.relation.id) + if self.is_client_created(): + event.defer() + logger.info("Relation data still available. Deferring the event") + return + + # Notify the requirer that the relation data was removed + self.on.oauth_info_removed.emit() + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + if not self.model.unit.is_leader(): + return + + data = event.relation.data[event.app] + if not data: + logger.info("No relation data available.") + return + + data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + + client_id = data.get("client_id") + client_secret_id = data.get("client_secret_id") + if not client_id or not client_secret_id: + logger.info("OAuth Provider info is available, waiting for client to be registered.") + # The client credentials are not ready yet, so we do nothing + # This could mean that the client credentials were removed from the databag, + # but we don't allow that (for now), so we don't have to check for it. + return + + self.on.oauth_info_changed.emit(client_id, client_secret_id) + + def _update_relation_data( + self, client_config: Optional[ClientConfig], relation_id: Optional[int] = None + ) -> None: + if not self.model.unit.is_leader() or not client_config: + return + + if not isinstance(client_config, ClientConfig): + raise ValueError(f"Unexpected client_config type: {type(client_config)}") + + client_config.validate() + + try: + relation = self.model.get_relation( + relation_name=self._relation_name, relation_id=relation_id + ) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + + if not relation or not relation.app: + return + + data = _dump_data(client_config.to_dict(), OAUTH_REQUIRER_JSON_SCHEMA) + relation.data[self.model.app].update(data) + + def is_client_created(self, relation_id: Optional[int] = None) -> bool: + """Check if the client has been created.""" + if len(self.model.relations) == 0: + return None + try: + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + + if not relation or not relation.app: + return None + + return ( + "client_id" in relation.data[relation.app] + and "client_secret_id" in relation.data[relation.app] + ) + + def get_provider_info(self, relation_id: Optional[int] = None) -> OauthProviderConfig: + """Get the provider information from the databag.""" + if len(self.model.relations) == 0: + return None + try: + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + if not relation or not relation.app: + return None + + data = relation.data[relation.app] + if not data: + logger.info("No relation data available.") + return + + data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + + client_secret_id = data.get("client_secret_id") + if client_secret_id: + _client_secret = self.get_client_secret(client_secret_id) + client_secret = _client_secret.get_content()[CLIENT_SECRET_FIELD] + data["client_secret"] = client_secret + + oauth_provider = OauthProviderConfig.from_dict(data) + return oauth_provider + + def get_client_secret(self, client_secret_id: str) -> Secret: + """Get the client_secret.""" + client_secret = self.model.get_secret(id=client_secret_id) + return client_secret + + def update_client_config( + self, client_config: ClientConfig, relation_id: Optional[int] = None + ) -> None: + """Update the client config stored in the object.""" + self._client_config = client_config + self._update_relation_data(client_config, relation_id=relation_id) + + +class ClientCreatedEvent(EventBase): + """Event to notify the Provider charm to create a new client.""" + + def __init__( + self, + handle: Handle, + redirect_uri: str, + scope: str, + grant_types: List[str], + audience: List, + token_endpoint_auth_method: str, + relation_id: int, + ) -> None: + super().__init__(handle) + self.redirect_uri = redirect_uri + self.scope = scope + self.grant_types = grant_types + self.audience = audience + self.token_endpoint_auth_method = token_endpoint_auth_method + self.relation_id = relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "redirect_uri": self.redirect_uri, + "scope": self.scope, + "grant_types": self.grant_types, + "audience": self.audience, + "token_endpoint_auth_method": self.token_endpoint_auth_method, + "relation_id": self.relation_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.redirect_uri = snapshot["redirect_uri"] + self.scope = snapshot["scope"] + self.grant_types = snapshot["grant_types"] + self.audience = snapshot["audience"] + self.token_endpoint_auth_method = snapshot["token_endpoint_auth_method"] + self.relation_id = snapshot["relation_id"] + + def to_client_config(self) -> ClientConfig: + """Convert the event information to a ClientConfig object.""" + return ClientConfig( + self.redirect_uri, + self.scope, + self.grant_types, + self.audience, + self.token_endpoint_auth_method, + ) + + +class ClientChangedEvent(EventBase): + """Event to notify the Provider charm that the client config changed.""" + + def __init__( + self, + handle: Handle, + redirect_uri: str, + scope: str, + grant_types: List, + audience: List, + token_endpoint_auth_method: str, + relation_id: int, + client_id: str, + ) -> None: + super().__init__(handle) + self.redirect_uri = redirect_uri + self.scope = scope + self.grant_types = grant_types + self.audience = audience + self.token_endpoint_auth_method = token_endpoint_auth_method + self.relation_id = relation_id + self.client_id = client_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "redirect_uri": self.redirect_uri, + "scope": self.scope, + "grant_types": self.grant_types, + "audience": self.audience, + "token_endpoint_auth_method": self.token_endpoint_auth_method, + "relation_id": self.relation_id, + "client_id": self.client_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.redirect_uri = snapshot["redirect_uri"] + self.scope = snapshot["scope"] + self.grant_types = snapshot["grant_types"] + self.audience = snapshot["audience"] + self.token_endpoint_auth_method = snapshot["token_endpoint_auth_method"] + self.relation_id = snapshot["relation_id"] + self.client_id = snapshot["client_id"] + + def to_client_config(self) -> ClientConfig: + """Convert the event information to a ClientConfig object.""" + return ClientConfig( + self.redirect_uri, + self.scope, + self.grant_types, + self.audience, + self.token_endpoint_auth_method, + self.client_id, + ) + + +class ClientDeletedEvent(EventBase): + """Event to notify the Provider charm that the client was deleted.""" + + def __init__( + self, + handle: Handle, + relation_id: int, + ) -> None: + super().__init__(handle) + self.relation_id = relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return {"relation_id": self.relation_id} + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.relation_id = snapshot["relation_id"] + + +class OAuthProviderEvents(ObjectEvents): + """Event descriptor for events raised by `OAuthProviderEvents`.""" + + client_created = EventSource(ClientCreatedEvent) + client_changed = EventSource(ClientChangedEvent) + client_deleted = EventSource(ClientDeletedEvent) + + +class OAuthProvider(OAuthRelation): + """A provider object for OIDC Providers.""" + + on = OAuthProviderEvents() + + def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + + events = self._charm.on[relation_name] + self.framework.observe( + events.relation_changed, + self._get_client_config_from_relation_data, + ) + self.framework.observe( + events.relation_departed, + self._on_relation_departed, + ) + + def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> None: + if not self.model.unit.is_leader(): + return + + data = event.relation.data[event.app] + if not data: + logger.info("No requirer relation data available.") + return + + client_data = _load_data(data, OAUTH_REQUIRER_JSON_SCHEMA) + redirect_uri = client_data.get("redirect_uri") + scope = client_data.get("scope") + grant_types = client_data.get("grant_types") + audience = client_data.get("audience") + token_endpoint_auth_method = client_data.get("token_endpoint_auth_method") + + data = event.relation.data[self._charm.app] + if not data: + logger.info("No provider relation data available.") + return + provider_data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + client_id = provider_data.get("client_id") + + relation_id = event.relation.id + + if client_id: + # Modify an existing client + self.on.client_changed.emit( + redirect_uri, + scope, + grant_types, + audience, + token_endpoint_auth_method, + relation_id, + client_id, + ) + else: + # Create a new client + self.on.client_created.emit( + redirect_uri, scope, grant_types, audience, token_endpoint_auth_method, relation_id + ) + + def _get_secret_label(self, relation: Relation) -> str: + return f"client_secret_{relation.id}" + + def _on_relation_departed(self, event: RelationDepartedEvent) -> None: + # Workaround for https://github.com/canonical/operator/issues/888 + self._pop_relation_data(event.relation.id) + + self._delete_juju_secret(event.relation) + self.on.client_deleted.emit(event.relation.id) + + def _create_juju_secret(self, client_secret: str, relation: Relation) -> Secret: + """Create a juju secret and grant it to a relation.""" + secret = {CLIENT_SECRET_FIELD: client_secret} + juju_secret = self.model.app.add_secret(secret, label=self._get_secret_label(relation)) + juju_secret.grant(relation) + return juju_secret + + def _delete_juju_secret(self, relation: Relation) -> None: + secret = self.model.get_secret(label=self._get_secret_label(relation)) + secret.remove_all_revisions() + + def set_provider_info_in_relation_data( + self, + issuer_url: str, + authorization_endpoint: str, + token_endpoint: str, + introspection_endpoint: str, + userinfo_endpoint: str, + jwks_endpoint: str, + scope: str, + groups: Optional[str] = None, + ca_chain: Optional[str] = None, + ) -> None: + """Put the provider information in the databag.""" + if not self.model.unit.is_leader(): + return + + data = { + "issuer_url": issuer_url, + "authorization_endpoint": authorization_endpoint, + "token_endpoint": token_endpoint, + "introspection_endpoint": introspection_endpoint, + "userinfo_endpoint": userinfo_endpoint, + "jwks_endpoint": jwks_endpoint, + "scope": scope, + } + if groups: + data["groups"] = groups + if ca_chain: + data["ca_chain"] = ca_chain + + for relation in self.model.relations[self._relation_name]: + relation.data[self.model.app].update(_dump_data(data)) + + def set_client_credentials_in_relation_data( + self, relation_id: int, client_id: str, client_secret: str + ) -> None: + """Put the client credentials in the databag.""" + if not self.model.unit.is_leader(): + return + + relation = self.model.get_relation(self._relation_name, relation_id) + if not relation or not relation.app: + return + # TODO: What if we are refreshing the client_secret? We need to add a + # new revision for that + secret = self._create_juju_secret(client_secret, relation) + data = dict(client_id=client_id, client_secret_id=secret.id) + relation.data[self.model.app].update(_dump_data(data)) diff --git a/charms/jimm-k8s/metadata.yaml b/charms/jimm-k8s/metadata.yaml index 703b28a22..5eef23279 100644 --- a/charms/jimm-k8s/metadata.yaml +++ b/charms/jimm-k8s/metadata.yaml @@ -58,6 +58,9 @@ requires: interface: loki_push_api optional: true limit: 1 + oauth: + interface: oauth + limit: 1 containers: jimm: diff --git a/charms/jimm-k8s/requirements.txt b/charms/jimm-k8s/requirements.txt index 7b6542db5..4d6a6257a 100644 --- a/charms/jimm-k8s/requirements.txt +++ b/charms/jimm-k8s/requirements.txt @@ -5,3 +5,4 @@ jsonschema >= 3.2.0 cryptography >= 3.4.8 hvac >= 0.11.0 requests >= 2.25.1 +jsonschema diff --git a/charms/jimm-k8s/src/charm.py b/charms/jimm-k8s/src/charm.py index 566db985d..1734dda29 100755 --- a/charms/jimm-k8s/src/charm.py +++ b/charms/jimm-k8s/src/charm.py @@ -19,6 +19,7 @@ import json import logging import socket +from urllib.parse import urljoin import hvac import requests @@ -27,6 +28,7 @@ DatabaseRequires, ) from charms.grafana_k8s.v0.grafana_dashboard import GrafanaDashboardProvider +from charms.hydra.v0.oauth import ClientConfig, OAuthInfoChangedEvent, OAuthRequirer from charms.loki_k8s.v0.loki_push_api import LogProxyConsumer from charms.nginx_ingress_integrator.v0.nginx_route import require_nginx_route from charms.openfga_k8s.v0.openfga import OpenFGARequires, OpenFGAStoreCreateEvent @@ -72,6 +74,10 @@ LOG_FILE = "/var/log/jimm" # This likely will just be JIMM's port. PROMETHEUS_PORT = 8080 +OAUTH = "oauth" +OAUTH_SCOPES = "openid email offline_access" +# TODO: Add "device_code" below once the charm interface supports it. +OAUTH_GRANT_TYPES = ["authorization_code", "refresh_token"] class DeferError(Exception): @@ -89,7 +95,10 @@ def __init__(self, *args): self._state = State(self.app, lambda: self.model.get_relation("peer")) self._unit_state = State(self.unit, lambda: self.model.get_relation("peer")) + self.oauth = OAuthRequirer(self, self._oauth_client_config, relation_name=OAUTH) + self.framework.observe(self.oauth.on.oauth_info_changed, self._on_oauth_info_changed) + self.framework.observe(self.oauth.on.oauth_info_removed, self._on_oauth_info_changed) self.framework.observe(self.on.peer_relation_changed, self._on_peer_relation_changed) self.framework.observe(self.on.jimm_pebble_ready, self._on_jimm_pebble_ready) self.framework.observe(self.on.config_changed, self._on_config_changed) @@ -199,6 +208,9 @@ def _on_jimm_pebble_ready(self, event): def _on_config_changed(self, event): self._update_workload(event) + def _on_oauth_info_changed(self, event: OAuthInfoChangedEvent): + self._update_workload(event) + @requires_state_setter def _on_leader_elected(self, event): if not self._state.private_key: @@ -235,6 +247,7 @@ def _update_workload(self, event): event.defer() return + self.oauth.update_client_config(client_config=self._oauth_client_config) self._ensure_bakery_agent_file(event) self._ensure_vault_file(event) if self.model.get_relation("vault") and not container.exists(self._vault_secret_filename): @@ -242,11 +255,18 @@ def _update_workload(self, event): self.unit.status = BlockedStatus("Vault relation present but vault setup is not ready yet") return + if not self.oauth.is_client_created(): + logger.warning("OAuth relation is not ready yet") + self.unit.status = BlockedStatus("Waiting for OAuth relation") + return + dns_name = self._get_dns_name(event) if not dns_name: logger.warning("dns name not set") return + oauth_provider_info = self.oauth.get_provider_info() + config_values = { "CANDID_PUBLIC_KEY": self.config.get("candid-public-key", ""), "CANDID_URL": self.config.get("candid-url", ""), @@ -265,8 +285,13 @@ def _update_workload(self, event): "OPENFGA_PORT": self._state.openfga_port, "PRIVATE_KEY": self.config.get("private-key", ""), "PUBLIC_KEY": self.config.get("public-key", ""), - "JIMM_JWT_EXPIRY": self.config.get("jwt-expiry", "5m"), + "JIMM_JWT_EXPIRY": self.config.get("jwt-expiry"), "JIMM_MACAROON_EXPIRY_DURATION": self.config.get("macaroon-expiry-duration", "24h"), + "JIMM_ACCESS_TOKEN_EXPIRY_DURATION": self.config.get("session-expiry-duration"), + "JIMM_OAUTH_ISSUER_URL": oauth_provider_info.issuer_url, + "JIMM_OAUTH_CLIENT_ID": oauth_provider_info.client_id, + "JIMM_OAUTH_CLIENT_SECRET": oauth_provider_info.client_secret, + "JIMM_OAUTH_SCOPES": oauth_provider_info.scope, } if self._state.dsn: config_values["JIMM_DSN"] = self._state.dsn @@ -735,6 +760,25 @@ def _on_create_authorization_model_action(self, event: ActionEvent): self._state.openfga_auth_model_id = authorization_model_id self._update_workload(event) + @property + def _oauth_client_config(self) -> ClientConfig: + dns = self.config.get("dns-name") + if dns is None or dns == "": + dns = "http://localhost" + dns = ensureFQDN(dns) + return ClientConfig( + urljoin(dns, "/oauth/callback"), + OAUTH_SCOPES, + OAUTH_GRANT_TYPES, + ) + + +def ensureFQDN(dns: str): # noqa: N802 + """Ensures a domain name has an https:// prefix.""" + if not dns.startswith("http"): + dns = "https://" + dns + return dns + def _json_data(event, key): logger.debug("getting relation data {}".format(key)) diff --git a/charms/jimm-k8s/tests/unit/test_charm.py b/charms/jimm-k8s/tests/unit/test_charm.py index 44306f157..3053aa224 100644 --- a/charms/jimm-k8s/tests/unit/test_charm.py +++ b/charms/jimm-k8s/tests/unit/test_charm.py @@ -15,6 +15,18 @@ from src.charm import JimmOperatorCharm +OAUTH_CLIENT_ID = "jimm_client_id" +OAUTH_CLIENT_SECRET = "test-secret" +OAUTH_PROVIDER_INFO = { + "authorization_endpoint": "https://example.oidc.com/oauth2/auth", + "introspection_endpoint": "https://example.oidc.com/admin/oauth2/introspect", + "issuer_url": "https://example.oidc.com", + "jwks_endpoint": "https://example.oidc.com/.well-known/jwks.json", + "scope": "openid profile email phone", + "token_endpoint": "https://example.oidc.com/oauth2/token", + "userinfo_endpoint": "https://example.oidc.com/userinfo", +} + MINIMAL_CONFIG = { "uuid": "1234567890", "candid-url": "test-candid-url", @@ -28,7 +40,6 @@ "JIMM_DASHBOARD_LOCATION": "https://jaas.ai/models", "JIMM_DNS_NAME": "juju-jimm-k8s-0.juju-jimm-k8s-endpoints.None.svc.cluster.local", "JIMM_ENABLE_JWKS_ROTATOR": "1", - "JIMM_JWT_EXPIRY": "5m", "JIMM_LISTEN_ADDR": ":8080", "JIMM_LOG_LEVEL": "info", "JIMM_UUID": "1234567890", @@ -37,9 +48,36 @@ "PUBLIC_KEY": "izcYsQy3TePp6bLjqOo3IRPFvkQd2IKtyODGqC6SdFk=", "JIMM_AUDIT_LOG_RETENTION_PERIOD_IN_DAYS": "0", "JIMM_MACAROON_EXPIRY_DURATION": "24h", + "JIMM_JWT_EXPIRY": "5m", + "JIMM_ACCESS_TOKEN_EXPIRY_DURATION": "6h", + "JIMM_OAUTH_ISSUER_URL": OAUTH_PROVIDER_INFO["issuer_url"], + "JIMM_OAUTH_CLIENT_ID": OAUTH_CLIENT_ID, + "JIMM_OAUTH_CLIENT_SECRET": OAUTH_CLIENT_SECRET, + "JIMM_OAUTH_SCOPES": OAUTH_PROVIDER_INFO["scope"], } +def get_expected_plan(env): + return { + "services": { + "jimm": { + "summary": "JAAS Intelligent Model Manager", + "startup": "disabled", + "override": "replace", + "command": "/root/jimmsrv", + "environment": env, + } + }, + "checks": { + "jimm-check": { + "override": "replace", + "period": "1m", + "http": {"url": "http://localhost:8080/debug/status"}, + } + }, + } + + class MockExec: def wait_output(): return True @@ -64,8 +102,22 @@ def setUp(self): self.harness.add_relation_unit(jimm_id, "juju-jimm-k8s/1") self.harness.container_pebble_ready("jimm") - rel_id = self.harness.add_relation("ingress", "nginx-ingress") - self.harness.add_relation_unit(rel_id, "nginx-ingress/0") + self.ingress_rel_id = self.harness.add_relation("ingress", "nginx-ingress") + self.harness.add_relation_unit(self.ingress_rel_id, "nginx-ingress/0") + + self.oauth_rel_id = self.harness.add_relation("oauth", "hydra") + self.harness.add_relation_unit(self.oauth_rel_id, "hydra/0") + secret_id = self.harness.add_model_secret("hydra", {"secret": OAUTH_CLIENT_SECRET}) + self.harness.grant_secret(secret_id, "juju-jimm-k8s") + self.harness.update_relation_data( + self.oauth_rel_id, + "hydra", + { + "client_id": OAUTH_CLIENT_ID, + "client_secret_id": secret_id, + **OAUTH_PROVIDER_INFO, + }, + ) # import ipdb; ipdb.set_trace() def test_on_pebble_ready(self): @@ -77,20 +129,7 @@ def test_on_pebble_ready(self): # Check the that the plan was updated plan = self.harness.get_container_pebble_plan("jimm") - self.assertEqual( - plan.to_dict(), - { - "services": { - "jimm": { - "summary": "JAAS Intelligent Model Manager", - "startup": "disabled", - "override": "replace", - "command": "/root/jimmsrv", - "environment": EXPECTED_ENV, - } - } - }, - ) + self.assertEqual(plan.to_dict(), get_expected_plan(EXPECTED_ENV)) def test_on_config_changed(self): container = self.harness.model.unit.get_container("jimm") @@ -104,20 +143,7 @@ def test_on_config_changed(self): # Check the that the plan was updated plan = self.harness.get_container_pebble_plan("jimm") - self.assertEqual( - plan.to_dict(), - { - "services": { - "jimm": { - "summary": "JAAS Intelligent Model Manager", - "startup": "disabled", - "override": "replace", - "command": "/root/jimmsrv", - "environment": EXPECTED_ENV, - } - } - }, - ) + self.assertEqual(plan.to_dict(), get_expected_plan(EXPECTED_ENV)) def test_postgres_secret_storage_config(self): container = self.harness.model.unit.get_container("jimm") @@ -134,20 +160,50 @@ def test_postgres_secret_storage_config(self): plan = self.harness.get_container_pebble_plan("jimm") expected_env = EXPECTED_ENV.copy() expected_env.update({"INSECURE_SECRET_STORAGE": "enabled"}) - self.assertEqual( - plan.to_dict(), - { - "services": { - "jimm": { - "summary": "JAAS Intelligent Model Manager", - "startup": "disabled", - "override": "replace", - "command": "/root/jimmsrv", - "environment": expected_env, - } - } - }, + self.assertEqual(plan.to_dict(), get_expected_plan(expected_env)) + + def test_app_dns_address(self): + self.harness.update_config(MINIMAL_CONFIG) + self.harness.update_config({"dns-name": "jimm.com"}) + oauth_client = self.harness.charm._oauth_client_config + self.assertEqual(oauth_client.redirect_uri, "https://jimm.com/oauth/callback") + + def test_app_enters_block_states_if_oauth_relation_removed(self): + self.harness.update_config(MINIMAL_CONFIG) + self.harness.remove_relation(self.oauth_rel_id) + container = self.harness.model.unit.get_container("jimm") + # Emit the pebble-ready event for jimm + self.harness.charm.on.jimm_pebble_ready.emit(container) + + # Check the that the plan is empty + plan = self.harness.get_container_pebble_plan("jimm") + self.assertEqual(plan.to_dict(), {}) + self.assertEqual(self.harness.charm.unit.status.name, BlockedStatus.name) + self.assertEqual(self.harness.charm.unit.status.message, "Waiting for OAuth relation") + + def test_app_enters_block_state_if_oauth_relation_not_ready(self): + self.harness.update_config(MINIMAL_CONFIG) + self.harness.remove_relation(self.oauth_rel_id) + oauth_relation = self.harness.add_relation("oauth", "hydra") + self.harness.add_relation_unit(oauth_relation, "hydra/0") + secret_id = self.harness.add_model_secret("hydra", {"secret": OAUTH_CLIENT_SECRET}) + self.harness.grant_secret(secret_id, "juju-jimm-k8s") + # If the client-id is empty we should detect that the oauth relation is not ready. + # The readiness check is handled by the OAuth library. + self.harness.update_relation_data( + oauth_relation, + "hydra", + {"client_id": ""}, ) + container = self.harness.model.unit.get_container("jimm") + # Emit the pebble-ready event for jimm + self.harness.charm.on.jimm_pebble_ready.emit(container) + + # Check the that the plan is empty + plan = self.harness.get_container_pebble_plan("jimm") + self.assertEqual(plan.to_dict(), {}) + self.assertEqual(self.harness.charm.unit.status.name, BlockedStatus.name) + self.assertEqual(self.harness.charm.unit.status.message, "Waiting for OAuth relation") def test_bakery_configuration(self): container = self.harness.model.unit.get_container("jimm") @@ -171,20 +227,7 @@ def test_bakery_configuration(self): expected_env.update({"BAKERY_AGENT_FILE": "/root/config/agent.json"}) # Check the that the plan was updated plan = self.harness.get_container_pebble_plan("jimm") - self.assertEqual( - plan.to_dict(), - { - "services": { - "jimm": { - "summary": "JAAS Intelligent Model Manager", - "startup": "disabled", - "override": "replace", - "command": "/root/jimmsrv", - "environment": expected_env, - } - } - }, - ) + self.assertEqual(plan.to_dict(), get_expected_plan(expected_env)) agent_data = container.pull("/root/config/agent.json") agent_json = json.loads(agent_data.read()) self.assertEqual( @@ -211,20 +254,7 @@ def test_audit_log_retention_config(self): expected_env.update({"JIMM_AUDIT_LOG_RETENTION_PERIOD_IN_DAYS": "10"}) # Check the that the plan was updated plan = self.harness.get_container_pebble_plan("jimm") - self.assertEqual( - plan.to_dict(), - { - "services": { - "jimm": { - "summary": "JAAS Intelligent Model Manager", - "startup": "disabled", - "override": "replace", - "command": "/root/jimmsrv", - "environment": expected_env, - } - } - }, - ) + self.assertEqual(plan.to_dict(), get_expected_plan(expected_env)) def test_dashboard_relation_joined(self): harness = Harness(JimmOperatorCharm) diff --git a/cmd/jimmsrv/main.go b/cmd/jimmsrv/main.go index 36eb1de48..db7a665d2 100644 --- a/cmd/jimmsrv/main.go +++ b/cmd/jimmsrv/main.go @@ -105,7 +105,7 @@ func start(ctx context.Context, s *service.Service) error { } scopes := os.Getenv("JIMM_OAUTH_SCOPES") - scopesParsed := strings.Split(scopes, ",") + scopesParsed := strings.Split(scopes, " ") for i, scope := range scopesParsed { scopesParsed[i] = strings.TrimSpace(scope) } diff --git a/docker-compose.yaml b/docker-compose.yaml index 2e663a1c0..f094ebe80 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -77,7 +77,7 @@ services: JIMM_OAUTH_ISSUER_URL: "http://keycloak:8082/realms/jimm" # Scheme required JIMM_OAUTH_CLIENT_ID: "jimm-device" JIMM_OAUTH_CLIENT_SECRET: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4" - JIMM_OAUTH_SCOPES: "openid, profile, email" # Comma separated list of scopes + JIMM_OAUTH_SCOPES: "openid profile email" # Space separated list of scopes volumes: - ./:/jimm/ - ./local/vault/approle.json:/vault/approle.json:rw