Skip to content

Commit

Permalink
Filter while paging in GetByVMID
Browse files Browse the repository at this point in the history
This updates GetByVMID, used by azure joining, to filter by id for
each page of vms instead of filtering after accumulating _all_ vms
in memory first. While this does reduce memory consumption, the listing
process is still suboptimal in that _all_ instances in a subscription
are retrieved in response to _every_ join request.
  • Loading branch information
rosstimothy authored and github-actions committed Dec 10, 2024
1 parent af06f54 commit 182a835
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
// If the token is from a user-assigned managed identity, the resource ID is
// for the identity and we need to look the VM up by VM ID.
} else {
vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID)
vm, err = vmClient.GetByVMID(ctx, vmID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("no VM found with matching VM ID")
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/join_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ func (m *mockAzureVMClient) Get(_ context.Context, resourceID string) (*azure.Vi
return vm, nil
}

func (m *mockAzureVMClient) GetByVMID(_ context.Context, resourceGroup, vmID string) (*azure.VirtualMachine, error) {
func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.VirtualMachine, error) {
for _, vm := range m.vms {
if vm.VMID == vmID && (resourceGroup == types.Wildcard || vm.ResourceGroup == resourceGroup) {
if vm.VMID == vmID {
return vm, nil
}
}
return nil, trace.NotFound("no vm in groups %q with id %q", resourceGroup, vmID)
return nil, trace.NotFound("no vm with id %q", vmID)
}

func makeVMClientGetter(clients map[string]*mockAzureVMClient) vmClientGetter {
Expand Down
24 changes: 14 additions & 10 deletions lib/cloud/azure/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type VirtualMachinesClient interface {
// Get returns the virtual machine for the given resource ID.
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
// GetByVMID returns the virtual machine for a given VM ID.
GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error)
GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error)
// ListVirtualMachines gets all of the virtual machines in the given resource group.
ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error)
}
Expand Down Expand Up @@ -146,15 +146,19 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine,
}

// GetByVMID returns the virtual machine for a given VM ID.
func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) {
vms, err := c.ListVirtualMachines(ctx, resourceGroup)
if err != nil {
return nil, trace.Wrap(err)
}
for _, vm := range vms {
if vm.Properties != nil && *vm.Properties.VMID == vmID {
result, err := parseVirtualMachine(vm)
return result, trace.Wrap(err)
func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) {
pager := newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{}))
for pager.more() {
res, err := pager.nextPage(ctx)
if err != nil {
return nil, trace.Wrap(ConvertResponseError(err))
}

for _, vm := range res {
if vm.Properties != nil && *vm.Properties.VMID == vmID {
result, err := parseVirtualMachine(vm)
return result, trace.Wrap(err)
}
}
}
return nil, trace.NotFound("no VM with ID %q", vmID)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2302,7 +2302,7 @@ func (m *mockAzureClient) Get(_ context.Context, _ string) (*azure.VirtualMachin
return nil, nil
}

func (m *mockAzureClient) GetByVMID(_ context.Context, _, _ string) (*azure.VirtualMachine, error) {
func (m *mockAzureClient) GetByVMID(_ context.Context, _ string) (*azure.VirtualMachine, error) {
return nil, nil
}

Expand Down

0 comments on commit 182a835

Please sign in to comment.