Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] Filter while paging in GetByVMID #50000

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
// If the token is from a user-assigned managed identity, the resource ID is
// for the identity and we need to look the VM up by VM ID.
} else {
vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID)
vm, err = vmClient.GetByVMID(ctx, vmID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("no VM found with matching VM ID")
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/join_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ func (m *mockAzureVMClient) Get(_ context.Context, resourceID string) (*azure.Vi
return vm, nil
}

func (m *mockAzureVMClient) GetByVMID(_ context.Context, resourceGroup, vmID string) (*azure.VirtualMachine, error) {
func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.VirtualMachine, error) {
for _, vm := range m.vms {
if vm.VMID == vmID && (resourceGroup == types.Wildcard || vm.ResourceGroup == resourceGroup) {
if vm.VMID == vmID {
return vm, nil
}
}
return nil, trace.NotFound("no vm in groups %q with id %q", resourceGroup, vmID)
return nil, trace.NotFound("no vm with id %q", vmID)
}

func makeVMClientGetter(clients map[string]*mockAzureVMClient) vmClientGetter {
Expand Down
24 changes: 14 additions & 10 deletions lib/cloud/azure/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type VirtualMachinesClient interface {
// Get returns the virtual machine for the given resource ID.
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
// GetByVMID returns the virtual machine for a given VM ID.
GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error)
GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error)
// ListVirtualMachines gets all of the virtual machines in the given resource group.
ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error)
}
Expand Down Expand Up @@ -146,15 +146,19 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine,
}

// GetByVMID returns the virtual machine for a given VM ID.
func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) {
vms, err := c.ListVirtualMachines(ctx, resourceGroup)
if err != nil {
return nil, trace.Wrap(err)
}
for _, vm := range vms {
if vm.Properties != nil && *vm.Properties.VMID == vmID {
result, err := parseVirtualMachine(vm)
return result, trace.Wrap(err)
func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) {
pager := newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{}))
for pager.more() {
res, err := pager.nextPage(ctx)
if err != nil {
return nil, trace.Wrap(ConvertResponseError(err))
}

for _, vm := range res {
if vm.Properties != nil && *vm.Properties.VMID == vmID {
result, err := parseVirtualMachine(vm)
return result, trace.Wrap(err)
}
}
}
return nil, trace.NotFound("no VM with ID %q", vmID)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2302,7 +2302,7 @@ func (m *mockAzureClient) Get(_ context.Context, _ string) (*azure.VirtualMachin
return nil, nil
}

func (m *mockAzureClient) GetByVMID(_ context.Context, _, _ string) (*azure.VirtualMachine, error) {
func (m *mockAzureClient) GetByVMID(_ context.Context, _ string) (*azure.VirtualMachine, error) {
return nil, nil
}

Expand Down
Loading