Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration tests for LocalSocketShellMain. #2304

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ kt_android_library(
srcs = [
"BlockingPublish.java",
"FileObserverShellMain.kt",
"LocalSocketShellMain.kt",
"ShellCommand.java",
"ShellCommandExecutor.java",
"ShellCommandExecutorServer.java",
"ShellCommandFileObserverExecutorServer.kt",
"ShellCommandLocalSocketExecutorServer.kt",
"ShellExecSharedConstants.java",
"ShellMain.java",
],
Expand All @@ -72,6 +74,8 @@ kt_android_library(
deps = [
":coroutine_file_observer",
":file_observer_protocol",
":local_socket_protocol",
":local_socket_protocol_pb_java_proto_lite",
"//services/speakeasy/java/androidx/test/services/speakeasy:protocol",
"//services/speakeasy/java/androidx/test/services/speakeasy/client",
"//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection",
Expand All @@ -94,6 +98,7 @@ kt_android_library(
"ShellExecutorFactory.java",
"ShellExecutorFileObserverImpl.kt",
"ShellExecutorImpl.java",
"ShellExecutorLocalSocketImpl.kt",
],
idl_srcs = ["Command.aidl"],
visibility = [":export"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.test.services.shellexecutor

import android.util.Log
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.util.concurrent.Executors
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.runInterruptible

/** Variant of ShellMain that uses a LocalSocket to communicate with the client. */
class LocalSocketShellMain {

suspend fun run(args: Array<String>): Int {
val scope = CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
val server = ShellCommandLocalSocketExecutorServer(scope = scope)
server.start()

val processArgs = args.toMutableList()
processArgs.addAll(
processArgs.size - 1,
listOf("-e", ShellExecSharedConstants.BINDER_KEY, server.binderKey()),
)
val pb = ProcessBuilder(processArgs.toList())

val exitCode: Int

try {
val process = pb.start()

val stdinCopier = scope.launch { copyStream("stdin", System.`in`, process.outputStream) }
val stdoutCopier = scope.launch { copyStream("stdout", process.inputStream, System.out) }
val stderrCopier = scope.launch { copyStream("stderr", process.errorStream, System.err) }

runInterruptible { process.waitFor() }
exitCode = process.exitValue()

stdinCopier.cancel() // System.`in`.close() does not force input.read() to return
stdoutCopier.join()
stderrCopier.join()
} finally {
server.stop(100.milliseconds)
}
return exitCode
}

suspend fun copyStream(name: String, input: InputStream, output: OutputStream) {
val buf = ByteArray(1024)
try {
while (true) {
val size = input.read(buf)
if (size == -1) break
output.write(buf, 0, size)
}
output.flush()
} catch (x: IOException) {
Log.e(TAG, "IOException on $name. Terminating.", x)
}
}

companion object {
private const val TAG = "LocalSocketShellMain"

@JvmStatic
public fun main(args: Array<String>) {
System.exit(runBlocking { LocalSocketShellMain().run(args) })
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.test.services.shellexecutor

import android.net.LocalServerSocket
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.Process as AndroidProcess
import android.util.Log
import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey
import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest
import java.io.IOException
import java.io.InterruptedIOException
import java.security.SecureRandom
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout

/** Server that run shell commands for a client talking over a LocalSocket. */
final class ShellCommandLocalSocketExecutorServer
@JvmOverloads
constructor(
private val scope: CoroutineScope =
CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
) {
// Use the same secret generation as SpeakEasy does.
private val secret = java.lang.Long.toHexString(SecureRandom().nextLong())
lateinit var socket: LocalServerSocket
lateinit var address: LocalSocketAddress
// Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job...
lateinit var serverJob: Job
// ...while all the child jobs are under a single SupervisorJob that we can join later.
val shellJobs = SupervisorJob()
val running = AtomicBoolean(true)

/** Returns the binder key to pass to client processes. */
fun binderKey(): String {
// The address can contain spaces, and since it gets passed through a command line, we need to
// encode it. java.net.URLEncoder is conveniently available in all SDK versions.
return address.asBinderKey(secret)
}

/** Runs a simple server. */
private suspend fun server() = coroutineScope {
while (running.get()) {
val connection =
try {
runInterruptible { socket.accept() }
} catch (x: Exception) {
// None of my tests have managed to trigger this one.
Log.e(TAG, "LocalServerSocket.accept() failed", x)
break
}
launch(scope.coroutineContext + shellJobs) { handleConnection(connection) }
}
}

/**
* Relays the output of process to connection with a series of RunCommandResponses.
*
* @param process The process to relay output from.
* @param connection The connection to relay output to.
* @return false if there was a problem, true otherwise.
*/
private suspend fun relay(process: Process, connection: LocalSocket): Boolean {
// Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB
// takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s),
// 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms ->
// 75ms).
val buffer = ByteArray(65536)
var size: Int

// LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use
// that as our loop constraint.
while (true) {
try {
size = runInterruptible { process.inputStream.read(buffer) }
if (size < 0) return true // EOF
if (size == 0) {
delay(1.milliseconds)
continue
}
} catch (x: InterruptedIOException) {
// We start getting these at API 24 when the timeout handling kicks in.
Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}")
return false
} catch (x: IOException) {
Log.i(TAG, "Error reading from ${process}; did it time out?", x)
return false
}

if (!connection.sendResponse(buffer = buffer, size = size)) {
return false
}
}
}

/** Handle one connection. */
private suspend fun handleConnection(connection: LocalSocket) {
// connection.localSocketAddress is always null, so no point in logging it.

// Close the connection when done.
connection.use {
val request = connection.readRequest()

if (request.secret.compareTo(secret) != 0) {
Log.w(TAG, "Ignoring request with wrong secret: $request")
return
}

val pb = request.toProcessBuilder()
pb.redirectErrorStream(true)

val process: Process
try {
process = pb.start()
} catch (x: IOException) {
Log.e(TAG, "Failed to start process", x)
connection.sendResponse(
buffer = x.stackTraceToString().toByteArray(),
exitCode = EXIT_CODE_FAILED_TO_START,
)
return
}

// We will not be writing anything to the process' stdin.
process.outputStream.close()

// Close the process' stdout when we're done reading.
process.inputStream.use {
// Launch a coroutine to relay the process' output to the client. If it times out, kill the
// process and cancel the job. This is more coroutine-friendly than using waitFor() to
// handle timeouts.
val ioJob = scope.async { relay(process, connection) }

try {
withTimeout(request.timeout()) {
if (!ioJob.await()) {
Log.w(TAG, "Relaying ${process} output failed")
}
runInterruptible { process.waitFor() }
}
} catch (x: TimeoutCancellationException) {
Log.e(TAG, "Process ${process} timed out after ${request.timeout()}")
process.destroy()
ioJob.cancel()
connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT)
return
}

connection.sendResponse(exitCode = process.exitValue())
}
}
}

/** Starts the server. */
fun start() {
socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}")
address = socket.localSocketAddress
Log.i(TAG, "Starting server on ${address.name}")

// Launch a coroutine to call socket.accept()
serverJob = scope.launch { server() }
}

/** Stops the server. */
fun stop(timeout: Duration) {
running.set(false)
// Closing the socket does not interrupt accept()...
socket.close()
runBlocking(scope.coroutineContext) {
try {
// ...so we simply cancel that job...
serverJob.cancel()
// ...and play nicely with all the shell jobs underneath.
withTimeout(timeout) {
shellJobs.complete()
shellJobs.join()
}
} catch (x: TimeoutCancellationException) {
Log.w(TAG, "Shell jobs did not stop after $timeout", x)
shellJobs.cancel()
}
}
}

private fun RunCommandRequest.timeout(): Duration =
if (timeoutMs <= 0) {
Duration.INFINITE
} else {
timeoutMs.milliseconds
}

/**
* Sets up a ProcessBuilder with information from the request; other configuration is up to the
* caller.
*/
private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder {
val pb = ProcessBuilder(argvList)
val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret!
Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}")
if (environmentMap.isNotEmpty()) {
pb.environment().putAll(environmentMap)
val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ")
Log.i(TAG, "Environment: $env")
}
return pb
}

private companion object {
const val TAG = "SCLSEServer" // up to 23 characters

const val EXIT_CODE_FAILED_TO_START = -1
const val EXIT_CODE_TIMED_OUT = -2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ public ShellExecutorFactory(Context context, String binderKey) {

public ShellExecutor create() {
// Binder keys for SpeakEasy are a string of hex digits. Binder keys for the FileObserver
// protocol are the absolute path of the directory that the server is watching.
if (binderKey.startsWith("/")) {
// protocol are the absolute path of the directory that the server is watching. Binder keys for
// the LocalSocket protocol start and end with a colon.
if (LocalSocketProtocol.isBinderKey(binderKey)) {
return new ShellExecutorLocalSocketImpl(binderKey);
} else if (binderKey.startsWith("/")) {
return new ShellExecutorFileObserverImpl(binderKey);
} else {
return new ShellExecutorImpl(context, binderKey);
Expand Down
Loading
Loading