diff --git a/src/main/kotlin/us/ihmc/cd/RemoteExtension.kt b/src/main/kotlin/us/ihmc/cd/RemoteExtension.kt index 670ccb87..42a619c7 100644 --- a/src/main/kotlin/us/ihmc/cd/RemoteExtension.kt +++ b/src/main/kotlin/us/ihmc/cd/RemoteExtension.kt @@ -4,13 +4,12 @@ import net.schmizz.sshj.SSHClient import net.schmizz.sshj.common.IOUtils import net.schmizz.sshj.connection.channel.direct.Session import net.schmizz.sshj.sftp.SFTPClient +import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts import net.schmizz.sshj.transport.verification.PromiscuousVerifier -import org.apache.commons.exec.OS import org.gradle.api.Action import us.ihmc.build.LogTools import java.io.IOException import java.nio.file.Files -import java.nio.file.Paths import java.util.concurrent.TimeUnit open class RemoteExtension @@ -30,26 +29,29 @@ open class RemoteExtension */ private fun authWithSSHKey(username: String, sshClient: SSHClient) { - val userSSHConfigFolder = Paths.get(System.getProperty("user.home")).resolve(".ssh") + val sshDir = OpenSSHKnownHosts.detectSSHDir() - val list = Files.list(userSSHConfigFolder) - - val privateKeyFiles = arrayListOf() - for (path in list) + if (sshDir.isDirectory) { - if (Files.isRegularFile(path) - && path.fileName.toString() != "config" - && path.fileName.toString() != "known_hosts" - && !path.fileName.toString().endsWith(".pub")) + val list = Files.list(sshDir.toPath()) + + val privateKeyFiles = arrayListOf() + for (path in list) { - val absoluteNormalizedString = path.toAbsolutePath().normalize().toString() - privateKeyFiles.add(absoluteNormalizedString) + if (Files.isRegularFile(path) + && path.fileName.toString() != "config" + && path.fileName.toString() != "known_hosts" + && !path.fileName.toString().endsWith(".pub")) + { + val absoluteNormalizedString = path.toAbsolutePath().normalize().toString() + privateKeyFiles.add(absoluteNormalizedString) + } } - } - LogTools.info("Passing keys to authPublicKey: $privateKeyFiles") + LogTools.info("Passing keys to authPublicKey: $privateKeyFiles") - sshClient.authPublickey(username, *privateKeyFiles.toTypedArray()) + sshClient.authPublickey(username, *privateKeyFiles.toTypedArray()) + } } class RemoteConnection(val ssh: SSHClient, val sftp: SFTPClient) @@ -97,13 +99,16 @@ open class RemoteExtension { val sshClient = SSHClient() - if (OS.isFamilyUnix()) + val sshDir = OpenSSHKnownHosts.detectSSHDir() + + if (sshDir.resolve("known_hosts").isFile) { sshClient.loadKnownHosts() } - else if (OS.isFamilyWindows()) + else { - // TODO: Intelligently try to find known_hosts location on Windows + LogTools.warn("Could not find known_hosts file. Disabling host key verification.") + sshClient.addHostKeyVerifier(PromiscuousVerifier()) }