From 182a8355678fd3ab36ca767c308f992396e69a72 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Mon, 9 Dec 2024 16:09:49 -0500 Subject: [PATCH] Filter while paging in GetByVMID This updates GetByVMID, used by azure joining, to filter by id for each page of vms instead of filtering after accumulating _all_ vms in memory first. While this does reduce memory consumption, the listing process is still suboptimal in that _all_ instances in a subscription are retrieved in response to _every_ join request. --- lib/auth/join_azure.go | 2 +- lib/auth/join_azure_test.go | 6 +++--- lib/cloud/azure/vm.go | 24 ++++++++++++++---------- lib/srv/discovery/discovery_test.go | 2 +- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 4ce6311a6970f..311fea19ca14e 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -269,7 +269,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken // If the token is from a user-assigned managed identity, the resource ID is // for the identity and we need to look the VM up by VM ID. } else { - vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID) + vm, err = vmClient.GetByVMID(ctx, vmID) if err != nil { if trace.IsNotFound(err) { return nil, trace.AccessDenied("no VM found with matching VM ID") diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index faa9bb6f0cb95..50b026598dcd7 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -73,13 +73,13 @@ func (m *mockAzureVMClient) Get(_ context.Context, resourceID string) (*azure.Vi return vm, nil } -func (m *mockAzureVMClient) GetByVMID(_ context.Context, resourceGroup, vmID string) (*azure.VirtualMachine, error) { +func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.VirtualMachine, error) { for _, vm := range m.vms { - if vm.VMID == vmID && (resourceGroup == types.Wildcard || vm.ResourceGroup == resourceGroup) { + if vm.VMID == vmID { return vm, nil } } - return nil, trace.NotFound("no vm in groups %q with id %q", resourceGroup, vmID) + return nil, trace.NotFound("no vm with id %q", vmID) } func makeVMClientGetter(clients map[string]*mockAzureVMClient) vmClientGetter { diff --git a/lib/cloud/azure/vm.go b/lib/cloud/azure/vm.go index 2c9242712d4ba..48572cbc2066e 100644 --- a/lib/cloud/azure/vm.go +++ b/lib/cloud/azure/vm.go @@ -46,7 +46,7 @@ type VirtualMachinesClient interface { // Get returns the virtual machine for the given resource ID. Get(ctx context.Context, resourceID string) (*VirtualMachine, error) // GetByVMID returns the virtual machine for a given VM ID. - GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) + GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) // ListVirtualMachines gets all of the virtual machines in the given resource group. ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) } @@ -146,15 +146,19 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, } // GetByVMID returns the virtual machine for a given VM ID. -func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) { - vms, err := c.ListVirtualMachines(ctx, resourceGroup) - if err != nil { - return nil, trace.Wrap(err) - } - for _, vm := range vms { - if vm.Properties != nil && *vm.Properties.VMID == vmID { - result, err := parseVirtualMachine(vm) - return result, trace.Wrap(err) +func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) { + pager := newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{})) + for pager.more() { + res, err := pager.nextPage(ctx) + if err != nil { + return nil, trace.Wrap(ConvertResponseError(err)) + } + + for _, vm := range res { + if vm.Properties != nil && *vm.Properties.VMID == vmID { + result, err := parseVirtualMachine(vm) + return result, trace.Wrap(err) + } } } return nil, trace.NotFound("no VM with ID %q", vmID) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index e7b0ec6f08347..a5396b4ccf53a 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -2302,7 +2302,7 @@ func (m *mockAzureClient) Get(_ context.Context, _ string) (*azure.VirtualMachin return nil, nil } -func (m *mockAzureClient) GetByVMID(_ context.Context, _, _ string) (*azure.VirtualMachine, error) { +func (m *mockAzureClient) GetByVMID(_ context.Context, _ string) (*azure.VirtualMachine, error) { return nil, nil }