Skip to content

Commit

Permalink
Merge branch 'master' into 2.0.x
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Oct 8, 2020
2 parents 4c09656 + cc8a264 commit 75266ea
Show file tree
Hide file tree
Showing 26 changed files with 265 additions and 183 deletions.
1 change: 1 addition & 0 deletions changelog/6943.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update `KafkaEventBroker` to support `SASL_SSL` and `PLAINTEXT` protocols.
18 changes: 18 additions & 0 deletions data/examples/wit/demo-flights.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
{
"text" : "i'm looking for a flight from london to amsterdam next monday",
"entities" : [
{
"entity" : "intent",
"value" : "\"flight_booking\"",
"start" : 0,
"end" : 42
},
{
"entity" : "location",
"value" : "\"london\"",
Expand All @@ -59,6 +65,12 @@
{
"text" : "i want to fly to berlin",
"entities" : [
{
"entity" : "intent",
"value" : "\"flight_booking\"",
"start" : 0,
"end" : 42
},
{
"entity" : "location",
"value" : "\"berlin\"",
Expand All @@ -71,6 +83,12 @@
{
"text" : "i want to fly from london",
"entities" : [
{
"entity" : "intent",
"value" : "\"flight_booking\"",
"start" : 0,
"end" : 42
},
{
"entity" : "location",
"value" : "\"london\"",
Expand Down
9 changes: 6 additions & 3 deletions data/test/wit_converted_to_rasa.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
]
},
{
"text": "i'm looking for a flight from london to amsterdam next monday",
"text": "i'm looking for a flight from london to amsterdam next monday",
"intent": "flight_booking",
"entities": [
{
"start": 30,
Expand All @@ -53,7 +54,8 @@
]
},
{
"text": "i want to fly to berlin",
"text": "i want to fly to berlin",
"intent": "flight_booking",
"entities": [
{
"start": 17,
Expand All @@ -65,7 +67,8 @@
]
},
{
"text": "i want to fly from london",
"text": "i want to fly from london",
"intent": "flight_booking",
"entities": [
{
"start": 19,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
event_broker:
security_protocol: SOMETHING
type: kafka
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
event_broker:
url: localhost
sasl_username: username
sasl_password: password
topic: topic
security_protocol: SASL_PLAINTEXT
client_id: kafka-python-rasa
security_protocol: PLAINTEXT
type: kafka
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
event_broker:
client_id: kafka-python-rasa
security_protocol: PLAINTEXT
type: kafka
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
event_broker:
url: localhost
sasl_username: username
sasl_password: password
topic: topic
security_protocol: SASL_PLAINTEXT
type: kafka
7 changes: 7 additions & 0 deletions data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
event_broker:
url: localhost
sasl_username: username
sasl_password: password
topic: topic
security_protocol: SASL_SSL
type: kafka
4 changes: 0 additions & 4 deletions data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,4 @@ event_broker:
url: localhost
topic: topic
security_protocol: SSL
ssl_cafile: CARoot.pem
ssl_certfile: certificate.pem
ssl_keyfile: key.pem
ssl_check_hostname: True
type: kafka
101 changes: 85 additions & 16 deletions rasa/core/brokers/kafka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional
from typing import Any, Text, List, Optional, Union, Dict

from rasa.core.brokers.broker import EventBroker
from rasa.shared.utils.io import DEFAULT_ENCODING
Expand All @@ -11,29 +11,69 @@
class KafkaEventBroker(EventBroker):
def __init__(
    self,
    url: Union[Text, List[Text], None],
    topic: Text = "rasa_core_events",
    client_id: Optional[Text] = None,
    group_id: Optional[Text] = None,
    sasl_username: Optional[Text] = None,
    sasl_password: Optional[Text] = None,
    ssl_cafile: Optional[Text] = None,
    ssl_certfile: Optional[Text] = None,
    ssl_keyfile: Optional[Text] = None,
    ssl_check_hostname: bool = False,
    security_protocol: Text = "SASL_PLAINTEXT",
    loglevel: Union[int, Text] = logging.ERROR,
    **kwargs: Any,
) -> None:
    """Store the connection settings of a Kafka event broker.

    No connection is opened here; `self.producer` stays `None` until a
    producer is created.

    Args:
        url: 'url[:port]' string (or list of 'url[:port]' strings) that the
            producer should contact to bootstrap initial cluster metadata.
            This does not have to be the full node list. It just needs to
            have at least one broker that will respond to a Metadata API
            Request.
        topic: Name of the Kafka topic the events are sent to.
        client_id: A name for this client. This string is passed in each
            request to servers and can be used to identify specific
            server-side log entries that correspond to this client.
        group_id: Group identifier stored on the broker.
            NOTE(review): group membership is a consumer-side concept in
            Kafka; confirm whether a publishing broker needs this setting.
        sasl_username: Username for plain authentication.
        sasl_password: Password for plain authentication.
        ssl_cafile: Optional filename of ca file to use in certificate
            verification.
        ssl_certfile: Optional filename of file in pem format containing
            the client certificate, as well as any ca certificates needed to
            establish the certificate's authenticity.
        ssl_keyfile: Optional filename containing the client private key.
        ssl_check_hostname: Flag to configure whether ssl handshake
            should verify that the certificate matches the brokers hostname.
        security_protocol: Protocol used to communicate with brokers.
            Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL.
            The value is upper-cased before it is stored.
        loglevel: Logging level of the kafka logger.
        **kwargs: Additional keyword arguments; accepted and ignored.
    """
    # Only needed for the type annotation of `self.producer` below.
    import kafka

    self.url = url
    self.topic = topic
    self.client_id = client_id
    self.group_id = group_id
    # Normalise the case so that e.g. "sasl_ssl" in a config file is accepted.
    self.security_protocol = security_protocol.upper()
    self.sasl_username = sasl_username
    self.sasl_password = sasl_password
    self.ssl_cafile = ssl_cafile
    self.ssl_certfile = ssl_certfile
    self.ssl_keyfile = ssl_keyfile
    self.ssl_check_hostname = ssl_check_hostname

    # The broker publishes events, so the client it holds is a producer.
    self.producer: Optional[kafka.KafkaProducer] = None

    logging.getLogger("kafka").setLevel(loglevel)

@classmethod
Expand All @@ -51,9 +91,18 @@ def publish(self, event) -> None:
def _create_producer(self) -> None:
    """Create the Kafka producer matching `self.security_protocol`.

    The created client is stored in `self.producer`. Every protocol uses a
    `kafka.KafkaProducer`, because events are published with
    `self.producer.send(...)`, which a consumer does not offer.

    Raises:
        ValueError: If `self.security_protocol` is not one of `PLAINTEXT`,
            `SASL_PLAINTEXT`, `SSL` or `SASL_SSL`.
    """
    import kafka

    def _serialize(value: Any) -> bytes:
        # Events are sent as JSON documents encoded with the default encoding.
        return json.dumps(value).encode(DEFAULT_ENCODING)

    # NOTE: `self.group_id` is intentionally not forwarded. It is a consumer
    # setting and `KafkaProducer` rejects configuration keys it does not know.
    if self.security_protocol == "PLAINTEXT":
        self.producer = kafka.KafkaProducer(
            client_id=self.client_id,
            bootstrap_servers=self.url,
            value_serializer=_serialize,
            security_protocol="PLAINTEXT",
            ssl_check_hostname=False,
        )
    elif self.security_protocol == "SASL_PLAINTEXT":
        self.producer = kafka.KafkaProducer(
            bootstrap_servers=self.url,
            value_serializer=_serialize,
            sasl_plain_username=self.sasl_username,
            sasl_plain_password=self.sasl_password,
            sasl_mechanism="PLAIN",
            security_protocol=self.security_protocol,
        )
    elif self.security_protocol == "SSL":
        self.producer = kafka.KafkaProducer(
            bootstrap_servers=self.url,
            value_serializer=_serialize,
            ssl_cafile=self.ssl_cafile,
            ssl_certfile=self.ssl_certfile,
            ssl_keyfile=self.ssl_keyfile,
            # NOTE(review): hostname verification is switched off here
            # regardless of `self.ssl_check_hostname`; confirm this is wanted.
            ssl_check_hostname=False,
            security_protocol=self.security_protocol,
        )
    elif self.security_protocol == "SASL_SSL":
        self.producer = kafka.KafkaProducer(
            client_id=self.client_id,
            bootstrap_servers=self.url,
            value_serializer=_serialize,
            security_protocol="SASL_SSL",
            sasl_mechanism="PLAIN",
            sasl_plain_username=self.sasl_username,
            sasl_plain_password=self.sasl_password,
            ssl_cafile=self.ssl_cafile,
            ssl_certfile=self.ssl_certfile,
            ssl_keyfile=self.ssl_keyfile,
            ssl_check_hostname=self.ssl_check_hostname,
        )
    else:
        raise ValueError(
            f"Cannot initialise `KafkaEventBroker`: "
            f"Invalid `security_protocol` ('{self.security_protocol}')."
        )

def _publish(self, event) -> None:
self.producer.send(self.topic, event)
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ def get_eval_data(

should_eval_entities = is_entity_extractor_present(interpreter)

for example in tqdm(test_data.training_examples):
for example in tqdm(test_data.nlu_examples):
result = interpreter.parse(example.get(TEXT), only_output_properties=False)

if should_eval_intents:
Expand Down Expand Up @@ -1861,7 +1861,7 @@ def compare_nlu(
_, train_included = train.train_test_split(percentage / 100)
# only count for the first run and ignore the others
if run == 0:
training_examples_per_run.append(len(train_included.training_examples))
training_examples_per_run.append(len(train_included.nlu_examples))

model_output_path = os.path.join(run_path, percent_string)
train_split_path = os.path.join(model_output_path, "train")
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/utils/bilou_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def build_tag_id_dict(
distinct_tags = set(
[
tag_without_prefix(e)
for example in training_data.training_examples
for example in training_data.nlu_examples
if example.get(bilou_key)
for e in example.get(bilou_key)
]
Expand All @@ -157,7 +157,7 @@ def apply_bilou_schema(training_data: TrainingData) -> None:
Args:
training_data: the training data
"""
for message in training_data.training_examples:
for message in training_data.nlu_examples:
entities = message.get(ENTITIES)

if not entities:
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ class ActionExecuted(Event):

def __init__(
self,
action_name: Text,
action_name: Optional[Text] = None,
policy: Optional[Text] = None,
confidence: Optional[float] = None,
timestamp: Optional[float] = None,
Expand Down
17 changes: 11 additions & 6 deletions rasa/shared/importers/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rasa.shared.core.training_data.structures import StoryGraph
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.constants import INTENT_NAME, TEXT
from rasa.shared.nlu.constants import ENTITIES, ACTION_NAME
from rasa.shared.importers.autoconfig import TrainingType
from rasa.shared.core.domain import IS_RETRIEVAL_INTENT_KEY

Expand Down Expand Up @@ -533,18 +533,23 @@ def _unique_events_from_stories(


def _messages_from_user_utterance(event: UserUttered) -> Message:
return Message(data={TEXT: event.text, INTENT_NAME: event.intent_name})
# sub state correctly encodes intent vs text
data = event.as_sub_state()
# sub state stores entities differently
if data.get(ENTITIES) and event.entities:
data[ENTITIES] = event.entities

return Message(data=data)


def _messages_from_action(event: ActionExecuted) -> Message:
return Message.build_from_action(
action_name=event.action_name, action_text=event.action_text or ""
)
# sub state correctly encodes action_name vs action_text
return Message(data=event.as_sub_state())


def _additional_training_data_from_default_actions() -> TrainingData:
additional_messages_from_default_actions = [
Message.build_from_action(action_name=action_name)
Message(data={ACTION_NAME: action_name})
for action_name in rasa.shared.core.constants.DEFAULT_ACTION_NAMES
]

Expand Down
1 change: 0 additions & 1 deletion rasa/shared/nlu/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
INTENT_RESPONSE_KEY = "intent_response_key"
ACTION_TEXT = "action_text"
ACTION_NAME = "action_name"
INTENT_NAME = "intent_name"
INTENT_NAME_KEY = "name"
METADATA = "metadata"
METADATA_INTENT = "intent"
Expand Down
24 changes: 2 additions & 22 deletions rasa/shared/nlu/training_data/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
FEATURE_TYPE_SEQUENCE,
ACTION_TEXT,
ACTION_NAME,
INTENT_NAME,
)

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -132,24 +131,6 @@ def build(
# pytype: enable=unsupported-operands
return cls(data, **kwargs)

@classmethod
def build_from_action(
cls,
action_text: Optional[Text] = "",
action_name: Optional[Text] = "",
**kwargs: Any,
) -> "Message":
"""
Build a `Message` from `ActionExecuted` data.
Args:
action_text: text of a bot's utterance
action_name: name of an action executed
Returns:
Message
"""
action_data = {ACTION_TEXT: action_text, ACTION_NAME: action_name}
return cls(data=action_data, **kwargs)

def get_full_intent(self) -> Text:
"""Get intent as it appears in training data"""

Expand Down Expand Up @@ -330,9 +311,8 @@ def is_core_message(self) -> bool:
Returns:
True, if message is a core message, false otherwise.
"""
return (
self.data.get(ACTION_NAME) is not None
or self.data.get(INTENT_NAME) is not None
return bool(
self.data.get(ACTION_NAME)
or self.data.get(ACTION_TEXT)
or (
(self.data.get(INTENT) or self.data.get(RESPONSE))
Expand Down
Loading

0 comments on commit 75266ea

Please sign in to comment.