Skip to content

Commit

Permalink
works. todo: only do this for A10 VMs
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Jun 29, 2024
1 parent 3351caa commit 74b669c
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions sky/skylet/providers/azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,29 @@ def _create_node(self, node_config, tags, count):
create_or_update = get_azure_sdk_function(
client=self.resource_client.deployments, function_name="create_or_update"
)
create_or_update(
poller = create_or_update(
resource_group_name=resource_group,
deployment_name=vm_name,
parameters=parameters,
).wait()
)
poller.wait()

# Configure driver extension for A10 GPUs
# TODO(tian): Only do this for A10 vms.
logger.info("Begin to configure A10 driver extension for VM: %s", vm_name)
self._configure_a10_driver_extension(vm_name)
logger.info("A10 driver extension configured for VM: %s", vm_name)
create_result = poller.result().as_dict()
output_resources = create_result.get("properties", {}).get(
"output_resources", []
)
vms_to_add_driver = []
for r in output_resources:
r_id = r.get("id", "")
if "Microsoft.Compute/virtualMachines" in r_id:
vms_to_add_driver.append(r_id.split("/")[-1])

for v in vms_to_add_driver:
logger.info(f"Begin to configure A10 driver extension for VM: {v}")
self._configure_a10_driver_extension(v)
logger.info(f"A10 driver extension configured for VM: {v}")

def _configure_a10_driver_extension(self, vm_name):
resource_group = self.provider_config["resource_group"]
Expand All @@ -345,9 +357,7 @@ def _configure_a10_driver_extension(self, vm_name):
"parameters": {
"vmName": {
"type": "string",
"metadata": {
"description": "Name of the virtual machine"
}
"metadata": {"description": "Name of the virtual machine"},
}
},
"resources": [
Expand Down

0 comments on commit 74b669c

Please sign in to comment.