Skip to content

Commit 7be3732

Browse files
authored
Fix remaining type checking exceptions. (#774)
* Fix remaining type checking exceptions.
1 parent f747414 commit 7be3732

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

pyproject.toml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ pyink-use-majority-quotes = true
9292
[tool.mypy]
9393
follow_untyped_imports = true
9494
warn_unreachable = true
95-
9695
strict = true
9796
# Current code is not compatible with all of the strict flags:
9897
disallow_any_generics = false
@@ -101,10 +100,4 @@ disallow_untyped_defs = false
101100
disallow_incomplete_defs = false
102101
check_untyped_defs = false
103102
no_implicit_reexport = false
104-
105-
# Remove follow_imports below once the exclude list is empty:
106-
follow_imports = "silent"
107103
files = "src"
108-
exclude = [
109-
'src/xpk/core/blueprint/blueprint_generator\.py',
110-
]

src/xpk/core/blueprint/blueprint_definitions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@ class DeploymentModule:
3636
settings: Optional[dict[str, Any]] = None
3737
use: Optional[list[str]] = None
3838

39+
def update_settings(self, additionalSettings: dict[str, Any]) -> None:
40+
if self.settings is None:
41+
self.settings = dict()
42+
self.settings.update(additionalSettings)
43+
44+
def set_setting(self, key: str, value: Any) -> None:
45+
if self.settings is None:
46+
self.settings = dict()
47+
self.settings[key] = value
48+
49+
def append_use(self, use: str) -> None:
50+
if self.use is None:
51+
self.use = list()
52+
self.use.append(use)
53+
3954

4055
@dataclass
4156
class DeploymentGroup:
@@ -59,6 +74,6 @@ class Blueprint:
5974
blueprint_name: Optional[str]
6075
toolkit_modules_url: str
6176
toolkit_modules_version: str
62-
vars: dict[str, str | list[str]] | None
77+
vars: dict[str, str | list[str] | dict[str, str]] | None
6378
terraform_providers: Optional[dict[str, Any]] = None
6479
validators: Optional[list[Any]] = None

src/xpk/core/blueprint/blueprint_generator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def generate_a3_mega_blueprint(
211211
outputs=["instructions"],
212212
)
213213
if capacity_type == CapacityType.FLEX_START:
214-
a3_megagpu_pool_0.settings.update(self.get_dws_flex_start())
214+
a3_megagpu_pool_0.update_settings(self.get_dws_flex_start())
215215
else:
216-
a3_megagpu_pool_0.settings.update({"static_node_count": num_nodes})
216+
a3_megagpu_pool_0.update_settings({"static_node_count": num_nodes})
217217

218218
set_placement_policy = capacity_type != CapacityType.SPOT
219219
workload = DeploymentModule(
@@ -252,8 +252,8 @@ def generate_a3_mega_blueprint(
252252

253253
print(reservation_placement_policy)
254254
if reservation_placement_policy is not None:
255-
a3_megagpu_pool_0.settings["placement_policy"] = (
256-
reservation_placement_policy
255+
a3_megagpu_pool_0.set_setting(
256+
"placement_policy", reservation_placement_policy
257257
)
258258

259259
primary_group = DeploymentGroup(
@@ -268,7 +268,7 @@ def generate_a3_mega_blueprint(
268268
],
269269
)
270270
if set_placement_policy and reservation_placement_policy is None:
271-
a3_megagpu_pool_0.use.append(group_placement_0.id)
271+
a3_megagpu_pool_0.append_use(group_placement_0.id)
272272
primary_group.modules.append(group_placement_0)
273273
a3_mega_blueprint = Blueprint(
274274
terraform_backend_defaults=self._getblock_terraform_backend(
@@ -580,9 +580,9 @@ def generate_a3_ultra_blueprint(
580580
outputs=["instructions"],
581581
)
582582
if capacity_type == CapacityType.FLEX_START:
583-
gpu_pool.settings.update(self.get_dws_flex_start())
583+
gpu_pool.update_settings(self.get_dws_flex_start())
584584
else:
585-
gpu_pool.settings.update({"static_node_count": num_nodes})
585+
gpu_pool.update_settings({"static_node_count": num_nodes})
586586

587587
workload_manager_install_id = "workload-manager-install"
588588
workload_manager_install = DeploymentModule(
@@ -855,9 +855,9 @@ def generate_a4_blueprint(
855855
outputs=["instructions"],
856856
)
857857
if capacity_type == CapacityType.FLEX_START:
858-
gpu_pool.settings.update(self.get_dws_flex_start())
858+
gpu_pool.update_settings(self.get_dws_flex_start())
859859
else:
860-
gpu_pool.settings.update({"static_node_count": num_nodes})
860+
gpu_pool.update_settings({"static_node_count": num_nodes})
861861

862862
workload_manager_install_id = "workload-manager-install"
863863
workload_manager_install = DeploymentModule(
@@ -956,7 +956,7 @@ def _getblock_reservation_affinity(
956956
)
957957

958958
def _getblock_terraform_backend(
959-
self, gcs_bucket: str, cluster_name: str, prefix: str = ""
959+
self, gcs_bucket: str | None, cluster_name: str, prefix: str = ""
960960
) -> dict | None:
961961
if gcs_bucket is None:
962962
return None
@@ -986,7 +986,7 @@ def _save_blueprint_to_file(
986986
yaml_parser.dump(xpk_blueprint, blueprint_file)
987987
return blueprint_path
988988

989-
def _get_blueprint_path(self, blueprint_name, prefix: str = ""):
989+
def _get_blueprint_path(self, blueprint_name, prefix: str = "") -> str:
990990
blueprint_path = os.path.join(
991991
self._get_storage_path(prefix), f"{blueprint_name}.yaml"
992992
)

0 commit comments

Comments
 (0)