diff --git a/server/storage/wal/file_pipeline.go b/server/storage/wal/file_pipeline.go index bdfa31b5e2a..c8ee4cce429 100644 --- a/server/storage/wal/file_pipeline.go +++ b/server/storage/wal/file_pipeline.go @@ -75,7 +75,7 @@ func (fp *filePipeline) Close() error { func (fp *filePipeline) alloc() (f *fileutil.LockedFile, err error) { // count % 2 so this file isn't the same as the one last published fpath := filepath.Join(fp.dir, fmt.Sprintf("%d.tmp", fp.count%2)) - if f, err = fileutil.LockFile(fpath, os.O_CREATE|os.O_WRONLY, fileutil.PrivateFileMode); err != nil { + if f, err = createNewWALFile[*fileutil.LockedFile](fpath, false); err != nil { return nil, err } if err = fileutil.Preallocate(f.File, fp.size, true); err != nil { diff --git a/server/storage/wal/repair.go b/server/storage/wal/repair.go index d1a887835da..16277540f34 100644 --- a/server/storage/wal/repair.go +++ b/server/storage/wal/repair.go @@ -67,7 +67,7 @@ func Repair(lg *zap.Logger, dirpath string) bool { case errors.Is(err, io.ErrUnexpectedEOF): brokenName := f.Name() + ".broken" - bf, bferr := os.OpenFile(brokenName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fileutil.PrivateFileMode) + bf, bferr := createNewWALFile[*os.File](brokenName, true) if bferr != nil { lg.Warn("failed to create backup file", zap.String("path", brokenName), zap.Error(bferr)) return false diff --git a/server/storage/wal/wal.go b/server/storage/wal/wal.go index 3a313876083..f8acafe5ab6 100644 --- a/server/storage/wal/wal.go +++ b/server/storage/wal/wal.go @@ -126,7 +126,7 @@ func Create(lg *zap.Logger, dirpath string, metadata []byte) (*WAL, error) { } p := filepath.Join(tmpdirpath, walName(0, 0)) - f, err := fileutil.LockFile(p, os.O_WRONLY|os.O_CREATE, fileutil.PrivateFileMode) + f, err := createNewWALFile[*fileutil.LockedFile](p, false) if err != nil { lg.Warn( "failed to flock an initial WAL file", @@ -233,6 +233,31 @@ func Create(lg *zap.Logger, dirpath string, metadata []byte) (*WAL, error) { return w, nil } +// createNewWALFile creates a WAL file. +// To create a locked file, use *fileutil.LockedFile type parameter. +// To create a standard file, use *os.File type parameter. +// If truncate is true, the file will be truncated if it already exists. +func createNewWALFile[T *os.File | *fileutil.LockedFile](path string, truncate bool) (T, error) { + flag := os.O_WRONLY | os.O_CREATE + if truncate { + flag |= os.O_TRUNC + } + + if _, isLockedFile := any(T(nil)).(*fileutil.LockedFile); isLockedFile { + lockedFile, err := fileutil.LockFile(path, flag, fileutil.PrivateFileMode) + if err != nil { + return nil, err + } + return any(lockedFile).(T), nil + } + + file, err := os.OpenFile(path, flag, fileutil.PrivateFileMode) + if err != nil { + return nil, err + } + return any(file).(T), nil +} + func (w *WAL) Reopen(lg *zap.Logger, snap walpb.Snapshot) (*WAL, error) { err := w.Close() if err != nil { diff --git a/server/storage/wal/wal_test.go b/server/storage/wal/wal_test.go index ed3a8893df5..e8125bf7a91 100644 --- a/server/storage/wal/wal_test.go +++ b/server/storage/wal/wal_test.go @@ -96,6 +96,60 @@ func TestNew(t *testing.T) { } } +func TestCreateNewWALFile(t *testing.T) { + tests := []struct { + name string + fileType interface{} + truncate bool + }{ + { + name: "creating standard file should succeed", + fileType: &os.File{}, + }, + { + name: "creating locked file should succeed", + fileType: &fileutil.LockedFile{}, + }, + { + name: "creating standard file with truncate should succeed", + fileType: &os.File{}, + truncate: true, + }, + { + name: "creating locked file with truncate should succeed", + fileType: &fileutil.LockedFile{}, + truncate: true, + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := filepath.Join(t.TempDir(), walName(0, uint64(i))) + + var err error + var f interface{} + + switch tt.fileType.(type) { + case *os.File: + f, err = createNewWALFile[*os.File](p, tt.truncate) + require.IsType(t, &os.File{}, f) + case *fileutil.LockedFile: + f, err = createNewWALFile[*fileutil.LockedFile](p, tt.truncate) + require.IsType(t, &fileutil.LockedFile{}, f) + } + + require.NoError(t, err) + + // Validate the file permissions + fi, err := os.Stat(p) + require.NoError(t, err) + expectedPerms := fmt.Sprintf("%o", os.FileMode(fileutil.PrivateFileMode)) + actualPerms := fmt.Sprintf("%o", fi.Mode().Perm()) + require.Equal(t, expectedPerms, actualPerms, "unexpected file permissions on %q", p) + }) + } +} + func TestCreateFailFromPollutedDir(t *testing.T) { p := t.TempDir() os.WriteFile(filepath.Join(p, "test.wal"), []byte("data"), os.ModeTemporary)