Skip to content

Commit

Permalink
RSDK-8714: Posetracker wrappers (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
gloriacai01 authored Sep 16, 2024
1 parent a1716af commit 8d664da
Show file tree
Hide file tree
Showing 7 changed files with 369 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.viam.sdk.core.component.posetracker;

import com.google.protobuf.Struct;
import com.viam.common.v1.Common;
import com.viam.sdk.core.component.Component;
import com.viam.sdk.core.resource.Resource;
import com.viam.sdk.core.resource.Subtype;
import com.viam.sdk.core.robot.RobotClient;

import java.util.List;
import java.util.Map;

/**
* PoseTracker represents a physical pose or motion tracking device.
*/
public abstract class PoseTracker extends Component {

public static final Subtype SUBTYPE = new Subtype(
Subtype.NAMESPACE_RDK,
Subtype.RESOURCE_TYPE_COMPONENT,
"poseTracker");

public PoseTracker(final String name) {
super(SUBTYPE, named(name));
}

/**
* Get the ResourceName of the component
*
* @param name the name of the component
* @return the component's ResourceName
*/
public static Common.ResourceName named(final String name) {
return Resource.named(SUBTYPE, name);
}

/**
* Get the component with the provided name from the provided robot.
* @param robot the RobotClient
* @param name the name of the component
* @return the component
*/
public static PoseTracker fromRobot(final RobotClient robot, final String name) {
return robot.getResource(PoseTracker.class, named(name));
}

/**
* Returns the current pose of each body tracked by the pose tracker.
* @param bodyNames Names of the bodies whose poses are being requested. In the event this parameter is not supplied or is
* an empty list, all available poses are returned.
* @return the mapping of each body name to the pose representing the center of the body
*/
public abstract Map<String, Common.PoseInFrame> getPoses(List<String> bodyNames, Struct extra);


/**
* Returns the current pose of each body tracked by the pose tracker.
* @param bodyNames Names of the bodies whose poses are being requested. In the event this parameter is not supplied or is
* an empty list, all available poses are returned.
* @return the mapping of each body name to the pose representing the center of the body
*/
public Map<String, Common.PoseInFrame> getPoses(List<String> bodyNames){
return getPoses(bodyNames,Struct.getDefaultInstance());
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.viam.sdk.core.component.posetracker;

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.viam.common.v1.Common;
import com.viam.component.v1.PoseTracker.GetPosesRequest;
import com.viam.component.v1.PoseTrackerServiceGrpc;
import com.viam.sdk.core.rpc.Channel;

import java.util.List;
import java.util.Map;
import java.util.Optional;

public class PoseTrackerRPCClient extends PoseTracker {
private final PoseTrackerServiceGrpc.PoseTrackerServiceBlockingStub client;

public PoseTrackerRPCClient(final String name, final Channel chan) {
super(name);
final PoseTrackerServiceGrpc.PoseTrackerServiceBlockingStub client = PoseTrackerServiceGrpc.newBlockingStub(chan);
if (chan.getCallCredentials().isPresent()) {
this.client = client.withCallCredentials(chan.getCallCredentials().get());
} else {
this.client = client;
}
}

@Override
public Struct doCommand(final Map<String, Value> command) {
return client.doCommand(Common.DoCommandRequest.newBuilder().
setName(getName().getName()).
setCommand(Struct.newBuilder().putAllFields(command).build()).
build()).getResult();
}

@Override
public Map<String, Common.PoseInFrame> getPoses(List<String> bodyNames, Struct extra) {
final GetPosesRequest request = GetPosesRequest.newBuilder().setName(getName().getName()).setExtra(extra).addAllBodyNames(bodyNames).build();
return client.getPoses(request).getBodyPosesMap();
}

@Override
public List<Common.Geometry> getGeometries(Optional<Struct> extra) {
final Common.GetGeometriesRequest.Builder builder = Common.GetGeometriesRequest.newBuilder().
setName(getName().getName());
extra.ifPresent(builder::setExtra);
return client.getGeometries(builder.build()).getGeometriesList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.viam.sdk.core.component.posetracker;

import com.google.protobuf.Struct;
import com.viam.common.v1.Common;
import com.viam.component.v1.PoseTrackerServiceGrpc;
import com.viam.sdk.core.resource.ResourceManager;
import com.viam.sdk.core.resource.ResourceRPCService;
import io.grpc.stub.StreamObserver;
import com.viam.sdk.core.component.posetracker.*;

import java.util.List;
import java.util.Map;
import java.util.Optional;

public class PoseTrackerRPCService extends PoseTrackerServiceGrpc.PoseTrackerServiceImplBase
implements ResourceRPCService<PoseTracker> {

private final ResourceManager manager;

public PoseTrackerRPCService(final ResourceManager manager) {
this.manager = manager;
}

@Override
public void getPoses(com.viam.component.v1.PoseTracker.GetPosesRequest request, StreamObserver<com.viam.component.v1.PoseTracker.GetPosesResponse> responseObserver) {
final PoseTracker poseTracker = getResource(PoseTracker.named(request.getName()));
final Map<String, Common.PoseInFrame> result = poseTracker.getPoses(request.getBodyNamesList(), request.getExtra());
responseObserver.onNext(com.viam.component.v1.PoseTracker.GetPosesResponse.newBuilder().putAllBodyPoses(result).build());
responseObserver.onCompleted();
}

@Override
public void doCommand(Common.DoCommandRequest request,
StreamObserver<Common.DoCommandResponse> responseObserver) {
final PoseTracker poseTracker = getResource(
PoseTracker.named(request.getName())
);
final Struct result = poseTracker.doCommand(request.getCommand().getFieldsMap());
responseObserver.onNext(Common.DoCommandResponse.newBuilder().setResult(result).build());
responseObserver.onCompleted();
}

@Override
public void getGeometries(Common.GetGeometriesRequest request, StreamObserver<Common.GetGeometriesResponse> responseObserver) {
final PoseTracker poseTracker = getResource(
PoseTracker.named(request.getName()));
final List<Common.Geometry> result = poseTracker.getGeometries(Optional.of(request.getExtra()));
responseObserver.onNext(Common.GetGeometriesResponse.newBuilder().addAllGeometries(result).build());
responseObserver.onCompleted();
}

@Override
public Class<PoseTracker> getResourceClass() {
return PoseTracker.class;
}

@Override
public ResourceManager getManager() {
return manager;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.viam.component.movementsensor.v1.MovementSensorServiceGrpc;
import com.viam.component.powersensor.v1.PowerSensorServiceGrpc;
import com.viam.component.sensor.v1.SensorServiceGrpc;
import com.viam.component.v1.PoseTrackerServiceGrpc;
import com.viam.sdk.core.component.base.*;
import com.viam.sdk.core.component.arm.*;
import com.viam.component.servo.v1.ServoServiceGrpc;
Expand All @@ -39,6 +40,7 @@
import com.viam.sdk.core.component.movementsensor.MovementSensor;
import com.viam.sdk.core.component.movementsensor.MovementSensorRPCClient;
import com.viam.sdk.core.component.movementsensor.MovementSensorRPCService;
import com.viam.sdk.core.component.posetracker.*;
import com.viam.sdk.core.component.powersensor.PowerSensor;
import com.viam.sdk.core.component.powersensor.PowerSensorRPCClient;
import com.viam.sdk.core.component.powersensor.PowerSensorRPCService;
Expand Down Expand Up @@ -131,6 +133,12 @@ public class ResourceManager implements Closeable {
MovementSensorRPCService::new,
MovementSensorRPCClient::new
));
Registry.registerSubtype(new ResourceRegistration<>(
PoseTracker.SUBTYPE,
PoseTrackerServiceGrpc.SERVICE_NAME,
PoseTrackerRPCService::new,
PoseTrackerRPCClient::new
));
Registry.registerSubtype(new ResourceRegistration<>(
PowerSensor.SUBTYPE,
PowerSensorServiceGrpc.SERVICE_NAME,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.viam.sdk.core.component.posetracker

import com.google.protobuf.Struct
import com.google.protobuf.Value
import com.viam.common.v1.Common
import com.viam.common.v1.Common.Geometry
import com.viam.common.v1.Common.PoseInFrame
import com.viam.sdk.core.component.posetracker.*;
import com.viam.sdk.core.resource.ResourceManager
import com.viam.sdk.core.rpc.BasicManagedChannel
import io.grpc.inprocess.InProcessChannelBuilder
import io.grpc.inprocess.InProcessServerBuilder
import io.grpc.testing.GrpcCleanupRule
import org.junit.Rule
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.Mockito.*
import java.util.*

class PoseTrackerRPCClientTest {
private lateinit var poseTracker: PoseTracker
private lateinit var client: PoseTrackerRPCClient

@JvmField
@Rule
val grpcCleanupRule: GrpcCleanupRule = GrpcCleanupRule()

@BeforeEach
fun setup() {
poseTracker = mock(
PoseTracker::class.java, withSettings().useConstructor("mock-poseTracker").defaultAnswer(
CALLS_REAL_METHODS
)
)
val resourceManager = ResourceManager(listOf(poseTracker))
val service = PoseTrackerRPCService(resourceManager)
val serviceName = InProcessServerBuilder.generateName()
grpcCleanupRule.register(
InProcessServerBuilder.forName(serviceName).directExecutor().addService(service).build().start()
)
val channel = grpcCleanupRule.register(InProcessChannelBuilder.forName(serviceName).directExecutor().build())
client = PoseTrackerRPCClient("mock-poseTracker", BasicManagedChannel(channel))
}

@Test
fun getPoses(){
val bodyNames = listOf("a", "b")
val pose = Common.Pose.newBuilder().setX(1.0).setY(1.0).setZ(1.0).setOX(2.0).setOY(2.0).setOZ(2.0).setTheta(3.0).build()
val poseFrames = mapOf("a" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("0").build(),
"b" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("1").build())
`when`(poseTracker.getPoses(eq(bodyNames), any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(poseFrames)
val res = client.getPoses(bodyNames)
verify(poseTracker).getPoses(bodyNames, Struct.getDefaultInstance())
assertEquals(poseFrames, res)
}

@Test
fun doCommand() {
val command = mapOf("foo" to Value.newBuilder().setStringValue("bar").build())
doReturn(Struct.newBuilder().putAllFields(command).build()).`when`(poseTracker).doCommand(anyMap())
val response = client.doCommand(command)
verify(poseTracker).doCommand(command)
assertEquals(command, response.fieldsMap)
}

@Test
fun getGeometries() {
doReturn(listOf<Geometry>()).`when`(poseTracker).getGeometries(any())
client.getGeometries(Optional.empty())
verify(poseTracker).getGeometries(any())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.viam.sdk.core.component.posetracker

import com.google.protobuf.Struct
import com.google.protobuf.Value
import com.viam.common.v1.Common
import com.viam.common.v1.Common.Geometry
import com.viam.common.v1.Common.PoseInFrame
import com.viam.component.v1.PoseTrackerServiceGrpc
import com.viam.component.v1.PoseTrackerServiceGrpc.PoseTrackerServiceBlockingStub
import com.viam.sdk.core.resource.ResourceManager
import io.grpc.inprocess.InProcessChannelBuilder
import io.grpc.inprocess.InProcessServerBuilder
import io.grpc.testing.GrpcCleanupRule
import org.junit.Rule
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.Mockito.*
import java.util.*

class PoseTrackerRPCServiceTest {
private lateinit var poseTracker: PoseTracker
private lateinit var client: PoseTrackerServiceBlockingStub

@JvmField
@Rule
val grpcCleanupRule: GrpcCleanupRule = GrpcCleanupRule()

@BeforeEach
fun setup() {
poseTracker = mock(
PoseTracker::class.java, withSettings().useConstructor("mock-poseTracker").defaultAnswer(
CALLS_REAL_METHODS
)
)

val resourceManager = ResourceManager(listOf(poseTracker))
val service = PoseTrackerRPCService(resourceManager)
val serviceName = InProcessServerBuilder.generateName()
grpcCleanupRule.register(
InProcessServerBuilder.forName(serviceName).directExecutor().addService(service).build().start()
)
client = PoseTrackerServiceGrpc.newBlockingStub(
grpcCleanupRule.register(
InProcessChannelBuilder.forName(serviceName).build()
)
)
}

@Test
fun getPoses(){
val bodyNames = listOf("a", "b")
val pose = Common.Pose.newBuilder().setX(1.0).setY(1.0).setZ(1.0).setOX(2.0).setOY(2.0).setOZ(2.0).setTheta(3.0).build()
val poseFrames = mapOf("a" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("0").build(),
"b" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("1").build())
`when`(poseTracker.getPoses(eq(bodyNames), any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(poseFrames)
val request = com.viam.component.v1.PoseTracker.GetPosesRequest.newBuilder().setName(poseTracker.name.name).addAllBodyNames(bodyNames).build()
val res = client.getPoses(request)
verify(poseTracker).getPoses(bodyNames, Struct.getDefaultInstance())
assertEquals(poseFrames, res.bodyPosesMap)
}
@Test
fun doCommand() {
val command =
Struct.newBuilder().putAllFields(mapOf("foo" to Value.newBuilder().setStringValue("bar").build())).build()
doReturn(command).`when`(poseTracker).doCommand(anyMap())
val request = Common.DoCommandRequest.newBuilder().setName(poseTracker.name.name).setCommand(command).build()
val response = client.doCommand(request)
verify(poseTracker).doCommand(command.fieldsMap)
assertEquals(command, response.result)
}

@Test
fun getGeometries() {
doReturn(listOf<Geometry>()).`when`(poseTracker).getGeometries(any())
val request = Common.GetGeometriesRequest.newBuilder().setName(poseTracker.name.name).build()
client.getGeometries(request)
verify(poseTracker).getGeometries(Optional.of(Struct.getDefaultInstance()))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.viam.sdk.core.component.posetracker

import com.google.protobuf.Struct
import com.viam.common.v1.Common
import com.viam.common.v1.Common.PoseInFrame
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.Answers
import org.mockito.Mockito.*

class PoseTrackerTest {
private lateinit var poseTracker: PoseTracker

@BeforeEach
fun setup() {
poseTracker = mock(PoseTracker::class.java, Answers.CALLS_REAL_METHODS)
}

@Test
fun getPoses(){
val bodyNames = listOf("a", "b")
val pose = Common.Pose.newBuilder().setX(1.0).setY(1.0).setZ(1.0).setOX(2.0).setOY(2.0).setOZ(2.0).setTheta(3.0).build()
val poseFrames = mapOf("a" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("0").build(),
"b" to PoseInFrame.newBuilder().setPose(pose).setReferenceFrame("1").build())
`when`(poseTracker.getPoses(eq(bodyNames), any(Struct::class.java) ?: Struct.getDefaultInstance())).thenReturn(poseFrames)
val res = poseTracker.getPoses(bodyNames)
verify(poseTracker).getPoses(bodyNames, Struct.getDefaultInstance())
assertEquals(poseFrames, res)
}
}

0 comments on commit 8d664da

Please sign in to comment.