From 547c4b59811091bd613564a2df5b8f3ebc72bd8f Mon Sep 17 00:00:00 2001 From: Shulhan Date: Thu, 21 Mar 2024 15:31:44 +0700 Subject: [PATCH] lib/memfs: trim trailing slash ("/") in the path of Get method The MemFS always store directory without slash. If caller request a directory node with slash, it will always return nil. --- lib/http/server.go | 20 ++++++++++++-------- lib/memfs/memfs.go | 3 +++ lib/memfs/memfs_test.go | 2 ++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/http/server.go b/lib/http/server.go index 37ff71dc..56a5d35b 100644 --- a/lib/http/server.go +++ b/lib/http/server.go @@ -345,17 +345,11 @@ func (srv *Server) Stop(wait time.Duration) (err error) { // [ServerOptions.EnableIndexHTML] is true, server will generate list of // content for index.html. func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) { - var ( - nodeIndexHTML *memfs.Node - pathHTML string - err error - ) - if srv.Options.Memfs == nil { return nil, false } - pathHTML = path.Join(reqPath, `index.html`) + var err error node, err = srv.Options.Memfs.Get(reqPath) if err != nil { @@ -363,6 +357,8 @@ func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) { return nil, false } + var pathHTML = path.Join(reqPath, `index.html`) + node, err = srv.Options.Memfs.Get(pathHTML) if err != nil { pathHTML = reqPath + `.html` @@ -375,9 +371,14 @@ func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) { } if node.IsDir() { + var ( + pathHTML = path.Join(reqPath, `index.html`) + nodeIndexHTML *memfs.Node + ) + nodeIndexHTML, err = srv.Options.Memfs.Get(pathHTML) if err == nil { - return nodeIndexHTML, true + return nodeIndexHTML, false } if !srv.Options.EnableIndexHTML { @@ -385,6 +386,9 @@ func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) { } node.GenerateIndexHTML() + + // Do not return isDir=true, to prevent the caller check and + // redirect the user to path with slash. } return node, false diff --git a/lib/memfs/memfs.go b/lib/memfs/memfs.go index a6daf8da..6ef40c2d 100644 --- a/lib/memfs/memfs.go +++ b/lib/memfs/memfs.go @@ -204,6 +204,9 @@ func (mfs *MemFS) Get(path string) (node *Node, err error) { if len(path) == 0 { return nil, fmt.Errorf(`%s: empty path`, logp) } + if path != `/` { + path = strings.TrimSuffix(path, `/`) + } node = mfs.PathNodes.Get(path) if node != nil { diff --git a/lib/memfs/memfs_test.go b/lib/memfs/memfs_test.go index 3c0f3f5e..12651413 100644 --- a/lib/memfs/memfs_test.go +++ b/lib/memfs/memfs_test.go @@ -290,6 +290,8 @@ func TestMemFS_Get(t *testing.T) { }, }, { path: "/include", + }, { + path: "/include/", }, { path: "/include/dir", expErr: os.ErrNotExist.Error(),