Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ func utf16FromString(str string) []uint16 {

// goBytes copies the given C byte array to a Go byte array (see `C.GoBytes`).
// This function avoids having cgo as dependency.
func goBytes(src uintptr, len uint32) []byte {
if src == uintptr(0) {
func goBytes(src *byte, len uint32) []byte {
if src == nil || len == 0 {
return []byte{}
}
rv := make([]byte, len)
copy(rv, *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: src,
Data: uintptr(unsafe.Pointer(src)),
Len: int(len),
Cap: int(len),
})))
Expand All @@ -59,7 +59,7 @@ func sysToCredential(cred *sysCREDENTIAL) (result *Credential) {
result.CredentialBlob = goBytes(cred.CredentialBlob, cred.CredentialBlobSize)
result.Attributes = make([]CredentialAttribute, cred.AttributeCount)
attrSlice := *(*[]sysCREDENTIAL_ATTRIBUTE)(unsafe.Pointer(&reflect.SliceHeader{
Data: cred.Attributes,
Data: uintptr(unsafe.Pointer(cred.Attributes)),
Len: int(cred.AttributeCount),
Cap: int(cred.AttributeCount),
}))
Expand All @@ -85,17 +85,13 @@ func sysFromCredential(cred *Credential) (result *sysCREDENTIAL) {
result.LastWritten = syscall.NsecToFiletime(cred.LastWritten.UnixNano())
result.CredentialBlobSize = uint32(len(cred.CredentialBlob))
if len(cred.CredentialBlob) > 0 {
result.CredentialBlob = uintptr(unsafe.Pointer(&cred.CredentialBlob[0]))
} else {
result.CredentialBlob = 0
result.CredentialBlob = &cred.CredentialBlob[0]
}
result.Persist = uint32(cred.Persist)
result.AttributeCount = uint32(len(cred.Attributes))
attributes := make([]sysCREDENTIAL_ATTRIBUTE, len(cred.Attributes))
if len(attributes) > 0 {
result.Attributes = uintptr(unsafe.Pointer(&attributes[0]))
} else {
result.Attributes = 0
result.Attributes = &attributes[0]
}
for i := range cred.Attributes {
inAttr := &cred.Attributes[i]
Expand All @@ -104,9 +100,7 @@ func sysFromCredential(cred *Credential) (result *sysCREDENTIAL) {
outAttr.Flags = 0
outAttr.ValueSize = uint32(len(inAttr.Value))
if len(inAttr.Value) > 0 {
outAttr.Value = uintptr(unsafe.Pointer(&inAttr.Value[0]))
} else {
outAttr.Value = 0
outAttr.Value = &inAttr.Value[0]
}
}
result.TargetAlias, _ = syscall.UTF16PtrFromString(cred.TargetAlias)
Expand Down
20 changes: 10 additions & 10 deletions conversion_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//go:build windows
// +build windows

package wincred

import (
"testing"
"time"
"unsafe"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -52,7 +52,7 @@ func BenchmarkUtf16ToByte(b *testing.B) {

func TestGoBytes(t *testing.T) {
input := []byte{1, 2, 3, 4, 5}
output := goBytes(uintptr(unsafe.Pointer(&input[0])), uint32(len(input)))
output := goBytes(&input[0], uint32(len(input)))
assert.Equal(t, len(input), len(output))
assert.Equal(t, input[0], output[0])
assert.Equal(t, input[1], output[1])
Expand All @@ -65,15 +65,15 @@ func TestGoBytes(t *testing.T) {

func TestGoBytes_Null(t *testing.T) {
assert.NotPanics(t, func() {
output := goBytes(0, 123)
output := goBytes(nil, 123)
assert.Equal(t, []byte{}, output)
})
}

func BenchmarkGoBytes(b *testing.B) {
input := []byte{1, 2, 3, 4, 5}
for i := 0; i < b.N; i++ {
goBytes(uintptr(unsafe.Pointer(&input[0])), uint32(len(input)))
_ = goBytes(&input[0], uint32(len(input)))
}
}

Expand Down Expand Up @@ -108,7 +108,7 @@ func TestConversion_CredentialBlob(t *testing.T) {
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.Equal(t, uint32(3), sys.CredentialBlobSize)
assert.NotEqual(t, uintptr(0), sys.CredentialBlob)
assert.NotNil(t, res.CredentialBlob)
assert.Equal(t, cred.CredentialBlob, res.CredentialBlob)
}

Expand All @@ -117,7 +117,7 @@ func TestConversion_CredentialBlob_Empty(t *testing.T) {
cred.CredentialBlob = []byte{} // empty blob
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.Equal(t, uintptr(0), sys.CredentialBlob)
assert.Nil(t, sys.CredentialBlob)
assert.Equal(t, uint32(0), sys.CredentialBlobSize)
assert.Equal(t, []byte{}, res.CredentialBlob)
}
Expand All @@ -127,7 +127,7 @@ func TestConversion_CredentialBlob_Nil(t *testing.T) {
cred.CredentialBlob = nil // nil blob
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.Equal(t, uintptr(0), sys.CredentialBlob)
assert.Nil(t, sys.CredentialBlob)
assert.Equal(t, uint32(0), sys.CredentialBlobSize)
assert.Equal(t, []byte{}, res.CredentialBlob)
}
Expand All @@ -140,7 +140,7 @@ func TestConversion_Attributes(t *testing.T) {
}
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.NotEqual(t, uintptr(0), sys.Attributes)
assert.NotNil(t, sys.Attributes)
assert.Equal(t, uint32(2), sys.AttributeCount)
assert.Equal(t, cred.Attributes, res.Attributes)
}
Expand All @@ -150,7 +150,7 @@ func TestConversion_Attributes_Empty(t *testing.T) {
cred.Attributes = []CredentialAttribute{}
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.Equal(t, uintptr(0), sys.Attributes)
assert.Nil(t, sys.Attributes)
assert.Equal(t, uint32(0), sys.AttributeCount)
assert.Equal(t, []CredentialAttribute{}, res.Attributes)
}
Expand All @@ -160,7 +160,7 @@ func TestConversion_Attributes_Nil(t *testing.T) {
cred.Attributes = nil
sys := sysFromCredential(cred)
res := sysToCredential(sys)
assert.Equal(t, uintptr(0), sys.Attributes)
assert.Nil(t, sys.Attributes)
assert.Equal(t, uint32(0), sys.AttributeCount)
assert.Equal(t, []CredentialAttribute{}, res.Attributes)
}
Expand Down
9 changes: 6 additions & 3 deletions sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package wincred

import (
"reflect"
"runtime"
"syscall"
"unsafe"

Expand Down Expand Up @@ -33,10 +34,10 @@ type sysCREDENTIAL struct {
Comment *uint16
LastWritten windows.Filetime
CredentialBlobSize uint32
CredentialBlob uintptr
CredentialBlob *byte
Persist uint32
AttributeCount uint32
Attributes uintptr
Attributes *sysCREDENTIAL_ATTRIBUTE
TargetAlias *uint16
UserName *uint16
}
Expand All @@ -46,7 +47,7 @@ type sysCREDENTIAL_ATTRIBUTE struct {
Keyword *uint16
Flags uint32
ValueSize uint32
Value uintptr
Value *byte
}

// https://docs.microsoft.com/en-us/windows/desktop/api/wincred/ns-wincred-_credentialw
Expand Down Expand Up @@ -93,6 +94,8 @@ func sysCredWrite(cred *Credential, typ sysCRED_TYPE) error {
uintptr(unsafe.Pointer(ncred)),
0,
)
// Make sure everything reachable from ncred stays alive through the call.
runtime.KeepAlive(ncred)
if ret == 0 {
return err
}
Expand Down
39 changes: 39 additions & 0 deletions sys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ package wincred

import (
"errors"
"runtime"
"testing"
"time"
"unsafe"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -89,3 +92,39 @@ func TestSysCredDelete_Mock(t *testing.T) {
assert.Nil(t, err)
mockCredDelete.AssertNumberOfCalls(t, "Call", 1)
}

func TestCredWrite_GCSafety_WithAttributes(t *testing.T) {
// Minimal repro for the Go 1.25 regression: we create a credential that has at least
// one Attribute with a non-empty Value (so sysFromCredential allocates the attributes
// slice internally). Then we force a GC after building the native struct and *before*
// calling CredWriteW. With the old uintptr-based fields, the GC can reclaim the slice,
// leaving dangling addresses and causing ERROR_INVALID_PARAMETER. With the fix, it’s fine.
cred := &Credential{
TargetName: "Foo",
Comment: "Bar",
LastWritten: time.Now(),
TargetAlias: "MyAlias",
UserName: "Nobody",
Persist: PersistLocalMachine,
CredentialBlob: []byte("secret"),
Attributes: []CredentialAttribute{
{Keyword: "label", Value: []byte("hello-world")},
},
}

ncred := sysFromCredential(cred)
ncred.Type = uint32(sysCRED_TYPE_GENERIC)

// run GC a few times to gc the attributes slice.
for i := 0; i < 5; i++ {
runtime.GC()
}

// call CredWriteW - same as sysCredWrite
ret, _, err := procCredWrite.Call(uintptr(unsafe.Pointer(ncred)), 0)
if ret == 0 {
t.Fatalf("CredWriteW failed: %v", err)
}

_ = sysCredDelete(cred, sysCRED_TYPE_GENERIC)
}