diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d3e40d854f57..0df9a9a58779 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -173,6 +173,8 @@ CHANGELOG* /x-pack/filebeat/modules.d/zoom.yml.disabled @elastic/security-service-integrations /x-pack/filebeat/processors/decode_cef/ @elastic/sec-deployment-and-devices /x-pack/heartbeat/ @elastic/obs-ds-hosted-services +/x-pack/libbeat/reader/parquet/ @elastic/security-service-integrations +/x-pack/libbeat/reader/etw/ @elastic/sec-windows-platform /x-pack/metricbeat/ @elastic/elastic-agent-data-plane /x-pack/metricbeat/docs/ # Listed without an owner to avoid maintaining doc ownership for each input and module. /x-pack/metricbeat/module/activemq @elastic/obs-infraobs-integrations @@ -219,4 +221,3 @@ CHANGELOG* /x-pack/osquerybeat/ @elastic/sec-deployment-and-devices /x-pack/packetbeat/ @elastic/sec-linux-platform /x-pack/winlogbeat/ @elastic/sec-windows-platform -/x-pack/libbeat/reader/parquet/ @elastic/security-service-integrations diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 645409067f14..8281f7b79ecb 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -181,6 +181,8 @@ Setting environmental variable ELASTIC_NETINFO:false in Elastic Agent pod will d *Libbeat* - Add watcher that can be used to monitor Linux kernel events. {pull}37833[37833] +- Added support for ETW reader. {pull}36914[36914] + *Heartbeat* - Added status to monitor run log report. - Upgrade github.com/elastic/go-elasticsearch/v8 to v8.12.0. {pull}37673[37673] diff --git a/x-pack/libbeat/Jenkinsfile.yml b/x-pack/libbeat/Jenkinsfile.yml index 9d4ecfa7bd08..9947fd0096c6 100644 --- a/x-pack/libbeat/Jenkinsfile.yml +++ b/x-pack/libbeat/Jenkinsfile.yml @@ -27,6 +27,43 @@ stages: branches: true ## for all the branches tags: true ## for all the tags stage: extended + ## For now Windows CI tests for Libbeat are only enabled for ETW + ## It only contains Go tests + windows-2022: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-2022" + stage: mandatory + windows-2019: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-2019" + stage: extended_win + windows-2016: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-2016" + stage: mandatory + windows-2012: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-2012-r2" + stage: extended_win + windows-11: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-11" + stage: extended_win + windows-10: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-10" + stage: extended_win + windows-8: + mage: "mage -w reader/etw build goUnitTest" + platforms: ## override default labels in this specific stage. + - "windows-8" + stage: extended_win unitTest: mage: "mage build unitTest" stage: mandatory diff --git a/x-pack/libbeat/reader/etw/config.go b/x-pack/libbeat/reader/etw/config.go new file mode 100644 index 000000000000..44f9e68ff2d0 --- /dev/null +++ b/x-pack/libbeat/reader/etw/config.go @@ -0,0 +1,16 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package etw + +type Config struct { + Logfile string // Path to the logfile + ProviderGUID string // GUID of the ETW provider + ProviderName string // Name of the ETW provider + SessionName string // Name for new ETW session + TraceLevel string // Level of tracing (e.g., "verbose") + MatchAnyKeyword uint64 // Filter for any matching keywords (bitmask) + MatchAllKeyword uint64 // Filter for all matching keywords (bitmask) + Session string // Existing session to attach +} diff --git a/x-pack/libbeat/reader/etw/controller.go b/x-pack/libbeat/reader/etw/controller.go new file mode 100644 index 000000000000..f17866440cfc --- /dev/null +++ b/x-pack/libbeat/reader/etw/controller.go @@ -0,0 +1,121 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "errors" + "fmt" + "syscall" +) + +// AttachToExistingSession queries the status of an existing ETW session. +// On success, it updates the Session's handler with the queried information. +func (s *Session) AttachToExistingSession() error { + // Convert the session name to UTF16 for Windows API compatibility. + sessionNamePtr, err := syscall.UTF16PtrFromString(s.Name) + if err != nil { + return fmt.Errorf("failed to convert session name: %w", err) + } + + // Query the current state of the ETW session. + err = s.controlTrace(0, sessionNamePtr, s.properties, EVENT_TRACE_CONTROL_QUERY) + switch { + case err == nil: + // Get the session handler from the properties struct. + s.handler = uintptr(s.properties.Wnode.Union1) + + return nil + + // Handle specific errors related to the query operation. + case errors.Is(err, ERROR_BAD_LENGTH): + return fmt.Errorf("bad length when querying handler: %w", err) + case errors.Is(err, ERROR_INVALID_PARAMETER): + return fmt.Errorf("invalid parameters when querying handler: %w", err) + case errors.Is(err, ERROR_WMI_INSTANCE_NOT_FOUND): + return fmt.Errorf("session is not running: %w", err) + default: + return fmt.Errorf("failed to get handler: %w", err) + } +} + +// CreateRealtimeSession initializes and starts a new real-time ETW session. +func (s *Session) CreateRealtimeSession() error { + // Convert the session name to UTF16 format for Windows API compatibility. + sessionPtr, err := syscall.UTF16PtrFromString(s.Name) + if err != nil { + return fmt.Errorf("failed to convert session name: %w", err) + } + + // Start the ETW trace session. + err = s.startTrace(&s.handler, sessionPtr, s.properties) + switch { + case err == nil: + + // Handle specific errors related to starting the trace session. + case errors.Is(err, ERROR_ALREADY_EXISTS): + return fmt.Errorf("session already exists: %w", err) + case errors.Is(err, ERROR_INVALID_PARAMETER): + return fmt.Errorf("invalid parameters when starting session trace: %w", err) + default: + return fmt.Errorf("failed to start trace: %w", err) + } + + // Set additional parameters for trace enabling. + // See https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-enable_trace_parameters#members + params := EnableTraceParameters{ + Version: 2, // ENABLE_TRACE_PARAMETERS_VERSION_2 + } + + // Zero timeout means asynchronous enablement + const timeout = 0 + + // Enable the trace session with extended options. + err = s.enableTrace(s.handler, &s.GUID, EVENT_CONTROL_CODE_ENABLE_PROVIDER, s.traceLevel, s.matchAnyKeyword, s.matchAllKeyword, timeout, ¶ms) + switch { + case err == nil: + return nil + // Handle specific errors related to enabling the trace session. + case errors.Is(err, ERROR_INVALID_PARAMETER): + return fmt.Errorf("invalid parameters when enabling session trace: %w", err) + case errors.Is(err, ERROR_TIMEOUT): + return fmt.Errorf("timeout value expired before the enable callback completed: %w", err) + case errors.Is(err, ERROR_NO_SYSTEM_RESOURCES): + return fmt.Errorf("exceeded the number of trace sessions that can enable the provider: %w", err) + default: + return fmt.Errorf("failed to enable trace: %w", err) + } +} + +// StopSession closes the ETW session and associated handles if they were created. +func (s *Session) StopSession() error { + if !s.Realtime { + return nil + } + + if isValidHandler(s.traceHandler) { + // Attempt to close the trace and handle potential errors. + if err := s.closeTrace(s.traceHandler); err != nil && !errors.Is(err, ERROR_CTX_CLOSE_PENDING) { + return fmt.Errorf("failed to close trace: %w", err) + } + } + + if s.NewSession { + // If we created the session, send a control command to stop it. + return s.controlTrace( + s.handler, + nil, + s.properties, + EVENT_TRACE_CONTROL_STOP, + ) + } + + return nil +} + +func isValidHandler(handler uint64) bool { + return handler != 0 && handler != INVALID_PROCESSTRACE_HANDLE +} diff --git a/x-pack/libbeat/reader/etw/controller_test.go b/x-pack/libbeat/reader/etw/controller_test.go new file mode 100644 index 000000000000..0c663433ad1f --- /dev/null +++ b/x-pack/libbeat/reader/etw/controller_test.go @@ -0,0 +1,190 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/windows" +) + +func TestAttachToExistingSession_Error(t *testing.T) { + // Mock implementation of controlTrace + controlTrace := func(traceHandle uintptr, + instanceName *uint16, + properties *EventTraceProperties, + controlCode uint32) error { + return ERROR_WMI_INSTANCE_NOT_FOUND + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + properties: &EventTraceProperties{}, + controlTrace: controlTrace, + } + + err := session.AttachToExistingSession() + assert.EqualError(t, err, "session is not running: The instance name passed was not recognized as valid by a WMI data provider.") +} + +func TestAttachToExistingSession_Success(t *testing.T) { + // Mock implementation of controlTrace + controlTrace := func(traceHandle uintptr, + instanceName *uint16, + properties *EventTraceProperties, + controlCode uint32) error { + // Set a mock handler value + properties.Wnode.Union1 = 12345 + return nil + } + + // Create a Session instance with initialized Properties + session := &Session{ + Name: "TestSession", + properties: &EventTraceProperties{}, + controlTrace: controlTrace, + } + + err := session.AttachToExistingSession() + + assert.NoError(t, err) + assert.Equal(t, uintptr(12345), session.handler, "Handler should be set to the mock value") +} + +func TestCreateRealtimeSession_StartTraceError(t *testing.T) { + // Mock implementation of startTrace + startTrace := func(traceHandle *uintptr, + instanceName *uint16, + properties *EventTraceProperties) error { + return ERROR_ALREADY_EXISTS + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + properties: &EventTraceProperties{}, + startTrace: startTrace, + } + + err := session.CreateRealtimeSession() + assert.EqualError(t, err, "session already exists: Cannot create a file when that file already exists.") +} + +func TestCreateRealtimeSession_EnableTraceError(t *testing.T) { + // Mock implementations + startTrace := func(traceHandle *uintptr, + instanceName *uint16, + properties *EventTraceProperties) error { + *traceHandle = 12345 // Mock handler value + return nil + } + + enableTrace := func(traceHandle uintptr, + providerId *windows.GUID, + isEnabled uint32, + level uint8, + matchAnyKeyword uint64, + matchAllKeyword uint64, + enableProperty uint32, + enableParameters *EnableTraceParameters) error { + return ERROR_INVALID_PARAMETER + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + properties: &EventTraceProperties{}, + startTrace: startTrace, + enableTrace: enableTrace, + } + + err := session.CreateRealtimeSession() + assert.EqualError(t, err, "invalid parameters when enabling session trace: The parameter is incorrect.") +} + +func TestCreateRealtimeSession_Success(t *testing.T) { + // Mock implementations + startTrace := func(traceHandle *uintptr, + instanceName *uint16, + properties *EventTraceProperties) error { + *traceHandle = 12345 // Mock handler value + return nil + } + + enableTrace := func(traceHandle uintptr, + providerId *windows.GUID, + isEnabled uint32, + level uint8, + matchAnyKeyword uint64, + matchAllKeyword uint64, + enableProperty uint32, + enableParameters *EnableTraceParameters) error { + return nil + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + properties: &EventTraceProperties{}, + startTrace: startTrace, + enableTrace: enableTrace, + } + + err := session.CreateRealtimeSession() + + assert.NoError(t, err) + assert.Equal(t, uintptr(12345), session.handler, "Handler should be set to the mock value") +} + +func TestStopSession_Error(t *testing.T) { + // Mock implementation of closeTrace + closeTrace := func(traceHandle uint64) error { + return ERROR_INVALID_PARAMETER + } + + // Create a Session instance + session := &Session{ + Realtime: true, + NewSession: true, + traceHandler: 12345, // Example handler value + properties: &EventTraceProperties{}, + closeTrace: closeTrace, + } + + err := session.StopSession() + assert.EqualError(t, err, "failed to close trace: The parameter is incorrect.") +} + +func TestStopSession_Success(t *testing.T) { + // Mock implementations + closeTrace := func(traceHandle uint64) error { + return nil + } + + controlTrace := func(traceHandle uintptr, + instanceName *uint16, + properties *EventTraceProperties, + controlCode uint32) error { + // Set a mock handler value + return nil + } + + // Create a Session instance + session := &Session{ + Realtime: true, + NewSession: true, + traceHandler: 12345, // Example handler value + properties: &EventTraceProperties{}, + closeTrace: closeTrace, + controlTrace: controlTrace, + } + + err := session.StopSession() + assert.NoError(t, err) +} diff --git a/x-pack/libbeat/reader/etw/event.go b/x-pack/libbeat/reader/etw/event.go new file mode 100644 index 000000000000..34faa8d21cb7 --- /dev/null +++ b/x-pack/libbeat/reader/etw/event.go @@ -0,0 +1,340 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "errors" + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +// propertyParser is used for parsing properties from raw EVENT_RECORD structures. +type propertyParser struct { + r *EventRecord + info *TraceEventInfo + data []byte + ptrSize uint32 +} + +// GetEventProperties extracts and returns properties from an ETW event record. +func GetEventProperties(r *EventRecord) (map[string]interface{}, error) { + // Handle the case where the event only contains a string. + if r.EventHeader.Flags == EVENT_HEADER_FLAG_STRING_ONLY { + userDataPtr := (*uint16)(unsafe.Pointer(r.UserData)) + return map[string]interface{}{ + "_": utf16AtOffsetToString(uintptr(unsafe.Pointer(userDataPtr)), 0), // Convert the user data from UTF16 to string. + }, nil + } + + // Initialize a new property parser for the event record. + p, err := newPropertyParser(r) + if err != nil { + return nil, fmt.Errorf("failed to parse event properties: %w", err) + } + + // Iterate through each property of the event and format it + properties := make(map[string]interface{}, int(p.info.TopLevelPropertyCount)) + for i := 0; i < int(p.info.TopLevelPropertyCount); i++ { + name := p.getPropertyName(i) + value, err := p.getPropertyValue(i) + if err != nil { + return nil, fmt.Errorf("failed to parse %q value: %w", name, err) + } + properties[name] = value + } + + return properties, nil +} + +// newPropertyParser initializes a new property parser for a given event record. +func newPropertyParser(r *EventRecord) (*propertyParser, error) { + info, err := getEventInformation(r) + if err != nil { + return nil, fmt.Errorf("failed to get event information: %w", err) + } + ptrSize := r.pointerSize() + // Return a new propertyParser instance initialized with event record data and metadata. + return &propertyParser{ + r: r, + info: info, + ptrSize: ptrSize, + data: unsafe.Slice((*uint8)(unsafe.Pointer(r.UserData)), r.UserDataLength), + }, nil +} + +// getEventPropertyInfoAtIndex looks for the EventPropertyInfo object at a specified index. +func (info *TraceEventInfo) getEventPropertyInfoAtIndex(i uint32) *EventPropertyInfo { + if i < info.PropertyCount { + // Calculate the address of the first element in EventPropertyInfoArray. + eventPropertyInfoPtr := uintptr(unsafe.Pointer(&info.EventPropertyInfoArray[0])) + // Adjust the pointer to point to the i-th EventPropertyInfo element. + eventPropertyInfoPtr += uintptr(i) * unsafe.Sizeof(EventPropertyInfo{}) + + return ((*EventPropertyInfo)(unsafe.Pointer(eventPropertyInfoPtr))) + } + return nil +} + +// getEventInformation retrieves detailed metadata about an event record. +func getEventInformation(r *EventRecord) (info *TraceEventInfo, err error) { + // Initially call TdhGetEventInformation to get the required buffer size. + var bufSize uint32 + if err = _TdhGetEventInformation(r, 0, nil, nil, &bufSize); errors.Is(err, ERROR_INSUFFICIENT_BUFFER) { + // Allocate enough memory for TRACE_EVENT_INFO based on the required size. + buff := make([]byte, bufSize) + info = ((*TraceEventInfo)(unsafe.Pointer(&buff[0]))) + // Retrieve the event information into the allocated buffer. + err = _TdhGetEventInformation(r, 0, nil, info, &bufSize) + } + + // Check for errors in retrieving the event information. + if err != nil { + return nil, fmt.Errorf("TdhGetEventInformation failed: %w", err) + } + + return info, nil +} + +// getPropertyName retrieves the name of the i-th event property in the event record. +func (p *propertyParser) getPropertyName(i int) string { + // Convert the UTF16 property name to a Go string. + namePtr := readPropertyName(p, i) + return windows.UTF16PtrToString((*uint16)(namePtr)) +} + +// readPropertyName gets the pointer to the property name in the event information structure. +func readPropertyName(p *propertyParser, i int) unsafe.Pointer { + // Calculate the pointer to the property name using its offset in the event property array. + return unsafe.Add(unsafe.Pointer(p.info), p.info.getEventPropertyInfoAtIndex(uint32(i)).NameOffset) +} + +// getPropertyValue retrieves the value of a specified event property. +func (p *propertyParser) getPropertyValue(i int) (interface{}, error) { + propertyInfo := p.info.getEventPropertyInfoAtIndex(uint32(i)) + + // Determine the size of the property array. + arraySize, err := p.getArraySize(*propertyInfo) + if err != nil { + return nil, fmt.Errorf("failed to get array size: %w", err) + } + + // Initialize a slice to hold the property values. + result := make([]interface{}, arraySize) + for j := 0; j < int(arraySize); j++ { + var ( + value interface{} + err error + ) + // Parse the property value based on its type (simple or structured). + if (propertyInfo.Flags & PropertyStruct) == PropertyStruct { + value, err = p.parseStruct(*propertyInfo) + } else { + value, err = p.parseSimpleType(*propertyInfo) + } + if err != nil { + return nil, err + } + result[j] = value + } + + // Return the entire result set or the single value, based on the property count. + if ((propertyInfo.Flags & PropertyParamCount) == PropertyParamCount) || + (propertyInfo.count() > 1) { + return result, nil + } + return result[0], nil +} + +// getArraySize calculates the size of an array property within an event. +func (p *propertyParser) getArraySize(propertyInfo EventPropertyInfo) (uint32, error) { + // Check if the property's count is specified by another property. + if (propertyInfo.Flags & PropertyParamCount) == PropertyParamCount { + var dataDescriptor PropertyDataDescriptor + // Locate the property containing the array size using the countPropertyIndex. + dataDescriptor.PropertyName = readPropertyName(p, int(propertyInfo.count())) + dataDescriptor.ArrayIndex = 0xFFFFFFFF + // Retrieve the length of the array from the specified property. + return getLengthFromProperty(p.r, &dataDescriptor) + } else { + // If the array size is directly specified, return it. + return uint32(propertyInfo.count()), nil + } +} + +// getLengthFromProperty retrieves the length of a property from an event record. +func getLengthFromProperty(r *EventRecord, dataDescriptor *PropertyDataDescriptor) (uint32, error) { + var length uint32 + // Call TdhGetProperty to get the length of the property specified by the dataDescriptor. + err := _TdhGetProperty( + r, + 0, + nil, + 1, + dataDescriptor, + uint32(unsafe.Sizeof(length)), + (*byte)(unsafe.Pointer(&length)), + ) + if err != nil { + return 0, err + } + return length, nil +} + +// parseStruct extracts and returns the fields from an embedded structure within a property. +func (p *propertyParser) parseStruct(propertyInfo EventPropertyInfo) (map[string]interface{}, error) { + // Determine the start and end indexes of the structure members within the property info. + startIndex := propertyInfo.structStartIndex() + lastIndex := startIndex + propertyInfo.numOfStructMembers() + + // Initialize a map to hold the structure's fields. + structure := make(map[string]interface{}, (lastIndex - startIndex)) + // Iterate through each member of the structure. + for j := startIndex; j < lastIndex; j++ { + name := p.getPropertyName(int(j)) + value, err := p.getPropertyValue(int(j)) + if err != nil { + return nil, fmt.Errorf("failed parse field '%s' of complex property type: %w", name, err) + } + structure[name] = value // Add the field to the structure map. + } + + return structure, nil +} + +// parseSimpleType parses a simple property type using TdhFormatProperty. +func (p *propertyParser) parseSimpleType(propertyInfo EventPropertyInfo) (string, error) { + var mapInfo *EventMapInfo + if propertyInfo.mapNameOffset() > 0 { + // If failed retrieving the map information, returns on error + var err error + mapInfo, err = p.getMapInfo(propertyInfo) + if err != nil { + return "", fmt.Errorf("failed to get map information due to: %w", err) + } + } + + // Get the length of the property. + propertyLength, err := p.getPropertyLength(propertyInfo) + if err != nil { + return "", fmt.Errorf("failed to get property length due to: %w", err) + } + + var userDataConsumed uint16 + + // Set a default buffer size for formatted data. + formattedDataSize := uint32(DEFAULT_PROPERTY_BUFFER_SIZE) + formattedData := make([]byte, int(formattedDataSize)) + + // Retry loop to handle buffer size adjustments. +retryLoop: + for { + var dataPtr *uint8 + if len(p.data) > 0 { + dataPtr = &p.data[0] + } + err := _TdhFormatProperty( + p.info, + mapInfo, + p.ptrSize, + propertyInfo.inType(), + propertyInfo.outType(), + uint16(propertyLength), + uint16(len(p.data)), + dataPtr, + &formattedDataSize, + &formattedData[0], + &userDataConsumed, + ) + + switch { + case err == nil: + // If formatting is successful, break out of the loop. + break retryLoop + case errors.Is(err, ERROR_INSUFFICIENT_BUFFER): + // Increase the buffer size if it's insufficient. + formattedData = make([]byte, formattedDataSize) + continue + case errors.Is(err, ERROR_EVT_INVALID_EVENT_DATA): + // Handle invalid event data error. + // Discarding MapInfo allows us to access + // at least the non-interpreted data. + if mapInfo != nil { + mapInfo = nil + continue + } + return "", fmt.Errorf("TdhFormatProperty failed: %w", err) // Handle unknown error + default: + return "", fmt.Errorf("TdhFormatProperty failed: %w", err) + } + } + // Update the data slice to account for consumed data. + p.data = p.data[userDataConsumed:] + + // Convert the formatted data to string and return. + return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(&formattedData[0]))), nil +} + +// getMapInfo retrieves mapping information for a given property. +func (p *propertyParser) getMapInfo(propertyInfo EventPropertyInfo) (*EventMapInfo, error) { + var mapSize uint32 + // Get the name of the map from the property info. + mapName := (*uint16)(unsafe.Add(unsafe.Pointer(p.info), propertyInfo.mapNameOffset())) + + // First call to get the required size of the map info. + err := _TdhGetEventMapInformation(p.r, mapName, nil, &mapSize) + switch { + case errors.Is(err, ERROR_NOT_FOUND): + // No mapping information available. This is not an error. + return nil, nil + case errors.Is(err, ERROR_INSUFFICIENT_BUFFER): + // Resize the buffer and try again. + default: + return nil, fmt.Errorf("TdhGetEventMapInformation failed to get size: %w", err) + } + + // Allocate buffer and retrieve the actual map information. + buff := make([]byte, int(mapSize)) + mapInfo := ((*EventMapInfo)(unsafe.Pointer(&buff[0]))) + err = _TdhGetEventMapInformation(p.r, mapName, mapInfo, &mapSize) + if err != nil { + return nil, fmt.Errorf("TdhGetEventMapInformation failed: %w", err) + } + + if mapInfo.EntryCount == 0 { + return nil, nil // No entries in the map. + } + + return mapInfo, nil +} + +// getPropertyLength returns the length of a specific property within TraceEventInfo. +func (p *propertyParser) getPropertyLength(propertyInfo EventPropertyInfo) (uint32, error) { + // Check if the length of the property is defined by another property. + if (propertyInfo.Flags & PropertyParamLength) == PropertyParamLength { + var dataDescriptor PropertyDataDescriptor + // Read the property name that contains the length information. + dataDescriptor.PropertyName = readPropertyName(p, int(propertyInfo.length())) + dataDescriptor.ArrayIndex = 0xFFFFFFFF + // Retrieve the length from the specified property. + return getLengthFromProperty(p.r, &dataDescriptor) + } + + inType := propertyInfo.inType() + outType := propertyInfo.outType() + // Special handling for properties representing IPv6 addresses. + // https://docs.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhformatproperty#remarks + if TdhIntypeBinary == inType && TdhOuttypeIpv6 == outType { + // Return the fixed size of an IPv6 address. + return 16, nil + } + + // Default case: return the length as defined in the property info. + // Note: A length of 0 can indicate a variable-length field (e.g., structure, string). + return uint32(propertyInfo.length()), nil +} diff --git a/x-pack/libbeat/reader/etw/provider.go b/x-pack/libbeat/reader/etw/provider.go new file mode 100644 index 000000000000..e0a20c3facd1 --- /dev/null +++ b/x-pack/libbeat/reader/etw/provider.go @@ -0,0 +1,81 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// utf16AtOffsetToString converts a UTF-16 encoded string +// at a specific offset in a struct to a Go string. +func utf16AtOffsetToString(pstruct uintptr, offset uintptr) string { + // Initialize a slice to store UTF-16 characters. + out := make([]uint16, 0, 64) + + // Start reading at the given offset. + wc := (*uint16)(unsafe.Pointer(pstruct + offset)) + + // Iterate over the UTF-16 characters until a null terminator is encountered. + for i := uintptr(2); *wc != 0; i += 2 { + out = append(out, *wc) + wc = (*uint16)(unsafe.Pointer(pstruct + offset + i)) + } + + // Convert the UTF-16 slice to a Go string and return. + return syscall.UTF16ToString(out) +} + +// guidFromProviderName searches for a provider by name and returns its GUID. +func guidFromProviderName(providerName string) (windows.GUID, error) { + // Returns if the provider name is empty. + if providerName == "" { + return windows.GUID{}, fmt.Errorf("empty provider name") + } + + var buf *ProviderEnumerationInfo + size := uint32(1) + + // Attempt to retrieve provider information with a buffer that increases in size until it's sufficient. + for { + tmp := make([]byte, size) + buf = (*ProviderEnumerationInfo)(unsafe.Pointer(&tmp[0])) + if err := enumerateProvidersFunc(buf, &size); !errors.Is(err, ERROR_INSUFFICIENT_BUFFER) { + break + } + } + + if buf.NumberOfProviders == 0 { + return windows.GUID{}, fmt.Errorf("no providers found") + } + + // Iterate through the list of providers to find a match by name. + startProvEnumInfo := uintptr(unsafe.Pointer(buf)) + it := uintptr(unsafe.Pointer(&buf.TraceProviderInfoArray[0])) + for i := uintptr(0); i < uintptr(buf.NumberOfProviders); i++ { + pInfo := (*TraceProviderInfo)(unsafe.Pointer(it + i*unsafe.Sizeof(buf.TraceProviderInfoArray[0]))) + name := utf16AtOffsetToString(startProvEnumInfo, uintptr(pInfo.ProviderNameOffset)) + + // If a match is found, return the corresponding GUID. + if name == providerName { + return pInfo.ProviderGuid, nil + } + } + + // No matching provider is found. + return windows.GUID{}, fmt.Errorf("unable to find GUID from provider name") +} + +// IsGUIDValid checks if GUID contains valid data +// (any of the fields in the GUID are non-zero) +func IsGUIDValid(guid windows.GUID) bool { + return guid.Data1 != 0 || guid.Data2 != 0 || guid.Data3 != 0 || guid.Data4 != [8]byte{} +} diff --git a/x-pack/libbeat/reader/etw/provider_test.go b/x-pack/libbeat/reader/etw/provider_test.go new file mode 100644 index 000000000000..d8c561ef3e4f --- /dev/null +++ b/x-pack/libbeat/reader/etw/provider_test.go @@ -0,0 +1,199 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "encoding/binary" + "syscall" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/windows" +) + +func TestUTF16AtOffsetToString(t *testing.T) { + // Create a UTF-16 string + sampleText := "This is a string test!" + utf16Str, _ := syscall.UTF16FromString(sampleText) + + // Convert it to uintptr (simulate as if it's part of a larger struct) + ptr := uintptr(unsafe.Pointer(&utf16Str[0])) + + // Test the function + result := utf16AtOffsetToString(ptr, 0) + assert.Equal(t, sampleText, result, "The converted string should match the original") + + // Test with offset (skip the first character) + offset := unsafe.Sizeof(utf16Str[0]) // Size of one UTF-16 character + resultWithOffset := utf16AtOffsetToString(ptr, offset) + assert.Equal(t, sampleText[1:], resultWithOffset, "The converted string with offset should skip the first character") +} + +func TestGUIDFromProviderName_EmptyName(t *testing.T) { + guid, err := guidFromProviderName("") + assert.EqualError(t, err, "empty provider name") + assert.Equal(t, windows.GUID{}, guid, "GUID should be empty for an empty provider name") +} + +func TestGUIDFromProviderName_EmptyProviderList(t *testing.T) { + // Defer restoration of the original function + t.Cleanup(func() { + enumerateProvidersFunc = _TdhEnumerateProviders + }) + + // Define a mock provider name and GUID for testing. + mockProviderName := "NonExistentProvider" + + enumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error { + // Check if the buffer size is sufficient + requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{}))*0 // As there are no providers + if *pBufferSize < requiredSize { + // Set the size required and return the error + *pBufferSize = requiredSize + return ERROR_INSUFFICIENT_BUFFER + } + + // Empty list of providers + *pBuffer = ProviderEnumerationInfo{ + NumberOfProviders: 0, + TraceProviderInfoArray: [anysizeArray]TraceProviderInfo{}, + } + return nil + } + + guid, err := guidFromProviderName(mockProviderName) + assert.EqualError(t, err, "no providers found") + assert.Equal(t, windows.GUID{}, guid, "GUID should be empty when the provider is not found") +} + +func TestGUIDFromProviderName_GUIDNotFound(t *testing.T) { + // Defer restoration of the original function + t.Cleanup(func() { + enumerateProvidersFunc = _TdhEnumerateProviders + }) + + // Define a mock provider name and GUID for testing. + mockProviderName := "NonExistentProvider" + realProviderName := "ExistentProvider" + mockGUID := windows.GUID{Data1: 1234, Data2: 5678} + + enumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error { + // Convert provider name to UTF-16 + utf16ProviderName, _ := syscall.UTF16FromString(realProviderName) + + // Calculate size needed for the provider name string + nameSize := (len(utf16ProviderName) + 1) * 2 // +1 for null-terminator + + requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{})) + uint32(nameSize) + if *pBufferSize < requiredSize { + *pBufferSize = requiredSize + return ERROR_INSUFFICIENT_BUFFER + } + + // Calculate the offset for the provider name + // It's placed after ProviderEnumerationInfo and TraceProviderInfo + nameOffset := unsafe.Sizeof(ProviderEnumerationInfo{}) + unsafe.Sizeof(TraceProviderInfo{}) + + // Convert pBuffer to a byte slice starting at the calculated offset for the name + byteBuffer := (*[1 << 30]byte)(unsafe.Pointer(pBuffer))[:] + // Copy the UTF-16 encoded name into the buffer + for i, char := range utf16ProviderName { + binary.LittleEndian.PutUint16(byteBuffer[nameOffset+(uintptr(i)*2):], char) + } + + // Create and populate the ProviderEnumerationInfo struct + *pBuffer = ProviderEnumerationInfo{ + NumberOfProviders: 1, + TraceProviderInfoArray: [anysizeArray]TraceProviderInfo{ + { + ProviderGuid: mockGUID, + ProviderNameOffset: uint32(nameOffset), + }, + }, + } + return nil + } + + guid, err := guidFromProviderName(mockProviderName) + assert.EqualError(t, err, "unable to find GUID from provider name") + assert.Equal(t, windows.GUID{}, guid, "GUID should be empty when the provider is not found") +} + +func TestGUIDFromProviderName_Success(t *testing.T) { + // Defer restoration of the original function + t.Cleanup(func() { + enumerateProvidersFunc = _TdhEnumerateProviders + }) + + // Define a mock provider name and GUID for testing. + mockProviderName := "MockProvider" + mockGUID := windows.GUID{Data1: 1234, Data2: 5678} + + enumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error { + // Convert provider name to UTF-16 + utf16ProviderName, _ := syscall.UTF16FromString(mockProviderName) + + // Calculate size needed for the provider name string + nameSize := (len(utf16ProviderName) + 1) * 2 // +1 for null-terminator + + requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{})) + uint32(nameSize) + if *pBufferSize < requiredSize { + *pBufferSize = requiredSize + return ERROR_INSUFFICIENT_BUFFER + } + + // Calculate the offset for the provider name + // It's placed after ProviderEnumerationInfo and TraceProviderInfo + nameOffset := unsafe.Sizeof(ProviderEnumerationInfo{}) + unsafe.Sizeof(TraceProviderInfo{}) + + // Convert pBuffer to a byte slice starting at the calculated offset for the name + byteBuffer := (*[1 << 30]byte)(unsafe.Pointer(pBuffer))[:] + // Copy the UTF-16 encoded name into the buffer + for i, char := range utf16ProviderName { + binary.LittleEndian.PutUint16(byteBuffer[nameOffset+(uintptr(i)*2):], char) + } + + // Create and populate the ProviderEnumerationInfo struct + *pBuffer = ProviderEnumerationInfo{ + NumberOfProviders: 1, + TraceProviderInfoArray: [anysizeArray]TraceProviderInfo{ + { + ProviderGuid: mockGUID, + ProviderNameOffset: uint32(nameOffset), + }, + }, + } + return nil + } + + // Run the test + guid, err := guidFromProviderName(mockProviderName) + assert.NoError(t, err) + assert.Equal(t, mockGUID, guid, "GUID should match the mock GUID") +} + +func TestIsGUIDValid_True(t *testing.T) { + // Valid GUID + validGUID := windows.GUID{ + Data1: 0xeb79061a, + Data2: 0xa566, + Data3: 0x4698, + Data4: [8]byte{0x12, 0x34, 0x3e, 0xd2, 0x80, 0x70, 0x33, 0xa0}, + } + + valid := IsGUIDValid(validGUID) + assert.True(t, valid, "IsGUIDValid should return true for a valid GUID") +} + +func TestIsGUIDValid_False(t *testing.T) { + // Invalid GUID (all zeros) + invalidGUID := windows.GUID{} + + valid := IsGUIDValid(invalidGUID) + assert.False(t, valid, "IsGUIDValid should return false for an invalid GUID") +} diff --git a/x-pack/libbeat/reader/etw/session.go b/x-pack/libbeat/reader/etw/session.go new file mode 100644 index 000000000000..3a8e7be51d7c --- /dev/null +++ b/x-pack/libbeat/reader/etw/session.go @@ -0,0 +1,250 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// For testing purposes we create a variable to store the function to call +// When running tests, these variables point to a mock function +var ( + guidFromProviderNameFunc = guidFromProviderName + setSessionGUIDFunc = setSessionGUID +) + +type Session struct { + // Name is the identifier for the session. + // It is used to identify the session in logs and also for Windows processes. + Name string + // GUID is the provider GUID to configure the session. + GUID windows.GUID + // properties of the session that are initialized in newSessionProperties() + // See https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_properties for more information + properties *EventTraceProperties + // handler of the event tracing session for which the provider is being configured. + // It is obtained from StartTrace when a new trace is started. + // This handler is needed to enable, query or stop the trace. + handler uintptr + // Realtime is a flag to know if the consumer reads from a logfile or real-time session. + Realtime bool // Real-time flag + // NewSession is a flag to indicate whether a new session has been created or attached to an existing one. + NewSession bool + // TraceLevel sets the maximum level of events that we want the provider to write. + traceLevel uint8 + // matchAnyKeyword is a 64-bit bitmask of keywords that determine the categories of events that we want the provider to write. + // The provider writes an event if the event's keyword bits match any of the bits set in this value + // or if the event has no keyword bits set, in addition to meeting the level and matchAllKeyword criteria. + matchAnyKeyword uint64 + // matchAllKeyword is a 64-bit bitmask of keywords that restricts the events that we want the provider to write. + // The provider typically writes an event if the event's keyword bits match all of the bits set in this value + // or if the event has no keyword bits set, in addition to meeting the level and matchAnyKeyword criteria. + matchAllKeyword uint64 + // traceHandler is the trace processing handle. + // It is used to control the trace that receives and processes events. + traceHandler uint64 + // Callback is the pointer to EventRecordCallback which receives and processes event trace events. + Callback func(*EventRecord) uintptr + // BufferCallback is the pointer to BufferCallback which processes retrieved metadata about the ETW buffers (optional). + BufferCallback func(*EventTraceLogfile) uintptr + + // Pointers to functions that make calls to the Windows API. + // In tests, these pointers can be replaced with mock functions to simulate API behavior without making actual calls to the Windows API. + startTrace func(*uintptr, *uint16, *EventTraceProperties) error + controlTrace func(traceHandle uintptr, instanceName *uint16, properties *EventTraceProperties, controlCode uint32) error + enableTrace func(traceHandle uintptr, providerId *windows.GUID, isEnabled uint32, level uint8, matchAnyKeyword uint64, matchAllKeyword uint64, enableProperty uint32, enableParameters *EnableTraceParameters) error + closeTrace func(traceHandle uint64) error + openTrace func(elf *EventTraceLogfile) (uint64, error) + processTrace func(handleArray *uint64, handleCount uint32, startTime *FileTime, endTime *FileTime) error +} + +// setSessionName determines the session name based on the provided configuration. +func setSessionName(conf Config) string { + // Iterate through potential session name values, returning the first non-empty one. + for _, value := range []string{conf.Logfile, conf.Session, conf.SessionName} { + if value != "" { + return value + } + } + + if conf.ProviderName != "" { + return fmt.Sprintf("Elastic-%s", conf.ProviderName) + } + + return fmt.Sprintf("Elastic-%s", conf.ProviderGUID) +} + +// setSessionGUID determines the session GUID based on the provided configuration. +func setSessionGUID(conf Config) (windows.GUID, error) { + var guid windows.GUID + var err error + + // If ProviderGUID is not set in the configuration, attempt to resolve it using the provider name. + if conf.ProviderGUID == "" { + guid, err = guidFromProviderNameFunc(conf.ProviderName) + if err != nil { + return windows.GUID{}, fmt.Errorf("error resolving GUID: %w", err) + } + } else { + // If ProviderGUID is set, parse it into a GUID structure. + guid, err = windows.GUIDFromString(conf.ProviderGUID) + if err != nil { + return windows.GUID{}, fmt.Errorf("error parsing Windows GUID: %w", err) + } + } + + return guid, nil +} + +// getTraceLevel converts a string representation of a trace level +// to its corresponding uint8 constant value +func getTraceLevel(level string) uint8 { + switch level { + case "critical": + return TRACE_LEVEL_CRITICAL + case "error": + return TRACE_LEVEL_ERROR + case "warning": + return TRACE_LEVEL_WARNING + case "information": + return TRACE_LEVEL_INFORMATION + case "verbose": + return TRACE_LEVEL_VERBOSE + default: + return TRACE_LEVEL_INFORMATION + } +} + +// newSessionProperties initializes and returns a pointer to EventTraceProperties +// with the necessary settings for starting an ETW session. +// See https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_properties +func newSessionProperties(sessionName string) *EventTraceProperties { + // Calculate buffer size for session properties. + sessionNameSize := (len(sessionName) + 1) * 2 + bufSize := sessionNameSize + int(unsafe.Sizeof(EventTraceProperties{})) + + // Allocate buffer and cast to EventTraceProperties. + propertiesBuf := make([]byte, bufSize) + sessionProperties := (*EventTraceProperties)(unsafe.Pointer(&propertiesBuf[0])) + + // Initialize mandatory fields of the EventTraceProperties struct. + // Filled based on https://learn.microsoft.com/en-us/windows/win32/etw/wnode-header + sessionProperties.Wnode.BufferSize = uint32(bufSize) + sessionProperties.Wnode.Guid = windows.GUID{} // GUID not required for non-private/kernel sessions + // ClientContext is used for timestamp resolution + // Not used unless adding PROCESS_TRACE_MODE_RAW_TIMESTAMP flag to EVENT_TRACE_LOGFILE struct + // See https://learn.microsoft.com/en-us/windows/win32/etw/wnode-header + sessionProperties.Wnode.ClientContext = 1 + sessionProperties.Wnode.Flags = WNODE_FLAG_TRACED_GUID + // Set logging mode to real-time + // See https://learn.microsoft.com/en-us/windows/win32/etw/logging-mode-constants + sessionProperties.LogFileMode = EVENT_TRACE_REAL_TIME_MODE + sessionProperties.LogFileNameOffset = 0 // Can be specified to log to a file as well as to a real-time session + sessionProperties.BufferSize = 64 // Default buffer size, can be configurable + sessionProperties.LoggerNameOffset = uint32(unsafe.Sizeof(EventTraceProperties{})) // Offset to the logger name + + return sessionProperties +} + +// NewSession initializes and returns a new ETW Session based on the provided configuration. +func NewSession(conf Config) (Session, error) { + var session Session + var err error + + // Assign ETW Windows API functions + session.startTrace = _StartTrace + session.controlTrace = _ControlTrace + session.enableTrace = _EnableTraceEx2 + session.openTrace = _OpenTrace + session.processTrace = _ProcessTrace + session.closeTrace = _CloseTrace + + session.Name = setSessionName(conf) + session.Realtime = true + + // If a current session is configured, set up the session properties and return. + if conf.Session != "" { + session.properties = newSessionProperties(session.Name) + return session, nil + } else if conf.Logfile != "" { + // If a logfile is specified, set up for non-realtime session. + session.Realtime = false + return session, nil + } + + session.NewSession = true // Indicate this is a new session + + session.GUID, err = setSessionGUIDFunc(conf) + if err != nil { + return Session{}, err + } + + // Initialize additional session properties. + session.properties = newSessionProperties(session.Name) + session.traceLevel = getTraceLevel(conf.TraceLevel) + session.matchAnyKeyword = conf.MatchAnyKeyword + session.matchAllKeyword = conf.MatchAllKeyword + + return session, nil +} + +// StartConsumer initializes and starts the ETW event tracing session. +func (s *Session) StartConsumer() error { + var elf EventTraceLogfile + var err error + + // Configure EventTraceLogfile based on the session type (realtime or not). + if !s.Realtime { + elf.LogFileMode = PROCESS_TRACE_MODE_EVENT_RECORD + logfilePtr, err := syscall.UTF16PtrFromString(s.Name) + if err != nil { + return fmt.Errorf("failed to convert logfile name: %w", err) + } + elf.LogFileName = logfilePtr + } else { + elf.LogFileMode = PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_REAL_TIME + sessionPtr, err := syscall.UTF16PtrFromString(s.Name) + if err != nil { + return fmt.Errorf("failed to convert session name: %w", err) + } + elf.LoggerName = sessionPtr + } + + // Set callback and context for the session. + if s.Callback == nil { + return fmt.Errorf("error loading callback") + } + elf.Callback = syscall.NewCallback(s.Callback) + elf.Context = 0 + + // Open an ETW trace processing handle for consuming events + // from an ETW real-time trace session or an ETW log file. + s.traceHandler, err = s.openTrace(&elf) + + switch { + case err == nil: + + // Handle specific errors for trace opening. + case errors.Is(err, ERROR_BAD_PATHNAME): + return fmt.Errorf("invalid log source when opening trace: %w", err) + case errors.Is(err, ERROR_ACCESS_DENIED): + return fmt.Errorf("access denied when opening trace: %w", err) + default: + return fmt.Errorf("failed to open trace: %w", err) + } + // Process the trace. This function blocks until processing ends. + if err := s.processTrace(&s.traceHandler, 1, nil, nil); err != nil { + return fmt.Errorf("failed to process trace: %w", err) + } + + return nil +} diff --git a/x-pack/libbeat/reader/etw/session_test.go b/x-pack/libbeat/reader/etw/session_test.go new file mode 100644 index 000000000000..005b9839d5c6 --- /dev/null +++ b/x-pack/libbeat/reader/etw/session_test.go @@ -0,0 +1,338 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "fmt" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/windows" +) + +// TestSetSessionName tests the setSessionName function with various configurations. +func TestSetSessionName(t *testing.T) { + testCases := []struct { + name string + config Config + expectedName string + }{ + { + name: "ProviderNameSet", + config: Config{ + ProviderName: "Provider1", + }, + expectedName: "Elastic-Provider1", + }, + { + name: "SessionNameSet", + config: Config{ + SessionName: "Session1", + }, + expectedName: "Session1", + }, + { + name: "LogFileSet", + config: Config{ + Logfile: "LogFile1.etl", + }, + expectedName: "LogFile1.etl", + }, + { + name: "FallbackToProviderGUID", + config: Config{ + ProviderGUID: "12345", + }, + expectedName: "Elastic-12345", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sessionName := setSessionName(tc.config) + assert.Equal(t, tc.expectedName, sessionName, "The session name should be correctly determined") + }) + } +} + +func mockGUIDFromProviderName(providerName string) (windows.GUID, error) { + // Return a mock GUID regardless of the input + return windows.GUID{Data1: 0x12345678, Data2: 0x1234, Data3: 0x5678, Data4: [8]byte{0x9A, 0xBC, 0xDE, 0xF0, 0x12, 0x34, 0x56, 0x78}}, nil +} + +func TestSetSessionGUID_ProviderName(t *testing.T) { + // Defer restoration of original function + t.Cleanup(func() { + guidFromProviderNameFunc = guidFromProviderName + }) + + // Replace with mock function + guidFromProviderNameFunc = mockGUIDFromProviderName + + conf := Config{ProviderName: "Provider1"} + expectedGUID := windows.GUID{Data1: 0x12345678, Data2: 0x1234, Data3: 0x5678, Data4: [8]byte{0x9A, 0xBC, 0xDE, 0xF0, 0x12, 0x34, 0x56, 0x78}} + + guid, err := setSessionGUID(conf) + assert.NoError(t, err) + assert.Equal(t, expectedGUID, guid, "The GUID should match the mock GUID") +} + +func TestSetSessionGUID_ProviderGUID(t *testing.T) { + // Example GUID string + guidString := "{12345678-1234-5678-1234-567812345678}" + + // Configuration with a set ProviderGUID + conf := Config{ProviderGUID: guidString} + + // Expected GUID based on the GUID string + expectedGUID := windows.GUID{Data1: 0x12345678, Data2: 0x1234, Data3: 0x5678, Data4: [8]byte{0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}} + + guid, err := setSessionGUID(conf) + + assert.NoError(t, err) + assert.Equal(t, expectedGUID, guid, "The GUID should match the expected value") +} + +func TestGetTraceLevel(t *testing.T) { + testCases := []struct { + name string + level string + expectedCode uint8 + }{ + {"CriticalLevel", "critical", TRACE_LEVEL_CRITICAL}, + {"ErrorLevel", "error", TRACE_LEVEL_ERROR}, + {"WarningLevel", "warning", TRACE_LEVEL_WARNING}, + {"InformationLevel", "information", TRACE_LEVEL_INFORMATION}, + {"VerboseLevel", "verbose", TRACE_LEVEL_VERBOSE}, + {"DefaultLevel", "unknown", TRACE_LEVEL_INFORMATION}, // Default case + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := getTraceLevel(tc.level) + assert.Equal(t, tc.expectedCode, result, "Trace level code should match the expected value") + }) + } +} + +func TestNewSessionProperties(t *testing.T) { + testCases := []struct { + name string + sessionName string + expectedSize uint32 + }{ + {"EmptyName", "", 2 + uint32(unsafe.Sizeof(EventTraceProperties{}))}, + {"NormalName", "Session1", 18 + uint32(unsafe.Sizeof(EventTraceProperties{}))}, + // Additional test cases can be added here + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + props := newSessionProperties(tc.sessionName) + + assert.Equal(t, tc.expectedSize, props.Wnode.BufferSize, "BufferSize should match expected value") + assert.Equal(t, windows.GUID{}, props.Wnode.Guid, "GUID should be empty") + assert.Equal(t, uint32(1), props.Wnode.ClientContext, "ClientContext should be 1") + assert.Equal(t, uint32(WNODE_FLAG_TRACED_GUID), props.Wnode.Flags, "Flags should match WNODE_FLAG_TRACED_GUID") + assert.Equal(t, uint32(EVENT_TRACE_REAL_TIME_MODE), props.LogFileMode, "LogFileMode should be set to real-time") + assert.Equal(t, uint32(0), props.LogFileNameOffset, "LogFileNameOffset should be 0") + assert.Equal(t, uint32(64), props.BufferSize, "BufferSize should be 64") + assert.Equal(t, uint32(unsafe.Sizeof(EventTraceProperties{})), props.LoggerNameOffset, "LoggerNameOffset should be the size of EventTraceProperties") + }) + } +} + +func TestNewSession_ProviderName(t *testing.T) { + // Defer restoration of original function + t.Cleanup(func() { + setSessionGUIDFunc = setSessionGUID + }) + + // Override setSessionGUIDFunc with mock + setSessionGUIDFunc = func(conf Config) (windows.GUID, error) { + return windows.GUID{ + Data1: 0x12345678, + Data2: 0x1234, + Data3: 0x5678, + Data4: [8]byte{0x9A, 0xBC, 0xDE, 0xF0, 0x12, 0x34, 0x56, 0x78}, + }, nil + } + + expectedGUID := windows.GUID{ + Data1: 0x12345678, + Data2: 0x1234, + Data3: 0x5678, + Data4: [8]byte{0x9A, 0xBC, 0xDE, 0xF0, 0x12, 0x34, 0x56, 0x78}, + } + + conf := Config{ + ProviderName: "Provider1", + SessionName: "Session1", + TraceLevel: "warning", + MatchAnyKeyword: 0xffffffffffffffff, + MatchAllKeyword: 0, + } + session, err := NewSession(conf) + + assert.NoError(t, err) + assert.Equal(t, "Session1", session.Name, "SessionName should match expected value") + assert.Equal(t, expectedGUID, session.GUID, "The GUID in the session should match the expected GUID") + assert.Equal(t, uint8(3), session.traceLevel, "TraceLevel should be 3 (warning)") + assert.Equal(t, true, session.NewSession) + assert.Equal(t, true, session.Realtime) + assert.NotNil(t, session.properties) +} + +func TestNewSession_GUIDError(t *testing.T) { + // Defer restoration of original function + t.Cleanup(func() { + setSessionGUIDFunc = setSessionGUID + }) + + // Override setSessionGUIDFunc with mock + setSessionGUIDFunc = func(conf Config) (windows.GUID, error) { + // Return an empty GUID and an error + return windows.GUID{}, fmt.Errorf("mock error") + } + + conf := Config{ + ProviderName: "Provider1", + SessionName: "Session1", + TraceLevel: "warning", + MatchAnyKeyword: 0xffffffffffffffff, + MatchAllKeyword: 0, + } + session, err := NewSession(conf) + + assert.EqualError(t, err, "mock error") + expectedSession := Session{} + assert.Equal(t, expectedSession, session, "Session should be its zero value when an error occurs") + +} + +func TestNewSession_AttachSession(t *testing.T) { + // Test case + conf := Config{ + Session: "Session1", + SessionName: "TestSession", + TraceLevel: "verbose", + MatchAnyKeyword: 0xffffffffffffffff, + MatchAllKeyword: 0, + } + session, err := NewSession(conf) + + assert.NoError(t, err) + assert.Equal(t, "Session1", session.Name, "SessionName should match expected value") + assert.Equal(t, false, session.NewSession) + assert.Equal(t, true, session.Realtime) + assert.NotNil(t, session.properties) +} + +func TestNewSession_Logfile(t *testing.T) { + // Test case + conf := Config{ + Logfile: "LogFile1.etl", + TraceLevel: "verbose", + MatchAnyKeyword: 0xffffffffffffffff, + MatchAllKeyword: 0, + } + session, err := NewSession(conf) + + assert.NoError(t, err) + assert.Equal(t, "LogFile1.etl", session.Name, "SessionName should match expected value") + assert.Equal(t, false, session.NewSession) + assert.Equal(t, false, session.Realtime) + assert.Nil(t, session.properties) +} + +func TestStartConsumer_CallbackNull(t *testing.T) { + // Create a Session instance + session := &Session{ + Name: "TestSession", + Realtime: false, + BufferCallback: nil, + Callback: nil, + } + + err := session.StartConsumer() + assert.EqualError(t, err, "error loading callback") +} + +func TestStartConsumer_OpenTraceError(t *testing.T) { + // Mock implementation of openTrace + openTrace := func(elf *EventTraceLogfile) (uint64, error) { + return 0, ERROR_ACCESS_DENIED // Mock a valid session handler + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + Realtime: false, + BufferCallback: nil, + Callback: func(*EventRecord) uintptr { + return 1 + }, + openTrace: openTrace, + } + + err := session.StartConsumer() + assert.EqualError(t, err, "access denied when opening trace: Access is denied.") +} + +func TestStartConsumer_ProcessTraceError(t *testing.T) { + // Mock implementations + openTrace := func(elf *EventTraceLogfile) (uint64, error) { + return 12345, nil // Mock a valid session handler + } + + processTrace := func(handleArray *uint64, handleCount uint32, startTime *FileTime, endTime *FileTime) error { + return ERROR_INVALID_PARAMETER + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + Realtime: true, + BufferCallback: nil, + Callback: func(*EventRecord) uintptr { + return 1 + }, + openTrace: openTrace, + processTrace: processTrace, + } + + err := session.StartConsumer() + assert.EqualError(t, err, "failed to process trace: The parameter is incorrect.") +} + +func TestStartConsumer_Success(t *testing.T) { + // Mock implementations + openTrace := func(elf *EventTraceLogfile) (uint64, error) { + return 12345, nil // Mock a valid session handler + } + + processTrace := func(handleArray *uint64, handleCount uint32, startTime *FileTime, endTime *FileTime) error { + return nil + } + + // Create a Session instance + session := &Session{ + Name: "TestSession", + Realtime: true, + BufferCallback: nil, + Callback: func(*EventRecord) uintptr { + return 1 + }, + openTrace: openTrace, + processTrace: processTrace, + } + + err := session.StartConsumer() + assert.NoError(t, err) + assert.Equal(t, uint64(12345), session.traceHandler, "traceHandler should be set to the mock value") +} diff --git a/x-pack/libbeat/reader/etw/syscall_advapi32.go b/x-pack/libbeat/reader/etw/syscall_advapi32.go new file mode 100644 index 000000000000..fe44b0022a46 --- /dev/null +++ b/x-pack/libbeat/reader/etw/syscall_advapi32.go @@ -0,0 +1,318 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "errors" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + advapi32 = windows.NewLazySystemDLL("advapi32.dll") + // Controller + startTraceW = advapi32.NewProc("StartTraceW") + enableTraceEx2 = advapi32.NewProc("EnableTraceEx2") // Manifest-based providers and filtering + controlTraceW = advapi32.NewProc("ControlTraceW") + // Consumer + openTraceW = advapi32.NewProc("OpenTraceW") + processTrace = advapi32.NewProc("ProcessTrace") + closeTrace = advapi32.NewProc("CloseTrace") +) + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace +type EventTrace struct { + Header EventTraceHeader + InstanceId uint32 + ParentInstanceId uint32 + ParentGuid windows.GUID + MofData uintptr + MofLength uint32 + UnionCtx uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_header +type EventTraceHeader struct { + Size uint16 + Union1 uint16 + Union2 uint32 + ThreadId uint32 + ProcessId uint32 + TimeStamp int64 + Union3 [16]byte + Union4 uint64 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_properties +type EventTraceProperties struct { + Wnode WnodeHeader + BufferSize uint32 + MinimumBuffers uint32 + MaximumBuffers uint32 + MaximumFileSize uint32 + LogFileMode uint32 + FlushTimer uint32 + EnableFlags uint32 + AgeLimit int32 + NumberOfBuffers uint32 + FreeBuffers uint32 + EventsLost uint32 + BuffersWritten uint32 + LogBuffersLost uint32 + RealTimeBuffersLost uint32 + LoggerThreadId syscall.Handle + LogFileNameOffset uint32 + LoggerNameOffset uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/etw/wnode-header +type WnodeHeader struct { + BufferSize uint32 + ProviderId uint32 + Union1 uint64 + Union2 int64 + Guid windows.GUID + ClientContext uint32 + Flags uint32 +} + +// Used to enable a provider via EnableTraceEx2 +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-enable_trace_parameters +type EnableTraceParameters struct { + Version uint32 + EnableProperty uint32 + ControlFlags uint32 + SourceId windows.GUID + EnableFilterDesc *EventFilterDescriptor + FilterDescrCount uint32 +} + +// Defines the filter data that a session passes +// to the provider's enable callback function +// https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_filter_descriptor +type EventFilterDescriptor struct { + Ptr uint64 + Size uint32 + Type uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_logfilew +type EventTraceLogfile struct { + LogFileName *uint16 // Logfile + LoggerName *uint16 // Real-time session + CurrentTime int64 + BuffersRead uint32 + LogFileMode uint32 + CurrentEvent EventTrace + LogfileHeader TraceLogfileHeader + BufferCallback uintptr + BufferSize uint32 + Filled uint32 + EventsLost uint32 + // Receive events (EventRecordCallback (TDH) or EventCallback) + // Tip: New code should use EventRecordCallback instead of EventCallback. + // The EventRecordCallback receives an EVENT_RECORD which contains + // more complete event information + Callback uintptr + IsKernelTrace uint32 + Context uintptr +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-trace_logfile_header +type TraceLogfileHeader struct { + BufferSize uint32 + VersionUnion uint32 + ProviderVersion uint32 + NumberOfProcessors uint32 + EndTime int64 + TimerResolution uint32 + MaximumFileSize uint32 + LogFileMode uint32 + BuffersWritten uint32 + Union1 [16]byte + LoggerName *uint16 + LogFileName *uint16 + TimeZone windows.Timezoneinformation + BootTime int64 + PerfFreq int64 + StartTime int64 + ReservedFlags uint32 + BuffersLost uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime +type FileTime struct { + dwLowDateTime uint32 + dwHighDateTime uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-systemtime +type SystemTime struct { + Year uint16 + Month uint16 + DayOfWeek uint16 + Day uint16 + Hour uint16 + Minute uint16 + Second uint16 + Milliseconds uint16 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-enabletrace +const ( + TRACE_LEVEL_NONE = 0 + TRACE_LEVEL_CRITICAL = 1 + TRACE_LEVEL_FATAL = 1 + TRACE_LEVEL_ERROR = 2 + TRACE_LEVEL_WARNING = 3 + TRACE_LEVEL_INFORMATION = 4 + TRACE_LEVEL_VERBOSE = 5 +) + +// https://learn.microsoft.com/en-us/windows/win32/api/evntprov/nc-evntprov-penablecallback +const ( + EVENT_CONTROL_CODE_DISABLE_PROVIDER = 0 + EVENT_CONTROL_CODE_ENABLE_PROVIDER = 1 + EVENT_CONTROL_CODE_CAPTURE_STATE = 2 +) + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-controltracea +const ( + EVENT_TRACE_CONTROL_QUERY = 0 + EVENT_TRACE_CONTROL_STOP = 1 + EVENT_TRACE_CONTROL_UPDATE = 2 + EVENT_TRACE_CONTROL_FLUSH = 3 +) + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/ns-evntrace-event_trace_logfilea +const ( + PROCESS_TRACE_MODE_REAL_TIME = 0x00000100 + PROCESS_TRACE_MODE_RAW_TIMESTAMP = 0x00001000 + PROCESS_TRACE_MODE_EVENT_RECORD = 0x10000000 +) + +const INVALID_PROCESSTRACE_HANDLE = 0xFFFFFFFFFFFFFFFF + +// https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes +const ( + ERROR_ACCESS_DENIED syscall.Errno = 5 + ERROR_INVALID_HANDLE syscall.Errno = 6 + ERROR_BAD_LENGTH syscall.Errno = 24 + ERROR_INVALID_PARAMETER syscall.Errno = 87 + ERROR_INSUFFICIENT_BUFFER syscall.Errno = 122 + ERROR_BAD_PATHNAME syscall.Errno = 161 + ERROR_ALREADY_EXISTS syscall.Errno = 183 + ERROR_NOT_FOUND syscall.Errno = 1168 + ERROR_NO_SYSTEM_RESOURCES syscall.Errno = 1450 + ERROR_TIMEOUT syscall.Errno = 1460 + ERROR_WMI_INSTANCE_NOT_FOUND syscall.Errno = 4201 + ERROR_CTX_CLOSE_PENDING syscall.Errno = 7007 + ERROR_EVT_INVALID_EVENT_DATA syscall.Errno = 15005 +) + +// https://learn.microsoft.com/en-us/windows/win32/etw/logging-mode-constants (to extend modes) +// https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wmistr/ns-wmistr-_wnode_header (to extend flags) +const ( + WNODE_FLAG_ALL_DATA = 0x00000001 + WNODE_FLAG_TRACED_GUID = 0x00020000 + EVENT_TRACE_REAL_TIME_MODE = 0x00000100 +) + +// Wrappers + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-starttracew +func _StartTrace(traceHandle *uintptr, + instanceName *uint16, + properties *EventTraceProperties) error { + r0, _, _ := startTraceW.Call( + uintptr(unsafe.Pointer(traceHandle)), + uintptr(unsafe.Pointer(instanceName)), + uintptr(unsafe.Pointer(properties))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-enabletraceex2 +func _EnableTraceEx2(traceHandle uintptr, + providerId *windows.GUID, + isEnabled uint32, + level uint8, + matchAnyKeyword uint64, + matchAllKeyword uint64, + enableProperty uint32, + enableParameters *EnableTraceParameters) error { + r0, _, _ := enableTraceEx2.Call( + traceHandle, + uintptr(unsafe.Pointer(providerId)), + uintptr(isEnabled), + uintptr(level), + uintptr(matchAnyKeyword), + uintptr(matchAllKeyword), + uintptr(enableProperty), + uintptr(unsafe.Pointer(enableParameters))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-controltracew +func _ControlTrace(traceHandle uintptr, + instanceName *uint16, + properties *EventTraceProperties, + controlCode uint32) error { + r0, _, _ := controlTraceW.Call( + traceHandle, + uintptr(unsafe.Pointer(instanceName)), + uintptr(unsafe.Pointer(properties)), + uintptr(controlCode)) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-opentracew +func _OpenTrace(logfile *EventTraceLogfile) (uint64, error) { + r0, _, err := openTraceW.Call( + uintptr(unsafe.Pointer(logfile))) + var errno syscall.Errno + if errors.As(err, &errno) && errno == 0 { + return uint64(r0), nil + } + return uint64(r0), err +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-processtrace +func _ProcessTrace(handleArray *uint64, + handleCount uint32, + startTime *FileTime, + endTime *FileTime) error { + r0, _, _ := processTrace.Call( + uintptr(unsafe.Pointer(handleArray)), + uintptr(handleCount), + uintptr(unsafe.Pointer(startTime)), + uintptr(unsafe.Pointer(endTime))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntrace/nf-evntrace-closetrace +func _CloseTrace(traceHandle uint64) error { + r0, _, _ := closeTrace.Call( + uintptr(traceHandle)) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} diff --git a/x-pack/libbeat/reader/etw/syscall_tdh.go b/x-pack/libbeat/reader/etw/syscall_tdh.go new file mode 100644 index 000000000000..73551ee123e2 --- /dev/null +++ b/x-pack/libbeat/reader/etw/syscall_tdh.go @@ -0,0 +1,323 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build windows + +package etw + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + tdh = windows.NewLazySystemDLL("tdh.dll") + tdhEnumerateProviders = tdh.NewProc("TdhEnumerateProviders") + tdhGetEventInformation = tdh.NewProc("TdhGetEventInformation") + tdhGetEventMapInformation = tdh.NewProc("TdhGetEventMapInformation") + tdhFormatProperty = tdh.NewProc("TdhFormatProperty") + tdhGetProperty = tdh.NewProc("TdhGetProperty") +) + +const anysizeArray = 1 +const DEFAULT_PROPERTY_BUFFER_SIZE = 256 + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-provider_enumeration_info +type ProviderEnumerationInfo struct { + NumberOfProviders uint32 + Reserved uint32 + TraceProviderInfoArray [anysizeArray]TraceProviderInfo +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-trace_provider_info +type TraceProviderInfo struct { + ProviderGuid windows.GUID + SchemaSource uint32 + ProviderNameOffset uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntcons/ns-evntcons-event_record +type EventRecord struct { + EventHeader EventHeader + BufferContext EtwBufferContext + ExtendedDataCount uint16 + UserDataLength uint16 + ExtendedData *EventHeaderExtendedDataItem + UserData uintptr // Event data + UserContext uintptr +} + +// https://learn.microsoft.com/en-us/windows/win32/api/relogger/ns-relogger-event_header +const ( + EVENT_HEADER_FLAG_STRING_ONLY = 0x0004 + EVENT_HEADER_FLAG_32_BIT_HEADER = 0x0020 + EVENT_HEADER_FLAG_64_BIT_HEADER = 0x0040 +) + +// https://learn.microsoft.com/en-us/windows/win32/api/relogger/ns-relogger-event_header +type EventHeader struct { + Size uint16 + HeaderType uint16 + Flags uint16 + EventProperty uint16 + ThreadId uint32 + ProcessId uint32 + TimeStamp int64 + ProviderId windows.GUID + EventDescriptor EventDescriptor + Time int64 + ActivityId windows.GUID +} + +func (e *EventRecord) pointerSize() uint32 { + if e.EventHeader.Flags&EVENT_HEADER_FLAG_32_BIT_HEADER == EVENT_HEADER_FLAG_32_BIT_HEADER { + return 4 + } + return 8 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor +type EventDescriptor struct { + Id uint16 + Version uint8 + Channel uint8 + Level uint8 + Opcode uint8 + Task uint16 + Keyword uint64 +} + +// https://learn.microsoft.com/en-us/windows/desktop/api/relogger/ns-relogger-etw_buffer_context +type EtwBufferContext struct { + Union uint16 + LoggerId uint16 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/evntcons/ns-evntcons-event_header_extended_data_item +type EventHeaderExtendedDataItem struct { + Reserved1 uint16 + ExtType uint16 + InternalStruct uint16 + DataSize uint16 + DataPtr uint64 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-tdh_context +type TdhContext struct { + ParameterValue uint32 + ParameterType int32 + ParameterSize uint32 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-trace_event_info +type TraceEventInfo struct { + ProviderGUID windows.GUID + EventGUID windows.GUID + EventDescriptor EventDescriptor + DecodingSource DecodingSource + ProviderNameOffset uint32 + LevelNameOffset uint32 + ChannelNameOffset uint32 + KeywordsNameOffset uint32 + TaskNameOffset uint32 + OpcodeNameOffset uint32 + EventMessageOffset uint32 + ProviderMessageOffset uint32 + BinaryXMLOffset uint32 + BinaryXMLSize uint32 + ActivityIDNameOffset uint32 + RelatedActivityIDNameOffset uint32 + PropertyCount uint32 + TopLevelPropertyCount uint32 + Flags TemplateFlags + EventPropertyInfoArray [anysizeArray]EventPropertyInfo +} + +// https://learn.microsoft.com/en-us/windows/desktop/api/tdh/ns-tdh-event_property_info +type EventPropertyInfo struct { + Flags PropertyFlags + NameOffset uint32 + TypeUnion struct { + u1 uint16 + u2 uint16 + u3 uint32 + } + CountUnion uint16 + LengthUnion uint16 + ResTagUnion uint32 +} + +func (i *EventPropertyInfo) count() uint16 { + return i.CountUnion +} + +func (i *EventPropertyInfo) length() uint16 { + return i.LengthUnion +} + +func (i *EventPropertyInfo) inType() uint16 { + return i.TypeUnion.u1 +} + +func (i *EventPropertyInfo) outType() uint16 { + return i.TypeUnion.u2 +} + +func (i *EventPropertyInfo) structStartIndex() uint16 { + return i.inType() +} + +func (i *EventPropertyInfo) numOfStructMembers() uint16 { + return i.outType() +} + +func (i *EventPropertyInfo) mapNameOffset() uint32 { + return i.TypeUnion.u3 +} + +const ( + TdhIntypeBinary = 14 + TdhOuttypeIpv6 = 24 +) + +type DecodingSource int32 +type TemplateFlags int32 + +type PropertyFlags int32 + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ne-tdh-property_flags +const ( + PropertyStruct = PropertyFlags(0x1) + PropertyParamLength = PropertyFlags(0x2) + PropertyParamCount = PropertyFlags(0x4) +) + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-event_map_info +type EventMapInfo struct { + NameOffset uint32 + Flag MapFlags + EntryCount uint32 + Union uint32 + MapEntryArray [anysizeArray]EventMapEntry +} + +type MapFlags int32 + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/ns-tdh-event_map_entry +type EventMapEntry struct { + OutputOffset uint32 + Union uint32 +} + +// https://learn.microsoft.com/en-us/windows/desktop/api/tdh/ns-tdh-property_data_descriptor +type PropertyDataDescriptor struct { + PropertyName unsafe.Pointer + ArrayIndex uint32 + Reserved uint32 +} + +// enumerateProvidersFunc is used to replace the pointer to the function in unit tests +var enumerateProvidersFunc = _TdhEnumerateProviders + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhenumerateproviders +func _TdhEnumerateProviders( + pBuffer *ProviderEnumerationInfo, + pBufferSize *uint32) error { + r0, _, _ := tdhEnumerateProviders.Call( + uintptr(unsafe.Pointer(pBuffer)), + uintptr(unsafe.Pointer(pBufferSize))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhgeteventinformation +func _TdhGetEventInformation(pEvent *EventRecord, + tdhContextCount uint32, + pTdhContext *TdhContext, + pBuffer *TraceEventInfo, + pBufferSize *uint32) error { + r0, _, _ := tdhGetEventInformation.Call( + uintptr(unsafe.Pointer(pEvent)), + uintptr(tdhContextCount), + uintptr(unsafe.Pointer(pTdhContext)), + uintptr(unsafe.Pointer(pBuffer)), + uintptr(unsafe.Pointer(pBufferSize))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhformatproperty +func _TdhFormatProperty( + eventInfo *TraceEventInfo, + mapInfo *EventMapInfo, + pointerSize uint32, + propertyInType uint16, + propertyOutType uint16, + propertyLength uint16, + userDataLength uint16, + userData *byte, + bufferSize *uint32, + buffer *uint8, + userDataConsumed *uint16) error { + r0, _, _ := tdhFormatProperty.Call( + uintptr(unsafe.Pointer(eventInfo)), + uintptr(unsafe.Pointer(mapInfo)), + uintptr(pointerSize), + uintptr(propertyInType), + uintptr(propertyOutType), + uintptr(propertyLength), + uintptr(userDataLength), + uintptr(unsafe.Pointer(userData)), + uintptr(unsafe.Pointer(bufferSize)), + uintptr(unsafe.Pointer(buffer)), + uintptr(unsafe.Pointer(userDataConsumed))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhgetproperty +func _TdhGetProperty(pEvent *EventRecord, + tdhContextCount uint32, + pTdhContext *TdhContext, + propertyDataCount uint32, + pPropertyData *PropertyDataDescriptor, + bufferSize uint32, + pBuffer *byte) error { + r0, _, _ := tdhGetProperty.Call( + uintptr(unsafe.Pointer(pEvent)), + uintptr(tdhContextCount), + uintptr(unsafe.Pointer(pTdhContext)), + uintptr(propertyDataCount), + uintptr(unsafe.Pointer(pPropertyData)), + uintptr(bufferSize), + uintptr(unsafe.Pointer(pBuffer))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhgeteventmapinformation +func _TdhGetEventMapInformation(pEvent *EventRecord, + pMapName *uint16, + pBuffer *EventMapInfo, + pBufferSize *uint32) error { + r0, _, _ := tdhGetEventMapInformation.Call( + uintptr(unsafe.Pointer(pEvent)), + uintptr(unsafe.Pointer(pMapName)), + uintptr(unsafe.Pointer(pBuffer)), + uintptr(unsafe.Pointer(pBufferSize))) + if r0 == 0 { + return nil + } + return syscall.Errno(r0) +}