Skip to content

Commit

Permalink
fix azure servicebus using managed identity support (#1801)
Browse files Browse the repository at this point in the history
* azure servicebus: use DefaultAzureCredential in documentation

* azure servicebus: only use connection string when using sas key

* azure servicebus: add two small tests for paring of connection string

* azure servicebus: fix lint issues
  • Loading branch information
marnikow authored Oct 4, 2023
1 parent 5f4c531 commit 0e445d1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
10 changes: 8 additions & 2 deletions kombu/transport/azureservicebus.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
.. code-block::
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
azureservicebus://DefaultAzureIdentity@SERVICE_BUSNAMESPACE
azureservicebus://DefaultAzureCredential@SERVICE_BUSNAMESPACE
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
Transport Options
Expand Down Expand Up @@ -140,6 +140,11 @@ def __init__(self, *args, **kwargs):
def _try_parse_connection_string(self) -> None:
self._namespace, self._credential = Transport.parse_uri(
self.conninfo.hostname)

if (isinstance(self._credential, DefaultAzureCredential) or
isinstance(self._credential, ManagedIdentityCredential)):
return None

if ":" in self._credential:
self._policy, self._sas_key = self._credential.split(':', 1)

Expand Down Expand Up @@ -434,7 +439,8 @@ class Transport(virtual.Transport):
can_parse_url = True

@staticmethod
def parse_uri(uri: str) -> tuple[str, str, str]:
def parse_uri(uri: str) -> tuple[str, str | DefaultAzureCredential |
ManagedIdentityCredential]:
# URL like:
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
# urllib parse does not work as the sas key could contain a slash
Expand Down
53 changes: 40 additions & 13 deletions t/unit/transport/test_azureservicebus.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
import azure.servicebus.exceptions # noqa
from azure.servicebus import ServiceBusMessage, ServiceBusReceiveMode # noqa

try:
from azure.identity import (DefaultAzureCredential,
ManagedIdentityCredential)
except ImportError:
DefaultAzureCredential = None
ManagedIdentityCredential = None

from kombu.transport import azureservicebus # noqa


Expand Down Expand Up @@ -95,7 +102,9 @@ def get_queue_runtime_properties(self, queue_name):


URL_NOCREDS = 'azureservicebus://'
URL_CREDS = 'azureservicebus://policyname:ke/y@hostname'
URL_CREDS_SAS = 'azureservicebus://policyname:ke/y@hostname'
URL_CREDS_DA = 'azureservicebus://DefaultAzureCredential@hostname'
URL_CREDS_MI = 'azureservicebus://ManagedIdentityCredential@hostname'


def test_queue_service_nocredentials():
Expand All @@ -105,9 +114,9 @@ def test_queue_service_nocredentials():
assert exc == 'Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}' # noqa


def test_queue_service():
def test_queue_service_sas():
# Test gettings queue service without credentials
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
with patch('kombu.transport.azureservicebus.ServiceBusClient') as m:
channel = conn.channel()

Expand All @@ -126,20 +135,38 @@ def test_queue_service():
assert m.from_connection_string.call_count == 1


def test_queue_service_da():
conn = Connection(URL_CREDS_DA, transport=azureservicebus.Transport)
channel = conn.channel()

# Check the DefaultAzureCredential has been parsed from the url correctly
# and the credential is a ManagedIdentityCredential
assert isinstance(channel._credential, DefaultAzureCredential)


def test_queue_service_mi():
conn = Connection(URL_CREDS_MI, transport=azureservicebus.Transport)
channel = conn.channel()

# Check the ManagedIdentityCredential has been parsed from the url
# correctly and the credential is a ManagedIdentityCredential
assert isinstance(channel._credential, ManagedIdentityCredential)


def test_conninfo():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()
assert channel.conninfo is conn


def test_transport_type():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()
assert not channel.transport_options


def test_default_wait_timeout_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()

assert channel.wait_time_seconds == \
Expand All @@ -148,7 +175,7 @@ def test_default_wait_timeout_seconds():

def test_custom_wait_timeout_seconds():
conn = Connection(
URL_CREDS,
URL_CREDS_SAS,
transport=azureservicebus.Transport,
transport_options={'wait_time_seconds': 10}
)
Expand All @@ -158,15 +185,15 @@ def test_custom_wait_timeout_seconds():


def test_default_peek_lock_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()

assert channel.peek_lock_seconds == \
azureservicebus.Channel.default_peek_lock_seconds


def test_custom_peek_lock_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport,
transport_options={'peek_lock_seconds': 65})
channel = conn.channel()

Expand All @@ -175,7 +202,7 @@ def test_custom_peek_lock_seconds():

def test_invalid_peek_lock_seconds():
# Max is 300
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport,
transport_options={'peek_lock_seconds': 900})
channel = conn.channel()

Expand Down Expand Up @@ -230,7 +257,7 @@ def mock_clients(
def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue:
exchange = Exchange('test_servicebus', type='direct')
queue = Queue(random_queue, exchange, random_queue)
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()

queue(channel).declare()
Expand Down Expand Up @@ -312,7 +339,7 @@ def test_purge(mock_queue: MockQueue):

def test_custom_queue_name_prefix():
conn = Connection(
URL_CREDS,
URL_CREDS_SAS,
transport=azureservicebus.Transport,
transport_options={'queue_name_prefix': 'test-queue'}
)
Expand All @@ -322,7 +349,7 @@ def test_custom_queue_name_prefix():


def test_custom_entity_name():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
conn = Connection(URL_CREDS_SAS, transport=azureservicebus.Transport)
channel = conn.channel()

# dashes allowed and dots replaced by dashes
Expand Down

0 comments on commit 0e445d1

Please sign in to comment.