Fix ssl connection #178

Merged: 7 commits, Apr 10, 2024
kafka/client_async.py: 8 additions & 1 deletion
@@ -266,7 +266,14 @@ def _conn_state_change(self, node_id, sock, conn):
             try:
                 self._selector.register(sock, selectors.EVENT_WRITE, conn)
             except KeyError:
-                self._selector.modify(sock, selectors.EVENT_WRITE, conn)
+                # SSL detaches the original socket, and transfers the
+                # underlying file descriptor to a new SSLSocket. We should
+                # explicitly unregister the original socket.
+                if conn.state == ConnectionStates.HANDSHAKE:
+                    self._selector.unregister(sock)
+                    self._selector.register(sock, selectors.EVENT_WRITE, conn)
+                else:
+                    self._selector.modify(sock, selectors.EVENT_WRITE, conn)

             if self.cluster.is_bootstrap(node_id):
                 self._last_bootstrap = time.time()
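Why the KeyError branch needs special handling: wrapping a socket for TLS moves the underlying file descriptor onto a new ssl.SSLSocket and detaches the original socket, so the object the selector registered during CONNECTING no longer owns the descriptor, and only an explicit unregister/register pair brings the registration back in line. The minimal sketch below illustrates that stdlib behaviour; it is not code from this PR, and example.org:443 is a placeholder endpoint.

import selectors
import socket
import ssl

# Illustration of the selector bookkeeping the HANDSHAKE branch above performs.
sel = selectors.DefaultSelector()

plain = socket.create_connection(("example.org", 443))  # placeholder endpoint
plain.setblocking(False)
sel.register(plain, selectors.EVENT_WRITE, data="conn")

# SSLContext.wrap_socket() detaches `plain` and hands its file descriptor
# to a new SSLSocket; `plain.fileno()` is now -1.
ctx = ssl.create_default_context()
wrapped = ctx.wrap_socket(plain, server_hostname="example.org",
                          do_handshake_on_connect=False)

# Drop the stale entry for the detached socket, then register the SSLSocket
# that actually owns the descriptor.
sel.unregister(plain)
sel.register(wrapped, selectors.EVENT_WRITE, data="conn")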
kafka/conn.py: 2 additions & 2 deletions
@@ -378,10 +378,10 @@ def connect(self):

                 if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
                     log.debug('%s: initiating SSL handshake', self)
-                    self.state = ConnectionStates.HANDSHAKE
-                    self.config['state_change_callback'](self.node_id, self._sock, self)
                     # _wrap_ssl can alter the connection state -- disconnects on failure
                     self._wrap_ssl()
+                    self.state = ConnectionStates.HANDSHAKE
+                    self.config['state_change_callback'](self.node_id, self._sock, self)

                 elif self.config['security_protocol'] == 'SASL_PLAINTEXT':
                     log.debug('%s: initiating SASL authentication', self)
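The reorder above matters because the state_change_callback is what keeps the client's selector registration current: invoking it after _wrap_ssl() hands the listener the wrapped SSLSocket rather than the plain socket that is about to be detached, and a wrap failure (which disconnects) is observed before HANDSHAKE is advertised. A rough, self-contained sketch of that ordering contract follows; the listener and endpoint are placeholders, not kafka-python internals.

import selectors
import socket
import ssl

sel = selectors.DefaultSelector()

def on_state_change(sock):
    # Stand-in for a listener like KafkaClient._conn_state_change: it keeps
    # the selector registration pointed at whatever socket it is handed.
    try:
        sel.register(sock, selectors.EVENT_WRITE)
    except KeyError:
        sel.unregister(sock)
        sel.register(sock, selectors.EVENT_WRITE)

plain = socket.create_connection(("example.org", 443))  # placeholder endpoint
plain.setblocking(False)
on_state_change(plain)                 # CONNECTING: plain socket registered

ctx = ssl.create_default_context()
sock = ctx.wrap_socket(plain, server_hostname="example.org",
                       do_handshake_on_connect=False)
on_state_change(sock)                  # HANDSHAKE: listener now sees the SSLSocket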
test/fixtures.py: 7 additions & 2 deletions
@@ -38,7 +38,7 @@ def gen_ssl_resources(directory):

     # Step 1
     keytool -keystore kafka.server.keystore.jks -alias localhost -validity 1 \
-      -genkey -storepass foobar -keypass foobar \
+      -genkey -keyalg RSA -storepass foobar -keypass foobar \
       -dname "CN=localhost, OU=kafka-python, O=kafka-python, L=SF, ST=CA, C=US" \
       -ext SAN=dns:localhost

@@ -289,7 +289,7 @@ def __init__(self, host, port, broker_id, zookeeper, zk_chroot,
             self.sasl_mechanism = sasl_mechanism.upper()
         else:
             self.sasl_mechanism = None
-        self.ssl_dir = self.test_resource('ssl')
+        self.ssl_dir = None

         # TODO: checking for port connection would be better than scanning logs
         # until then, we need the pattern to work across all supported broker versions
@@ -410,6 +410,8 @@ def start(self):
         jaas_conf = self.tmp_dir.join("kafka_server_jaas.conf")
         properties_template = self.test_resource("kafka.properties")
         jaas_conf_template = self.test_resource("kafka_server_jaas.conf")
+        self.ssl_dir = self.tmp_dir
+        gen_ssl_resources(self.ssl_dir.strpath)

         args = self.kafka_run_class_args("kafka.Kafka", properties.strpath)
         env = self.kafka_run_class_env()
@@ -641,6 +643,9 @@ def _enrich_client_params(self, params, **defaults):
         if self.sasl_mechanism in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'):
             params.setdefault('sasl_plain_username', self.broker_user)
             params.setdefault('sasl_plain_password', self.broker_password)
+        if self.transport in ["SASL_SSL", "SSL"]:
+            params.setdefault("ssl_cafile", self.ssl_dir.join('ca-cert').strpath)
+            params.setdefault("security_protocol", self.transport)
         return params

     @staticmethod
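The ssl_cafile and security_protocol defaults injected above are the same settings an ordinary kafka-python client needs to reach an SSL listener; the fixture's client helpers route their parameters through _enrich_client_params, so the producers and consumers requested by the new tests pick them up automatically. A hedged usage sketch follows; the bootstrap address, topic name, and CA path are placeholders, not values from this PR.

from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers="localhost:9093",        # placeholder SSL listener
    security_protocol="SSL",
    ssl_cafile="/tmp/kafka-fixture/ca-cert",   # CA that signed the broker certificate
)
producer.send("example-topic", b"hello over TLS")
producer.flush()
producer.close()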
test/test_ssl_integration.py (new file): 67 additions
@@ -0,0 +1,67 @@
import logging
import uuid

import pytest

from kafka.admin import NewTopic
from kafka.protocol.metadata import MetadataRequest_v1
from test.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore


@pytest.fixture(scope="module")
def ssl_kafka(request, kafka_broker_factory):
    return kafka_broker_factory(transport="SSL")[0]


@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9")
def test_admin(request, ssl_kafka):
    topic_name = special_to_underscore(request.node.name + random_string(4))
    admin, = ssl_kafka.get_admin_clients(1)
    admin.create_topics([NewTopic(topic_name, 1, 1)])
    assert topic_name in ssl_kafka.get_topic_names()


@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9")
def test_produce_and_consume(request, ssl_kafka):
    topic_name = special_to_underscore(request.node.name + random_string(4))
    ssl_kafka.create_topics([topic_name], num_partitions=2)
    producer, = ssl_kafka.get_producers(1)

    messages_and_futures = []  # [(message, produce_future),]
    for i in range(100):
        encoded_msg = "{}-{}-{}".format(i, request.node.name, uuid.uuid4()).encode("utf-8")
        future = producer.send(topic_name, value=encoded_msg, partition=i % 2)
        messages_and_futures.append((encoded_msg, future))
    producer.flush()

    for (msg, f) in messages_and_futures:
        assert f.succeeded()

    consumer, = ssl_kafka.get_consumers(1, [topic_name])
    messages = {0: [], 1: []}
    for i, message in enumerate(consumer, 1):
        logging.debug("Consumed message %s", repr(message))
        messages[message.partition].append(message)
        if i >= 100:
            break

    assert_message_count(messages[0], 50)
    assert_message_count(messages[1], 50)


@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9")
def test_client(request, ssl_kafka):
    topic_name = special_to_underscore(request.node.name + random_string(4))
    ssl_kafka.create_topics([topic_name], num_partitions=1)

    client, = ssl_kafka.get_clients(1)
    request = MetadataRequest_v1(None)
    client.send(0, request)
    for _ in range(10):
        result = client.poll(timeout_ms=10000)
        if len(result) > 0:
            break
    else:
        raise RuntimeError("Couldn't fetch topic response from Broker.")
    result = result[0]
    assert topic_name in [t[1] for t in result.topics]