Skip to content

Commit

Permalink
Fix TLS flag (#44)
Browse files Browse the repository at this point in the history
* fix early tls deployment by only reloading patroni config if it's already running

* added unit test for reloading patroni

* lint

* removing postgres restart check

* adding series flags to test apps

* adding series flags to test apps

* made series into a list

* Update test_new_relations.py

* updating test to better emulate bundle deploymen

* updated tls lib

* Fix TLS IP address on CSR

* removed failing test for now. This will be fixed before merge

* Fix TLS

* Update TLS lib

* Remove unneeded series

* Add comment

* Add Juju agent version bootstrap constraint

* Add test for restart method

* Update TLS lib

* Add test for update config method

* Update TLS lib

* Improve comment

Co-authored-by: WRFitch <[email protected]>
Co-authored-by: Will Fitch <[email protected]>
  • Loading branch information
3 people authored Oct 27, 2022
1 parent 4994516 commit 5c6e5f7
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 25 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:
uses: charmed-kubernetes/actions-operator@main
with:
provider: lxd
# This is needed until https://bugs.launchpad.net/juju/+bug/1992833 is fixed.
bootstrap-options: "--agent-version 2.9.34"
- name: Run integration tests
run: tox -e database-relation-integration

Expand Down Expand Up @@ -105,6 +107,8 @@ jobs:
uses: charmed-kubernetes/actions-operator@main
with:
provider: lxd
# This is needed until https://bugs.launchpad.net/juju/+bug/1992833 is fixed.
bootstrap-options: "--agent-version 2.9.34"
- name: Run integration tests
run: tox -e ha-self-healing-integration

Expand Down
48 changes: 37 additions & 11 deletions lib/charms/postgresql_k8s/v0/postgresql_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import base64
import ipaddress
import logging
import re
import socket
Expand All @@ -34,7 +35,7 @@
from cryptography.x509.extensions import ExtensionType
from ops.charm import ActionEvent
from ops.framework import Object
from ops.pebble import PathError, ProtocolError
from ops.pebble import ConnectionError, PathError, ProtocolError

# The unique Charmhub library identifier, never change it
LIBID = "c27af44a92df4ef38d7ae06418b2800f"
Expand All @@ -44,7 +45,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version.
LIBPATCH = 1
LIBPATCH = 4

logger = logging.getLogger(__name__)
SCOPE = "unit"
Expand All @@ -54,11 +55,14 @@
class PostgreSQLTLS(Object):
"""In this class we manage certificates relation."""

def __init__(self, charm, peer_relation):
def __init__(
self, charm, peer_relation: str, additional_dns_names: Optional[List[str]] = None
):
"""Manager of PostgreSQL relation with TLS Certificates Operator."""
super().__init__(charm, "client-relations")
self.charm = charm
self.peer_relation = peer_relation
self.additional_dns_names = additional_dns_names or []
self.certs = TLSCertificatesRequiresV1(self.charm, TLS_RELATION)
self.framework.observe(
self.charm.on.set_tls_private_key_action, self._on_set_tls_private_key
Expand Down Expand Up @@ -86,8 +90,8 @@ def _request_certificate(self, param: Optional[str]):
csr = generate_csr(
private_key=key,
subject=self.charm.get_hostname_by_unit(self.charm.unit.name),
sans=self._get_sans(),
additional_critical_extensions=self._get_tls_extensions(),
**self._get_sans(),
)

self.charm.set_secret(SCOPE, "key", key.decode("utf-8"))
Expand Down Expand Up @@ -133,7 +137,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:

try:
self.charm.push_tls_files_to_workload()
except (PathError, ProtocolError) as e:
except (ConnectionError, PathError, ProtocolError) as e:
logger.error("Cannot push TLS certificates: %r", e)
event.defer()
return
Expand All @@ -149,35 +153,57 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
new_csr = generate_csr(
private_key=key,
subject=self.charm.get_hostname_by_unit(self.charm.unit.name),
sans=self._get_sans(),
additional_critical_extensions=self._get_tls_extensions(),
**self._get_sans(),
)
self.certs.request_certificate_renewal(
old_certificate_signing_request=old_csr,
new_certificate_signing_request=new_csr,
)
self.charm.set_secret(SCOPE, "csr", new_csr.decode("utf-8"))

def _get_sans(self) -> List[str]:
"""Create a list of DNS names for a PostgreSQL unit.
def _get_sans(self) -> dict:
"""Create a list of Subject Alternative Names for a PostgreSQL unit.
Returns:
A list representing the hostnames of the PostgreSQL unit.
A list representing the IP and hostnames of the PostgreSQL unit.
"""

def is_ip_address(address: str) -> bool:
"""Returns whether and address is an IP address."""
try:
ipaddress.ip_address(address)
return True
except (ipaddress.AddressValueError, ValueError):
return False

unit_id = self.charm.unit.name.split("/")[1]
return [

# Create a list of all the Subject Alternative Names.
sans = [
f"{self.charm.app.name}-{unit_id}",
self.charm.get_hostname_by_unit(self.charm.unit.name),
socket.getfqdn(),
str(self.charm.model.get_binding(self.peer_relation).network.bind_address),
]
sans.extend(self.additional_dns_names)

# Separate IP addresses and DNS names.
sans_ip = [san for san in sans if is_ip_address(san)]
sans_dns = [san for san in sans if not is_ip_address(san)]

return {
"sans_ip": sans_ip,
"sans_dns": sans_dns,
}

@staticmethod
def _get_tls_extensions() -> Optional[List[ExtensionType]]:
"""Return a list of TLS extensions for which certificate key can be used."""
basic_constraints = x509.BasicConstraints(ca=True, path_length=None)
return [basic_constraints]

def get_tls_files(self) -> (Optional[str], Optional[str]):
def get_tls_files(self) -> (Optional[str], Optional[str], Optional[str]):
"""Prepare TLS files in special PostgreSQL way.
PostgreSQL needs three files:
Expand Down
24 changes: 18 additions & 6 deletions lib/charms/tls_certificates_interface/v1/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
import logging
import uuid
from datetime import datetime, timedelta
from ipaddress import IPv4Address
from typing import Dict, List, Optional

from cryptography import x509
Expand Down Expand Up @@ -657,7 +658,9 @@ def generate_csr(
email_address: str = None,
country_name: str = None,
private_key_password: Optional[bytes] = None,
sans: Optional[List[str]] = None,
sans_oid: Optional[str] = None,
sans_ip: Optional[List[str]] = None,
sans_dns: Optional[List[str]] = None,
additional_critical_extensions: Optional[List] = None,
) -> bytes:
"""Generates a CSR using private key and subject.
Expand All @@ -672,7 +675,9 @@ def generate_csr(
email_address (str): Email address.
country_name (str): Country Name.
private_key_password (bytes): Private key password
sans (list): List of subject alternative names
sans_dns (list): List of DNS subject alternative names
sans_ip (list): List of IP subject alternative names
sans_oid (str): Additional OID
additional_critical_extensions (list): List if critical additional extension objects.
Object must be a x509 ExtensionType.
Expand All @@ -693,10 +698,17 @@ def generate_csr(
if country_name:
subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name))
csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name))
if sans:
csr = csr.add_extension(
x509.SubjectAlternativeName([x509.DNSName(san) for san in sans]), critical=False
)

_sans = []
if sans_oid:
_sans.append(x509.RegisteredID(x509.ObjectIdentifier(sans_oid)))
if sans_ip:
_sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip])
if sans_dns:
_sans.extend([x509.DNSName(san) for san in sans_dns])
if _sans:
csr = csr.add_extension(x509.SubjectAlternativeName(_sans), critical=False)

if additional_critical_extensions:
for extension in additional_critical_extensions:
csr = csr.add_extension(extension, critical=True)
Expand Down
14 changes: 10 additions & 4 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,12 @@ def push_tls_files_to_workload(self) -> None:
self.update_config()

def _restart(self, _) -> None:
"""Restart Patroni and PostgreSQL."""
if not self._patroni.restart_patroni():
logger.exception("failed to restart PostgreSQL")
self.unit.status = BlockedStatus("failed to restart Patroni and PostgreSQL")
"""Restart PostgreSQL."""
try:
self._patroni.restart_postgresql()
except RetryError as e:
logger.error("failed to restart PostgreSQL")
self.unit.status = BlockedStatus(f"failed to restart PostgreSQL with error {e}")

def update_config(self) -> None:
"""Updates Patroni config file based on the existence of the TLS files."""
Expand All @@ -833,6 +835,10 @@ def update_config(self) -> None:
# Update and reload configuration based on TLS files availability.
self._patroni.render_patroni_yml_file(enable_tls=enable_tls)
if not self._patroni.member_started:
# If Patroni/PostgreSQL has not started yet and TLS relations was initialised,
# then mark TLS as enabled. This commonly happens when the charm is deployed
# in a bundle together with the TLS certificates operator.
self.unit_peer_data.update({"tls": "enabled" if enable_tls else ""})
return

restart_postgresql = enable_tls != self.postgresql.is_tls_enabled()
Expand Down
6 changes: 2 additions & 4 deletions tests/integration/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ async def test_deploy_active(ops_test: OpsTest):
charm, resources={"patroni": "patroni.tar.gz"}, application_name=APP_NAME, num_units=3
)
await ops_test.juju("attach-resource", APP_NAME, "patroni=patroni.tar.gz")
await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000)
# No wait between deploying charms, since we can't guarantee users will wait. Furthermore,
# bundles don't wait between deploying charms.


@pytest.mark.tls_tests
Expand All @@ -36,9 +37,6 @@ async def test_tls_enabled(ops_test: OpsTest) -> None:
# Deploy TLS Certificates operator.
config = {"generate-self-signed-certificates": "true", "ca-common-name": "Test CA"}
await ops_test.model.deploy(TLS_CERTIFICATES_APP_NAME, channel="edge", config=config)
await ops_test.model.wait_for_idle(
apps=[TLS_CERTIFICATES_APP_NAME], status="active", timeout=1000
)

# Relate it to the PostgreSQL to enable TLS.
await ops_test.model.relate(DATABASE_APP_NAME, TLS_CERTIFICATES_APP_NAME)
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from ops.model import ActiveStatus, BlockedStatus, WaitingStatus
from ops.testing import Harness
from tenacity import RetryError

from charm import PostgresqlOperatorCharm
from constants import PEER
Expand Down Expand Up @@ -424,3 +425,74 @@ def test_set_secret(self, _):
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["password"]
== "test-password"
)

@patch_network_get(private_address="1.1.1.1")
@patch("charm.Patroni.restart_postgresql")
def test_restart(self, _restart_postgresql):
# Test a successful restart.
self.charm._restart(None)
self.assertFalse(isinstance(self.charm.unit.status, BlockedStatus))

# Test a failed restart.
_restart_postgresql.side_effect = RetryError(last_attempt=1)
self.charm._restart(None)
self.assertTrue(isinstance(self.charm.unit.status, BlockedStatus))

@patch_network_get(private_address="1.1.1.1")
@patch("charms.rolling_ops.v0.rollingops.RollingOpsManager._on_acquire_lock")
@patch("charm.Patroni.reload_patroni_configuration")
@patch("charm.Patroni.member_started", new_callable=PropertyMock)
@patch("charm.Patroni.render_patroni_yml_file")
@patch("charms.postgresql_k8s.v0.postgresql_tls.PostgreSQLTLS.get_tls_files")
def test_update_config(
self,
_get_tls_files,
_render_patroni_yml_file,
_member_started,
_reload_patroni_configuration,
_restart,
):
with patch.object(PostgresqlOperatorCharm, "postgresql", Mock()) as postgresql_mock:
# Mock some properties.
postgresql_mock.is_tls_enabled = PropertyMock(side_effect=[False, False, False])
_member_started.side_effect = [True, True, False]

# Test without TLS files available.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": "enabled"}
) # Mock some data in the relation to test that it change.
_get_tls_files.return_value = [None]
self.charm.update_config()
_render_patroni_yml_file.assert_called_once_with(enable_tls=False)
_reload_patroni_configuration.assert_called_once()
_restart.assert_not_called()
self.assertNotIn(
"tls", self.harness.get_relation_data(self.rel_id, self.charm.unit.name)
)

# Test with TLS files available.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": ""}
) # Mock some data in the relation to test that it change.
_get_tls_files.return_value = ["something"]
_render_patroni_yml_file.reset_mock()
_reload_patroni_configuration.reset_mock()
self.charm.update_config()
_render_patroni_yml_file.assert_called_once_with(enable_tls=True)
_reload_patroni_configuration.assert_called_once()
_restart.assert_called_once()
self.assertEqual(
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["tls"], "enabled"
)

# Test with member not started yet.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": ""}
) # Mock some data in the relation to test that it change.
_reload_patroni_configuration.reset_mock()
self.charm.update_config()
_reload_patroni_configuration.assert_not_called()
_restart.assert_called_once()
self.assertEqual(
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["tls"], "enabled"
)

0 comments on commit 5c6e5f7

Please sign in to comment.