From 0f312980a92a8d13e2dcee0dfd840cc0ec28a1f2 Mon Sep 17 00:00:00 2001 From: alwx Date: Wed, 7 Oct 2020 11:26:20 +0200 Subject: [PATCH 1/6] Update `KafkaEventBroker` to support `SASL_SSL` and `PLAINTEXT` protocols. --- rasa/core/brokers/kafka.py | 105 +++++++++++++++++++++++++++++++------ tests/core/test_broker.py | 6 +-- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/rasa/core/brokers/kafka.py b/rasa/core/brokers/kafka.py index fb54f9ab36ac..c57fcb8e86ff 100644 --- a/rasa/core/brokers/kafka.py +++ b/rasa/core/brokers/kafka.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional +from typing import Any, Text, List, Optional, Union, Dict from rasa.core.brokers.broker import EventBroker from rasa.shared.utils.io import DEFAULT_ENCODING @@ -11,22 +11,66 @@ class KafkaEventBroker(EventBroker): def __init__( self, - host, - sasl_username=None, - sasl_password=None, - ssl_cafile=None, - ssl_certfile=None, - ssl_keyfile=None, - ssl_check_hostname=False, - topic="rasa_core_events", - security_protocol="SASL_PLAINTEXT", - loglevel=logging.ERROR, + url: Union[Text, List[Text], None], + topic: Text = "rasa_core_events", + client_id: Optional[Text] = None, + group_id: Optional[Text] = None, + sasl_username: Union[Text, int, None] = None, + sasl_password: Optional[Text] = None, + ssl_cafile: Optional[Text] = None, + ssl_certfile: Optional[Text] = None, + ssl_keyfile: Optional[Text] = None, + ssl_check_hostname: bool = False, + security_protocol: Text = "SASL_PLAINTEXT", + loglevel: Union[int, Text] = logging.ERROR, + **kwargs: Any, ) -> None: + """Kafka event broker. + Args: + url: 'url[:port]' string (or list of 'url[:port]' + strings) that the consumer should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. The default port is 9092. If no servers are + specified, it will default to `localhost:9092`. 
+ topic: Topics to subscribe to. If not set, call subscribe() or assign() + before consuming records + client_id: A name for this client. This string is passed in each request + to servers and can be used to identify specific server-side log entries + that correspond to this client. Also submitted to `GroupCoordinator` for + logging with respect to consumer group administration. + Default: ‘kafka-python-{version}’ + group_id: The name of the consumer group to join for dynamic partition + assignment (if enabled), and to use for fetching and committing offsets. + If None, auto-partition assignment (via group coordinator) and offset + commits are disabled. Default: None + sasl_username: Username for sasl PLAIN authentication. + Required if `sasl_mechanism` is `PLAIN`. + sasl_password: Password for sasl PLAIN authentication. + Required if `sasl_mechanism` is PLAIN. + ssl_cafile: Optional filename of ca file to use in certificate + verification. Default: None. + ssl_certfile: Optional filename of file in pem format containing + the client certificate, as well as any ca certificates needed to + establish the certificate's authenticity. Default: None. + ssl_keyfile: Optional filename containing the client private key. + Default: None. + ssl_check_hostname: Flag to configure whether ssl handshake + should verify that the certificate matches the brokers hostname. + Default: False. + security_protocol: Protocol used to communicate with brokers. + Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. + Default: PLAINTEXT. + loglevel: Logging level of the kafka logger. 
+ + """ self.producer = None - self.host = host + self.url = url self.topic = topic - self.security_protocol = security_protocol + self.client_id = client_id + self.group_id = group_id + self.security_protocol = security_protocol.upper() self.sasl_username = sasl_username self.sasl_password = sasl_password self.ssl_cafile = ssl_cafile @@ -51,9 +95,18 @@ def publish(self, event) -> None: def _create_producer(self) -> None: import kafka - if self.security_protocol == "SASL_PLAINTEXT": + if self.security_protocol == "PLAINTEXT": + self.consumer = kafka.KafkaConsumer( + self.topic, + bootstrap_servers=self.url, + client_id=self.client_id, + group_id=self.group_id, + security_protocol="PLAINTEXT", + ssl_check_hostname=False, + ) + elif self.security_protocol == "SASL_PLAINTEXT": self.producer = kafka.KafkaProducer( - bootstrap_servers=[self.host], + bootstrap_servers=self.url, value_serializer=lambda v: json.dumps(v).encode(DEFAULT_ENCODING), sasl_plain_username=self.sasl_username, sasl_plain_password=self.sasl_password, @@ -62,7 +115,7 @@ def _create_producer(self) -> None: ) elif self.security_protocol == "SSL": self.producer = kafka.KafkaProducer( - bootstrap_servers=[self.host], + bootstrap_servers=self.url, value_serializer=lambda v: json.dumps(v).encode(DEFAULT_ENCODING), ssl_cafile=self.ssl_cafile, ssl_certfile=self.ssl_certfile, @@ -70,6 +123,26 @@ def _create_producer(self) -> None: ssl_check_hostname=False, security_protocol=self.security_protocol, ) + elif self.security_protocol == "SASL_SSL": + self.consumer = kafka.KafkaConsumer( + self.topic, + bootstrap_servers=self.url, + client_id=self.client_id, + group_id=self.group_id, + security_protocol="SASL_SSL", + sasl_mechanism="PLAIN", + sasl_plain_username=self.sasl_username, + sasl_plain_password=self.sasl_password, + ssl_cafile=self.ssl_cafile, + ssl_certfile=self.ssl_certfile, + ssl_keyfile=self.ssl_keyfile, + ssl_check_hostname=self.ssl_check_hostname, + ) + else: + raise ValueError( + f"Cannot initialise 
`KafkaEventBroker` " + f"with security protocol '{self.security_protocol}'." + ) def _publish(self, event) -> None: self.producer.send(self.topic, event) diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py index 20f3a7bac93b..883e6faeb379 100644 --- a/tests/core/test_broker.py +++ b/tests/core/test_broker.py @@ -201,13 +201,13 @@ def test_kafka_broker_from_config(): expected = KafkaEventBroker( "localhost", - "username", - "password", + sasl_username="username", + sasl_password="password", topic="topic", security_protocol="SASL_PLAINTEXT", ) - assert actual.host == expected.host + assert actual.url == expected.url assert actual.sasl_username == expected.sasl_username assert actual.sasl_password == expected.sasl_password assert actual.topic == expected.topic From c24805080ad6ff3a865ff1884e146020f6d5a982 Mon Sep 17 00:00:00 2001 From: alwx Date: Wed, 7 Oct 2020 11:55:09 +0200 Subject: [PATCH 2/6] Tests --- .../kafka_invalid_security_protocol.yml | 3 ++ .../kafka_plaintext_endpoint.yml | 8 ++-- .../kafka_plaintext_endpoint_no_url.yml | 4 ++ .../kafka_sasl_plaintext_endpoint.yml | 7 ++++ .../event_brokers/kafka_sasl_ssl_endpoint.yml | 7 ++++ .../event_brokers/kafka_ssl_endpoint.yml | 6 +-- rasa/core/brokers/kafka.py | 8 +++- tests/core/test_broker.py | 39 +++++++++++++++---- 8 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml create mode 100644 data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml create mode 100644 data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml create mode 100644 data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml diff --git a/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml b/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml new file mode 100644 index 000000000000..40d67ea1f8ac --- /dev/null +++ 
b/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml @@ -0,0 +1,3 @@ +event_broker: + security_protocol: SOMETHING + type: kafka \ No newline at end of file diff --git a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml index 373ca01380f8..6539b2b819cd 100644 --- a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml +++ b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml @@ -1,7 +1,5 @@ event_broker: url: localhost - sasl_username: username - sasl_password: password - topic: topic - security_protocol: SASL_PLAINTEXT - type: kafka + client_id: kafka-python-rasa + security_protocol: PLAINTEXT + type: kafka \ No newline at end of file diff --git a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml new file mode 100644 index 000000000000..ef00912d8c2b --- /dev/null +++ b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml @@ -0,0 +1,4 @@ +event_broker: + client_id: kafka-python-rasa + security_protocol: PLAINTEXT + type: kafka \ No newline at end of file diff --git a/data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml b/data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml new file mode 100644 index 000000000000..373ca01380f8 --- /dev/null +++ b/data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml @@ -0,0 +1,7 @@ +event_broker: + url: localhost + sasl_username: username + sasl_password: password + topic: topic + security_protocol: SASL_PLAINTEXT + type: kafka diff --git a/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml b/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml new file mode 100644 index 000000000000..41e93a042e58 --- /dev/null +++ b/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml @@ -0,0 +1,7 @@ +event_broker: + url: localhost + sasl_username: 
username + sasl_password: password + topic: topic + security_protocol: SASL_SSL + type: kafka \ No newline at end of file diff --git a/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml b/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml index d97fdac70c50..606de903b813 100644 --- a/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml +++ b/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml @@ -2,8 +2,4 @@ event_broker: url: localhost topic: topic security_protocol: SSL - ssl_cafile: CARoot.pem - ssl_certfile: certificate.pem - ssl_keyfile: key.pem - ssl_check_hostname: True - type: kafka + type: kafka \ No newline at end of file diff --git a/rasa/core/brokers/kafka.py b/rasa/core/brokers/kafka.py index c57fcb8e86ff..a7532756a983 100644 --- a/rasa/core/brokers/kafka.py +++ b/rasa/core/brokers/kafka.py @@ -65,6 +65,8 @@ def __init__( loglevel: Logging level of the kafka logger. """ + import kafka + self.producer = None self.url = url self.topic = topic @@ -78,6 +80,8 @@ def __init__( self.ssl_keyfile = ssl_keyfile self.ssl_check_hostname = ssl_check_hostname + self.consumer: Optional[kafka.KafkaConsumer] = None + logging.getLogger("kafka").setLevel(loglevel) @classmethod @@ -140,8 +144,8 @@ def _create_producer(self) -> None: ) else: raise ValueError( - f"Cannot initialise `KafkaEventBroker` " - f"with security protocol '{self.security_protocol}'." + f"Cannot initialise `KafkaEventBroker`: " + f"Invalid `security_protocol` ('{self.security_protocol}')." 
) def _publish(self, event) -> None: diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py index 883e6faeb379..2a8194a30f02 100644 --- a/tests/core/test_broker.py +++ b/tests/core/test_broker.py @@ -1,14 +1,14 @@ import json import logging -from pathlib import Path import textwrap - -from typing import Union, Text, List, Optional, Type - +import kafka import pytest -from _pytest.logging import LogCaptureFixture +from pathlib import Path +from typing import Union, Text, List, Optional, Type, Any +from _pytest.logging import LogCaptureFixture from _pytest.monkeypatch import MonkeyPatch +from tests.core.conftest import DEFAULT_ENDPOINTS_FILE import rasa.shared.utils.io import rasa.utils.io @@ -19,7 +19,6 @@ from rasa.core.brokers.sql import SQLEventBroker from rasa.shared.core.events import Event, Restarted, SlotSet, UserUttered from rasa.utils.endpoints import EndpointConfig, read_endpoint_config -from tests.core.conftest import DEFAULT_ENDPOINTS_FILE TEST_EVENTS = [ UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []), @@ -194,7 +193,7 @@ def test_load_non_existent_custom_broker_name(): def test_kafka_broker_from_config(): - endpoints_path = "data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml" + endpoints_path = "data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml" cfg = read_endpoint_config(endpoints_path, "event_broker") actual = KafkaEventBroker.from_endpoint_config(cfg) @@ -213,6 +212,32 @@ def test_kafka_broker_from_config(): assert actual.topic == expected.topic +@pytest.mark.parametrize( + "file,exception", + [ + # `_create_producer()` raises `kafka.errors.NoBrokersAvailable` exception + # which means that configuration seems correct but a connection to + # broker cannot be established + ("kafka_sasl_plaintext_endpoint.yml", kafka.errors.NoBrokersAvailable), + ("kafka_plaintext_endpoint.yml", kafka.errors.NoBrokersAvailable), + ("kafka_sasl_ssl_endpoint.yml", kafka.errors.NoBrokersAvailable), + 
("kafka_ssl_endpoint.yml", kafka.errors.NoBrokersAvailable), + # `ValueError` exception is raised when the `security_protocol` is incorrect + ("kafka_invalid_security_protocol.yml", ValueError), + # `TypeError` exception is raised when there is no `url` specified + ("kafka_plaintext_endpoint_no_url.yml", TypeError), + ], +) +def test_kafka_broker_security_protocols(file: Text, exception: Exception): + endpoints_path = f"data/test_endpoints/event_brokers/{file}" + cfg = read_endpoint_config(endpoints_path, "event_broker") + + actual = KafkaEventBroker.from_endpoint_config(cfg) + with pytest.raises(exception): + # noinspection PyProtectedMember + actual._create_producer() + + def test_no_pika_logs_if_no_debug_mode(caplog: LogCaptureFixture): from rasa.core.brokers import pika From 5d179bf6d11efd008237f41c501571c540fbc281 Mon Sep 17 00:00:00 2001 From: alwx Date: Wed, 7 Oct 2020 11:56:07 +0200 Subject: [PATCH 3/6] Changelog entry --- changelog/6943.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/6943.improvement.md diff --git a/changelog/6943.improvement.md b/changelog/6943.improvement.md new file mode 100644 index 000000000000..e4a7fd763054 --- /dev/null +++ b/changelog/6943.improvement.md @@ -0,0 +1 @@ +Update `KafkaEventBroker` to support `SASL_SSL` and `PLAINTEXT` protocols. 
\ No newline at end of file From 1594bee8c94436b3ea291bca8da08476a9ca0d7b Mon Sep 17 00:00:00 2001 From: alwx Date: Wed, 7 Oct 2020 11:57:34 +0200 Subject: [PATCH 4/6] Code style --- .../event_brokers/kafka_invalid_security_protocol.yml | 2 +- .../test_endpoints/event_brokers/kafka_plaintext_endpoint.yml | 2 +- .../event_brokers/kafka_plaintext_endpoint_no_url.yml | 2 +- data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml | 2 +- data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml | 2 +- tests/core/test_broker.py | 4 +++- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml b/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml index 40d67ea1f8ac..74c07e10b27f 100644 --- a/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml +++ b/data/test_endpoints/event_brokers/kafka_invalid_security_protocol.yml @@ -1,3 +1,3 @@ event_broker: security_protocol: SOMETHING - type: kafka \ No newline at end of file + type: kafka diff --git a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml index 6539b2b819cd..c402333a9925 100644 --- a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml +++ b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml @@ -2,4 +2,4 @@ event_broker: url: localhost client_id: kafka-python-rasa security_protocol: PLAINTEXT - type: kafka \ No newline at end of file + type: kafka diff --git a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml index ef00912d8c2b..eb71a075e9d8 100644 --- a/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml +++ b/data/test_endpoints/event_brokers/kafka_plaintext_endpoint_no_url.yml @@ -1,4 +1,4 @@ event_broker: client_id: kafka-python-rasa security_protocol: PLAINTEXT - type: kafka \ No newline at end 
of file + type: kafka diff --git a/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml b/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml index 41e93a042e58..072c7f7c7f87 100644 --- a/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml +++ b/data/test_endpoints/event_brokers/kafka_sasl_ssl_endpoint.yml @@ -4,4 +4,4 @@ event_broker: sasl_password: password topic: topic security_protocol: SASL_SSL - type: kafka \ No newline at end of file + type: kafka diff --git a/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml b/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml index 606de903b813..dd415c52abc2 100644 --- a/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml +++ b/data/test_endpoints/event_brokers/kafka_ssl_endpoint.yml @@ -2,4 +2,4 @@ event_broker: url: localhost topic: topic security_protocol: SSL - type: kafka \ No newline at end of file + type: kafka diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py index 2a8194a30f02..c5b1eb6e1803 100644 --- a/tests/core/test_broker.py +++ b/tests/core/test_broker.py @@ -193,7 +193,9 @@ def test_load_non_existent_custom_broker_name(): def test_kafka_broker_from_config(): - endpoints_path = "data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml" + endpoints_path = ( + "data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml" + ) cfg = read_endpoint_config(endpoints_path, "event_broker") actual = KafkaEventBroker.from_endpoint_config(cfg) From 3cd8403224bb10d648cd66c35fbcb36c29fcd078 Mon Sep 17 00:00:00 2001 From: alwx Date: Wed, 7 Oct 2020 14:38:06 +0200 Subject: [PATCH 5/6] Review comments addressed --- rasa/core/brokers/kafka.py | 36 ++++++++++++++---------------------- tests/core/test_broker.py | 4 ++-- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/rasa/core/brokers/kafka.py b/rasa/core/brokers/kafka.py index a7532756a983..dfd58aa554d5 100644 --- a/rasa/core/brokers/kafka.py +++ b/rasa/core/brokers/kafka.py @@ 
-15,7 +15,7 @@ def __init__( topic: Text = "rasa_core_events", client_id: Optional[Text] = None, group_id: Optional[Text] = None, - sasl_username: Union[Text, int, None] = None, + sasl_username: Optional[Text] = None, sasl_password: Optional[Text] = None, ssl_cafile: Optional[Text] = None, ssl_certfile: Optional[Text] = None, @@ -29,39 +29,31 @@ def __init__( Args: url: 'url[:port]' string (or list of 'url[:port]' - strings) that the consumer should contact to bootstrap initial + strings) that the producer should contact to bootstrap initial cluster metadata. This does not have to be the full node list. It just needs to have at least one broker that will respond to a - Metadata API Request. The default port is 9092. If no servers are - specified, it will default to `localhost:9092`. - topic: Topics to subscribe to. If not set, call subscribe() or assign() - before consuming records + Metadata API Request. + topic: Topics to subscribe to. client_id: A name for this client. This string is passed in each request to servers and can be used to identify specific server-side log entries that correspond to this client. Also submitted to `GroupCoordinator` for - logging with respect to consumer group administration. - Default: ‘kafka-python-{version}’ - group_id: The name of the consumer group to join for dynamic partition + logging with respect to producer group administration. + group_id: The name of the producer group to join for dynamic partition assignment (if enabled), and to use for fetching and committing offsets. If None, auto-partition assignment (via group coordinator) and offset - commits are disabled. Default: None - sasl_username: Username for sasl PLAIN authentication. - Required if `sasl_mechanism` is `PLAIN`. - sasl_password: Password for sasl PLAIN authentication. - Required if `sasl_mechanism` is PLAIN. + commits are disabled. + sasl_username: Username for plain authentication. + sasl_password: Password for plain authentication. 
ssl_cafile: Optional filename of ca file to use in certificate - verification. Default: None. + verification. ssl_certfile: Optional filename of file in pem format containing the client certificate, as well as any ca certificates needed to - establish the certificate's authenticity. Default: None. + establish the certificate's authenticity. ssl_keyfile: Optional filename containing the client private key. - Default: None. ssl_check_hostname: Flag to configure whether ssl handshake should verify that the certificate matches the brokers hostname. - Default: False. security_protocol: Protocol used to communicate with brokers. Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. loglevel: Logging level of the kafka logger. """ @@ -80,7 +72,7 @@ def __init__( self.ssl_keyfile = ssl_keyfile self.ssl_check_hostname = ssl_check_hostname - self.consumer: Optional[kafka.KafkaConsumer] = None + self.producer: Optional[kafka.KafkaConsumer] = None logging.getLogger("kafka").setLevel(loglevel) @@ -100,7 +92,7 @@ def _create_producer(self) -> None: import kafka if self.security_protocol == "PLAINTEXT": - self.consumer = kafka.KafkaConsumer( + self.producer = kafka.KafkaConsumer( self.topic, bootstrap_servers=self.url, client_id=self.client_id, @@ -128,7 +120,7 @@ def _create_producer(self) -> None: security_protocol=self.security_protocol, ) elif self.security_protocol == "SASL_SSL": - self.consumer = kafka.KafkaConsumer( + self.producer = kafka.KafkaConsumer( self.topic, bootstrap_servers=self.url, client_id=self.client_id, diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py index c5b1eb6e1803..91f722d2087c 100644 --- a/tests/core/test_broker.py +++ b/tests/core/test_broker.py @@ -218,8 +218,8 @@ def test_kafka_broker_from_config(): "file,exception", [ # `_create_producer()` raises `kafka.errors.NoBrokersAvailable` exception - # which means that configuration seems correct but a connection to - # broker cannot be established + # which 
means that the configuration seems correct but a connection to + # the broker cannot be established ("kafka_sasl_plaintext_endpoint.yml", kafka.errors.NoBrokersAvailable), ("kafka_plaintext_endpoint.yml", kafka.errors.NoBrokersAvailable), ("kafka_sasl_ssl_endpoint.yml", kafka.errors.NoBrokersAvailable), From cc8a2647604cfc96138813c3081d5268af622c10 Mon Sep 17 00:00:00 2001 From: Vladimir Vlasov Date: Thu, 8 Oct 2020 17:39:55 +0200 Subject: [PATCH 6/6] properly fix intent_name by removing it (#6962) * properly fix intent_name by removing it * fix training_data filtering * return dissappeared nlu_examples * fix creating message from events * remove unused import * fix a mess created by empty string * pass only nlu examples * fix is_core_message * fix is_core_message * fix pattern utils tests * fix test_training_data * fix test_regex_entity_extractor * fix test_test * use only NLU examples if needed * fix test_bilou_utils * increase timeout for pipeline test * update examples for tests * update check if training is possible Co-authored-by: Tanja Bergmann Co-authored-by: Daksh --- data/examples/wit/demo-flights.json | 18 +++++++ data/test/wit_converted_to_rasa.json | 9 ++-- rasa/nlu/test.py | 4 +- rasa/nlu/utils/bilou_utils.py | 4 +- rasa/shared/core/events.py | 2 +- rasa/shared/importers/importer.py | 17 ++++--- rasa/shared/nlu/constants.py | 1 - rasa/shared/nlu/training_data/message.py | 24 +-------- .../shared/nlu/training_data/training_data.py | 49 +++++++------------ rasa/train.py | 6 +-- .../extractors/test_regex_entity_extractor.py | 31 +++++++++--- tests/nlu/test_train.py | 5 ++ tests/nlu/utils/test_bilou_utils.py | 12 +++-- tests/nlu/utils/test_pattern_utils.py | 12 ++++- tests/shared/importers/test_importer.py | 18 +++---- .../shared/nlu/training_data/test_message.py | 14 +----- .../nlu/training_data/test_training_data.py | 42 +--------------- 17 files changed, 119 insertions(+), 149 deletions(-) diff --git a/data/examples/wit/demo-flights.json 
b/data/examples/wit/demo-flights.json index 81084857a778..504df6bdc342 100755 --- a/data/examples/wit/demo-flights.json +++ b/data/examples/wit/demo-flights.json @@ -34,6 +34,12 @@ { "text" : "i'm looking for a flight from london to amsterdam next monday", "entities" : [ + { + "entity" : "intent", + "value" : "\"flight_booking\"", + "start" : 0, + "end" : 42 + }, { "entity" : "location", "value" : "\"london\"", @@ -59,6 +65,12 @@ { "text" : "i want to fly to berlin", "entities" : [ + { + "entity" : "intent", + "value" : "\"flight_booking\"", + "start" : 0, + "end" : 42 + }, { "entity" : "location", "value" : "\"berlin\"", @@ -71,6 +83,12 @@ { "text" : "i want to fly from london", "entities" : [ + { + "entity" : "intent", + "value" : "\"flight_booking\"", + "start" : 0, + "end" : 42 + }, { "entity" : "location", "value" : "\"london\"", diff --git a/data/test/wit_converted_to_rasa.json b/data/test/wit_converted_to_rasa.json index 947d7297cb27..8219770d6d5c 100644 --- a/data/test/wit_converted_to_rasa.json +++ b/data/test/wit_converted_to_rasa.json @@ -28,7 +28,8 @@ ] }, { - "text": "i'm looking for a flight from london to amsterdam next monday", + "text": "i'm looking for a flight from london to amsterdam next monday", + "intent": "flight_booking", "entities": [ { "start": 30, @@ -53,7 +54,8 @@ ] }, { - "text": "i want to fly to berlin", + "text": "i want to fly to berlin", + "intent": "flight_booking", "entities": [ { "start": 17, @@ -65,7 +67,8 @@ ] }, { - "text": "i want to fly from london", + "text": "i want to fly from london", + "intent": "flight_booking", "entities": [ { "start": 19, diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index eb2c900a1066..eadd6fd69460 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -1306,7 +1306,7 @@ def get_eval_data( should_eval_entities = is_entity_extractor_present(interpreter) - for example in tqdm(test_data.training_examples): + for example in tqdm(test_data.nlu_examples): result = 
interpreter.parse(example.get(TEXT), only_output_properties=False) if should_eval_intents: @@ -1861,7 +1861,7 @@ def compare_nlu( _, train_included = train.train_test_split(percentage / 100) # only count for the first run and ignore the others if run == 0: - training_examples_per_run.append(len(train_included.training_examples)) + training_examples_per_run.append(len(train_included.nlu_examples)) model_output_path = os.path.join(run_path, percent_string) train_split_path = os.path.join(model_output_path, "train") diff --git a/rasa/nlu/utils/bilou_utils.py b/rasa/nlu/utils/bilou_utils.py index adec1cbc8e57..de29ae67cfb6 100644 --- a/rasa/nlu/utils/bilou_utils.py +++ b/rasa/nlu/utils/bilou_utils.py @@ -130,7 +130,7 @@ def build_tag_id_dict( distinct_tags = set( [ tag_without_prefix(e) - for example in training_data.training_examples + for example in training_data.nlu_examples if example.get(bilou_key) for e in example.get(bilou_key) ] @@ -157,7 +157,7 @@ def apply_bilou_schema(training_data: TrainingData) -> None: Args: training_data: the training data """ - for message in training_data.training_examples: + for message in training_data.nlu_examples: entities = message.get(ENTITIES) if not entities: diff --git a/rasa/shared/core/events.py b/rasa/shared/core/events.py index 0007411ad230..9e50238b5ba7 100644 --- a/rasa/shared/core/events.py +++ b/rasa/shared/core/events.py @@ -1039,7 +1039,7 @@ class ActionExecuted(Event): def __init__( self, - action_name: Text, + action_name: Optional[Text] = None, policy: Optional[Text] = None, confidence: Optional[float] = None, timestamp: Optional[float] = None, diff --git a/rasa/shared/importers/importer.py b/rasa/shared/importers/importer.py index 9d25c9df40f5..f0a6ccb4a668 100644 --- a/rasa/shared/importers/importer.py +++ b/rasa/shared/importers/importer.py @@ -13,7 +13,7 @@ from rasa.shared.core.training_data.structures import StoryGraph from rasa.shared.nlu.training_data.message import Message from 
rasa.shared.nlu.training_data.training_data import TrainingData -from rasa.shared.nlu.constants import INTENT_NAME, TEXT +from rasa.shared.nlu.constants import ENTITIES, ACTION_NAME from rasa.shared.importers.autoconfig import TrainingType from rasa.shared.core.domain import IS_RETRIEVAL_INTENT_KEY @@ -533,18 +533,23 @@ def _unique_events_from_stories( def _messages_from_user_utterance(event: UserUttered) -> Message: - return Message(data={TEXT: event.text, INTENT_NAME: event.intent_name}) + # sub state correctly encodes intent vs text + data = event.as_sub_state() + # sub state stores entities differently + if data.get(ENTITIES) and event.entities: + data[ENTITIES] = event.entities + + return Message(data=data) def _messages_from_action(event: ActionExecuted) -> Message: - return Message.build_from_action( - action_name=event.action_name, action_text=event.action_text or "" - ) + # sub state correctly encodes action_name vs action_text + return Message(data=event.as_sub_state()) def _additional_training_data_from_default_actions() -> TrainingData: additional_messages_from_default_actions = [ - Message.build_from_action(action_name=action_name) + Message(data={ACTION_NAME: action_name}) for action_name in rasa.shared.core.constants.DEFAULT_ACTION_NAMES ] diff --git a/rasa/shared/nlu/constants.py b/rasa/shared/nlu/constants.py index dde63d4c32f1..5bc16fa6c86f 100644 --- a/rasa/shared/nlu/constants.py +++ b/rasa/shared/nlu/constants.py @@ -4,7 +4,6 @@ INTENT_RESPONSE_KEY = "intent_response_key" ACTION_TEXT = "action_text" ACTION_NAME = "action_name" -INTENT_NAME = "intent_name" INTENT_NAME_KEY = "name" METADATA = "metadata" METADATA_INTENT = "intent" diff --git a/rasa/shared/nlu/training_data/message.py b/rasa/shared/nlu/training_data/message.py index adb0227a452a..7b1881b52bb4 100644 --- a/rasa/shared/nlu/training_data/message.py +++ b/rasa/shared/nlu/training_data/message.py @@ -19,7 +19,6 @@ FEATURE_TYPE_SEQUENCE, ACTION_TEXT, ACTION_NAME, - INTENT_NAME, ) if 
typing.TYPE_CHECKING: @@ -132,24 +131,6 @@ def build( # pytype: enable=unsupported-operands return cls(data, **kwargs) - @classmethod - def build_from_action( - cls, - action_text: Optional[Text] = "", - action_name: Optional[Text] = "", - **kwargs: Any, - ) -> "Message": - """ - Build a `Message` from `ActionExecuted` data. - Args: - action_text: text of a bot's utterance - action_name: name of an action executed - Returns: - Message - """ - action_data = {ACTION_TEXT: action_text, ACTION_NAME: action_name} - return cls(data=action_data, **kwargs) - def get_full_intent(self) -> Text: """Get intent as it appears in training data""" @@ -330,9 +311,8 @@ def is_core_message(self) -> bool: Returns: True, if message is a core message, false otherwise. """ - return ( - self.data.get(ACTION_NAME) is not None - or self.data.get(INTENT_NAME) is not None + return bool( + self.data.get(ACTION_NAME) or self.data.get(ACTION_TEXT) or ( (self.data.get(INTENT) or self.data.get(RESPONSE)) diff --git a/rasa/shared/nlu/training_data/training_data.py b/rasa/shared/nlu/training_data/training_data.py index af4b2d53ebca..f4bbaf024fa0 100644 --- a/rasa/shared/nlu/training_data/training_data.py +++ b/rasa/shared/nlu/training_data/training_data.py @@ -21,7 +21,7 @@ ENTITIES, TEXT, ACTION_NAME, - INTENT_NAME, + ACTION_TEXT, ) from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.training_data import util @@ -130,21 +130,21 @@ def sanitize_examples(examples: List[Message]) -> List[Message]: return list(OrderedDict.fromkeys(examples)) + @lazy_property + def nlu_examples(self) -> List[Message]: + return [ex for ex in self.training_examples if not ex.is_core_message()] + @lazy_property def intent_examples(self) -> List[Message]: - return [ex for ex in self.training_examples if ex.get(INTENT)] + return [ex for ex in self.nlu_examples if ex.get(INTENT)] @lazy_property def response_examples(self) -> List[Message]: - return [ex for ex in self.training_examples if 
ex.get(INTENT_RESPONSE_KEY)] + return [ex for ex in self.nlu_examples if ex.get(INTENT_RESPONSE_KEY)] @lazy_property def entity_examples(self) -> List[Message]: - return [ex for ex in self.training_examples if ex.get(ENTITIES)] - - @lazy_property - def nlu_examples(self) -> List[Message]: - return [ex for ex in self.training_examples if not ex.is_core_message()] + return [ex for ex in self.nlu_examples if ex.get(ENTITIES)] @lazy_property def intents(self) -> Set[Text]: @@ -577,38 +577,23 @@ def is_empty(self) -> bool: """Checks if any training data was loaded.""" lists_to_check = [ - self._training_examples_without_empty_e2e_examples(), + self.training_examples, self.entity_synonyms, self.regex_features, self.lookup_tables, ] return not any([len(lst) > 0 for lst in lists_to_check]) - def without_empty_e2e_examples(self) -> "TrainingData": - """Removes training data examples from intent labels and action names which - were added for end-to-end training. - - Returns: - Itself but without training examples which don't have a text or intent. 
- """ - training_examples = copy.deepcopy(self.training_examples) - entity_synonyms = self.entity_synonyms.copy() - regex_features = copy.deepcopy(self.regex_features) - lookup_tables = copy.deepcopy(self.lookup_tables) - responses = copy.deepcopy(self.responses) - copied = TrainingData( - training_examples, entity_synonyms, regex_features, lookup_tables, responses - ) - copied.training_examples = self._training_examples_without_empty_e2e_examples() - - return copied + def can_train_nlu_model(self) -> bool: + """Checks if any NLU training data was loaded.""" - def _training_examples_without_empty_e2e_examples(self) -> List[Message]: - return [ - example - for example in self.training_examples - if not example.get(ACTION_NAME) and not example.get(INTENT_NAME) + lists_to_check = [ + self.nlu_examples, + self.entity_synonyms, + self.regex_features, + self.lookup_tables, ] + return not any([len(lst) > 0 for lst in lists_to_check]) def list_to_str(lst: List[Text], delim: Text = ", ", quote: Text = "'") -> Text: diff --git a/rasa/train.py b/rasa/train.py index 68406287968d..802ba6dc99e6 100644 --- a/rasa/train.py +++ b/rasa/train.py @@ -157,7 +157,7 @@ async def _train_async_internal( file_importer.get_stories(), file_importer.get_nlu_data() ) - if stories.is_empty() and nlu_data.is_empty(): + if stories.is_empty() and nlu_data.can_train_nlu_model(): print_error( "No training data given. Please provide stories and NLU data in " "order to train a Rasa model using the '--data' argument." @@ -174,7 +174,7 @@ async def _train_async_internal( additional_arguments=nlu_additional_arguments, ) - if nlu_data.is_empty(): + if nlu_data.can_train_nlu_model(): print_warning("No NLU data present. 
Just a Rasa Core model will be trained.") return await _train_core_with_validated_data( file_importer, @@ -495,7 +495,7 @@ async def _train_nlu_async( ) training_data = await file_importer.get_nlu_data() - if training_data.is_empty(): + if training_data.can_train_nlu_model(): print_error( f"Path '{nlu_data}' doesn't contain valid NLU data in it. " f"Please verify the data format. " diff --git a/tests/nlu/extractors/test_regex_entity_extractor.py b/tests/nlu/extractors/test_regex_entity_extractor.py index 1b625450d1fd..677de16e5e9c 100644 --- a/tests/nlu/extractors/test_regex_entity_extractor.py +++ b/tests/nlu/extractors/test_regex_entity_extractor.py @@ -4,7 +4,7 @@ from rasa.shared.nlu.training_data.training_data import TrainingData from rasa.shared.nlu.training_data.message import Message -from rasa.shared.nlu.constants import ENTITIES, TEXT +from rasa.shared.nlu.constants import ENTITIES, TEXT, INTENT from rasa.nlu.extractors.regex_entity_extractor import RegexEntityExtractor @@ -86,12 +86,17 @@ def test_process( training_data.lookup_tables = lookup training_data.training_examples = [ Message( - data={TEXT: "Hi Max!", "entities": [{"entity": "person", "value": "Max"}]} + data={ + TEXT: "Hi Max!", + INTENT: "greet", + ENTITIES: [{"entity": "person", "value": "Max"}], + } ), Message( data={ TEXT: "I live in Berlin", - "entities": [{"entity": "city", "value": "Berlin"}], + INTENT: "inform", + ENTITIES: [{"entity": "city", "value": "Berlin"}], } ), ] @@ -165,12 +170,17 @@ def test_lowercase( training_data.lookup_tables = lookup training_data.training_examples = [ Message( - data={TEXT: "Hi Max!", "entities": [{"entity": "person", "value": "Max"}]} + data={ + TEXT: "Hi Max!", + INTENT: "greet", + ENTITIES: [{"entity": "person", "value": "Max"}], + } ), Message( data={ TEXT: "I live in Berlin", - "entities": [{"entity": "city", "value": "Berlin"}], + INTENT: "inform", + ENTITIES: [{"entity": "city", "value": "Berlin"}], } ), ] @@ -184,18 +194,23 @@ def 
test_lowercase( def test_do_not_overwrite_any_entities(): - message = Message(data={TEXT: "Max lives in Berlin."}) + message = Message(data={TEXT: "Max lives in Berlin.", INTENT: "inform"}) message.set(ENTITIES, [{"entity": "person", "value": "Max", "start": 0, "end": 3}]) training_data = TrainingData() training_data.training_examples = [ Message( - data={TEXT: "Hi Max!", "entities": [{"entity": "person", "value": "Max"}]} + data={ + TEXT: "Hi Max!", + INTENT: "greet", + ENTITIES: [{"entity": "person", "value": "Max"}], + } ), Message( data={ TEXT: "I live in Berlin", - "entities": [{"entity": "city", "value": "Berlin"}], + INTENT: "inform", + ENTITIES: [{"entity": "city", "value": "Berlin"}], } ), ] diff --git a/tests/nlu/test_train.py b/tests/nlu/test_train.py index 67bbd856a336..d0990176ddad 100644 --- a/tests/nlu/test_train.py +++ b/tests/nlu/test_train.py @@ -116,6 +116,7 @@ def test_all_components_are_in_at_least_one_test_pipeline(): ), "`all_components` template is missing component." 
+@pytest.mark.timeout(600) @pytest.mark.parametrize("language, pipeline", pipelines_for_tests()) async def test_train_persist_load_parse(language, pipeline, component_builder, tmpdir): _config = RasaNLUModelConfig({"pipeline": pipeline, "language": language}) @@ -135,6 +136,7 @@ async def test_train_persist_load_parse(language, pipeline, component_builder, t assert loaded.parse("Rasa is great!") is not None +@pytest.mark.timeout(600) @pytest.mark.parametrize("language, pipeline", pipelines_for_non_windows_tests()) @pytest.mark.skip_on_windows async def test_train_persist_load_parse_non_windows( @@ -157,6 +159,7 @@ def test_train_model_without_data(language, pipeline, component_builder, tmpdir) assert loaded.parse("Rasa is great!") is not None +@pytest.mark.timeout(600) @pytest.mark.parametrize("language, pipeline", pipelines_for_non_windows_tests()) @pytest.mark.skip_on_windows def test_train_model_without_data_non_windows( @@ -165,6 +168,7 @@ def test_train_model_without_data_non_windows( test_train_model_without_data(language, pipeline, component_builder, tmpdir) +@pytest.mark.timeout(600) @pytest.mark.parametrize("language, pipeline", pipelines_for_tests()) def test_load_and_persist_without_train(language, pipeline, component_builder, tmpdir): _config = RasaNLUModelConfig({"pipeline": pipeline, "language": language}) @@ -178,6 +182,7 @@ def test_load_and_persist_without_train(language, pipeline, component_builder, t assert loaded.parse("Rasa is great!") is not None +@pytest.mark.timeout(600) @pytest.mark.parametrize("language, pipeline", pipelines_for_non_windows_tests()) @pytest.mark.skip_on_windows def test_load_and_persist_without_train_non_windows( diff --git a/tests/nlu/utils/test_bilou_utils.py b/tests/nlu/utils/test_bilou_utils.py index e8ff3425e614..feacdf36477e 100644 --- a/tests/nlu/utils/test_bilou_utils.py +++ b/tests/nlu/utils/test_bilou_utils.py @@ -67,13 +67,15 @@ def test_remove_bilou_prefixes(): def test_build_tag_id_dict(): - message_1 = 
Message.build(text="Germany is part of the European Union") + message_1 = Message.build( + text="Germany is part of the European Union", intent="inform" + ) message_1.set( BILOU_ENTITIES, ["U-location", "O", "O", "O", "O", "B-organisation", "L-organisation"], ) - message_2 = Message.build(text="Berlin is the capital of Germany") + message_2 = Message.build(text="Berlin is the capital of Germany", intent="inform") message_2.set(BILOU_ENTITIES, ["U-location", "O", "O", "O", "O", "U-location"]) training_data = TrainingData([message_1, message_2]) @@ -96,7 +98,9 @@ def test_build_tag_id_dict(): def test_apply_bilou_schema(): tokenizer = WhitespaceTokenizer() - message_1 = Message.build(text="Germany is part of the European Union") + message_1 = Message.build( + text="Germany is part of the European Union", intent="inform" + ) message_1.set( ENTITIES, [ @@ -110,7 +114,7 @@ def test_apply_bilou_schema(): ], ) - message_2 = Message.build(text="Berlin is the capital of Germany") + message_2 = Message.build(text="Berlin is the capital of Germany", intent="inform") message_2.set( ENTITIES, [ diff --git a/tests/nlu/utils/test_pattern_utils.py b/tests/nlu/utils/test_pattern_utils.py index c2d237978329..864541c06ca8 100644 --- a/tests/nlu/utils/test_pattern_utils.py +++ b/tests/nlu/utils/test_pattern_utils.py @@ -77,7 +77,11 @@ def test_extract_patterns_use_only_entities_regexes( if entity: training_data.training_examples = [ Message( - data={"text": "text", "entities": [{"entity": entity, "value": "text"}]} + data={ + "text": "text", + "intent": "greet", + "entities": [{"entity": entity, "value": "text"}], + } ) ] if regex_features: @@ -109,7 +113,11 @@ def test_extract_patterns_use_only_entities_lookup_tables( if entity: training_data.training_examples = [ Message( - data={"text": "text", "entities": [{"entity": entity, "value": "text"}]} + data={ + "text": "text", + "intent": "greet", + "entities": [{"entity": entity, "value": "text"}], + } ) ] if lookup_tables: diff --git 
a/tests/shared/importers/test_importer.py b/tests/shared/importers/test_importer.py index 2203f1c89de7..6a1ac45209ee 100644 --- a/tests/shared/importers/test_importer.py +++ b/tests/shared/importers/test_importer.py @@ -23,7 +23,7 @@ ) from rasa.shared.importers.multi_project import MultiProjectImporter from rasa.shared.importers.rasa import RasaFileImporter -from rasa.shared.nlu.constants import ACTION_TEXT, ACTION_NAME, INTENT_NAME, TEXT +from rasa.shared.nlu.constants import ACTION_TEXT, ACTION_NAME, INTENT, TEXT from rasa.shared.nlu.training_data.message import Message @@ -190,14 +190,14 @@ async def test_import_nlu_training_data_from_e2e_stories( StoryStep( events=[ SlotSet("some slot", "doesn't matter"), - UserUttered("greet_from_stories", {"name": "greet_from_stories"}), + UserUttered(intent={"name": "greet_from_stories"}), ActionExecuted("utter_greet_from_stories"), ] ), StoryStep( events=[ UserUttered("how are you doing?"), - ActionExecuted("utter_greet_from_stories", action_text="Hi Joey."), + ActionExecuted(action_text="Hi Joey."), ] ), ] @@ -227,12 +227,10 @@ async def mocked_stories(*_: Any, **__: Any) -> StoryGraph: # Check if the NLU training data was added correctly from the story training data expected_additional_messages = [ - Message(data={TEXT: "greet_from_stories", INTENT_NAME: "greet_from_stories"}), - Message(data={ACTION_NAME: "utter_greet_from_stories", ACTION_TEXT: ""}), - Message(data={TEXT: "how are you doing?", INTENT_NAME: None}), - Message( - data={ACTION_NAME: "utter_greet_from_stories", ACTION_TEXT: "Hi Joey."} - ), + Message(data={INTENT: "greet_from_stories"}), + Message(data={ACTION_NAME: "utter_greet_from_stories"}), + Message(data={TEXT: "how are you doing?"}), + Message(data={ACTION_TEXT: "Hi Joey."}), ] assert all(m in nlu_data.training_examples for m in expected_additional_messages) @@ -294,7 +292,7 @@ async def test_import_nlu_training_data_with_default_actions( extended_training_data = await default_importer.get_nlu_data() 
assert all( - Message(data={ACTION_NAME: action_name, ACTION_TEXT: ""}) + Message(data={ACTION_NAME: action_name}) in extended_training_data.training_examples for action_name in rasa.shared.core.constants.DEFAULT_ACTION_NAMES ) diff --git a/tests/shared/nlu/training_data/test_message.py b/tests/shared/nlu/training_data/test_message.py index 88e1dc8fe07a..96753d851a97 100644 --- a/tests/shared/nlu/training_data/test_message.py +++ b/tests/shared/nlu/training_data/test_message.py @@ -13,7 +13,6 @@ ACTION_NAME, INTENT, RESPONSE, - INTENT_NAME, ) import rasa.shared.nlu.training_data.message from rasa.shared.nlu.training_data.message import Message @@ -261,24 +260,15 @@ def test_ordered(): ] -def test_build_from_action(): - test_action_name = "test_action_name" - test_action_text = "test action text" - assert Message.build_from_action( - action_text=test_action_text, action_name=test_action_name - ) == Message(data={ACTION_NAME: test_action_name, ACTION_TEXT: test_action_text}) - - @pytest.mark.parametrize( "message, core_message", [ (Message({INTENT: "intent", TEXT: "text"}), False), (Message({RESPONSE: "response", TEXT: "text"}), False), (Message({INTENT: "intent"}), True), - (Message({ACTION_TEXT: "action text", ACTION_NAME: ""}), True), - (Message({ACTION_NAME: "action"}), True), + (Message({ACTION_TEXT: "action text"}), True), + (Message({ACTION_NAME: "action name"}), True), (Message({TEXT: "text"}), True), - (Message({TEXT: None, INTENT_NAME: "affirm"}), True), ], ) def test_is_core_message( diff --git a/tests/shared/nlu/training_data/test_training_data.py b/tests/shared/nlu/training_data/test_training_data.py index 059698745337..5b6c32d44746 100644 --- a/tests/shared/nlu/training_data/test_training_data.py +++ b/tests/shared/nlu/training_data/test_training_data.py @@ -1,14 +1,8 @@ -import asyncio -from pathlib import Path from typing import Text, List import pytest import rasa.shared.utils.io -from rasa.shared.core.domain import Domain -from 
rasa.shared.core.events import UserUttered, ActionExecuted -from rasa.shared.core.training_data.structures import StoryStep, StoryGraph -from rasa.shared.importers.importer import E2EImporter, TrainingDataImporter from rasa.shared.nlu.constants import TEXT, INTENT_RESPONSE_KEY from rasa.nlu.convert import convert_training_data from rasa.nlu.extractors.mitie_entity_extractor import MitieEntityExtractor @@ -44,7 +38,7 @@ def test_wit_data(): td = load_data("data/examples/wit/demo-flights.json") assert not td.is_empty() assert len(td.entity_examples) == 4 - assert len(td.intent_examples) == 1 + assert len(td.intent_examples) == 4 assert len(td.training_examples) == 4 assert td.entity_synonyms == {} assert td.intents == {"flight_booking"} @@ -607,37 +601,3 @@ def test_custom_attributes(tmp_path): assert len(td.training_examples) == 1 example = td.training_examples[0] assert example.get("sentiment") == 0.8 - - -async def test_without_additional_e2e_examples(tmp_path: Path): - domain_path = tmp_path / "domain.yml" - domain_path.write_text(Domain.empty().as_yaml()) - - config_path = tmp_path / "config.yml" - config_path.touch() - - existing = TrainingDataImporter.load_from_dict( - {}, str(config_path), str(domain_path), [] - ) - - stories = StoryGraph( - [ - StoryStep( - events=[ - UserUttered("greet_from_stories", {"name": "greet_from_stories"}), - ActionExecuted("utter_greet_from_stories"), - ] - ) - ] - ) - - # Patch to return our test stories - existing.get_stories = asyncio.coroutine(lambda *args: stories) - - importer = E2EImporter(existing) - - training_data = await importer.get_nlu_data() - - assert training_data.training_examples - assert training_data.is_empty() - assert not training_data.without_empty_e2e_examples().training_examples