diff --git a/canopen/network.py b/canopen/network.py index 02bec899..57667eab 100644 --- a/canopen/network.py +++ b/canopen/network.py @@ -75,6 +75,9 @@ def unsubscribe(self, can_id, callback=None) -> None: If given, remove only this callback. Otherwise all callbacks for the CAN ID. """ + if can_id not in self.subscribers: + return + if callback is None: del self.subscribers[can_id] else: diff --git a/canopen/node/local.py b/canopen/node/local.py index 8f2493d9..763d4cfe 100644 --- a/canopen/node/local.py +++ b/canopen/node/local.py @@ -51,6 +51,7 @@ def associate_network(self, network: canopen.network.Network): def remove_network(self) -> None: self.network.unsubscribe(self.sdo.rx_cobid, self.sdo.on_request) self.network.unsubscribe(0, self.nmt.on_command) + self.stop_pdo_services() self.network = canopen.network._UNINITIALIZED_NETWORK self.sdo.network = canopen.network._UNINITIALIZED_NETWORK self.tpdo.network = canopen.network._UNINITIALIZED_NETWORK @@ -64,6 +65,21 @@ def add_read_callback(self, callback): def add_write_callback(self, callback): self._write_callbacks.append(callback) + def start_pdo_services(self, period: float): + """ + Start the PDO related services of the node. + :param period: Service interval in seconds. + """ + self.rpdo.subscribe() + self.tpdo.start(period=period) + + def stop_pdo_services(self): + """ + Stop the PDO related services of the node. + """ + self.rpdo.unsubscribe() + self.tpdo.stop() + def get_data( self, index: int, subindex: int, check_readable: bool = False ) -> bytes: diff --git a/canopen/pdo/__init__.py b/canopen/pdo/__init__.py index 533309f8..20af6bac 100644 --- a/canopen/pdo/__init__.py +++ b/canopen/pdo/__init__.py @@ -2,7 +2,7 @@ from canopen import node from canopen.pdo.base import PdoBase, PdoMap, PdoMaps, PdoVariable - +import canopen.network __all__ = [ "PdoBase", @@ -74,6 +74,19 @@ def __init__(self, node): self.map = PdoMaps(0x1800, 0x1A00, self, 0x180) logger.debug('TPDO Map as %d', len(self.map)) + def start(self, period: float): + """Start transmission of all TPDOs. + + :param float period: Transmission period in seconds. + :raises TypeError: Exception is thrown if the node associated with the PDO does not + support this function. + """ + if isinstance(self.node, node.LocalNode): + for pdo in self.map.values(): + pdo.start(period) + else: + raise TypeError('The node type does not support this function.') + def stop(self): """Stop transmission of all TPDOs. diff --git a/canopen/pdo/base.py b/canopen/pdo/base.py index 0ba65199..c3e9a0e7 100644 --- a/canopen/pdo/base.py +++ b/canopen/pdo/base.py @@ -77,6 +77,11 @@ def subscribe(self): for pdo_map in self.map.values(): pdo_map.subscribe() + def unsubscribe(self) -> None: + """Unregister the node's PDOs for reception on the network.""" + for pdo_map in self.map.values(): + pdo_map.unsubscribe() + def export(self, filename): """Export current configuration to a database file. @@ -469,6 +474,12 @@ def subscribe(self) -> None: logger.info("Subscribing to enabled PDO 0x%X on the network", self.cob_id) self.pdo_node.network.subscribe(self.cob_id, self.on_message) + def unsubscribe(self) -> None: + """Unregister the PDO for reception on the network.""" + if self.enabled: + logger.info("Unsubscribing from enabled PDO 0x%X on the network", self.cob_id) + self.pdo_node.network.unsubscribe(self.cob_id, self.on_message) + def clear(self) -> None: """Clear all variables from this map.""" self.map = [] @@ -534,6 +545,13 @@ def start(self, period: Optional[float] = None) -> None: raise ValueError("A valid transmission period has not been given") logger.info("Starting %s with a period of %s seconds", self.name, self.period) + if self.cob_id is None and self.predefined_cob_id is not None: + self.cob_id = self.predefined_cob_id + logger.info("Using predefined COB-ID 0x%X", self.cob_id) + + if self.cob_id is None: + raise ValueError("COB-ID has not been set") + self._task = self.pdo_node.network.send_periodic( self.cob_id, self.data, self.period) diff --git a/test/test_network.py b/test/test_network.py index 1d45a1c2..e9892e75 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -144,6 +144,8 @@ def hook(*args, i=i): self.assertEqual(accumulators[2], [(2, bytes([4, 5, 6]), 1003)]) self.network.unsubscribe(0) + # Should not raise an error. + self.network.unsubscribe(10) self.network.notify(0, bytes([7, 7, 7]), 1004) # Verify that no new data was added to the accumulator. self.assertEqual(accumulators[0], [(0, bytes([1, 2, 3]), 1000)]) diff --git a/test/test_pdo.py b/test/test_pdo.py index 1badc89d..0c99598c 100644 --- a/test/test_pdo.py +++ b/test/test_pdo.py @@ -80,6 +80,16 @@ def test_pdo_export(self): self.assertIn("ID", header) self.assertIn("Frame Name", header) + def test_tpdo_start_stop(self): + network = canopen.Network() + network.connect("test", interface="virtual") + self.node.associate_network(network) + self.node.tpdo[1].start(period=0.01) + self.node.tpdo[1].stop() + + def test_rpdo_subscribe_unsubscribe(self): + self.node.rpdo.subscribe() + self.node.rpdo.unsubscribe() if __name__ == "__main__": unittest.main()