From 4010b16eef6632331ba564dc76174a3db7ff4903 Mon Sep 17 00:00:00 2001 From: Tiago Peczenyj Date: Sun, 5 Nov 2023 14:46:42 +0100 Subject: [PATCH] Add support to fs.fs on serve static files (#1640) * substitute *os.File by fs.File * refactor error handling by using the new recommended form * finish implementation * substitute seek(offset,0) by seek(offset, io.SeekStart) * add unit test * use io.SeekStart on Seek method --- compress.go | 10 +- fs.go | 238 +++++++++++++++---- fs_fs_test.go | 640 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 841 insertions(+), 47 deletions(-) create mode 100644 fs_fs_test.go diff --git a/compress.go b/compress.go index 50d381b804..8494d569be 100644 --- a/compress.go +++ b/compress.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" "io" - "os" + "io/fs" "sync" "github.com/klauspost/compress/flate" @@ -421,7 +421,7 @@ func newCompressWriterPoolMap() []*sync.Pool { return m } -func isFileCompressible(f *os.File, minCompressRatio float64) bool { +func isFileCompressible(f fs.File, minCompressRatio float64) bool { // Try compressing the first 4kb of the file // and see if it can be compressed by more than // the given minCompressRatio. @@ -433,7 +433,11 @@ func isFileCompressible(f *os.File, minCompressRatio float64) bool { } _, err := copyZeroAlloc(zw, lr) releaseStacklessGzipWriter(zw, CompressDefaultCompression) - f.Seek(0, 0) //nolint:errcheck + seeker, ok := f.(io.Seeker) + if !ok { + return false + } + seeker.Seek(0, io.SeekStart) //nolint:errcheck if err != nil { return false } diff --git a/fs.go b/fs.go index fc679de347..4d8c1c1ebf 100644 --- a/fs.go +++ b/fs.go @@ -6,6 +6,7 @@ import ( "fmt" "html" "io" + "io/fs" "mime" "net/http" "os" @@ -136,6 +137,32 @@ var ( rootFSHandler RequestHandler ) +// ServeFS returns HTTP response containing compressed file contents from the given fs.FS's path. +// +// HTTP response may contain uncompressed file contents in the following cases: +// +// - Missing 'Accept-Encoding: gzip' request header. +// - No write access to directory containing the file. +// +// Directory contents is returned if path points to directory. +// +// See also ServeFile. +func ServeFS(ctx *RequestCtx, filesystem fs.FS, path string) { + f := &FS{ + FS: filesystem, + Root: "", + AllowEmptyRoot: true, + GenerateIndexPages: true, + Compress: true, + CompressBrotli: true, + AcceptByteRange: true, + } + handler := f.NewRequestHandler() + + ctx.Request.SetRequestURI(path) + handler(ctx) +} + // PathRewriteFunc must return new request path based on arbitrary ctx // info such as ctx.Path(). // @@ -225,6 +252,9 @@ func NewPathPrefixStripper(prefixSize int) PathRewriteFunc { type FS struct { noCopy noCopy + // FS is filesystem to serve files from. eg: embed.FS os.DirFS + FS fs.FS + // Path to the root directory to serve files from. Root string @@ -391,16 +421,20 @@ func (fs *FS) NewRequestHandler() RequestHandler { } func (fs *FS) normalizeRoot(root string) string { - // Serve files from the current working directory if Root is empty or if Root is a relative path. - if (!fs.AllowEmptyRoot && len(root) == 0) || (len(root) > 0 && !filepath.IsAbs(root)) { - path, err := os.Getwd() - if err != nil { - path = "." + // fs.FS uses relative paths, that paths are slash-separated on all systems, even Windows. + if fs.FS == nil { + // Serve files from the current working directory if Root is empty or if Root is a relative path. 
+ if (!fs.AllowEmptyRoot && len(root) == 0) || (len(root) > 0 && !filepath.IsAbs(root)) { + path, err := os.Getwd() + if err != nil { + path = "." + } + root = path + "/" + root } - root = path + "/" + root + + // convert the root directory slashes to the native format + root = filepath.FromSlash(root) } - // convert the root directory slashes to the native format - root = filepath.FromSlash(root) // strip trailing slashes from the root path for len(root) > 0 && root[len(root)-1] == os.PathSeparator { @@ -440,6 +474,7 @@ func (fs *FS) initRequestHandler() { } h := &fsHandler{ + fs: fs.FS, root: root, indexNames: fs.IndexNames, pathRewrite: fs.PathRewrite, @@ -456,6 +491,10 @@ func (fs *FS) initRequestHandler() { cacheGzip: make(map[string]*fsFile), } + if h.fs == nil { + h.fs = &osFS{} // It provides os.Open and os.Stat + } + go func() { var pendingFiles []*fsFile @@ -488,6 +527,7 @@ func (fs *FS) initRequestHandler() { } type fsHandler struct { + fs fs.FS root string indexNames []string pathRewrite PathRewriteFunc @@ -510,7 +550,8 @@ type fsHandler struct { type fsFile struct { h *fsHandler - f *os.File + f fs.File + filename string // fs.FileInfo.Name() return filename, isn't filepath. dirIndex []byte contentType string contentLength int @@ -555,6 +596,9 @@ func (ff *fsFile) smallFileReader() (io.Reader, error) { const maxSmallFileSize = 2 * 4096 func (ff *fsFile) isBig() bool { + if _, ok := ff.h.fs.(*osFS); !ok { // fs.FS only uses bigFileReader, memory cache uses fsSmallFileReader + return ff.f != nil + } return ff.contentLength > maxSmallFileSize && len(ff.dirIndex) == 0 } @@ -577,7 +621,7 @@ func (ff *fsFile) bigFileReader() (io.Reader, error) { return r, nil } - f, err := os.Open(ff.f.Name()) + f, err := ff.h.fs.Open(ff.filename) if err != nil { return nil, fmt.Errorf("cannot open already opened file: %w", err) } @@ -614,14 +658,18 @@ func (ff *fsFile) decReadersCount() { // bigFileReader attempts to trigger sendfile // for sending big files over the wire. 
type bigFileReader struct { - f *os.File + f fs.File ff *fsFile r io.Reader lr io.LimitedReader } func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error { - if _, err := r.f.Seek(int64(startPos), 0); err != nil { + seeker, ok := r.f.(io.Seeker) + if !ok { + return errors.New("must implement io.Seeker") + } + if _, err := seeker.Seek(int64(startPos), io.SeekStart); err != nil { return err } r.r = &r.lr @@ -646,7 +694,12 @@ func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) { func (r *bigFileReader) Close() error { r.r = r.f - n, err := r.f.Seek(0, 0) + seeker, ok := r.f.(io.Seeker) + if !ok { + _ = r.f.Close() + return errors.New("must implement io.Seeker") + } + n, err := seeker.Seek(0, io.SeekStart) if err == nil { if n == 0 { ff := r.ff @@ -655,7 +708,7 @@ func (r *bigFileReader) Close() error { ff.bigFilesLock.Unlock() } else { _ = r.f.Close() - err = errors.New("bug: File.Seek(0,0) returned (non-zero, nil)") + err = errors.New("bug: File.Seek(0, io.SeekStart) returned (non-zero, nil)") } } else { _ = r.f.Close() @@ -697,7 +750,11 @@ func (r *fsSmallFileReader) Read(p []byte) (int, error) { ff := r.ff if ff.f != nil { - n, err := ff.f.ReadAt(p, int64(r.startPos)) + ra, ok := ff.f.(io.ReaderAt) + if !ok { + return 0, errors.New("must implement io.ReaderAt") + } + n, err := ra.ReadAt(p, int64(r.startPos)) r.startPos += n return n, err } @@ -732,7 +789,11 @@ func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) { if len(buf) > tailLen { buf = buf[:tailLen] } - n, err = ff.f.ReadAt(buf, int64(curPos)) + ra, ok := ff.f.(io.ReaderAt) + if !ok { + return 0, errors.New("must implement io.ReaderAt") + } + n, err = ra.ReadAt(buf, int64(curPos)) nw, errw := w.Write(buf[:n]) curPos += nw if errw == nil && nw != n { @@ -799,6 +860,12 @@ func cleanCacheNolock(cache map[string]*fsFile, pendingFiles, filesToRelease []* } func (h *fsHandler) pathToFilePath(path string) string { + if _, ok := h.fs.(*osFS); !ok { + if len(path) < 1 { + return path + } + return path[1:] + } return filepath.FromSlash(h.root + path) } @@ -1051,7 +1118,7 @@ func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress if err == nil { return ff, nil } - if !os.IsNotExist(err) { + if !errors.Is(err, fs.ErrNotExist) { return nil, fmt.Errorf("cannot open file %q: %w", indexFilePath, err) } } @@ -1060,7 +1127,7 @@ func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath) } - return h.createDirIndex(ctx.URI(), dirPath, mustCompress, fileEncoding) + return h.createDirIndex(ctx, dirPath, mustCompress, fileEncoding) } var ( @@ -1068,9 +1135,11 @@ var ( errNoCreatePermission = errors.New("no 'create file' permissions") ) -func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) { +func (h *fsHandler) createDirIndex(ctx *RequestCtx, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) { w := &bytebufferpool.ByteBuffer{} + base := ctx.URI() + basePathEscaped := html.EscapeString(string(base.Path())) _, _ = fmt.Fprintf(w, "%s", basePathEscaped) _, _ = fmt.Fprintf(w, "

%s", basePathEscaped)
@@ -1084,28 +1153,29 @@ func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool,
 		_, _ = fmt.Fprintf(w, `  • ..
  • `, parentPathEscaped) } - f, err := os.Open(dirPath) + dirEntries, err := fs.ReadDir(h.fs, dirPath) if err != nil { return nil, err } - fileinfos, err := f.Readdir(0) - _ = f.Close() - if err != nil { - return nil, err - } - - fm := make(map[string]os.FileInfo, len(fileinfos)) - filenames := make([]string, 0, len(fileinfos)) + fm := make(map[string]fs.FileInfo, len(dirEntries)) + filenames := make([]string, 0, len(dirEntries)) nestedContinue: - for _, fi := range fileinfos { - name := fi.Name() + for _, de := range dirEntries { + name := de.Name() for _, cfs := range h.compressedFileSuffixes { if strings.HasSuffix(name, cfs) { // Do not show compressed files on index page. continue nestedContinue } } + fi, err := de.Info() + if err != nil { + ctx.Logger().Printf("cannot fetch information from dir entry %q: %v, skip", name, err) + + continue nestedContinue + } + fm[name] = fi filenames = append(filenames, name) } @@ -1163,7 +1233,7 @@ const ( ) func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) (*fsFile, error) { - f, err := os.Open(filePath) + f, err := h.fs.Open(filePath) if err != nil { return nil, err } @@ -1182,10 +1252,15 @@ func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) if strings.HasSuffix(filePath, h.compressedFileSuffixes[fileEncoding]) || fileInfo.Size() > fsMaxCompressibleFileSize || !isFileCompressible(f, fsMinCompressRatio) { - return h.newFSFile(f, fileInfo, false, "") + return h.newFSFile(f, fileInfo, false, filePath, "") } compressedFilePath := h.filePathToCompressed(filePath) + + if _, ok := h.fs.(*osFS); !ok { + return h.newCompressedFSFileCache(f, fileInfo, compressedFilePath, fileEncoding) + } + if compressedFilePath != filePath { if err := os.MkdirAll(filepath.Dir(compressedFilePath), os.ModePerm); err != nil { return nil, err @@ -1207,7 +1282,7 @@ func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) return ff, err } -func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string, fileEncoding string) (*fsFile, error) { +func (h *fsHandler) compressFileNolock(f fs.File, fileInfo fs.FileInfo, filePath, compressedFilePath string, fileEncoding string) (*fsFile, error) { // Attempt to open compressed file created by another concurrent // goroutine. 
// It is safe opening such a file, since the file creation @@ -1223,7 +1298,7 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat zf, err := os.Create(tmpFilePath) if err != nil { _ = f.Close() - if !os.IsPermission(err) { + if !errors.Is(err, fs.ErrPermission) { return nil, fmt.Errorf("cannot create temporary file %q: %w", tmpFilePath, err) } return nil, errNoCreatePermission @@ -1258,8 +1333,71 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat return h.newCompressedFSFile(compressedFilePath, fileEncoding) } +// newCompressedFSFileCache use memory cache compressed files +func (h *fsHandler) newCompressedFSFileCache(f fs.File, fileInfo fs.FileInfo, filePath, fileEncoding string) (*fsFile, error) { + var ( + w = &bytebufferpool.ByteBuffer{} + err error + ) + + if fileEncoding == "br" { + zw := acquireStacklessBrotliWriter(w, CompressDefaultCompression) + _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessBrotliWriter(zw, CompressDefaultCompression) + } else if fileEncoding == "gzip" { + zw := acquireStacklessGzipWriter(w, CompressDefaultCompression) + _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessGzipWriter(zw, CompressDefaultCompression) + } + defer func() { _ = f.Close() }() + + if err != nil { + return nil, fmt.Errorf("error when compressing file %q: %w", filePath, err) + } + + seeker, ok := f.(io.Seeker) + if !ok { + return nil, errors.New("not implemented io.Seeker") + } + if _, err = seeker.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + ext := fileExtension(fileInfo.Name(), false, h.compressedFileSuffixes[fileEncoding]) + contentType := mime.TypeByExtension(ext) + if len(contentType) == 0 { + data, err := readFileHeader(f, false, fileEncoding) + if err != nil { + return nil, fmt.Errorf("cannot read header of the file %q: %w", fileInfo.Name(), err) + } + contentType = http.DetectContentType(data) + } + + dirIndex := w.B + lastModified := fileInfo.ModTime() + ff := &fsFile{ + h: h, + dirIndex: dirIndex, + contentType: contentType, + contentLength: len(dirIndex), + compressed: true, + lastModified: lastModified, + lastModifiedStr: AppendHTTPDate(nil, lastModified), + + t: time.Now(), + } + + return ff, nil +} + func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (*fsFile, error) { - f, err := os.Open(filePath) + f, err := h.fs.Open(filePath) if err != nil { return nil, fmt.Errorf("cannot open compressed file %q: %w", filePath, err) } @@ -1268,7 +1406,7 @@ func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (* _ = f.Close() return nil, fmt.Errorf("cannot obtain info for compressed file %q: %w", filePath, err) } - return h.newFSFile(f, fileInfo, true, fileEncoding) + return h.newFSFile(f, fileInfo, true, filePath, fileEncoding) } func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding string) (*fsFile, error) { @@ -1277,9 +1415,9 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding filePath += h.compressedFileSuffixes[fileEncoding] } - f, err := os.Open(filePath) + f, err := h.fs.Open(filePath) if err != nil { - if mustCompress && os.IsNotExist(err) { + if mustCompress && errors.Is(err, fs.ErrNotExist) { return h.compressAndOpenFSFile(filePathOriginal, fileEncoding) } return nil, err @@ -1301,7 +1439,7 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding } if 
mustCompress { - fileInfoOriginal, err := os.Stat(filePathOriginal) + fileInfoOriginal, err := fs.Stat(h.fs, filePathOriginal) if err != nil { _ = f.Close() return nil, fmt.Errorf("cannot obtain info for original file %q: %w", filePathOriginal, err) @@ -1318,10 +1456,10 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding } } - return h.newFSFile(f, fileInfo, mustCompress, fileEncoding) + return h.newFSFile(f, fileInfo, mustCompress, filePath, fileEncoding) } -func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, fileEncoding string) (*fsFile, error) { +func (h *fsHandler) newFSFile(f fs.File, fileInfo fs.FileInfo, compressed bool, filePath, fileEncoding string) (*fsFile, error) { n := fileInfo.Size() contentLength := int(n) if n != int64(contentLength) { @@ -1335,7 +1473,7 @@ func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, if len(contentType) == 0 { data, err := readFileHeader(f, compressed, fileEncoding) if err != nil { - return nil, fmt.Errorf("cannot read header of the file %q: %w", f.Name(), err) + return nil, fmt.Errorf("cannot read header of the file %q: %w", fileInfo.Name(), err) } contentType = http.DetectContentType(data) } @@ -1344,6 +1482,7 @@ func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, ff := &fsFile{ h: h, f: f, + filename: filePath, contentType: contentType, contentLength: contentLength, compressed: compressed, @@ -1355,7 +1494,7 @@ func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, return ff, nil } -func readFileHeader(f *os.File, compressed bool, fileEncoding string) ([]byte, error) { +func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte, error) { r := io.Reader(f) var ( br *brotli.Reader @@ -1381,7 +1520,11 @@ func readFileHeader(f *os.File, compressed bool, fileEncoding string) ([]byte, e N: 512, } data, err := io.ReadAll(lr) - if _, err := f.Seek(0, 0); err != nil { + seeker, ok := f.(io.Seeker) + if !ok { + return nil, errors.New("must implement io.Seeker") + } + if _, err := seeker.Seek(0, io.SeekStart); err != nil { return nil, err } @@ -1456,3 +1599,10 @@ func getFileLock(absPath string) *sync.Mutex { filelock := v.(*sync.Mutex) return filelock } + +var _ fs.FS = (*osFS)(nil) + +type osFS struct{} + +func (o *osFS) Open(name string) (fs.File, error) { return os.Open(name) } +func (o *osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(name) } diff --git a/fs_fs_test.go b/fs_fs_test.go new file mode 100644 index 0000000000..11f143facf --- /dev/null +++ b/fs_fs_test.go @@ -0,0 +1,640 @@ +package fasthttp + +import ( + "bufio" + "bytes" + "embed" + "os" + "runtime" + "strings" + "testing" + "time" +) + +//go:embed fasthttputil fs.go README.md testdata examples +var fsTestFilesystem embed.FS + +func TestFSServeFileHead(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.Header.SetMethod(MethodHead) + req.SetRequestURI("http://foobar.com/baz") + ctx.Init(&req, nil, nil) + + ServeFS(&ctx, fsTestFilesystem, "fs.go") + + var resp Response + resp.SkipBody = true + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce := resp.Header.ContentEncoding() + if len(ce) > 0 { + t.Fatalf("Unexpected 'Content-Encoding' %q", ce) + } + + body := resp.Body() + if len(body) > 0 { + t.Fatalf("unexpected response body %q. 
Expecting empty body", body) + } + + expectedBody, err := getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + contentLength := resp.Header.ContentLength() + if contentLength != len(expectedBody) { + t.Fatalf("unexpected Content-Length: %d. expecting %d", contentLength, len(expectedBody)) + } +} + +func TestFSServeFileCompressed(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + var resp Response + + // request compressed gzip file + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip") + ServeFS(&ctx, fsTestFilesystem, "fs.go") + + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce := resp.Header.ContentEncoding() + if string(ce) != "gzip" { + t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") + } + + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expectedBody, err := getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(body, expectedBody) { + t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) + } + + // request compressed brotli file + ctx.Request.Reset() + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "br") + ServeFS(&ctx, fsTestFilesystem, "fs.go") + + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err = resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce = resp.Header.ContentEncoding() + if string(ce) != "br" { + t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "br") + } + + body, err = resp.BodyUnbrotli() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expectedBody, err = getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(body, expectedBody) { + t.Fatalf("unexpected body %q. 
expecting %q", body, expectedBody) + } +} + +func TestFSFSByteRangeConcurrent(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: fsTestFilesystem, + Root: "", + AcceptByteRange: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + concurrency := 10 + ch := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + for j := 0; j < 5; j++ { + testFSByteRange(t, h, "/fs.go") + testFSByteRange(t, h, "/README.md") + } + ch <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-time.After(time.Second): + t.Fatalf("timeout") + case <-ch: + } + } +} + +func TestFSFSByteRangeSingleThread(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: fsTestFilesystem, + Root: ".", + AcceptByteRange: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + testFSByteRange(t, h, "/fs.go") + testFSByteRange(t, h, "/README.md") +} + +func TestFSFSCompressConcurrent(t *testing.T) { + t.Parallel() + // go 1.16 timeout may occur + if strings.HasPrefix(runtime.Version(), "go1.16") { + t.SkipNow() + } + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: fsTestFilesystem, + Root: ".", + GenerateIndexPages: true, + Compress: true, + CompressBrotli: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + concurrency := 4 + ch := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + for j := 0; j < 5; j++ { + testFSFSCompress(t, h, "/fs.go") + testFSFSCompress(t, h, "/examples/") + testFSFSCompress(t, h, "/README.md") + } + ch <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-ch: + case <-time.After(time.Second * 2): + t.Fatalf("timeout") + } + } +} + +func TestFSFSCompressSingleThread(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: fsTestFilesystem, + Root: ".", + GenerateIndexPages: true, + Compress: true, + CompressBrotli: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + testFSFSCompress(t, h, "/fs.go") + testFSFSCompress(t, h, "/examples/") + testFSFSCompress(t, h, "/README.md") +} + +func testFSFSCompress(t *testing.T, h RequestHandler, filePath string) { + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + var resp Response + + // request uncompressed file + ctx.Request.Reset() + ctx.Request.SetRequestURI(filePath) + h(&ctx) + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Errorf("unexpected error: %v. filePath=%q", err, filePath) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) + } + ce := resp.Header.ContentEncoding() + if string(ce) != "" { + t.Errorf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) + } + body := string(resp.Body()) + + // request compressed gzip file + ctx.Request.Reset() + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip") + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Errorf("unexpected error: %v. filePath=%q", err, filePath) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d. 
filePath=%q", resp.StatusCode(), StatusOK, filePath) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "gzip" { + t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) + } + zbody, err := resp.BodyGunzip() + if err != nil { + t.Errorf("unexpected error when gunzipping response body: %v. filePath=%q", err, filePath) + } + if string(zbody) != body { + t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath) + } + + // request compressed brotli file + ctx.Request.Reset() + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.Set(HeaderAcceptEncoding, "br") + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %v. filePath=%q", err, filePath) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "br" { + t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "br", filePath) + } + zbody, err = resp.BodyUnbrotli() + if err != nil { + t.Errorf("unexpected error when unbrotling response body: %v. filePath=%q", err, filePath) + } + if string(zbody) != body { + t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath) + } +} + +func TestFSServeFileContentType(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://foobar.com/baz") + ctx.Init(&req, nil, nil) + + ServeFS(&ctx, fsTestFilesystem, "testdata/test.png") + + var resp Response + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []byte("image/png") + if !bytes.Equal(resp.Header.ContentType(), expected) { + t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, resp.Header.ContentType()) + } +} + +func TestFSServeFileDirectoryRedirect(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.SetRequestURI("http://foobar.com") + ctx.Init(&req, nil, nil) + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, fsTestFilesystem, "fasthttputil") + if ctx.Response.StatusCode() != StatusFound { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil' without trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusFound) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, fsTestFilesystem, "fasthttputil/") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil/' with trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, fsTestFilesystem, "fs.go") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for file '/fs.go'. 
Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } +} + +// //* +// *// +var dirTestFilesystem = os.DirFS(".") + +func TestDirFSServeFileHead(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.Header.SetMethod(MethodHead) + req.SetRequestURI("http://foobar.com/baz") + ctx.Init(&req, nil, nil) + + ServeFS(&ctx, dirTestFilesystem, "fs.go") + + var resp Response + resp.SkipBody = true + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce := resp.Header.ContentEncoding() + if len(ce) > 0 { + t.Fatalf("Unexpected 'Content-Encoding' %q", ce) + } + + body := resp.Body() + if len(body) > 0 { + t.Fatalf("unexpected response body %q. Expecting empty body", body) + } + + expectedBody, err := getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + contentLength := resp.Header.ContentLength() + if contentLength != len(expectedBody) { + t.Fatalf("unexpected Content-Length: %d. expecting %d", contentLength, len(expectedBody)) + } +} + +func TestDirFSServeFileCompressed(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + var resp Response + + // request compressed gzip file + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip") + ServeFS(&ctx, dirTestFilesystem, "fs.go") + + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce := resp.Header.ContentEncoding() + if string(ce) != "gzip" { + t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") + } + + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expectedBody, err := getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(body, expectedBody) { + t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) + } + + // request compressed brotli file + ctx.Request.Reset() + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "br") + ServeFS(&ctx, fsTestFilesystem, "fs.go") + + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err = resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ce = resp.Header.ContentEncoding() + if string(ce) != "br" { + t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "br") + } + + body, err = resp.BodyUnbrotli() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expectedBody, err = getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(body, expectedBody) { + t.Fatalf("unexpected body %q. 
expecting %q", body, expectedBody) + } +} + +func TestDirFSFSByteRangeConcurrent(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: dirTestFilesystem, + Root: "", + AcceptByteRange: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + concurrency := 10 + ch := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + for j := 0; j < 5; j++ { + testFSByteRange(t, h, "/fs.go") + testFSByteRange(t, h, "/README.md") + } + ch <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-time.After(time.Second): + t.Fatalf("timeout") + case <-ch: + } + } +} + +func TestDirFSFSByteRangeSingleThread(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: dirTestFilesystem, + Root: ".", + AcceptByteRange: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + testFSByteRange(t, h, "/fs.go") + testFSByteRange(t, h, "/README.md") +} + +func TestDirFSFSCompressConcurrent(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: dirTestFilesystem, + Root: ".", + GenerateIndexPages: true, + Compress: true, + CompressBrotli: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + concurrency := 4 + ch := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + for j := 0; j < 5; j++ { + testFSFSCompress(t, h, "/fs.go") + testFSFSCompress(t, h, "/examples/") + testFSFSCompress(t, h, "/README.md") + } + ch <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-ch: + case <-time.After(time.Second * 2): + t.Fatalf("timeout") + } + } +} + +func TestDirFSFSCompressSingleThread(t *testing.T) { + t.Parallel() + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + FS: dirTestFilesystem, + Root: ".", + GenerateIndexPages: true, + Compress: true, + CompressBrotli: true, + CleanStop: stop, + } + h := fs.NewRequestHandler() + + testFSFSCompress(t, h, "/fs.go") + testFSFSCompress(t, h, "/examples/") + testFSFSCompress(t, h, "/README.md") +} + +func TestDirFSServeFileContentType(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://foobar.com/baz") + ctx.Init(&req, nil, nil) + + ServeFS(&ctx, dirTestFilesystem, "testdata/test.png") + + var resp Response + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []byte("image/png") + if !bytes.Equal(resp.Header.ContentType(), expected) { + t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, resp.Header.ContentType()) + } +} + +func TestDirFSServeFileDirectoryRedirect(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + req.SetRequestURI("http://foobar.com") + ctx.Init(&req, nil, nil) + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, dirTestFilesystem, "fasthttputil") + if ctx.Response.StatusCode() != StatusFound { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil' without trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusFound) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, dirTestFilesystem, "fasthttputil/") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil/' with trailing slash. 
Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFS(&ctx, dirTestFilesystem, "fs.go") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for file '/fs.go'. Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } +}