From 36ee09d8f2a013db01e0412913a3d21b10225df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Sat, 2 Nov 2024 22:37:49 +0100 Subject: [PATCH] wip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jan-Otto Kröpke --- internal/mi/mi_test.go | 34 ++++++++ internal/mi/operation.go | 12 +++ internal/mi/session.go | 170 ++++++++++++++++++++++++++++++++++++--- 3 files changed, 203 insertions(+), 13 deletions(-) diff --git a/internal/mi/mi_test.go b/internal/mi/mi_test.go index b3a834321..ca0a0c199 100644 --- a/internal/mi/mi_test.go +++ b/internal/mi/mi_test.go @@ -118,6 +118,40 @@ func Test_MI_Query(t *testing.T) { require.NoError(t, err) } +func Test_MI_QueryUnmarshal(t *testing.T) { + application, err := mi.Application_Initialize() + require.NoError(t, err) + require.NotEmpty(t, application) + + destinationOptions, err := application.NewDestinationOptions() + require.NoError(t, err) + require.NotEmpty(t, destinationOptions) + + err = destinationOptions.SetTimeout(1 * time.Second) + require.NoError(t, err) + + err = destinationOptions.SetLocale(mi.LocaleEnglish) + require.NoError(t, err) + + session, err := application.NewSession(destinationOptions) + require.NoError(t, err) + require.NotEmpty(t, session) + + var processes []win32Process + + err = session.QueryUnmarshal(mi.OperationFlagsStandardRTTI, nil, mi.NamespaceRootCIMv2, mi.QueryDialectWQL, + "select Name from win32_process where handle = 0", &processes) + + require.NoError(t, err) + require.Equal(t, []win32Process{{Name: "System Idle Process"}}, processes) + + err = session.Close() + require.NoError(t, err) + + err = application.Close() + require.NoError(t, err) +} + func Test_MI_EmptyQuery(t *testing.T) { application, err := mi.Application_Initialize() require.NoError(t, err) diff --git a/internal/mi/operation.go b/internal/mi/operation.go index e94df5221..1521949ba 100644 --- a/internal/mi/operation.go +++ b/internal/mi/operation.go @@ -70,6 +70,18 @@ type OperationOptionsFT struct { GetInterval uintptr } +type OperationCallbacks struct { + CallbackContext uintptr + PromptUser uintptr + WriteError uintptr + WriteMessage uintptr + WriteProgress uintptr + InstanceResult uintptr + IndicationResult uintptr + ClassResult uintptr + StreamedParameterResult uintptr +} + // Close closes an operation handle. // // https://learn.microsoft.com/en-us/windows/win32/api/mi/nf-mi-mi_operation_close diff --git a/internal/mi/session.go b/internal/mi/session.go index ce79216ab..74dd3bdfb 100644 --- a/internal/mi/session.go +++ b/internal/mi/session.go @@ -3,7 +3,9 @@ package mi import ( "errors" "fmt" + "reflect" "syscall" + "time" "unsafe" "golang.org/x/sys/windows" @@ -167,29 +169,171 @@ func (s *Session) QueryInstances(flags OperationFlags, operationOptions *Operati return operation, nil } -// Query queries for a set of instances based on a query expression. -func (s *Session) Query(dst any, namespaceName Namespace, queryExpression string) error { - operation, err := s.QueryInstances(OperationFlagsStandardRTTI, nil, namespaceName, QueryDialectWQL, queryExpression) +// QueryUnmarshal queries for a set of instances based on a query expression. +// +// https://learn.microsoft.com/en-us/windows/win32/api/mi/nf-mi-mi_session_queryinstances +func (s *Session) QueryUnmarshal(flags OperationFlags, operationOptions *OperationOptions, namespaceName Namespace, + queryDialect QueryDialect, queryExpression string, dst any, +) error { + if s == nil || s.ft == nil { + return ErrNotInitialized + } + + queryExpressionUTF16, err := windows.UTF16PtrFromString(queryExpression) if err != nil { - return fmt.Errorf("WMI query failed: %w", err) + return err } - if err = operation.Unmarshal(dst); err != nil { - return fmt.Errorf("failed to unmarshal WMI query results: %w", err) + operation := &Operation{} + + if operationOptions == nil { + operationOptions = s.defaultOperationOptions } - for { - instance, moreResults, err := operation.GetInstance() + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || dv.IsNil() { + return ErrInvalidEntityType + } + + dv = dv.Elem() - _, _ = instance, err + elemType := dv.Type().Elem() + elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem() - if !moreResults { - break + if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct { + return ErrInvalidEntityType + } + + dv.Set(reflect.MakeSlice(dv.Type(), 0, 0)) + + errCh := make(chan error, 1) + + operationCallbacks := &OperationCallbacks{ + InstanceResult: windows.NewCallback(func( + operation *Operation, + _ uintptr, + instance *Instance, + moreResults Boolean, + instanceResult ResultError, + errorMessageUTF16 *uint16, + errorDetails *Instance, + _ uintptr, + ) uintptr { + defer func() { + if moreResults == False { + close(errCh) + } + }() + + if !errors.Is(instanceResult, MI_RESULT_OK) { + errCh <- fmt.Errorf("%w: %s", instanceResult, windows.UTF16PtrToString(errorMessageUTF16)) + + return 0 + } + + if instance == nil { + return 0 + } + + counter, err := instance.GetElementCount() + if err != nil { + errCh <- fmt.Errorf("failed to get element count: %w", err) + + return 0 + } + + if counter == 0 { + return 0 + } + + for i := range elemType.NumField() { + field := elemValue.Field(i) + + // Check if the field has an `mi` tag + miTag := elemType.Field(i).Tag.Get("mi") + if miTag == "" { + continue + } + + element, err := instance.GetElement(miTag) + if err != nil { + errCh <- fmt.Errorf("failed to get element: %w", err) + + return 0 + } + + switch element.valueType { + case ValueTypeBOOLEAN: + field.SetBool(element.value == 1) + case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64: + field.SetUint(uint64(element.value)) + case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64: + field.SetInt(int64(element.value)) + case ValueTypeSTRING: + if element.value == 0 { + errCh <- fmt.Errorf("%s: invalid pointer: value is nil", miTag) + + return 0 + } + + // Convert the UTF-16 string to a Go string + stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value))) + + field.SetString(stringValue) + case ValueTypeREAL32, ValueTypeREAL64: + field.SetFloat(float64(element.value)) + default: + errCh <- fmt.Errorf("unsupported value type: %d", element.valueType) + + return 0 + } + } + + dv.Set(reflect.Append(dv, elemValue)) + + return 0 + }), + } + + r0, _, _ := syscall.SyscallN( + s.ft.QueryInstances, + uintptr(unsafe.Pointer(s)), + uintptr(flags), + uintptr(unsafe.Pointer(operationOptions)), + uintptr(unsafe.Pointer(namespaceName)), + uintptr(unsafe.Pointer(queryDialect)), + uintptr(unsafe.Pointer(queryExpressionUTF16)), + uintptr(unsafe.Pointer(operationCallbacks)), + uintptr(unsafe.Pointer(operation)), + ) + + if result := ResultError(r0); !errors.Is(result, MI_RESULT_OK) { + return result + } + + defer operation.Close() + + var errs []error + + for { + select { + case err := <-errCh: + if err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) + case <-time.After(10 * time.Second): + return errors.New("timeout") } } +} - if err = operation.Close(); err != nil { - return fmt.Errorf("failed to close WMI query operation: %w", err) +// Query queries for a set of instances based on a query expression. +func (s *Session) Query(dst any, namespaceName Namespace, queryExpression string) error { + err := s.QueryUnmarshal(OperationFlagsStandardRTTI, nil, namespaceName, QueryDialectWQL, queryExpression, dst) + if err != nil { + return fmt.Errorf("WMI query failed: %w", err) } return nil