diff --git a/server/storage/wal/file_pipeline.go b/server/storage/wal/file_pipeline.go index bdfa31b5e2a1..f2db83ffa5b0 100644 --- a/server/storage/wal/file_pipeline.go +++ b/server/storage/wal/file_pipeline.go @@ -75,7 +75,7 @@ func (fp *filePipeline) Close() error { func (fp *filePipeline) alloc() (f *fileutil.LockedFile, err error) { // count % 2 so this file isn't the same as the one last published fpath := filepath.Join(fp.dir, fmt.Sprintf("%d.tmp", fp.count%2)) - if f, err = fileutil.LockFile(fpath, os.O_CREATE|os.O_WRONLY, fileutil.PrivateFileMode); err != nil { + if f, err = createNewWALFile[*fileutil.LockedFile](fpath, true, false); err != nil { return nil, err } if err = fileutil.Preallocate(f.File, fp.size, true); err != nil { diff --git a/server/storage/wal/repair.go b/server/storage/wal/repair.go index d1a887835da6..4cc6dd22e461 100644 --- a/server/storage/wal/repair.go +++ b/server/storage/wal/repair.go @@ -67,7 +67,7 @@ func Repair(lg *zap.Logger, dirpath string) bool { case errors.Is(err, io.ErrUnexpectedEOF): brokenName := f.Name() + ".broken" - bf, bferr := os.OpenFile(brokenName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fileutil.PrivateFileMode) + bf, bferr := createNewWALFile[*os.File](brokenName, false, true) if bferr != nil { lg.Warn("failed to create backup file", zap.String("path", brokenName), zap.Error(bferr)) return false diff --git a/server/storage/wal/wal.go b/server/storage/wal/wal.go index 3a313876083f..f5c4cf9d0294 100644 --- a/server/storage/wal/wal.go +++ b/server/storage/wal/wal.go @@ -126,7 +126,7 @@ func Create(lg *zap.Logger, dirpath string, metadata []byte) (*WAL, error) { } p := filepath.Join(tmpdirpath, walName(0, 0)) - f, err := fileutil.LockFile(p, os.O_WRONLY|os.O_CREATE, fileutil.PrivateFileMode) + f, err := createNewWALFile[*fileutil.LockedFile](p, true, false) if err != nil { lg.Warn( "failed to flock an initial WAL file", @@ -233,6 +233,39 @@ func Create(lg *zap.Logger, dirpath string, metadata []byte) (*WAL, error) { return w, nil } +// createNewWALFile creates a WAL file. +// To create a locked file, use *fileutil.LockedFile type parameter and set lock to true. +// To create a standard file, use *os.File type parameter and set lock to false. +// If truncate is true, the file will be truncated if it already exists. +func createNewWALFile[T *os.File | *fileutil.LockedFile](path string, lock, truncate bool) (T, error) { + flag := os.O_WRONLY | os.O_CREATE + if truncate { + flag |= os.O_TRUNC + } + + if lock { + lockedFile, err := fileutil.LockFile(path, flag, fileutil.PrivateFileMode) + if err != nil { + return nil, err + } + lf, ok := any(lockedFile).(T) + if !ok { + return nil, fmt.Errorf("expected file type '%T', but got '%T'", (T)(nil), lockedFile) + } + return lf, nil + } + + file, err := os.OpenFile(path, flag, fileutil.PrivateFileMode) + if err != nil { + return nil, err + } + f, ok := any(file).(T) + if !ok { + return nil, fmt.Errorf("expected file type '%T', but got '%T'", (T)(nil), file) + } + return f, nil +} + func (w *WAL) Reopen(lg *zap.Logger, snap walpb.Snapshot) (*WAL, error) { err := w.Close() if err != nil { diff --git a/server/storage/wal/wal_test.go b/server/storage/wal/wal_test.go index ed3a8893df55..32b7c8bce39b 100644 --- a/server/storage/wal/wal_test.go +++ b/server/storage/wal/wal_test.go @@ -96,6 +96,71 @@ func TestNew(t *testing.T) { } } +func TestCreateNewWALFile(t *testing.T) { + tests := []struct { + name string + fileType interface{} + lock bool + truncate bool + wantErr error + }{ + { + name: "*os.File type constraint with lock set to true should return error", + fileType: &os.File{}, + lock: true, + wantErr: errors.New("expected file type '*os.File', but got '*fileutil.LockedFile'"), + }, + { + name: "*fileutil.LockedFile type constraint with lock set to false should return error", + fileType: &fileutil.LockedFile{}, + lock: false, + wantErr: errors.New("expected file type '*fileutil.LockedFile', but got '*os.File'"), + }, + { + name: "*os.File type constraint with lock set to false should succeed", + fileType: &os.File{}, + lock: false, + wantErr: nil, + }, + { + name: "*fileutil.LockedFile type constraint with lock set to true should succeed", + fileType: &fileutil.LockedFile{}, + lock: true, + wantErr: nil, + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := filepath.Join(t.TempDir(), walName(0, uint64(i))) + + var err error + var f interface{} + + switch tt.fileType.(type) { + case *os.File: + f, err = createNewWALFile[*os.File](p, tt.lock, tt.truncate) + require.IsType(t, &os.File{}, f) + case *fileutil.LockedFile: + f, err = createNewWALFile[*fileutil.LockedFile](p, tt.lock, tt.truncate) + require.IsType(t, &fileutil.LockedFile{}, f) + } + + if tt.wantErr == nil { + require.NoError(t, err) + } else { + require.EqualError(t, err, tt.wantErr.Error()) + } + + fi, err := os.Stat(p) + require.NoError(t, err) + expectedPerms := fmt.Sprintf("%o", os.FileMode(fileutil.PrivateFileMode)) + actualPerms := fmt.Sprintf("%o", fi.Mode().Perm()) + require.Equal(t, expectedPerms, actualPerms, "unexpected file permissions on %q", p) + }) + } +} + func TestCreateFailFromPollutedDir(t *testing.T) { p := t.TempDir() os.WriteFile(filepath.Join(p, "test.wal"), []byte("data"), os.ModeTemporary)