generated from canonical/is-charms-template-repo
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement REST client for k8s charm (#6)
* Add K8sSnapManager * Add license * Refactor API to accept a ConnectionFactory * Ignore woke in pylint configuration
- Loading branch information
1 parent
b843495
commit ed15099
Showing
5 changed files
with
328 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
ops >= 2.2.0 | ||
pydantic == 1.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
# Copyright 2024 Canonical Ltd. | ||
# See LICENSE file for licensing details. | ||
|
||
"""Module for managing Kubernetes Snap interactions.""" | ||
import json | ||
import socket | ||
from contextlib import contextmanager | ||
from http.client import HTTPConnection, HTTPException | ||
from typing import List, Type, TypeVar | ||
|
||
from pydantic import BaseModel, Field, validator | ||
|
||
|
||
class K8sdAPIManagerError(Exception): | ||
"""Base exception for K8sd API Manager errors.""" | ||
|
||
|
||
class K8sdConnectionError(K8sdAPIManagerError): | ||
"""Raised when there is a connection error.""" | ||
|
||
|
||
class InvalidResponseError(K8sdAPIManagerError): | ||
"""Raised when the response is invalid or unexpected.""" | ||
|
||
|
||
class BaseRequestModel(BaseModel): | ||
"""Base model for k8s request responses. | ||
Attributes: | ||
type (str): The type of the request. | ||
status (str): The status of the response, defaults to an empty string. | ||
status_code (int): The status code of the response. | ||
operation (str): The operation of the request, defaults to an empty string. | ||
error_code (int): The error code associated with the response. | ||
error (str): The error message, defaults to an empty string. | ||
""" | ||
|
||
type: str | ||
status: str = Field(default="") | ||
status_code: int | ||
operation: str = Field(default="") | ||
error_code: int | ||
error: str = Field(default="") | ||
|
||
@validator("status_code", always=True) | ||
def check_status_code(cls, v): | ||
"""Validate the status_code field. | ||
Args: | ||
v (int): The value of the status_code field to validate. | ||
Returns: | ||
int: The validated status code if it is 200. | ||
Raises: | ||
ValueError: If the status_code is not 200. | ||
""" | ||
if v != 200: | ||
raise ValueError(f"Status code must be 200. Received {v}") | ||
return v | ||
|
||
@validator("error_code", always=True) | ||
def check_error_code(cls, v): | ||
"""Validate the error_code field. | ||
Args: | ||
v (int): The value of the error_code field to validate. | ||
Returns: | ||
int: The validated error code if it is 0. | ||
Raises: | ||
ValueError: If the error_code is not 0. | ||
""" | ||
if v != 0: | ||
raise ValueError(f"Error code must be 0. Received {v}") | ||
return v | ||
|
||
@validator("error", always=True) | ||
def check_error(cls, v, values): | ||
"""Validate the error field. | ||
Args: | ||
v (str): The value of the error_code field to validate. | ||
values (dict): Dictionary of field values. | ||
Returns: | ||
str: The validated error message. | ||
Raises: | ||
ValueError: If the error_code is non-zero and the error message is missing. | ||
""" | ||
if "error_code" in values and values["error_code"] != 0 and not v: | ||
raise ValueError("Error message must be provided for non-zero error code") | ||
return v | ||
|
||
|
||
class UpdateComponentResponse(BaseRequestModel): | ||
"""Response model for updating a k8s component.""" | ||
|
||
|
||
class TokenMetadata(BaseModel): | ||
"""Model representing metadata for a token. | ||
Attributes: | ||
token (str): The actual token string. | ||
""" | ||
|
||
token: str | ||
|
||
|
||
class AuthTokenResponse(BaseRequestModel): | ||
"""Response model for Kubernetes authentication token requests. | ||
Attributes: | ||
metadata (TokenMetadata): Metadata containing the authentication token. | ||
""" | ||
|
||
metadata: TokenMetadata | ||
|
||
|
||
class CreateJoinTokenResponse(BaseRequestModel): | ||
"""Response model for join token creation requests. | ||
Attributes: | ||
metadata (TokenMetadata): Metadata containing the newly created join token. | ||
""" | ||
|
||
metadata: TokenMetadata | ||
|
||
|
||
T = TypeVar("T", bound=BaseRequestModel) | ||
|
||
|
||
class UnixSocketHTTPConnection(HTTPConnection): | ||
"""HTTP connection over a Unix socket.""" | ||
|
||
def __init__(self, unix_socket: str, timeout: int = 30): | ||
"""Initialise the UnixSocketHTTPConnection. | ||
Args: | ||
unix_socket (str): Path to the Unix socket. | ||
timeout (int): Connection timeout in seconds. | ||
""" | ||
super().__init__("localhost", timeout=timeout) | ||
self.unix_socket = unix_socket | ||
|
||
def connect(self): | ||
"""Establish a connection to the server using a Unix socket. | ||
Raises: | ||
K8sdConnectionError: If there is an error connecting to the Unix socket. | ||
""" | ||
try: | ||
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) | ||
self.sock.settimeout(self.timeout) | ||
self.sock.connect(self.unix_socket) | ||
except socket.error as e: | ||
raise K8sdConnectionError(f"Error connecting to socket: {e}") from e | ||
|
||
|
||
class ConnectionFactory: | ||
"""Abstract factory for creating connection objects.""" | ||
|
||
def create_connection(self): | ||
"""Create a new connection instance. | ||
Raises: | ||
NotImplementedError: If create_connection is not implemented by the subclass. | ||
""" | ||
raise NotImplementedError("create_connection must be implemented by subclasses") | ||
|
||
|
||
class UnixSocketConnectionFactory(ConnectionFactory): | ||
"""Concrete factory for creating Unix socket connections.""" | ||
|
||
def __init__(self, unix_socket: str, timeout: int = 30): | ||
"""Initialize a new instance of UnixSocketConnectionFactory. | ||
Args: | ||
unix_socket (str): The file path to the Unix socket. | ||
timeout (int, optional): The timeout for the connection in seconds. | ||
Defaults to 30 seconds. | ||
""" | ||
self.unix_socket = unix_socket | ||
self.timeout = timeout | ||
|
||
@contextmanager | ||
def create_connection(self): | ||
"""Create and manage a Unix socket HTTP connection. | ||
Yields: | ||
UnixSocketHTTPConnection: The created Unix socket HTTP connection. | ||
""" | ||
conn = UnixSocketHTTPConnection(self.unix_socket, self.timeout) | ||
try: | ||
yield conn | ||
finally: | ||
conn.close() | ||
|
||
|
||
class HTTPConnectionFactory(ConnectionFactory): | ||
"""Concrete factory for creating HTTP connections.""" | ||
|
||
def __init__(self, host: str, port=None, timeout: int = 30): | ||
"""Initialize a new instance of HTTPConnectionFactory. | ||
Args: | ||
host (str): Hostname for the HTTP connection. | ||
port (int): Port for the HTTP connection. | ||
timeout (int, optional): The timeout for the connection in seconds. | ||
Defaults to 30 seconds. | ||
""" | ||
self.host = host | ||
self.port = port | ||
self.timeout = timeout | ||
|
||
@contextmanager | ||
def create_connection(self): | ||
"""Create and manage an HTTP connection. | ||
Yields: | ||
HTTPConnection: The created HTTP connection. | ||
""" | ||
conn = HTTPConnection(self.host, self.port, self.timeout) | ||
try: | ||
yield conn | ||
finally: | ||
conn.close() | ||
|
||
|
||
class K8sdAPIManager: | ||
"""Manager for K8sd API interactions.""" | ||
|
||
def __init__(self, factory: ConnectionFactory): | ||
"""Initialise the K8sdAPIManager. | ||
Args: | ||
factory (ConnectionFactory): An instance of a connection factory that will be used | ||
to create connections. This factory determines the type | ||
of connection (e.g., Unix socket or HTTP). | ||
""" | ||
self.factory = factory | ||
|
||
def _send_request(self, endpoint: str, method: str, body: dict, response_cls: Type[T]) -> T: | ||
"""Send a request to the k8sd API endpoint. | ||
Args: | ||
endpoint (str): The endpoint to send the request to. | ||
method (str): HTTP method for the request. | ||
body (dict): Body of the request. | ||
response_cls (Type[T]): The response class to deserialize the response into. | ||
Raises: | ||
K8sdConnectionError: If there's an HTTP or socket error. | ||
InvalidResponseError: If the response has invalid JSON or fails validation. | ||
Returns: | ||
T: An instance of the response class with the response data. | ||
""" | ||
try: | ||
with self.factory.create_connection() as connection: | ||
connection.request( | ||
method, | ||
endpoint, | ||
body=json.dumps(body), | ||
headers={"Content-Type": "application/json"}, | ||
) | ||
response = connection.getresponse() | ||
data = response.read().decode() | ||
return response_cls.parse_raw(data) | ||
|
||
except ValueError as e: | ||
raise InvalidResponseError(f"Request failed: {e}") from e | ||
except (socket.error, HTTPException) as e: | ||
raise K8sdConnectionError(f"HTTP or Socket error: {e}") from e | ||
|
||
def create_join_token(self, name: str): | ||
"""Create a join token. | ||
Args: | ||
name (str): Name of the node. | ||
Returns: | ||
str: The generated join token if successful. | ||
""" | ||
endpoint = "/1.0/k8sd/tokens" | ||
body = {"name": name} | ||
join_response = self._send_request(endpoint, "POST", body, CreateJoinTokenResponse) | ||
return join_response.metadata.token | ||
|
||
def enable_component(self, name: str, enable: bool): | ||
"""Enable or disable a k8s component. | ||
Args: | ||
name (str): Name of the component. | ||
enable (bool): True to enable, False to disable the component. | ||
""" | ||
endpoint = f"/1.0/k8sd/components/{name}" | ||
body = {"status": "enable" if enable else "disable"} | ||
self._send_request(endpoint, "PUT", body, UpdateComponentResponse) | ||
|
||
def request_auth_token(self, username: str, groups: List[str]) -> str: | ||
"""Request a Kubernetes authentication token. | ||
Args: | ||
username (str): Username for which the token is requested. | ||
groups (List[str]): Groups associated with the user. | ||
Returns: | ||
str: The authentication token. | ||
""" | ||
endpoint = "/1.0/kubernetes/auth/tokens" | ||
body = {"username": username, "groups": groups} | ||
auth_response = self._send_request(endpoint, "POST", body, AuthTokenResponse) | ||
return auth_response.metadata.token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters