diff --git a/internal/wclayer/converttobaselayer.go b/internal/wclayer/converttobaselayer.go index 65baf6d29e..86e68e0789 100644 --- a/internal/wclayer/converttobaselayer.go +++ b/internal/wclayer/converttobaselayer.go @@ -2,38 +2,63 @@ package wclayer import ( "context" + "fmt" "os" "path/filepath" "syscall" "github.com/Microsoft/hcsshim/internal/hcserror" + "github.com/Microsoft/hcsshim/internal/longpath" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/safefile" "github.com/Microsoft/hcsshim/internal/winapi" "github.com/pkg/errors" "go.opencensus.io/trace" + "golang.org/x/sys/windows" ) var hiveNames = []string{"DEFAULT", "SAM", "SECURITY", "SOFTWARE", "SYSTEM"} -// Ensure the given file exists as an ordinary file, and create a zero-length file if not. -func ensureFile(path string, root *os.File) error { - stat, err := safefile.LstatRelative(path, root) - if err != nil && os.IsNotExist(err) { - newFile, err := safefile.OpenRelative(path, root, 0, syscall.FILE_SHARE_WRITE, winapi.FILE_CREATE, 0) - if err != nil { - return err - } - return newFile.Close() +// Ensure the given file exists as an ordinary file, and create a minimal hive file if not. +func ensureHive(path string, root *os.File) (err error) { + _, err = safefile.LstatRelative(path, root) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("accessing %s: %w", path, err) + } + + version := windows.RtlGetVersion() + if version == nil { + return fmt.Errorf("failed to get OS version") } + var fullPath string + fullPath, err = longpath.LongAbs(filepath.Join(root.Name(), path)) if err != nil { - return err + return fmt.Errorf("getting path: %w", err) } - if !stat.Mode().IsRegular() { - fullPath := filepath.Join(root.Name(), path) - return errors.Errorf("%s has unexpected file mode %s", fullPath, stat.Mode().String()) + var key syscall.Handle + err = winapi.ORCreateHive(&key) + if err != nil { + return fmt.Errorf("creating hive: %w", err) + } + + defer func() { + closeErr := winapi.ORCloseHive(&key) + if closeErr != nil { + err = fmt.Errorf("closing hive key: %w", closeErr) + } + }() + + var hivePath *uint16 + hivePath, err = syscall.UTF16PtrFromString(fullPath) + if err != nil { + return fmt.Errorf("getting path: %w", err) + } + + err = winapi.ORSaveHive(key, hivePath, version.MajorVersion, version.MinorVersion) + if err != nil { + return fmt.Errorf("saving hive: %w", err) } return nil @@ -48,7 +73,7 @@ func ensureBaseLayer(root *os.File) (hasUtilityVM bool, err error) { for _, hiveName := range hiveNames { hivePath := filepath.Join(hiveSourcePath, hiveName) - if err = ensureFile(hivePath, root); err != nil { + if err = ensureHive(hivePath, root); err != nil { return } } diff --git a/internal/winapi/ofreg.go b/internal/winapi/ofreg.go new file mode 100644 index 0000000000..9f7cd3537e --- /dev/null +++ b/internal/winapi/ofreg.go @@ -0,0 +1,5 @@ +package winapi + +//sys ORCreateHive(key *syscall.Handle) (regerrno error) = offreg.ORCreateHive +//sys ORSaveHive(key syscall.Handle, file *uint16, OsMajorVersion uint32, OsMinorVersion uint32) (regerrno error) = offreg.ORSaveHive +//sys ORCloseHive(key *syscall.Handle) (regerrno error) = offreg.ORCloseHive diff --git a/internal/winapi/winapi.go b/internal/winapi/winapi.go index b45fc7de43..2073ca3a39 100644 --- a/internal/winapi/winapi.go +++ b/internal/winapi/winapi.go @@ -1,3 +1,3 @@ package winapi -//go:generate go run ..\..\mksyscall_windows.go -output zsyscall_windows.go bindflt.go user.go console.go system.go net.go path.go thread.go jobobject.go logon.go memory.go process.go processor.go devices.go filesystem.go errors.go +//go:generate go run ..\..\mksyscall_windows.go -output zsyscall_windows.go bindflt.go user.go console.go system.go net.go path.go thread.go jobobject.go logon.go memory.go process.go processor.go devices.go filesystem.go errors.go ofreg.go diff --git a/internal/winapi/zsyscall_windows.go b/internal/winapi/zsyscall_windows.go index 5485115387..7a7369e1bc 100644 --- a/internal/winapi/zsyscall_windows.go +++ b/internal/winapi/zsyscall_windows.go @@ -44,6 +44,7 @@ var ( modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modcfgmgr32 = windows.NewLazySystemDLL("cfgmgr32.dll") + modoffreg = windows.NewLazySystemDLL("offreg.dll") procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") procNetLocalGroupGetInfo = modnetapi32.NewProc("NetLocalGroupGetInfo") @@ -78,6 +79,9 @@ var ( procNtOpenDirectoryObject = modntdll.NewProc("NtOpenDirectoryObject") procNtQueryDirectoryObject = modntdll.NewProc("NtQueryDirectoryObject") procRtlNtStatusToDosError = modntdll.NewProc("RtlNtStatusToDosError") + procORCreateHive = modoffreg.NewProc("ORCreateHive") + procORSaveHive = modoffreg.NewProc("ORSaveHive") + procORCloseHive = modoffreg.NewProc("ORCloseHive") ) func BfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { @@ -405,3 +409,27 @@ func RtlNtStatusToDosError(status uint32) (winerr error) { } return } + +func ORCreateHive(key *syscall.Handle) (regerrno error) { + r0, _, _ := syscall.Syscall(procORCreateHive.Addr(), 1, uintptr(unsafe.Pointer(key)), 0, 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func ORSaveHive(key syscall.Handle, file *uint16, OsMajorVersion uint32, OsMinorVersion uint32) (regerrno error) { + r0, _, _ := syscall.Syscall6(procORSaveHive.Addr(), 4, uintptr(key), uintptr(unsafe.Pointer(file)), uintptr(OsMajorVersion), uintptr(OsMinorVersion), 0, 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func ORCloseHive(key *syscall.Handle) (regerrno error) { + r0, _, _ := syscall.Syscall(procORCloseHive.Addr(), 1, uintptr(unsafe.Pointer(key)), 0, 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +}