Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ pyink-use-majority-quotes = true
[tool.mypy]
follow_untyped_imports = true
warn_unreachable = true

strict = true
# Current code is not compatible with all of the strict flags:
disallow_any_generics = false
Expand All @@ -101,10 +100,4 @@ disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = false
no_implicit_reexport = false

# Remove follow_imports below once the exclude list is empty:
follow_imports = "silent"
files = "src"
exclude = [
'src/xpk/core/blueprint/blueprint_generator\.py',
]
17 changes: 16 additions & 1 deletion src/xpk/core/blueprint/blueprint_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ class DeploymentModule:
settings: Optional[dict[str, Any]] = None
use: Optional[list[str]] = None

def update_settings(self, additionalSettings: dict[str, Any]) -> None:
if self.settings is None:
self.settings = dict()
self.settings.update(additionalSettings)

def set_setting(self, key: str, value: Any) -> None:
if self.settings is None:
self.settings = dict()
self.settings[key] = value

def append_use(self, use: str) -> None:
if self.use is None:
self.use = list()
self.use.append(use)


@dataclass
class DeploymentGroup:
Expand All @@ -59,6 +74,6 @@ class Blueprint:
blueprint_name: Optional[str]
toolkit_modules_url: str
toolkit_modules_version: str
vars: dict[str, str | list[str]] | None
vars: dict[str, str | list[str] | dict[str, str]] | None
terraform_providers: Optional[dict[str, Any]] = None
validators: Optional[list[Any]] = None
22 changes: 11 additions & 11 deletions src/xpk/core/blueprint/blueprint_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def generate_a3_mega_blueprint(
outputs=["instructions"],
)
if capacity_type == CapacityType.FLEX_START:
a3_megagpu_pool_0.settings.update(self.get_dws_flex_start())
a3_megagpu_pool_0.update_settings(self.get_dws_flex_start())
else:
a3_megagpu_pool_0.settings.update({"static_node_count": num_nodes})
a3_megagpu_pool_0.update_settings({"static_node_count": num_nodes})

set_placement_policy = capacity_type != CapacityType.SPOT
workload = DeploymentModule(
Expand Down Expand Up @@ -252,8 +252,8 @@ def generate_a3_mega_blueprint(

print(reservation_placement_policy)
if reservation_placement_policy is not None:
a3_megagpu_pool_0.settings["placement_policy"] = (
reservation_placement_policy
a3_megagpu_pool_0.set_setting(
"placement_policy", reservation_placement_policy
)

primary_group = DeploymentGroup(
Expand All @@ -268,7 +268,7 @@ def generate_a3_mega_blueprint(
],
)
if set_placement_policy and reservation_placement_policy is None:
a3_megagpu_pool_0.use.append(group_placement_0.id)
a3_megagpu_pool_0.append_use(group_placement_0.id)
primary_group.modules.append(group_placement_0)
a3_mega_blueprint = Blueprint(
terraform_backend_defaults=self._getblock_terraform_backend(
Expand Down Expand Up @@ -580,9 +580,9 @@ def generate_a3_ultra_blueprint(
outputs=["instructions"],
)
if capacity_type == CapacityType.FLEX_START:
gpu_pool.settings.update(self.get_dws_flex_start())
gpu_pool.update_settings(self.get_dws_flex_start())
else:
gpu_pool.settings.update({"static_node_count": num_nodes})
gpu_pool.update_settings({"static_node_count": num_nodes})

workload_manager_install_id = "workload-manager-install"
workload_manager_install = DeploymentModule(
Expand Down Expand Up @@ -855,9 +855,9 @@ def generate_a4_blueprint(
outputs=["instructions"],
)
if capacity_type == CapacityType.FLEX_START:
gpu_pool.settings.update(self.get_dws_flex_start())
gpu_pool.update_settings(self.get_dws_flex_start())
else:
gpu_pool.settings.update({"static_node_count": num_nodes})
gpu_pool.update_settings({"static_node_count": num_nodes})

workload_manager_install_id = "workload-manager-install"
workload_manager_install = DeploymentModule(
Expand Down Expand Up @@ -956,7 +956,7 @@ def _getblock_reservation_affinity(
)

def _getblock_terraform_backend(
self, gcs_bucket: str, cluster_name: str, prefix: str = ""
self, gcs_bucket: str | None, cluster_name: str, prefix: str = ""
) -> dict | None:
if gcs_bucket is None:
return None
Expand Down Expand Up @@ -986,7 +986,7 @@ def _save_blueprint_to_file(
yaml_parser.dump(xpk_blueprint, blueprint_file)
return blueprint_path

def _get_blueprint_path(self, blueprint_name, prefix: str = ""):
def _get_blueprint_path(self, blueprint_name, prefix: str = "") -> str:
blueprint_path = os.path.join(
self._get_storage_path(prefix), f"{blueprint_name}.yaml"
)
Expand Down
Loading