From d872b65dd2113c9f4f01a9bf3493c62f4df74145 Mon Sep 17 00:00:00 2001 From: gloriacai01 Date: Tue, 17 Sep 2024 09:53:13 -0400 Subject: [PATCH] input controller + tests --- core/sdk/build.gradle | 3 + .../sdk/core/resource/ResourceManager.java | 9 + .../viam/sdk/core/component/input/Input.kt | 254 ++++++++++++++++++ .../input/InputControllerRPCClient.kt | 161 +++++++++++ .../input/InputControllerRPCService.kt | 132 +++++++++ .../component/input/InputRPCClientTest.kt | 154 +++++++++++ .../component/input/InputRPCServiceTest.kt | 184 +++++++++++++ .../sdk/core/component/input/InputTest.kt | 59 ++++ 8 files changed, 956 insertions(+) create mode 100644 core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/Input.kt create mode 100644 core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCClient.kt create mode 100644 core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCService.kt create mode 100644 core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCClientTest.kt create mode 100644 core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCServiceTest.kt create mode 100644 core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputTest.kt diff --git a/core/sdk/build.gradle b/core/sdk/build.gradle index 148cfda7b..9902bd8b2 100644 --- a/core/sdk/build.gradle +++ b/core/sdk/build.gradle @@ -9,6 +9,8 @@ buildscript { } plugins { id "com.github.hierynomus.license-report" version "0.16.1" + id "org.jetbrains.kotlin.jvm" version "2.0.0" + } apply plugin: 'kotlin' @@ -16,6 +18,7 @@ apply plugin: 'kotlin' ext.pomDisplayName = "Viam Core SDK" dependencies { + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.9.0-RC.2") api 'io.grpc:grpc-protobuf-lite:1.63.0' api 'io.grpc:grpc-stub:1.63.0' implementation 'org.json:json:20240205' diff --git a/core/sdk/src/main/java/com/viam/sdk/core/resource/ResourceManager.java b/core/sdk/src/main/java/com/viam/sdk/core/resource/ResourceManager.java index a8815d59f..a5cbe62eb 100644 --- a/core/sdk/src/main/java/com/viam/sdk/core/resource/ResourceManager.java +++ b/core/sdk/src/main/java/com/viam/sdk/core/resource/ResourceManager.java @@ -6,6 +6,7 @@ import com.viam.component.camera.v1.CameraServiceGrpc; import com.viam.component.generic.v1.GenericServiceGrpc; import com.viam.component.gripper.v1.GripperServiceGrpc; +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc; import com.viam.component.motor.v1.MotorServiceGrpc; import com.viam.component.movementsensor.v1.MovementSensorServiceGrpc; import com.viam.component.sensor.v1.SensorServiceGrpc; @@ -21,6 +22,7 @@ import com.viam.sdk.core.component.gripper.Gripper; import com.viam.sdk.core.component.gripper.GripperRPCClient; import com.viam.sdk.core.component.gripper.GripperRPCService; +import com.viam.sdk.core.component.input.*; import com.viam.sdk.core.component.motor.Motor; import com.viam.sdk.core.component.motor.MotorRPCClient; import com.viam.sdk.core.component.motor.MotorRPCService; @@ -95,6 +97,13 @@ public class ResourceManager implements Closeable { SensorRPCService::new, SensorRPCClient::new )); + Registry.registerSubtype(new ResourceRegistration<>( + Controller.SUBTYPE, + InputControllerServiceGrpc.SERVICE_NAME, + InputControllerRPCService::new, + InputControllerRPCClient::new + )); + // SERVICES Registry.registerSubtype(new ResourceRegistration<>( diff --git a/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/Input.kt b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/Input.kt new file mode 100644 index 000000000..a56fb0d6d --- /dev/null +++ b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/Input.kt @@ -0,0 +1,254 @@ +package com.viam.sdk.core.component.input + +import com.google.protobuf.Struct +import com.google.protobuf.Timestamp +import com.viam.common.v1.Common.ResourceName +import com.viam.component.inputcontroller.v1.InputController +import com.viam.sdk.core.component.Component +import com.viam.sdk.core.resource.Resource +import com.viam.sdk.core.resource.Subtype +import com.viam.sdk.core.robot.RobotClient + + +/** + * EventType represents the type of input event. + */ +enum class EventType(val value: String) { + + /** Callbacks registered for this event will be called in ADDITION to other registered event callbacks.*/ + ALL_EVENTS("AllEvents"), + + /** Sent at controller initialization, and on reconnects. */ + CONNECT("Connect"), + + /** If unplugged, or wireless/network times out.*/ + DISCONNECT("Disconnect"), + + /** Typical key press.*/ + BUTTON_PRESS("ButtonPress"), + + /** Key release */ + BUTTON_RELEASE("ButtonRelease"), + + /** Key is held down. This wil likely be a repeated event.*/ + BUTTON_HOLD("ButtonHold"), + + /** Both up and down for convenience during registration, not typically emitted.*/ + BUTTON_CHANGE("ButtonChange"), + + /** Absolute position is reported via Value, a la joysticks. */ + POSITION_CHANGE_ABSOLUTE("PositionChangeAbs"), + + /** Relative position is reported via Value, a la mice, or simulating axes with up/down buttons. */ + POSITION_CHANGE_RELATIVE("PositionChangeRel"); + + companion object { + fun fromValue(value: String): EventType = when (value) { + "AllEvents" -> EventType.ALL_EVENTS + "Connect" -> EventType.CONNECT + "Disconnect" -> EventType.DISCONNECT + "ButtonPress" -> EventType.BUTTON_PRESS + "ButtonRelease" -> EventType.BUTTON_RELEASE + "ButtonHold" -> EventType.BUTTON_HOLD + "ButtonChange" -> EventType.BUTTON_CHANGE + "PositionChangeAbs" -> EventType.POSITION_CHANGE_ABSOLUTE + "PositionChangeRel" -> EventType.POSITION_CHANGE_RELATIVE + else -> throw IllegalArgumentException("Unknown event type $value") + } + } +} + + +/** + * Control identifies the input (specific Axis or Button) of a controller. + */ +enum class Control(val value: String) { + + // Axes + ABSOLUTE_X("AbsoluteX"), + ABSOLUTE_Y("AbsoluteY"), + ABSOLUTE_Z("AbsoluteZ"), + ABSOLUTE_RX("AbsoluteRX"), + ABSOLUTE_RY("AbsoluteRY"), + ABSOLUTE_RZ("AbsoluteRZ"), + ABSOLUTE_HAT0_X("AbsoluteHat0X"), + ABSOLUTE_HAT0_Y("AbsoluteHat0Y"), + + // Buttons + BUTTON_SOUTH("ButtonSouth"), + BUTTON_EAST("ButtonEast"), + BUTTON_WEST("ButtonWest"), + BUTTON_NORTH("ButtonNorth"), + BUTTON_LT("ButtonLT"), + BUTTON_RT("ButtonRT"), + BUTTON_LT2("ButtonLT2"), + BUTTON_RT2("ButtonRT2"), + BUTTON_L_THUMB("ButtonLThumb"), + BUTTON_R_THUMB("ButtonRThumb"), + BUTTON_SELECT("ButtonSelect"), + BUTTON_START("ButtonStart"), + BUTTON_MENU("ButtonMenu"), + BUTTON_RECORD("ButtonRecord"), + BUTTON_E_STOP("ButtonEStop"), + + // Pedals + ABSOLUTE_PEDAL_ACCELERATOR("AbsolutePedalAccelerator"), + ABSOLUTE_PEDAL_BRAKE("AbsolutePedalBrake"), + ABSOLUTE_PEDAL_CLUTCH("AbsolutePedalClutch"); + + companion object { + fun fromValue(value: String): Control = when (value) { + "AbsoluteX" -> ABSOLUTE_X + "AbsoluteY" -> ABSOLUTE_Y + "AbsoluteZ" -> ABSOLUTE_Z + "AbsoluteRX" -> ABSOLUTE_RX + "AbsoluteRY" -> ABSOLUTE_RY + "AbsoluteRZ" -> ABSOLUTE_RZ + "AbsoluteHat0X" -> ABSOLUTE_HAT0_X + "AbsoluteHat0Y" -> ABSOLUTE_HAT0_Y + "ButtonSouth" -> BUTTON_SOUTH + "ButtonEast" -> BUTTON_EAST + "ButtonWest" -> BUTTON_WEST + "ButtonNorth" -> BUTTON_NORTH + "ButtonLT" -> BUTTON_LT + "ButtonRT" -> BUTTON_RT + "ButtonLT2" -> BUTTON_LT2 + "ButtonRT2" -> BUTTON_RT2 + "ButtonLThumb" -> BUTTON_L_THUMB + "ButtonRThumb" -> BUTTON_R_THUMB + "ButtonSelect" -> BUTTON_SELECT + "ButtonStart" -> BUTTON_START + "ButtonMenu" -> BUTTON_MENU + "ButtonRecord" -> BUTTON_RECORD + "ButtonEStop" -> BUTTON_E_STOP + "AbsolutePedalAccelerator" -> ABSOLUTE_PEDAL_ACCELERATOR + "AbsolutePedalBrake" -> ABSOLUTE_PEDAL_BRAKE + "AbsolutePedalClutch" -> ABSOLUTE_PEDAL_CLUTCH + else -> { + throw IllegalArgumentException("Unknown control $value") + } + } + } + +} + +class Event(val time: Long, val event: EventType, val control: Control, val value: Double) { + + fun proto(): InputController.Event { + return InputController.Event.newBuilder().setEvent(event.value).setControl(control.value).setValue(value) + .setTime(Timestamp.newBuilder().setSeconds(time).build()).build() + } + + companion object { //not sure if im converting time correctly + fun fromProto(proto: InputController.Event): Event { + return Event( + proto.time.seconds, + EventType.fromValue(proto.event), + Control.fromValue(proto.control), + proto.value + ) + } + } +} +typealias ControlFunction = (Event) -> Unit + +/** + * Controller is a logical "container" more than an actual device + * Could be a single gamepad, or a collection of digitalInterrupts and analogReaders, a keyboard, etc. + */ +abstract class Controller(name: String) : Component(SUBTYPE, named(name)) { + companion object { + @JvmField + val SUBTYPE = Subtype(Subtype.NAMESPACE_RDK, Subtype.RESOURCE_TYPE_COMPONENT, "input_controller") + + /** + * Get the ResourceName of the component + * @param name the name of the component + * @return the component's ResourceName + */ + @JvmStatic + fun named(name: String): ResourceName { + return Resource.named(SUBTYPE, name) + } + + /** + * Get the component with the provided name from the provided robot. + * @param robot the RobotClient + * @param name the name of the component + * @return the component + */ + @JvmStatic + fun fromRobot(robot: RobotClient, name: String): Controller { + return robot.getResource(Controller::class.java, named(name)) + } + } + + /** + * Returns a list of Controls provided by the Controller + * @return List of controls provided by the Controller + */ + abstract fun getControls(extra: Struct): List + + /** + * Returns a list of Controls provided by the Controller + * @return List of controls provided by the Controller + */ + fun getControls(): List { + return getControls(Struct.getDefaultInstance()) + } + + + /** + * Returns the most recent Event for each input (which should be the current state) + * @return the most recent event for each input + */ + abstract fun getEvents(extra: Struct): Map + + /** + * Returns the most recent Event for each input (which should be the current state) + * @return the most recent event for each input + */ + fun getEvents(): Map { + return getEvents(Struct.getDefaultInstance()); + } + + /** + * Register a function that will fire on given EventTypes for a given Control + * @param control the control to register the function for + * @param triggers the events that will trigger the function + * @param function the function to run on specific triggers + */ + abstract fun registerControlCallback( + control: Control, + triggers: List, + function: ControlFunction?, + extra: Struct + ) + + /** + * Register a function that will fire on given EventTypes for a given Control + * @param control the control to register the function for + * @param triggers the events that will trigger the function + * @param function the function to run on specific triggers + */ + fun registerControlCallback(control: Control, triggers: List, function: ControlFunction?) { + registerControlCallback(control, triggers, function, Struct.getDefaultInstance()) + } + + /** + * Directly send an Event (such as a button press) from external code + * @param event the event to trigger + */ + abstract fun triggerEvent(event: Event, extra: Struct) + + /** + * Directly send an Event (such as a button press) from external code + * @param event the event to trigger + */ + fun triggerEvent(event: Event) { + triggerEvent(event, Struct.getDefaultInstance()) + } + + +} + diff --git a/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCClient.kt b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCClient.kt new file mode 100644 index 000000000..9c6d08fe7 --- /dev/null +++ b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCClient.kt @@ -0,0 +1,161 @@ +package com.viam.sdk.core.component.input + +import com.google.protobuf.Struct +import com.google.protobuf.Value +import com.viam.common.v1.Common +import com.viam.common.v1.Common.GetGeometriesRequest +import com.viam.component.inputcontroller.v1.InputController +import com.viam.component.inputcontroller.v1.InputController.TriggerEventRequest +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc.InputControllerServiceBlockingStub +import com.viam.sdk.core.rpc.Channel +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import java.time.Instant +import java.util.* +import java.util.concurrent.locks.Lock +import java.util.concurrent.locks.ReadWriteLock +import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.locks.ReentrantReadWriteLock +import kotlin.concurrent.withLock +import kotlin.jvm.optionals.getOrDefault + +class InputControllerRPCClient(name: String, channel: Channel) : Controller(name) { + private val client: InputControllerServiceBlockingStub + private var callbacks: MutableMap> + private var isStreaming: Boolean + private var lock: ReadWriteLock + private var streamLock: Lock + + init { + val client = InputControllerServiceGrpc.newBlockingStub(channel) + if (channel.callCredentials.isPresent) { + this.client = client.withCallCredentials(channel.callCredentials.get()) + } else { + this.client = client + } + this.callbacks = mutableMapOf() + this.lock = ReentrantReadWriteLock() //ensures only one thread edits properties + this.streamLock = ReentrantLock() //ensures only one stream active + this.isStreaming = false + } + + override fun getControls(extra: Struct): List { + val request = + InputController.GetControlsRequest.newBuilder().setController(this.name.name).setExtra(extra).build() + val response = client.getControls(request) + return response.controlsList.map { control -> Control.fromValue(control) } + } + + override fun getEvents(extra: Struct): Map { + val request = + InputController.GetEventsRequest.newBuilder().setController(this.name.name).setExtra(extra).build() + val response = client.getEvents(request) + return response.eventsList.associate { event -> + Control.fromValue(event.control) to Event.fromProto(event) + } + } + + override fun registerControlCallback( + control: Control, + triggers: List, + function: ControlFunction?, + extra: Struct + ): Unit = runBlocking { + lock.writeLock().lock() + for (trigger in triggers) { + if (trigger == EventType.BUTTON_CHANGE) + callbacks[control] = + mutableMapOf(EventType.BUTTON_PRESS to function, EventType.BUTTON_RELEASE to function) + else + callbacks[control] = mutableMapOf(trigger to function) + } + lock.writeLock().unlock() + + launch { + streamEvents(extra) + } + } + + private suspend fun streamEvents(extra: Struct) { + streamLock.withLock { + if (this.isStreaming) return + this.isStreaming = true + } + if (this.callbacks.isEmpty()) return + val request = InputController.StreamEventsRequest.newBuilder().setController(this.name.name) + .addAllEvents(listOf()).setExtra(extra) + + lock.writeLock().lock() + for ((c, cb) in this.callbacks.entries) { + val evs: List = cb.entries.filter { (_, func) -> func != null }.map { (et, _) -> et.value } + val cancelled = cb.entries.filter { (_, func) -> func == null }.map { (et, _) -> et.value } + val event = InputController.StreamEventsRequest.Events.newBuilder().setControl(c.value).addAllEvents(evs) + .addAllCancelledEvents(cancelled).build() + request.addEvents(event) + } + lock.writeLock().unlock() + + try { + val response = this.client.streamEvents(request.build()) + sendConnectionStatus(true) + for (r in response) { + executeCallback(Event.fromProto(r.event)) + } + + } catch (e: Exception) { + System.err.println(e) + } finally { + sendConnectionStatus(false) + streamLock.withLock { + this.isStreaming = false + } + } + } + + private fun sendConnectionStatus(connected: Boolean) { + for (control in this.callbacks.keys) { + val eventType = if (connected) EventType.CONNECT else EventType.DISCONNECT + val event = Event(Instant.now().epochSecond, eventType, control, value = 0.0) + executeCallback(event) + } + } + + private fun executeCallback(event: Event) { + try { + val cbMap = callbacks[event.control] + val cb = cbMap?.getOrDefault(event.event, null) + val allCB = cbMap?.getOrDefault(EventType.ALL_EVENTS, null) + if (cb != null) { + cb(event) + } + if (allCB != null) { + allCB(event) + } + } catch (e: Exception) { + return + } + } + + override fun triggerEvent(event: Event, extra: Struct) { + val request = + TriggerEventRequest.newBuilder().setController(this.name.name).setEvent(event.proto()).setExtra(extra) + .build() + this.client.triggerEvent(request) + } + + override fun doCommand(command: Map?): Struct { + val request = Common.DoCommandRequest.newBuilder().setName(this.name.name) + .setCommand(Struct.newBuilder().putAllFields(command).build()).build() + val response = this.client.doCommand(request) + return response.result + } + + override fun getGeometries(extra: Optional): List { + val request = GetGeometriesRequest.newBuilder().setName(this.name.name) + .setExtra(extra.getOrDefault(Struct.getDefaultInstance())).build() + val response = this.client.getGeometries(request) + return response.geometriesList + } + +} diff --git a/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCService.kt b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCService.kt new file mode 100644 index 000000000..a9cb90cc5 --- /dev/null +++ b/core/sdk/src/main/kotlin/com/viam/sdk/core/component/input/InputControllerRPCService.kt @@ -0,0 +1,132 @@ +package com.viam.sdk.core.component.input + +import com.viam.common.v1.Common.* +import com.viam.component.inputcontroller.v1.InputController +import com.viam.component.inputcontroller.v1.InputController.* +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc +import com.viam.sdk.core.resource.ResourceManager +import com.viam.sdk.core.resource.ResourceRPCService +import io.grpc.stub.StreamObserver +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.cancelChildren +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import java.util.* + +internal class InputControllerRPCService(private val manager: ResourceManager) : + InputControllerServiceGrpc.InputControllerServiceImplBase(), ResourceRPCService { + + @OptIn(ExperimentalCoroutinesApi::class) + override fun streamEvents( + request: StreamEventsRequest, responseObserver: StreamObserver + ) = runBlocking { + val controller = getResource(Controller.named(request.controller)) + val channel = Channel(1024) + + fun cleanup() { + channel.close() + coroutineContext.cancelChildren() + //unregister events + for (event in request.eventsList) { + val triggers = event.eventsList.map { EventType.fromValue(it) } + if (triggers.isNotEmpty()) controller.registerControlCallback( + Control.fromValue(event.control), triggers, null, request.extra + ) + } + } + + val ctrlFunc = fun(ev: Event) { + launch { + try { + val response = ev.proto() + channel.send(response) + } catch (e: Exception) { + cleanup() + } + } + } + + for (event in request.eventsList) { + val triggers = event.eventsList.map { EventType.fromValue(it) } //handle eerror + if (triggers.isNotEmpty()) controller.registerControlCallback( + Control.fromValue(event.control), + triggers, + ctrlFunc, + request.extra + ) + val cancelledTriggers = event.cancelledEventsList.map { EventType.fromValue(it) } + if (cancelledTriggers.isNotEmpty()) controller.registerControlCallback( + Control.fromValue(event.control), cancelledTriggers, null, request.extra + ) + } + + while (true) { + val receive = channel.receive() + responseObserver.onNext(StreamEventsResponse.newBuilder().setEvent(receive).build()) + if (channel.isEmpty) { + break + } + } + + cleanup() + responseObserver.onCompleted() + } + + override fun getControls( + request: GetControlsRequest, responseObserver: StreamObserver + ) { + val controller = getResource(Controller.named(request.controller)) + val controls = controller.getControls(request.extra) + responseObserver.onNext(GetControlsResponse.newBuilder().addAllControls(controls.map { it.value }).build()) + responseObserver.onCompleted() + } + + override fun getEvents( + request: GetEventsRequest, responseObserver: StreamObserver + ) { + val controller = getResource(Controller.named(request.controller)) + val events = controller.getEvents(request.extra) + responseObserver.onNext(GetEventsResponse.newBuilder().addAllEvents(events.values.map { it.proto() }).build()) + responseObserver.onCompleted() + } + + override fun triggerEvent( + request: TriggerEventRequest, responseObserver: StreamObserver + ) { + val controller = getResource(Controller.named(request.controller)) + try { + controller.triggerEvent(Event.fromProto(request.event), request.extra) + } catch (e: Exception) { + throw e + } + responseObserver.onNext(TriggerEventResponse.newBuilder().build()) + responseObserver.onCompleted() + } + + override fun doCommand( + request: DoCommandRequest, responseObserver: StreamObserver + ) { + val controller = getResource(Controller.named(request.name)) + val result = controller.doCommand(request.command.fieldsMap) + responseObserver.onNext(DoCommandResponse.newBuilder().setResult(result).build()) + responseObserver.onCompleted() + } + + override fun getGeometries( + request: GetGeometriesRequest, responseObserver: StreamObserver + ) { + val controller = getResource(Controller.named(request.name)) + val result = controller.getGeometries(Optional.of(request.extra)) + responseObserver.onNext(GetGeometriesResponse.newBuilder().addAllGeometries(result).build()) + responseObserver.onCompleted() + } + + override fun getResourceClass(): Class { + return Controller::class.java + } + + override fun getManager(): ResourceManager { + return this.manager + } +} \ No newline at end of file diff --git a/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCClientTest.kt b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCClientTest.kt new file mode 100644 index 000000000..e2e83753a --- /dev/null +++ b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCClientTest.kt @@ -0,0 +1,154 @@ +package com.viam.sdk.core.component.input + +import com.google.protobuf.Struct +import com.google.protobuf.Value +import com.viam.common.v1.Common.Geometry +import com.viam.sdk.core.component.input.Controller +import com.viam.sdk.core.resource.ResourceManager +import com.viam.sdk.core.rpc.BasicManagedChannel +import io.grpc.inprocess.InProcessChannelBuilder +import io.grpc.inprocess.InProcessServerBuilder +import io.grpc.testing.GrpcCleanupRule +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.Rule +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.Mockito.* +import java.time.Instant +import java.util.* + +class InputRPCClientTest { + + private lateinit var inputController: Controller + private lateinit var client: InputControllerRPCClient + + @JvmField + @Rule + val grpcCleanupRule: GrpcCleanupRule = GrpcCleanupRule() + + @BeforeEach + fun setup() { + inputController = mock( + Controller::class.java, withSettings().useConstructor("mock-inputController").defaultAnswer( + CALLS_REAL_METHODS + ) + ) + val resourceManager = ResourceManager(listOf(inputController)) + val service = InputControllerRPCService(resourceManager) + val serviceName = InProcessServerBuilder.generateName() + grpcCleanupRule.register( + InProcessServerBuilder.forName(serviceName).directExecutor().addService(service).build().start() + ) + val channel = grpcCleanupRule.register(InProcessChannelBuilder.forName(serviceName).directExecutor().build()) + client = InputControllerRPCClient("mock-inputController", BasicManagedChannel(channel)) + } + + @Test + fun getControls() { + val controls = listOf(Control.ABSOLUTE_X) + `when`(inputController.getControls(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(controls) + val response = client.getControls() + verify(inputController).getControls(Struct.getDefaultInstance()) + assertEquals(controls, response) + } + + @Test + fun getEvents() { + val events = + mapOf(Control.ABSOLUTE_X to Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0)) + `when`(inputController.getEvents(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(events) + val response = client.getEvents() + verify(inputController).getEvents(Struct.getDefaultInstance()) + assertEquals(events.keys, response.keys) + + } + + @Test + fun triggerEvent() { + val event = Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0) + client.triggerEvent(event) + verify(inputController).triggerEvent( + (any(Event::class.java) ?: event), + eq(Struct.getDefaultInstance()) ?: Struct.getDefaultInstance() + ) + + } + + @Test + fun registerControlCallback() = runBlocking { + val callbacks = mutableMapOf>() + var callbackCount = 1 + val testEventFun = fun(ev: Event) { + callbackCount += 1 + } + + doAnswer { + val event = it.arguments[0] as Event + val func = callbacks[event.control]?.get(event.event) + if (func != null) { + func(event) + } + null + }.`when`(inputController).triggerEvent( + any(Event::class.java) ?: Event( + Instant.now().epochSecond, + EventType.BUTTON_RELEASE, + Control.BUTTON_START, + 1.0 + ) + ) + + doAnswer { + val control = it.arguments[0] as Control + val triggers = it.arguments[1] as List + val func = it.arguments[2] as ControlFunction? + callbacks[control] = triggers.associateWith { func } + null + }.`when`(inputController).registerControlCallback( + eq(Control.BUTTON_START) ?: Control.BUTTON_START, + eq(listOf(EventType.BUTTON_RELEASE)) ?: listOf(EventType.BUTTON_RELEASE), + any(), + any(Struct::class.java) ?: Struct.getDefaultInstance() + ) + + + launch { + delay(2000L) + for (i in 1..5) { + var ev = Event(Instant.now().epochSecond, EventType.BUTTON_RELEASE, Control.BUTTON_START, i.toDouble()) + inputController.triggerEvent(ev) + + } + } + + client.registerControlCallback(Control.BUTTON_START, listOf(EventType.BUTTON_RELEASE), testEventFun) + verify(inputController, times(2)).registerControlCallback( + eq(Control.BUTTON_START) ?: Control.BUTTON_START, + eq(listOf(EventType.BUTTON_RELEASE)) ?: listOf(EventType.BUTTON_RELEASE), + any(), + eq(Struct.getDefaultInstance()) ?: Struct.getDefaultInstance() + ) // occurs twice: once for registering, once for unregistering during cleanup since this calls streamEvent + assertEquals(callbackCount, 6) + + } + + + @Test + fun doCommand() { + val command = mapOf("foo" to Value.newBuilder().setStringValue("bar").build()) + doReturn(Struct.newBuilder().putAllFields(command).build()).`when`(inputController).doCommand(anyMap()) + val response = client.doCommand(command) + verify(inputController).doCommand(command) + assertEquals(command, response.fieldsMap) + } + + @Test + fun getGeometries() { + doReturn(listOf()).`when`(inputController).getGeometries(any()) + client.getGeometries(Optional.empty()) + verify(inputController).getGeometries(any()) + } +} \ No newline at end of file diff --git a/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCServiceTest.kt b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCServiceTest.kt new file mode 100644 index 000000000..5819b46d2 --- /dev/null +++ b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputRPCServiceTest.kt @@ -0,0 +1,184 @@ +package com.viam.sdk.core.component.input + +import com.google.protobuf.Struct +import com.google.protobuf.Value +import com.viam.common.v1.Common +import com.viam.common.v1.Common.Geometry +import com.viam.component.inputcontroller.v1.InputController +import com.viam.component.inputcontroller.v1.InputController.* +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc +import com.viam.component.inputcontroller.v1.InputControllerServiceGrpc.InputControllerServiceBlockingStub +import com.viam.sdk.core.component.input.Controller +import com.viam.sdk.core.resource.ResourceManager +import io.grpc.inprocess.InProcessChannelBuilder +import io.grpc.inprocess.InProcessServerBuilder +import io.grpc.testing.GrpcCleanupRule +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.Rule +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.Mockito.* +import java.time.Instant +import java.util.* +import kotlin.random.Random + +class InputRPCServiceTest { + + private lateinit var inputController: Controller + private lateinit var client: InputControllerServiceBlockingStub + private lateinit var callbacks: MutableMap> + + @JvmField + @Rule + val grpcCleanup: GrpcCleanupRule = GrpcCleanupRule() + + @BeforeEach + fun setup() { + inputController = mock( + Controller::class.java, + withSettings().useConstructor("mock-inputController").defaultAnswer(CALLS_REAL_METHODS) + ) + callbacks = mutableMapOf() + + val resourceManager = ResourceManager(listOf(inputController)) + val service = InputControllerRPCService(resourceManager) + val serviceName = InProcessServerBuilder.generateName() + grpcCleanup.register( + InProcessServerBuilder.forName(serviceName).directExecutor().addService(service).build().start() + ) + client = InputControllerServiceGrpc.newBlockingStub( + grpcCleanup.register( + InProcessChannelBuilder.forName(serviceName).directExecutor().build() + ) + ) + } + + + @Test + fun getControls() { + val controls = listOf(Control.ABSOLUTE_X) + `when`(inputController.getControls(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(controls) + val request = GetControlsRequest.newBuilder().setController(inputController.name.name) + .setExtra(Struct.getDefaultInstance()).build() + val response = client.getControls(request) + verify(inputController).getControls(Struct.getDefaultInstance()) + assertEquals(controls[0], Control.fromValue(response.controlsList[0])) + + } + + @Test + fun getEvents() { + val events = + mapOf(Control.ABSOLUTE_X to Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0)) + `when`(inputController.getEvents(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(events) + val request = + GetEventsRequest.newBuilder().setController(inputController.name.name).setExtra(Struct.getDefaultInstance()) + .build() + val response = client.getEvents(request) + verify(inputController).getEvents(Struct.getDefaultInstance()) + assertEquals(events.values.toList()[0].time, Event.fromProto(response.eventsList[0]).time) + assertEquals(events.values.toList()[0].control, Event.fromProto(response.eventsList[0]).control) + assertEquals(events.values.toList()[0].event, Event.fromProto(response.eventsList[0]).event) + assertEquals(events.values.toList()[0].value, Event.fromProto(response.eventsList[0]).value) + } + + @Test + fun triggerEvents() { + val event = Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0).proto() + val request = + InputController.TriggerEventRequest.newBuilder().setController(inputController.name.name).setEvent(event) + .setExtra(Struct.getDefaultInstance()).build() + client.triggerEvent(request) + verify(inputController).triggerEvent( + (any(Event::class.java) ?: Event.fromProto(event)), + eq(Struct.getDefaultInstance()) ?: Struct.getDefaultInstance() + ) + } + + @Test + fun streamEvents() = runBlocking { + + doAnswer { + val event = it.arguments[0] as Event + val func = callbacks[event.control]?.get(event.event) + if (func != null) { + func(event) + } + null + }.`when`(inputController).triggerEvent( + any(Event::class.java) ?: Event( + Instant.now().epochSecond, EventType.BUTTON_RELEASE, Control.BUTTON_START, 1.0 + ) + ) + + doAnswer { + val control = it.arguments[0] as Control + val triggers = it.arguments[1] as List + val func = it.arguments[2] as ControlFunction? + callbacks[control] = triggers.associateWith { func } + + null + }.`when`(inputController).registerControlCallback( + eq(Control.BUTTON_START) ?: Control.BUTTON_START, + eq(listOf(EventType.BUTTON_RELEASE)) ?: listOf(EventType.BUTTON_RELEASE), + any(), + any(Struct::class.java) ?: Struct.getDefaultInstance() + ) + + + var expected = mutableListOf() + fun createEvents(): MutableList { + val responses: MutableList = mutableListOf() + for (i in 1..5) { + var value = Random.nextDouble() + var ev = Event(Instant.now().epochSecond, EventType.BUTTON_RELEASE, Control.BUTTON_START, value) + inputController.triggerEvent(ev) + responses.add(StreamEventsResponse.newBuilder().setEvent(ev.proto()).build()) + + } + return responses + } + launch { + delay(2000L) + expected = createEvents() + } + + val events = listOf( + StreamEventsRequest.Events.newBuilder().setControl(Control.BUTTON_START.value) + .addAllEvents(listOf(EventType.BUTTON_RELEASE.value)).build() + ) + val request = + StreamEventsRequest.newBuilder().setController(inputController.name.name).addAllEvents(events).build() + val response = client.streamEvents(request) + for ((index, value) in response.withIndex()) { + assertEquals(expected[index], value) + } + + //check unregistered callbacks + assertNull(callbacks[Control.BUTTON_START]!![EventType.BUTTON_RELEASE]) + } + + @Test + fun doCommand() { + val command = + Struct.newBuilder().putAllFields(mapOf("foo" to Value.newBuilder().setStringValue("bar").build())).build() + doReturn(command).`when`(inputController).doCommand(anyMap()) + val request = + Common.DoCommandRequest.newBuilder().setName(inputController.name.name).setCommand(command).build() + val response = client.doCommand(request) + verify(inputController).doCommand(command.fieldsMap) + assertEquals(command, response.result) + } + + @Test + fun getGeometries() { + doReturn(listOf()).`when`(inputController).getGeometries(any()) + val request = Common.GetGeometriesRequest.newBuilder().setName(inputController.name.name).build() + client.getGeometries(request) + verify(inputController).getGeometries(Optional.of(Struct.getDefaultInstance())) + } +} \ No newline at end of file diff --git a/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputTest.kt b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputTest.kt new file mode 100644 index 000000000..58bf9bff0 --- /dev/null +++ b/core/sdk/src/test/kotlin/com/viam/sdk/core/component/input/InputTest.kt @@ -0,0 +1,59 @@ +package com.viam.sdk.core.component.input + +import com.google.protobuf.Struct +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.Answers +import org.mockito.Mockito.* +import java.time.Instant + +class InputTest { + private lateinit var inputController: Controller + + @BeforeEach + fun setup() { + inputController = mock(Controller::class.java, Answers.CALLS_REAL_METHODS) + } + + @Test + fun getControls() { + val controls = listOf(Control.ABSOLUTE_X) + `when`(inputController.getControls(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(controls) + val response = inputController.getControls() + verify(inputController).getControls() + assertEquals(controls, response) + } + + @Test + fun getEvents() { + val events = + mapOf(Control.ABSOLUTE_X to Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0)) + `when`(inputController.getEvents(any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(events) + val response = inputController.getEvents() + verify(inputController).getEvents() + assertEquals(events, response) + } + + @Test + fun triggerEvent() { + val event = Event(Instant.now().epochSecond, EventType.CONNECT, Control.ABSOLUTE_X, 0.0) + inputController.triggerEvent(event) + verify(inputController).triggerEvent( + (any(Event::class.java) ?: event), + eq(Struct.getDefaultInstance()) ?: Struct.getDefaultInstance() + ) + } + + @Test + fun registerControlCallback() { + val testEventFun = fun(ev: Event) { return Unit } + inputController.registerControlCallback(Control.BUTTON_START, listOf(EventType.BUTTON_RELEASE), testEventFun) + verify(inputController).registerControlCallback( + Control.BUTTON_START, + listOf(EventType.BUTTON_RELEASE), + testEventFun, + Struct.getDefaultInstance() + ) + } +} \ No newline at end of file