diff --git a/Sources/SSHClient/Internal/Command/SSHCommandSession.swift b/Sources/SSHClient/Internal/Command/SSHCommandSession.swift index 2b59b9e..d73ec31 100644 --- a/Sources/SSHClient/Internal/Command/SSHCommandSession.swift +++ b/Sources/SSHClient/Internal/Command/SSHCommandSession.swift @@ -45,6 +45,8 @@ private class SSHCommandHandler: ChannelDuplexHandler { private let invocation: SSHCommandInvocation private let promise: Promise + private var isSuccess = false + // MARK: - Life Cycle init(invocation: SSHCommandInvocation, @@ -53,10 +55,6 @@ private class SSHCommandHandler: ChannelDuplexHandler { self.promise = promise } - deinit { - promise.fail(SSHConnectionError.unknown) - } - func handlerAdded(context: ChannelHandlerContext) { let execRequest = SSHChannelRequestEvent.ExecRequest( command: invocation.command.command, @@ -64,23 +62,27 @@ private class SSHCommandHandler: ChannelDuplexHandler { ) context .channel - .setOption(ChannelOptions.allowRemoteHalfClosure, value: true) - .flatMap { +// .setOption(ChannelOptions.allowRemoteHalfClosure, value: true) + .eventLoop.flatSubmit { context.triggerUserOutboundEvent(execRequest) } .whenFailure { _ in - context.close(promise: nil) + context.channel.close(promise: nil) } } func errorCaught(context: ChannelHandlerContext, error: Error) { context.channel.close(promise: nil) - promise.fail(SSHConnectionError.unknown) context.fireErrorCaught(error) } - func handlerRemoved(context: ChannelHandlerContext) { - promise.succeed(()) + func channelInactive(context: ChannelHandlerContext) { + if isSuccess { + promise.succeed(()) + } else { + promise.fail(SSHShellError.unknown) + } + context.fireChannelInactive() } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { @@ -93,6 +95,7 @@ private class SSHCommandHandler: ChannelDuplexHandler { case let event as ChannelEvent: switch event { case .inputClosed: + isSuccess = true context.channel.close(promise: nil) case .outputClosed: break @@ -115,7 +118,6 @@ private class SSHCommandHandler: ChannelDuplexHandler { switch channelData.type { case .channel: invocation.onChunk?(.init(channel: .standard, data: data)) - return case .stdErr: invocation.onChunk?(.init(channel: .error, data: data)) default: diff --git a/Sources/SSHClient/Internal/Shell/IOSSHShell.swift b/Sources/SSHClient/Internal/Shell/IOSSHShell.swift index c6c01bc..ef8ba82 100644 --- a/Sources/SSHClient/Internal/Shell/IOSSHShell.swift +++ b/Sources/SSHClient/Internal/Shell/IOSSHShell.swift @@ -117,3 +117,5 @@ class IOSSHShell { } } } + +extension IOSSHShell: SSHSession {} diff --git a/Sources/SSHClient/Internal/Shell/SSHShellHandler.swift b/Sources/SSHClient/Internal/Shell/SSHShellHandler.swift index 2b95474..99b8dab 100644 --- a/Sources/SSHClient/Internal/Shell/SSHShellHandler.swift +++ b/Sources/SSHClient/Internal/Shell/SSHShellHandler.swift @@ -25,9 +25,9 @@ class StartShellHandler: ChannelInboundHandler { _ = context .channel .eventLoop -// TODO (gz): Move option to bootstrapper -// https://forums.swift.org/t/unit-testing-channeloptions/51797 -// .setOption(ChannelOptions.allowRemoteHalfClosure, value: true) + // TODO: (gz): Move option to bootstrapper + // https://forums.swift.org/t/unit-testing-channeloptions/51797 + // .setOption(ChannelOptions.allowRemoteHalfClosure, value: true) .flatSubmit { let promise = context.channel.eventLoop.makePromise(of: Void.self) let request = SSHChannelRequestEvent.ShellRequest(wantReply: true) diff --git a/Sources/SSHClient/SSHCommand.swift b/Sources/SSHClient/SSHCommand.swift index 96ca012..4f596f6 100644 --- a/Sources/SSHClient/SSHCommand.swift +++ b/Sources/SSHClient/SSHCommand.swift @@ -1,17 +1,17 @@ import Foundation -public struct SSHCommandStatus: Sendable { +public struct SSHCommandStatus: Sendable, Hashable { public let exitStatus: Int } -public enum SSHCommandResponseChunk: Sendable { +public enum SSHCommandResponseChunk: Sendable, Hashable { case chunk(SSHCommandChunk) case status(SSHCommandStatus) } -public struct SSHCommandChunk: Sendable { - public enum Channel: Sendable { +public struct SSHCommandChunk: Sendable, Hashable { + public enum Channel: Sendable, Hashable { case standard case error } @@ -20,7 +20,7 @@ public struct SSHCommandChunk: Sendable { public let data: Data } -public struct SSHCommand: Sendable { +public struct SSHCommand: Sendable, Hashable { public let command: String public init(_ command: String) { diff --git a/Tests/SSHClientTests/Unit/IOSSHCommandTests.swift b/Tests/SSHClientTests/Unit/IOSSHCommandTests.swift new file mode 100644 index 0000000..3ebe66b --- /dev/null +++ b/Tests/SSHClientTests/Unit/IOSSHCommandTests.swift @@ -0,0 +1,113 @@ + +import Foundation +import NIOCore +import NIOEmbedded +import NIOSSH +@testable import SSHClient +import XCTest + +class IOSSHCommandTests: XCTestCase { + func testSimpleInvocation() throws { + let context = IOSSHCommandTestsContext(command: SSHCommand("echo")) + let future = try context.assertStart() + try context.serverEnd() + XCTAssertNoThrow(try future.wait()) + } + + func testRegularInvocation() throws { + let context = IOSSHCommandTestsContext(command: SSHCommand("echo main")) + let future = try context.assertStart() + var isCompleted = false + future.whenComplete { _ in isCompleted = true } + let data = context.harness.channel.triggerInboundChannelString("main") + XCTAssertEqual( + [SSHCommandChunk(channel: .standard, data: data)], + context.chunks + ) + let error = context.harness.channel.triggerInboundSTDErrString("error") + XCTAssertEqual( + [ + SSHCommandChunk(channel: .standard, data: data), + SSHCommandChunk(channel: .error, data: error), + ], + context.chunks + ) + let exitStatus = SSHChannelRequestEvent.ExitStatus(exitStatus: 1) + context.harness.channel.triggerInbound(exitStatus) + XCTAssertEqual( + [SSHCommandStatus(exitStatus: exitStatus.exitStatus)], + context.status + ) + XCTAssertFalse(isCompleted) + try context.serverEnd() + XCTAssertNoThrow(try future.wait()) + } + + func testChannelClosingOnInputClosed() throws { + let context = IOSSHCommandTestsContext(command: SSHCommand("echo")) + let future = try context.assertStart() + try context.harness.channel.close().wait() + context.harness.run() + XCTAssertThrowsError(try future.wait()) + } + + func testChannelClosingOnError() throws { + let context = IOSSHCommandTestsContext(command: SSHCommand("echo")) + let future = try context.assertStart() + context.harness.channel.fireErrorCaught() + context.harness.run() + XCTAssertThrowsError(try future.wait()) + } + + func testChannelClosingOnOutboundFailure() throws { + let context = IOSSHCommandTestsContext(command: SSHCommand("echo")) + context.harness.channel.shouldFailOnOutboundEvent = true + try context.harness.channel.connect().wait() + let futur = try context.harness.start(context.session) + context.harness.run() + XCTAssertThrowsError(try futur.wait()) + } +} + +private class IOSSHCommandTestsContext { + private(set) var invocation: SSHCommandInvocation! + private(set) var session: SSHCommandSession! + let harness = SSHSessionHarness() + + private(set) var chunks: [SSHCommandChunk] = [] + private(set) var status: [SSHCommandStatus] = [] + + private let channel = EmbeddedSSHChannel() + + init(command: SSHCommand) { + invocation = SSHCommandInvocation( + command: command, + onChunk: { self.chunks.append($0) }, + onStatus: { self.status.append($0) } + ) + session = SSHCommandSession(invocation: invocation) + } + + func assertStart() throws -> Future { + try channel.connect().wait() + let promise = try harness.start(session) + harness.channel.run() + XCTAssertTrue(harness.channel.isActive) + XCTAssertEqual( + harness.channel.outboundEvents, + [SSHChannelRequestEvent.ExecRequest( + command: invocation.command.command, + wantReply: true + )] + ) + XCTAssertEqual(chunks, []) + XCTAssertEqual(status, []) + return promise + } + + func serverEnd() throws { + let closing = ChannelEvent.inputClosed + harness.channel.triggerInbound(closing) + harness.run() + } +} diff --git a/Tests/SSHClientTests/Unit/IOSSHShellTests.swift b/Tests/SSHClientTests/Unit/IOSSHShellTests.swift index fc86acc..db69c73 100644 --- a/Tests/SSHClientTests/Unit/IOSSHShellTests.swift +++ b/Tests/SSHClientTests/Unit/IOSSHShellTests.swift @@ -1,108 +1,94 @@ import Foundation -import XCTest import NIOCore import NIOEmbedded import NIOSSH @testable import SSHClient +import XCTest class IOSSHShellTests: XCTestCase { - func testSuccessfulSSHChannelStart() throws { - let shell = EmbeddedIOShell() - try shell.assertStart() - } - - func testFailedSSHChannelStart() throws { - let shell = EmbeddedIOShell() - let promise = try shell.start() - shell.run() - try shell.channel.close().wait() - shell.run() - XCTAssertEqual(shell.recordedData, []) - XCTAssertEqual(shell.recordedStates, [.failed(.unknown)]) - XCTAssertThrowsError(try promise.futureResult.wait()) + let context = IOShellContext() + try context.assertStart() } func testFailedOnOutboundSSHChannelStart() throws { - let shell = EmbeddedIOShell() - let promise = try shell.start() - shell.channel.shouldFailOnOutboundEvent = true - XCTAssertTrue(shell.channel.isActive) - shell.run() - XCTAssertFalse(shell.channel.isActive) - XCTAssertEqual(shell.recordedData, []) - XCTAssertEqual(shell.recordedStates, [.failed(.unknown)]) - XCTAssertThrowsError(try promise.futureResult.wait()) + let context = IOShellContext() + context.harness.channel.shouldFailOnOutboundEvent = true + let promise = try context.harness.start(context.shell) + context.harness.channel.run() + XCTAssertThrowsError(try promise.wait()) + XCTAssertEqual(context.recordedData, []) + XCTAssertEqual(context.recordedStates, [.failed(.unknown)]) } func testReadingWhenConnected() throws { - let shell = EmbeddedIOShell() - try shell.assertStart() - let data = shell.channel.triggerInboundChannelString("Data") - shell.run() - XCTAssertEqual(shell.recordedData, [data]) - let error = shell.channel.triggerInboundChannelString("Error") - shell.run() - XCTAssertEqual(shell.recordedData, [data, error]) + let context = IOShellContext() + try context.assertStart() + let data = context.harness.channel.triggerInboundChannelString("Data") + context.harness.channel.run() + XCTAssertEqual(context.recordedData, [data]) + let error = context.harness.channel.triggerInboundChannelString("Error") + context.harness.channel.run() + XCTAssertEqual(context.recordedData, [data, error]) } func testReadingWhenNotConnected() throws { - let shell = EmbeddedIOShell() - let _ = shell.channel.triggerInboundChannelString("Data") - shell.run() - XCTAssertEqual(shell.recordedData, []) + let context = IOShellContext() + let _ = context.harness.channel.triggerInboundChannelString("Data") + context.harness.channel.run() + XCTAssertEqual(context.recordedData, []) } func testWritingWhenConnected() throws { - let shell = EmbeddedIOShell() - try shell.assertStart() + let context = IOShellContext() + try context.assertStart() let data = "Data".data(using: .utf8)! - let future = shell.shell.write(data) - shell.run() + let future = context.shell.write(data) + context.harness.channel.run() try future.wait() XCTAssertEqual( - try shell.channel.readAllOutbound(), + try context.harness.channel.readAllOutbound(), [SSHChannelData(type: .channel, data: .byteBuffer(.init(data: data)))] ) } func testWritingWhenNotConnected() throws { - let shell = EmbeddedIOShell() + let context = IOShellContext() let data = "Data".data(using: .utf8)! - let future = shell.shell.write(data) - shell.run() + let future = context.shell.write(data) + context.harness.channel.run() XCTAssertThrowsError(try future.wait()) } func testServerDisconnection() throws { - let shell = EmbeddedIOShell() - try shell.assertStart() - try shell.channel.close().wait() - shell.run() - XCTAssertEqual(shell.recordedStates, [.ready, .failed(.unknown)]) + let context = IOShellContext() + try context.assertStart() + try context.harness.channel.close().wait() + context.harness.channel.run() + XCTAssertEqual(context.recordedStates, [.ready, .failed(.unknown)]) } func testClientDisconnection() throws { - let shell = EmbeddedIOShell() - try shell.assertStart() - let future = shell.shell.close() - shell.run() + let context = IOShellContext() + try context.assertStart() + let future = context.shell.close() + context.harness.channel.run() try future.wait() - XCTAssertEqual(shell.recordedStates, [.ready, .closed]) + XCTAssertEqual(context.recordedStates, [.ready, .closed]) } } -private class EmbeddedIOShell { - - let channel = EmbeddedSSHChannel() +private class IOShellContext { + let harness: SSHSessionHarness let shell: IOSSHShell private(set) var recordedStates: [SSHShell.State] = [] private(set) var recordedData: [Data] = [] init() { - shell = IOSSHShell(eventLoop: channel.loop) + harness = SSHSessionHarness() + shell = IOSSHShell(eventLoop: harness.channel.loop) shell.stateUpdateHandler = { [weak self] state in self?.recordedStates.append(state) } @@ -111,36 +97,20 @@ private class EmbeddedIOShell { } } - func start() throws -> Promise { - let promise = channel.loop.makePromise(of: Void.self) - try channel.connect().wait() - let context = SSHSessionContext( - channel: channel.channel, - promise: promise - ) - shell.start(in: context) - try channel.startMonitoringOutbound() - return promise - } - func assertStart() throws { - let promise = try start() - run() - XCTAssertTrue(channel.isActive) + let promise = try harness.start(shell) + harness.channel.run() + XCTAssertTrue(harness.channel.isActive) XCTAssertEqual( - channel.outboundEvents, + harness.channel.outboundEvents, [SSHChannelRequestEvent.ShellRequest(wantReply: true)] ) XCTAssertEqual(recordedData, []) XCTAssertEqual(recordedStates, []) - channel.triggerInbound(ChannelSuccessEvent()) - run() - try promise.futureResult.wait() + harness.channel.triggerInbound(ChannelSuccessEvent()) + harness.channel.run() + try promise.wait() XCTAssertEqual(recordedData, []) XCTAssertEqual(recordedStates, [.ready]) } - - func run() { - channel.run() - } } diff --git a/Tests/SSHClientTests/Unit/Utils/EmbeddedSSHChannel.swift b/Tests/SSHClientTests/Unit/Utils/EmbeddedSSHChannel.swift index 1aaf9f6..e459a57 100644 --- a/Tests/SSHClientTests/Unit/Utils/EmbeddedSSHChannel.swift +++ b/Tests/SSHClientTests/Unit/Utils/EmbeddedSSHChannel.swift @@ -1,13 +1,32 @@ import Foundation -import XCTest import NIOCore import NIOEmbedded import NIOSSH @testable import SSHClient +import XCTest -class EmbeddedSSHChannel { +struct SSHSessionHarness { + let channel = EmbeddedSSHChannel() + func start(_ session: S) throws -> Future { + let promise = channel.loop.makePromise(of: Void.self) + try channel.connect().wait() + let context = SSHSessionContext( + channel: channel.channel, + promise: promise + ) + try channel.startMonitoringOutbound() + session.start(in: context) + return promise.futureResult + } + + func run() { + channel.run() + } +} + +class EmbeddedSSHChannel { var channel: Channel { embeddedChannel } @@ -32,6 +51,11 @@ class EmbeddedSSHChannel { get { recorder.shouldFailOnOutboundEvent } } + func fireErrorCaught() { + struct AnError: Error {} + embeddedChannel.pipeline.fireErrorCaught(AnError()) + } + func triggerInbound(_ event: Any) { embeddedChannel.pipeline.fireUserInboundEventTriggered(event) } @@ -57,7 +81,7 @@ class EmbeddedSSHChannel { } func readOutbound() throws -> SSHChannelData? { - return try embeddedChannel.readOutbound(as: SSHChannelData.self) + try embeddedChannel.readOutbound(as: SSHChannelData.self) } func readAllOutbound() throws -> [SSHChannelData] { @@ -69,11 +93,11 @@ class EmbeddedSSHChannel { } func connect() -> Future { - return channel.connect(to: try! .init(unixDomainSocketPath: "/fake")) + channel.connect(to: try! .init(unixDomainSocketPath: "/fake")) } func close() -> Future { - return channel.close() + channel.close() } func run() {