diff --git a/src/deadline/client/ui/widgets/host_requirements_tab.py b/src/deadline/client/ui/widgets/host_requirements_tab.py index 19224793..ec4563fc 100644 --- a/src/deadline/client/ui/widgets/host_requirements_tab.py +++ b/src/deadline/client/ui/widgets/host_requirements_tab.py @@ -3,44 +3,45 @@ """ UI widgets for the host requirements tab. """ -from typing import Any, Dict, List, Optional, Union, Literal +import re +from logging import getLogger from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union from qtpy.QtCore import Qt # type: ignore from qtpy.QtGui import ( # type: ignore - QFont, - QValidator, - QIntValidator, - QDoubleValidator, QBrush, + QDoubleValidator, + QFont, QIcon, + QIntValidator, QRegularExpressionValidator, QStandardItem, + QValidator, ) from qtpy.QtWidgets import ( # type: ignore QComboBox, + QDoubleSpinBox, + QFrame, QGroupBox, QHBoxLayout, QLabel, + QLineEdit, + QListView, + QListWidget, + QListWidgetItem, + QPushButton, QRadioButton, QSizePolicy, QSpacerItem, - QDoubleSpinBox, QSpinBox, QVBoxLayout, QWidget, - QPushButton, - QListWidget, - QListWidgetItem, - QFrame, - QLineEdit, - QListView, ) from deadline.client.exceptions import NonValidInputError -from ..dataclasses import HostRequirements, CustomRequirements, OsRequirements, HardwareRequirements -from logging import getLogger +from ..dataclasses import CustomRequirements, HardwareRequirements, HostRequirements, OsRequirements logger = getLogger(__name__) @@ -72,11 +73,19 @@ ) CUSTOM_CAPABILITY_NAME_REGEX = "^(\\.[a-zA-Z][a-zA-Z0-9]{0,63})+$" - ATTRIBUTE_CAPABILITY_VALUE_REGEX = "^[a-zA-Z_]([a-zA-Z0-9_\\-]{0,99})$" -ATTRIBUTE_CAPABILITY_PREFIX = "attr.worker." -AMOUNT_CAPABILITY_PREFIX = "amount.worker." +ATTRIBUTE_CAPABILITY_PREFIX = "attr." +AMOUNT_CAPABILITY_PREFIX = "amount." + +RESERVED_FIRST_IDENTIFIERS = ["worker", "job", "step", "task"] + +# An attribute name needs to be <= 100 characters. Each Identifier must start with a letter or underscore and can +# be a maximum of 64 characters long. Periods separate each Identifier. +IDENTIFIER_REGEX = "[a-zA-Z_][a-zA-Z0-9_]{0,63}" + +ATTRIBUTE_CAPABILITY_NAME_REGEX = rf"^({IDENTIFIER_REGEX})(\.{IDENTIFIER_REGEX})*$" +AMOUNT_CAPABILITY_NAME_REGEX = rf"^({IDENTIFIER_REGEX})(\.{IDENTIFIER_REGEX})*$" class AddIcon(QIcon): @@ -557,9 +566,7 @@ def _build_ui(self): self.name_label.setFixedWidth(LABEL_FIXED_WIDTH) self.name_line_edit = QLineEdit() self.name_line_edit.setFixedWidth(LABEL_FIXED_WIDTH) - self.name_line_edit.setValidator( - QRegularExpressionValidator(ATTRIBUTE_CAPABILITY_VALUE_REGEX) - ) + self.name_line_edit.setValidator(QRegularExpressionValidator(AMOUNT_CAPABILITY_NAME_REGEX)) assert (100 - len(AMOUNT_CAPABILITY_PREFIX)) > 0 self.name_line_edit.setMaxLength(100 - len(AMOUNT_CAPABILITY_PREFIX)) @@ -611,7 +618,7 @@ def name(self, name: str): self.name_line_edit.setText(name) @property - def minimum(self) -> Optional[int]: + def minimum(self) -> Optional[float]: if self.min_spin_box.has_input(): return self.min_spin_box.value() return None @@ -621,7 +628,7 @@ def minimum(self, minimum: int): self.min_spin_box.setValue(minimum) @property - def maximum(self) -> Optional[int]: + def maximum(self) -> Optional[float]: if self.max_spin_box.has_input(): return self.max_spin_box.value() return None @@ -639,6 +646,9 @@ def get_requirement(self) -> Dict[str, Any]: """ requirement: Dict[str, Any] = {} if self.name_line_edit.text(): + if self.name[-1] == ".": + raise NonValidInputError("Your requirement name cannot end with a period.") + requirement = {"name": AMOUNT_CAPABILITY_PREFIX + self.name} minimum = self.minimum maximum = self.maximum @@ -656,6 +666,13 @@ def get_requirement(self) -> Dict[str, Any]: elif maximum: requirement["max"] = maximum + parsed_name_match = re.match(AMOUNT_CAPABILITY_NAME_REGEX, self.name) + + if parsed_name_match and parsed_name_match.group(1) in RESERVED_FIRST_IDENTIFIERS: + raise NonValidInputError( + "Please make sure that the first identifier in your name is not a reserved identifier. " + + str(RESERVED_FIRST_IDENTIFIERS) + ) else: raise NonValidInputError( "Please fill out all custom amount names in the custom host requirement options!" @@ -690,7 +707,7 @@ def _build_ui(self): assert (100 - len(ATTRIBUTE_CAPABILITY_PREFIX)) > 0 self.name_line_edit.setMaxLength(100 - len(ATTRIBUTE_CAPABILITY_PREFIX)) self.name_line_edit.setValidator( - QRegularExpressionValidator(ATTRIBUTE_CAPABILITY_VALUE_REGEX) + QRegularExpressionValidator(ATTRIBUTE_CAPABILITY_NAME_REGEX) ) self.add_value_button = None @@ -876,6 +893,16 @@ def get_requirement(self) -> Dict[str, Any]: requirements_are_valid = True if self.name_line_edit.text(): + if self.name[-1] == ".": + raise NonValidInputError("Your requirement name cannot end with a period.") + + parsed_name_match = re.match(AMOUNT_CAPABILITY_NAME_REGEX, self.name) + if parsed_name_match and parsed_name_match.group(1) in RESERVED_FIRST_IDENTIFIERS: + raise NonValidInputError( + "Please make sure that the first identifier in your name is not a reserved identifier. " + + str(RESERVED_FIRST_IDENTIFIERS) + ) + try: values = self.values except ValueError: @@ -893,6 +920,7 @@ def get_requirement(self) -> Dict[str, Any]: raise NonValidInputError( "Please fill out all custom attribute names and values in the custom host requirements options!" ) + return requirement @@ -1051,7 +1079,7 @@ def __init__(self, items: List[str], parent=None): self.addItems(items) def has_input(self) -> bool: - return PLACEHOLDER_TEXT != self.currentText() + return self.currentText() != PLACEHOLDER_TEXT class OptionalMultiSelectComboBox(QComboBox): diff --git a/test/unit/deadline_client/ui/widgets/test_host_requirements_tab.py b/test/unit/deadline_client/ui/widgets/test_host_requirements_tab.py index e9d671ff..57ea8744 100644 --- a/test/unit/deadline_client/ui/widgets/test_host_requirements_tab.py +++ b/test/unit/deadline_client/ui/widgets/test_host_requirements_tab.py @@ -1,24 +1,29 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -import pytest +import math from unittest.mock import MagicMock +import pytest + try: from deadline.client.ui.widgets.host_requirements_tab import ( - HardwareRequirementsWidget, + AMOUNT_CAPABILITY_PREFIX, + ATTRIBUTE_CAPABILITY_PREFIX, + MAX_INT_VALUE, + RESERVED_FIRST_IDENTIFIERS, CustomAmountWidget, - CustomAttributeWidget, CustomAttributeValueWidget, + CustomAttributeWidget, CustomRequirementsWidget, - ATTRIBUTE_CAPABILITY_PREFIX, - AMOUNT_CAPABILITY_PREFIX, - MAX_INT_VALUE, + HardwareRequirementsWidget, ) except ImportError: # The tests in this file should be skipped if Qt UI related modules cannot be loaded pytest.importorskip("deadline.client.ui.widgets.host_requirements_tab") +from deadline.client.exceptions import NonValidInputError +IDENTFIER_MAX_LENGTH = 64 AMOUNT_NAME_MAX_LENGTH = 100 - len(AMOUNT_CAPABILITY_PREFIX) ATTRIBUTE_NAME_MAX_LENGTH = 100 - len(ATTRIBUTE_CAPABILITY_PREFIX) @@ -62,7 +67,7 @@ def test_name_in_custom_amount_widget_should_be_truncated(qtbot): assert widget.name_line_edit.text() == invalid_str[:AMOUNT_NAME_MAX_LENGTH] -def test_name_in_custom_amount_widget_should_follow_regex_pattern(qtbot): +def test_name_in_custom_amount_widget_should_not_allow_invalid_chars(qtbot): widget = CustomAmountWidget(MagicMock(), 1) qtbot.addWidget(widget) @@ -71,6 +76,51 @@ def test_name_in_custom_amount_widget_should_follow_regex_pattern(qtbot): assert widget.name_line_edit.hasAcceptableInput() is False +def test_name_in_custom_amount_widget_should_allow_identifiers(qtbot): + widget = CustomAmountWidget(MagicMock(), 1) + qtbot.addWidget(widget) + + valid_identifier = "a" + (".a" * math.floor((AMOUNT_NAME_MAX_LENGTH - 1) / 2)) + widget.name_line_edit.setText(valid_identifier) + assert widget.name_line_edit.hasAcceptableInput() + + +def test_name_in_custom_amount_widget_does_not_allow_invalid_identifiers(qtbot): + widget = CustomAmountWidget(MagicMock(), 1) + qtbot.addWidget(widget) + + valid_identifier = "a" + invalid_identifier = "a" * (IDENTFIER_MAX_LENGTH + 1) + + widget.name_line_edit.setText(".".join([valid_identifier, invalid_identifier])) + assert widget.name_line_edit.hasAcceptableInput() is False + + +def test_name_in_custom_amount_widget_should_not_allow_missing_identifiers(qtbot): + widget = CustomAmountWidget(MagicMock(), 1) + qtbot.addWidget(widget) + + missing_identifier = "a..a" + widget.name_line_edit.setText(missing_identifier) + assert widget.name_line_edit.hasAcceptableInput() is False + + +def test_name_in_custom_amount_widget_should_not_allow_reserved_first_identifier(qtbot): + widget = CustomAmountWidget(MagicMock(), 1) + qtbot.addWidget(widget) + + for reserved_identifier in RESERVED_FIRST_IDENTIFIERS: + widget.name_line_edit.setText(reserved_identifier) + with pytest.raises(NonValidInputError) as e: + widget.get_requirement() + + assert str( + e.value + ) == "Please make sure that the first identifier in your name is not a reserved identifier. " + str( + RESERVED_FIRST_IDENTIFIERS + ) + + def test_value_in_custom_amount_widget_should_be_integer_within_range(qtbot): widget = CustomAmountWidget(MagicMock(), 1) qtbot.addWidget(widget) @@ -117,3 +167,68 @@ def test_value_in_custom_attribute_widget_should_follow_regex_pattern(qtbot): invalid_str = "" widget.line_edit.setText(invalid_str) assert widget.line_edit.hasAcceptableInput() is False + + +def test_name_in_custom_attribute_widget_should_allow_identifiers(qtbot): + widget = CustomAttributeWidget(MagicMock(), 1, CustomRequirementsWidget()) + qtbot.addWidget(widget) + + valid_identifier = "a" + (".a" * math.floor((AMOUNT_NAME_MAX_LENGTH - 1) / 2)) + widget.name_line_edit.setText(valid_identifier) + assert widget.name_line_edit.hasAcceptableInput() + + +def test_name_in_custom_attribute_widget_does_not_allow_invalid_identifiers(qtbot): + widget = CustomAttributeWidget(MagicMock(), 1, CustomRequirementsWidget()) + qtbot.addWidget(widget) + + valid_identifier = "a" + invalid_identifier = "a" * (IDENTFIER_MAX_LENGTH + 1) + + widget.name_line_edit.setText(".".join([valid_identifier, invalid_identifier])) + assert widget.name_line_edit.hasAcceptableInput() is False + + +def test_name_in_custom_attribute_widget_should_not_allow_missing_identifiers(qtbot): + widget = CustomAttributeWidget(MagicMock(), 1, CustomRequirementsWidget()) + qtbot.addWidget(widget) + + missing_identifier = "a..a" + widget.name_line_edit.setText(missing_identifier) + assert widget.name_line_edit.hasAcceptableInput() is False + + +def test_name_in_custom_attribute_widget_should_not_end_with_period(qtbot): + widget = CustomAttributeWidget(MagicMock(), 1, CustomRequirementsWidget()) + value_widget = CustomAttributeValueWidget(MagicMock(), widget) + qtbot.addWidget(widget) + qtbot.addWidget(value_widget) + + value_widget.line_edit.setText("test") + + identifier_ends_with_period = "a." + widget.name_line_edit.setText(identifier_ends_with_period) + with pytest.raises(NonValidInputError) as e: + widget.get_requirement() + + assert str(e.value) == "Your requirement name cannot end with a period." + + +def test_name_in_custom_attribute_widget_should_not_allow_reserved_first_identifier(qtbot): + widget = CustomAttributeWidget(MagicMock(), 1, CustomRequirementsWidget()) + value_widget = CustomAttributeValueWidget(MagicMock(), widget) + qtbot.addWidget(widget) + qtbot.addWidget(value_widget) + + value_widget.line_edit.setText("test") + + for reserved_identifier in RESERVED_FIRST_IDENTIFIERS: + widget.name_line_edit.setText(reserved_identifier) + with pytest.raises(NonValidInputError) as e: + widget.get_requirement() + + assert str( + e.value + ) == "Please make sure that the first identifier in your name is not a reserved identifier. " + str( + RESERVED_FIRST_IDENTIFIERS + )