Skip to content

Commit

Permalink
Add option to remove prefix from path in FileMiddleware (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Jul 21, 2024
1 parent 7060933 commit a563c7b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
30 changes: 29 additions & 1 deletion Sources/Hummingbird/Middleware/FileMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,28 @@ public protocol FileMiddlewareFileAttributes {
public struct FileMiddleware<Context: RequestContext, Provider: FileProvider>: RouterMiddleware where Provider.FileAttributes: FileMiddlewareFileAttributes {
let cacheControl: CacheControl
let searchForIndexHtml: Bool
let urlBasePath: String?
let fileProvider: Provider

/// Create FileMiddleware
/// - Parameters:
/// - rootFolder: Root folder to look for files
/// - urlBasePath: Prefix to remove from request URL
/// - cacheControl: What cache control headers to include in response
/// - searchForIndexHtml: Should we look for index.html in folders
/// - threadPool: ThreadPool used by file loading
/// - logger: Logger used to output file information
public init(
_ rootFolder: String = "public",
urlBasePath: String? = nil,
cacheControl: CacheControl = .init([]),
searchForIndexHtml: Bool = false,
threadPool: NIOThreadPool = NIOThreadPool.singleton,
logger: Logger = Logger(label: "FileMiddleware")
) where Provider == LocalFileSystem {
self.cacheControl = cacheControl
self.searchForIndexHtml = searchForIndexHtml
self.urlBasePath = urlBasePath.map { String($0.dropSuffix("/")) }
self.fileProvider = LocalFileSystem(
rootFolder: rootFolder,
threadPool: threadPool,
Expand All @@ -72,15 +76,18 @@ public struct FileMiddleware<Context: RequestContext, Provider: FileProvider>: R
/// Create FileMiddleware using custom ``FileProvider``.
/// - Parameters:
/// - fileProvider: File provider
/// - urlBasePath: Prefix to remove from request URL
/// - cacheControl: What cache control headers to include in response
/// - indexHtml: Should we look for index.html in folders
public init(
fileProvider: Provider,
urlBasePath: String? = nil,
cacheControl: CacheControl = .init([]),
searchForIndexHtml: Bool = false
) {
self.cacheControl = cacheControl
self.searchForIndexHtml = searchForIndexHtml
self.urlBasePath = urlBasePath.map { String($0.dropSuffix("/")) }
self.fileProvider = fileProvider
}

Expand All @@ -94,8 +101,12 @@ public struct FileMiddleware<Context: RequestContext, Provider: FileProvider>: R
throw error
}

guard request.method == .get || request.method == .head else {
throw error
}

// Remove percent encoding from URI path
guard let path = request.uri.path.removingPercentEncoding else {
guard var path = request.uri.path.removingPercentEncoding else {
throw HTTPError(.badRequest, message: "Invalid percent encoding in URL")
}

Expand All @@ -104,6 +115,23 @@ public struct FileMiddleware<Context: RequestContext, Provider: FileProvider>: R
throw HTTPError(.badRequest)
}

// Do we have a prefix to remove from the path
if let urlBasePath {
// If path doesnt have prefix then throw error
guard path.hasPrefix(urlBasePath) else {
throw error
}
let subPath = path.dropFirst(urlBasePath.count)
if subPath.first == nil {
path = "/"
} else if subPath.first == "/" {
path = String(subPath)
} else {
// If first character isn't a "/" then the base path isn't a complete folder name
// in this situation, so isn't inside the specified folder
throw error
}
}
// get file attributes and actual file path and ID (It might be an index.html)
let (actualPath, actualID, attributes) = try await self.getFileAttributes(path)
// we have a file so indicate it came from the FileMiddleware
Expand Down
68 changes: 68 additions & 0 deletions Tests/HummingbirdTests/FileMiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,74 @@ class FileMiddlewareTests: XCTestCase {
}
}

func testPathPrefix() async throws {
// echo file provider. Returns file name as contents of file
struct MemoryFileProvider: FileProvider {
let prefix: String
struct FileAttributes: FileMiddlewareFileAttributes {
var isFolder: Bool
var modificationDate: Date { .distantPast }
let size: Int
}

func getFileIdentifier(_ path: String) -> String? {
return path
}

func getAttributes(id path: String) async throws -> FileAttributes? {
return .init(
isFolder: path.last == "/",
size: path.utf8.count
)
}

func loadFile(id path: String, context: some RequestContext) async throws -> ResponseBody {
let buffer = context.allocator.buffer(string: self.prefix + path)
return .init(byteBuffer: buffer)
}

func loadFile(id path: String, range: ClosedRange<Int>, context: some RequestContext) async throws -> ResponseBody {
let buffer = context.allocator.buffer(string: self.prefix + path)
guard let slice = buffer.getSlice(at: range.lowerBound, length: range.count) else { throw HTTPError(.rangeNotSatisfiable) }
return .init(byteBuffer: slice)
}
}
let router = Router()
router.add(middleware: FileMiddleware(fileProvider: MemoryFileProvider(prefix: "memory:/"), urlBasePath: "/test", searchForIndexHtml: true))
router.add(middleware: FileMiddleware(fileProvider: MemoryFileProvider(prefix: "memory2:/"), urlBasePath: "/test2", searchForIndexHtml: true))
let app = Application(responder: router.buildResponder())

try await app.test(.router) { client in
try await client.execute(uri: "/test/hello", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory://hello")
}
try await client.execute(uri: "/test/hello/", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory://hello/index.html")
}
try await client.execute(uri: "/test", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory://index.html")
}
try await client.execute(uri: "/test/", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory://index.html")
}
try await client.execute(uri: "/goodbye", method: .get) { response in
XCTAssertEqual(response.status, .notFound)
}
try await client.execute(uri: "/testHello", method: .get) { response in
XCTAssertEqual(response.status, .notFound)
}
try await client.execute(uri: "/test2/hello", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory2://hello")
}
try await client.execute(uri: "/test2/hello/", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory2://hello/index.html")
}
try await client.execute(uri: "/test2", method: .get) { response in
XCTAssertEqual(String(buffer: response.body), "memory2://index.html")
}
}
}

func testCustomFileProvider() async throws {
// basic file provider
struct MemoryFileProvider: FileProvider {
Expand Down

0 comments on commit a563c7b

Please sign in to comment.