diff --git a/services/vision/colordetector/color_detector.go b/services/vision/colordetector/color_detector.go index 609780b82c5..d65ad163bbc 100644 --- a/services/vision/colordetector/color_detector.go +++ b/services/vision/colordetector/color_detector.go @@ -59,5 +59,5 @@ func registerColorDetector( return nil, errors.Errorf("could not find camera %q", conf.DefaultCamera) } } - return vision.NewService(name, r, nil, nil, detector, nil, conf.DefaultCamera) + return vision.DeprecatedNewService(name, r, nil, nil, detector, nil, conf.DefaultCamera) } diff --git a/services/vision/colordetector/color_detector_test.go b/services/vision/colordetector/color_detector_test.go index 1baafb74ab7..c222cb58d77 100644 --- a/services/vision/colordetector/color_detector_test.go +++ b/services/vision/colordetector/color_detector_test.go @@ -8,6 +8,7 @@ import ( "go.viam.com/utils/artifact" "go.viam.com/rdk/components/camera" + "go.viam.com/rdk/logging" "go.viam.com/rdk/resource" "go.viam.com/rdk/rimage" "go.viam.com/rdk/services/vision" @@ -23,6 +24,7 @@ func TestColorDetector(t *testing.T) { } ctx := context.Background() r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { return nil } name := vision.Named("test_cd") srv, err := registerColorDetector(ctx, name, &inp, r) test.That(t, err, test.ShouldBeNil) @@ -68,6 +70,7 @@ func TestRegistrationWithDefaultCamera(t *testing.T) { } r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { return nil } r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { if name == cameraName { return inject.NewCamera(cameraName.Name), nil diff --git a/services/vision/detectionstosegments/detections_to_3dsegments.go b/services/vision/detectionstosegments/detections_to_3dsegments.go index 9672ea54c81..6871cd8c419 100644 --- a/services/vision/detectionstosegments/detections_to_3dsegments.go +++ b/services/vision/detectionstosegments/detections_to_3dsegments.go @@ -72,5 +72,5 @@ func register3DSegmenterFromDetector( 
return nil, errors.Errorf("could not find camera %q", conf.DefaultCamera) } } - return vision.NewService(name, r, nil, nil, detector, segmenter, conf.DefaultCamera) + return vision.DeprecatedNewService(name, r, nil, nil, detector, segmenter, conf.DefaultCamera) } diff --git a/services/vision/detectionstosegments/detections_to_3dsegments_test.go b/services/vision/detectionstosegments/detections_to_3dsegments_test.go index 473a3ddf24c..b596ccf6fa7 100644 --- a/services/vision/detectionstosegments/detections_to_3dsegments_test.go +++ b/services/vision/detectionstosegments/detections_to_3dsegments_test.go @@ -11,6 +11,7 @@ import ( "go.viam.com/test" "go.viam.com/rdk/components/camera" + "go.viam.com/rdk/logging" pc "go.viam.com/rdk/pointcloud" "go.viam.com/rdk/resource" "go.viam.com/rdk/rimage" @@ -29,9 +30,10 @@ func (s *simpleDetector) Detect(context.Context, image.Image) ([]objectdetection func Test3DSegmentsFromDetector(t *testing.T) { r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { return nil } m := &simpleDetector{} name := vision.Named("testDetector") - svc, err := vision.NewService(name, r, nil, nil, m.Detect, nil, "") + svc, err := vision.DeprecatedNewService(name, r, nil, nil, m.Detect, nil, "") test.That(t, err, test.ShouldBeNil) cam := &inject.Camera{} cam.NextPointCloudFunc = func(ctx context.Context) (pc.PointCloud, error) { diff --git a/services/vision/fake/vision.go b/services/vision/fake/vision.go index 1cff572253d..a492f8e9a13 100644 --- a/services/vision/fake/vision.go +++ b/services/vision/fake/vision.go @@ -71,5 +71,5 @@ func registerFake( name resource.Name, r robot.Robot, ) (vision.Service, error) { - return vision.NewService(name, r, nil, fakeClassifier, fakeDetector, nil, fakeCameraName) + return vision.DeprecatedNewService(name, r, nil, fakeClassifier, fakeDetector, nil, fakeCameraName) } diff --git a/services/vision/fake/vision_test.go b/services/vision/fake/vision_test.go index 58cbdc32990..a1e4f5f0932 100644 --- 
a/services/vision/fake/vision_test.go +++ b/services/vision/fake/vision_test.go @@ -7,6 +7,7 @@ import ( "go.viam.com/test" + "go.viam.com/rdk/logging" "go.viam.com/rdk/services/vision" "go.viam.com/rdk/testutils/inject" ) @@ -14,6 +15,9 @@ import ( func TestFakeVision(t *testing.T) { ctx := context.Background() r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { + return nil + } name := vision.Named("test_fake") srv, err := registerFake(name, r) test.That(t, err, test.ShouldBeNil) diff --git a/services/vision/mlvision/ml_model.go b/services/vision/mlvision/ml_model.go index 196df9ed7f7..8d4fde77034 100644 --- a/services/vision/mlvision/ml_model.go +++ b/services/vision/mlvision/ml_model.go @@ -221,7 +221,7 @@ func registerMLModelVisionService( } // Don't return a close function, because you don't want to close the underlying ML service - return vision.NewService(name, r, nil, classifierFunc, detectorFunc, segmenter3DFunc, params.DefaultCamera) + return vision.DeprecatedNewService(name, r, nil, classifierFunc, detectorFunc, segmenter3DFunc, params.DefaultCamera) } func getLabelsFromFile(labelPath string) []string { diff --git a/services/vision/mlvision/ml_model_test.go b/services/vision/mlvision/ml_model_test.go index a5d16c91c58..b9f1b20ff27 100644 --- a/services/vision/mlvision/ml_model_test.go +++ b/services/vision/mlvision/ml_model_test.go @@ -26,7 +26,11 @@ func BenchmarkAddMLVisionModel(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - service, err := registerMLModelVisionService(ctx, name, &modelCfg, &inject.Robot{}, logging.NewLogger("benchmark")) + r := inject.Robot{} + r.LoggerFunc = func() logging.Logger { + return nil + } + service, err := registerMLModelVisionService(ctx, name, &modelCfg, &r, logging.NewLogger("benchmark")) test.That(b, err, test.ShouldBeNil) test.That(b, service, test.ShouldNotBeNil) test.That(b, service.Name(), test.ShouldResemble, name) @@ -42,7 +46,11 @@ func BenchmarkUseMLVisionModel(b *testing.B) { 
test.That(b, pic, test.ShouldNotBeNil) modelCfg := MLModelConfig{ModelName: name.Name} - service, err := registerMLModelVisionService(ctx, name, &modelCfg, &inject.Robot{}, logging.NewLogger("benchmark")) + r := inject.Robot{} + r.LoggerFunc = func() logging.Logger { + return nil + } + service, err := registerMLModelVisionService(ctx, name, &modelCfg, &r, logging.NewLogger("benchmark")) test.That(b, err, test.ShouldBeNil) test.That(b, service, test.ShouldNotBeNil) test.That(b, service.Name(), test.ShouldResemble, name) @@ -573,6 +581,9 @@ func TestRegistrationWithDefaultCamera(t *testing.T) { modelCfg := MLModelConfig{ModelName: modelName.Name} r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { + return nil + } r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { switch name { case modelName: diff --git a/services/vision/obstaclesdepth/obstacles_depth.go b/services/vision/obstaclesdepth/obstacles_depth.go index 826a400c1e8..b90968d04da 100644 --- a/services/vision/obstaclesdepth/obstacles_depth.go +++ b/services/vision/obstaclesdepth/obstacles_depth.go @@ -101,7 +101,7 @@ func registerObstaclesDepth( } segmenter := myObsDep.buildObsDepth(logger) // does the thing - return svision.NewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) + return svision.DeprecatedNewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) } // BuildObsDepth will check for intrinsics and determine how to build based on that. 
diff --git a/services/vision/obstaclesdistance/obstacles_distance.go b/services/vision/obstaclesdistance/obstacles_distance.go index 8955fe53772..9696b5ca975 100644 --- a/services/vision/obstaclesdistance/obstacles_distance.go +++ b/services/vision/obstaclesdistance/obstacles_distance.go @@ -120,7 +120,7 @@ func registerObstacleDistanceDetector( return nil, errors.Errorf("could not find camera %q", conf.DefaultCamera) } } - return svision.NewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) + return svision.DeprecatedNewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) } func medianFromPointClouds(ctx context.Context, clouds []pointcloud.PointCloud) (r3.Vector, error) { diff --git a/services/vision/obstaclesdistance/obstacles_distance_test.go b/services/vision/obstaclesdistance/obstacles_distance_test.go index 4cbd102bc05..630e323c161 100644 --- a/services/vision/obstaclesdistance/obstacles_distance_test.go +++ b/services/vision/obstaclesdistance/obstacles_distance_test.go @@ -10,6 +10,7 @@ import ( "go.viam.com/utils/artifact" "go.viam.com/rdk/components/camera" + "go.viam.com/rdk/logging" pc "go.viam.com/rdk/pointcloud" "go.viam.com/rdk/resource" "go.viam.com/rdk/rimage" @@ -37,6 +38,9 @@ func TestObstacleDist(t *testing.T) { cam.NextPointCloudFunc = func(ctx context.Context) (pc.PointCloud, error) { return nil, errors.New("no pointcloud") } + r.LoggerFunc = func() logging.Logger { + return nil + } r.ResourceNamesFunc = func() []resource.Name { return []resource.Name{camera.Named("fakeCamera")} } @@ -139,6 +143,9 @@ func TestRegistrationWithDefaultCamera(t *testing.T) { } r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { + return nil + } r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { if name == cameraName { return inject.NewCamera(cameraName.Name), nil diff --git a/services/vision/obstaclespointcloud/obstacles_pointcloud.go b/services/vision/obstaclespointcloud/obstacles_pointcloud.go index 
24da660aa81..0772d1b53ac 100644 --- a/services/vision/obstaclespointcloud/obstacles_pointcloud.go +++ b/services/vision/obstaclespointcloud/obstacles_pointcloud.go @@ -60,5 +60,5 @@ func registerOPSegmenter( return nil, errors.Errorf("could not find camera %q", conf.DefaultCamera) } } - return vision.NewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) + return vision.DeprecatedNewService(name, r, nil, nil, nil, segmenter, conf.DefaultCamera) } diff --git a/services/vision/obstaclespointcloud/obstacles_pointcloud_test.go b/services/vision/obstaclespointcloud/obstacles_pointcloud_test.go index 09619cd44d9..ecaebc9139c 100644 --- a/services/vision/obstaclespointcloud/obstacles_pointcloud_test.go +++ b/services/vision/obstaclespointcloud/obstacles_pointcloud_test.go @@ -10,6 +10,7 @@ import ( "go.viam.com/test" "go.viam.com/rdk/components/camera" + "go.viam.com/rdk/logging" pc "go.viam.com/rdk/pointcloud" "go.viam.com/rdk/resource" "go.viam.com/rdk/services/vision" @@ -19,6 +20,7 @@ import ( func TestRadiusClusteringSegmentation(t *testing.T) { r := &inject.Robot{} + r.LoggerFunc = func() logging.Logger { return nil } cam := &inject.Camera{} cam.NextPointCloudFunc = func(ctx context.Context) (pc.PointCloud, error) { return nil, errors.New("no pointcloud") diff --git a/services/vision/vision.go b/services/vision/vision.go index 824f4e80045..aa5408822ae 100644 --- a/services/vision/vision.go +++ b/services/vision/vision.go @@ -9,18 +9,14 @@ import ( "context" "image" - "github.com/pkg/errors" - "go.opencensus.io/trace" servicepb "go.viam.com/api/service/vision/v1" - "go.viam.com/rdk/components/camera" "go.viam.com/rdk/data" "go.viam.com/rdk/resource" "go.viam.com/rdk/robot" viz "go.viam.com/rdk/vision" "go.viam.com/rdk/vision/classification" "go.viam.com/rdk/vision/objectdetection" - "go.viam.com/rdk/vision/segmentation" "go.viam.com/rdk/vision/viscapture" ) @@ -235,19 +231,6 @@ func FromDependencies(deps resource.Dependencies, name string) (Service, error) 
return resource.FromDependencies[Service](deps, Named(name)) } -// vizModel wraps the vision model with all the service interface methods. -type vizModel struct { - resource.Named - resource.AlwaysRebuild - r robot.Robot // in order to get access to all cameras - properties Properties - closerFunc func(ctx context.Context) error // close the underlying model - classifierFunc classification.Classifier - detectorFunc objectdetection.Detector - segmenter3DFunc segmentation.Segmenter - defaultCamera string -} - // Properties returns various information regarding the current vision service, // specifically, which vision tasks are supported by the resource. type Properties struct { @@ -255,241 +238,3 @@ type Properties struct { DetectionSupported bool ObjectPCDsSupported bool } - -// NewService wraps the vision model in the struct that fulfills the vision service interface. -func NewService( - name resource.Name, - r robot.Robot, - c func(ctx context.Context) error, - cf classification.Classifier, - df objectdetection.Detector, - s3f segmentation.Segmenter, - defaultCamera string, -) (Service, error) { - if cf == nil && df == nil && s3f == nil { - return nil, errors.Errorf( - "model %q does not fulfill any method of the vision service. It is neither a detector, nor classifier, nor 3D segmenter", name) - } - - p := Properties{false, false, false} - if cf != nil { - p.ClassificationSupported = true - } - if df != nil { - p.DetectionSupported = true - } - if s3f != nil { - p.ObjectPCDsSupported = true - } - - return &vizModel{ - Named: name.AsNamed(), - r: r, - properties: p, - closerFunc: c, - classifierFunc: cf, - detectorFunc: df, - segmenter3DFunc: s3f, - defaultCamera: defaultCamera, - }, nil -} - -// Detections returns the detections of given image if the model implements objectdetector.Detector. 
-func (vm *vizModel) Detections( - ctx context.Context, - img image.Image, - extra map[string]interface{}, -) ([]objectdetection.Detection, error) { - ctx, span := trace.StartSpan(ctx, "service::vision::Detections::"+vm.Named.Name().String()) - defer span.End() - if vm.detectorFunc == nil { - return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name()) - } - return vm.detectorFunc(ctx, img) -} - -// DetectionsFromCamera returns the detections of the next image from the given camera. -func (vm *vizModel) DetectionsFromCamera( - ctx context.Context, - cameraName string, - extra map[string]interface{}, -) ([]objectdetection.Detection, error) { - ctx, span := trace.StartSpan(ctx, "service::vision::DetectionsFromCamera::"+vm.Named.Name().String()) - defer span.End() - if cameraName == "" && vm.defaultCamera == "" { - return nil, errors.New("no camera name provided and no default camera found") - } else if cameraName == "" { - cameraName = vm.defaultCamera - } - if vm.detectorFunc == nil { - return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name()) - } - cam, err := camera.FromRobot(vm.r, cameraName) - if err != nil { - return nil, errors.Wrapf(err, "could not find camera named %s", cameraName) - } - img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam) - if err != nil { - return nil, errors.Wrapf(err, "could not get image from %s", cameraName) - } - return vm.detectorFunc(ctx, img) -} - -// Classifications returns the classifications of given image if the model implements classifications.Classifier. 
-func (vm *vizModel) Classifications( - ctx context.Context, - img image.Image, - n int, - extra map[string]interface{}, -) (classification.Classifications, error) { - ctx, span := trace.StartSpan(ctx, "service::vision::Classifications::"+vm.Named.Name().String()) - defer span.End() - if vm.classifierFunc == nil { - return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name()) - } - fullClassifications, err := vm.classifierFunc(ctx, img) - if err != nil { - return nil, errors.Wrap(err, "could not get classifications from image") - } - return fullClassifications.TopN(n) -} - -// ClassificationsFromCamera returns the classifications of the next image from the given camera. -func (vm *vizModel) ClassificationsFromCamera( - ctx context.Context, - cameraName string, - n int, - extra map[string]interface{}, -) (classification.Classifications, error) { - ctx, span := trace.StartSpan(ctx, "service::vision::ClassificationsFromCamera::"+vm.Named.Name().String()) - defer span.End() - if cameraName == "" && vm.defaultCamera == "" { - return nil, errors.New("no camera name provided and no default camera found") - } else if cameraName == "" { - cameraName = vm.defaultCamera - } - if vm.classifierFunc == nil { - return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name()) - } - cam, err := camera.FromRobot(vm.r, cameraName) - if err != nil { - return nil, errors.Wrapf(err, "could not find camera named %s", cameraName) - } - img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam) - if err != nil { - return nil, errors.Wrapf(err, "could not get image from %s", cameraName) - } - fullClassifications, err := vm.classifierFunc(ctx, img) - if err != nil { - return nil, errors.Wrap(err, "could not get classifications from image") - } - return fullClassifications.TopN(n) -} - -// GetObjectPointClouds returns all the found objects in a 3D image if the model implements Segmenter3D. 
-func (vm *vizModel) GetObjectPointClouds( - ctx context.Context, - cameraName string, - extra map[string]interface{}, -) ([]*viz.Object, error) { - if vm.segmenter3DFunc == nil { - return nil, errors.Errorf("vision model %q does not implement a 3D segmenter", vm.Named.Name().String()) - } - ctx, span := trace.StartSpan(ctx, "service::vision::GetObjectPointClouds::"+vm.Named.Name().String()) - defer span.End() - if cameraName == "" && vm.defaultCamera == "" { - return nil, errors.New("no camera name provided and no default camera found") - } else if cameraName == "" { - cameraName = vm.defaultCamera - } - cam, err := camera.FromRobot(vm.r, cameraName) - if err != nil { - return nil, err - } - return vm.segmenter3DFunc(ctx, cam) -} - -// GetProperties returns a Properties object that details the vision capabilities of the model. -func (vm *vizModel) GetProperties(ctx context.Context, extra map[string]interface{}) (*Properties, error) { - _, span := trace.StartSpan(ctx, "service::vision::GetProperties::"+vm.Named.Name().String()) - defer span.End() - - return &vm.properties, nil -} - -func (vm *vizModel) CaptureAllFromCamera( - ctx context.Context, - cameraName string, - opt viscapture.CaptureOptions, - extra map[string]interface{}, -) (viscapture.VisCapture, error) { - ctx, span := trace.StartSpan(ctx, "service::vision::ClassificationsFromCamera::"+vm.Named.Name().String()) - defer span.End() - if cameraName == "" && vm.defaultCamera == "" { - return viscapture.VisCapture{}, errors.New("no camera name provided and no default camera found") - } else if cameraName == "" { - cameraName = vm.defaultCamera - } - cam, err := camera.FromRobot(vm.r, cameraName) - if err != nil { - return viscapture.VisCapture{}, errors.Wrapf(err, "could not find camera named %s", cameraName) - } - img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam) - if err != nil { - return viscapture.VisCapture{}, errors.Wrapf(err, "could not get image from %s", cameraName) - } - logger := 
vm.r.Logger() - var detections []objectdetection.Detection - if opt.ReturnDetections { - if !vm.properties.DetectionSupported { - logger.Debugf("detections requested but vision model %q does not implement a Detector", vm.Named.Name()) - } else { - detections, err = vm.Detections(ctx, img, extra) - if err != nil { - return viscapture.VisCapture{}, err - } - } - } - var classifications classification.Classifications - if opt.ReturnClassifications { - logger := vm.r.Logger() - if !vm.properties.ClassificationSupported { - logger.Debugf("classifications requested in CaptureAll but vision model %q does not implement a Classifier", - vm.Named.Name()) - } else { - classifications, err = vm.Classifications(ctx, img, 0, extra) - if err != nil { - return viscapture.VisCapture{}, err - } - } - } - - var objPCD []*viz.Object - if opt.ReturnObject { - if !vm.properties.ObjectPCDsSupported { - logger := vm.r.Logger() - logger.Debugf("object point cloud requested in CaptureAll but vision model %q does not implement a 3D Segmenter", vm.Named.Name()) - } else { - objPCD, err = vm.GetObjectPointClouds(ctx, cameraName, extra) - if err != nil { - return viscapture.VisCapture{}, err - } - } - } - if !opt.ReturnImage { - img = nil - } - return viscapture.VisCapture{ - Image: img, - Detections: detections, - Classifications: classifications, - Objects: objPCD, - }, nil -} - -func (vm *vizModel) Close(ctx context.Context) error { - if vm.closerFunc == nil { - return nil - } - return vm.closerFunc(ctx) -} diff --git a/services/vision/vision_service_builder.go b/services/vision/vision_service_builder.go new file mode 100644 index 00000000000..5d1c44b11f3 --- /dev/null +++ b/services/vision/vision_service_builder.go @@ -0,0 +1,332 @@ +package vision + +import ( + "context" + "image" + + "github.com/pkg/errors" + "go.opencensus.io/trace" + + "go.viam.com/rdk/components/camera" + "go.viam.com/rdk/logging" + "go.viam.com/rdk/resource" + "go.viam.com/rdk/robot" + viz "go.viam.com/rdk/vision" + 
"go.viam.com/rdk/vision/classification"
+	"go.viam.com/rdk/vision/objectdetection"
+	"go.viam.com/rdk/vision/segmentation"
+	"go.viam.com/rdk/vision/viscapture"
+)
+
+// vizModel wraps the vision model with all the service interface methods.
+type vizModel struct {
+	resource.Named
+	resource.AlwaysRebuild
+	// logger receives the debug messages emitted by CaptureAllFromCamera.
+	logger logging.Logger
+	// properties records which of the three vision tasks this model supports.
+	properties Properties
+	// closerFunc closes the underlying model; nil means there is nothing to close.
+	closerFunc func(ctx context.Context) error
+	// getCamera resolves a camera by name at call time.
+	getCamera func(cameraName string) (camera.Camera, error)
+	// classifierFunc, detectorFunc and segmenter3DFunc are the optional model
+	// implementations; a nil entry means the task is unsupported.
+	classifierFunc  classification.Classifier
+	detectorFunc    objectdetection.Detector
+	segmenter3DFunc segmentation.Segmenter
+	// defaultCamera is used by the camera-based methods when the caller passes "".
+	defaultCamera string
+}
+
+// NewService wraps the vision model in the struct that fulfills the vision service interface.
+//
+// Cameras are looked up in deps by name each time a camera-based method is called.
+// At least one of cf, df and s3f must be non-nil, otherwise an error is returned.
+// closer may be nil. If logger is nil, a logger named after the service is created
+// so that the debug logging in CaptureAllFromCamera cannot dereference a nil interface.
+func NewService(
+	name resource.Name,
+	deps resource.Dependencies,
+	logger logging.Logger,
+	closer func(ctx context.Context) error,
+	cf classification.Classifier,
+	df objectdetection.Detector,
+	s3f segmentation.Segmenter,
+	defaultCamera string,
+) (Service, error) {
+	if cf == nil && df == nil && s3f == nil {
+		return nil, errors.Errorf(
+			"model %q does not fulfill any method of the vision service. It is neither a detector, nor classifier, nor 3D segmenter", name)
+	}
+
+	// Keyed literal: each capability is supported exactly when its function was supplied.
+	p := Properties{
+		ClassificationSupported: cf != nil,
+		DetectionSupported:      df != nil,
+		ObjectPCDsSupported:     s3f != nil,
+	}
+
+	// vizModel logs through this value unconditionally, so never store a nil logger.
+	if logger == nil {
+		logger = logging.NewLogger(name.String())
+	}
+
+	cameraGetter := func(cameraName string) (camera.Camera, error) {
+		return camera.FromDependencies(deps, cameraName)
+	}
+
+	return &vizModel{
+		Named:           name.AsNamed(),
+		logger:          logger,
+		properties:      p,
+		closerFunc:      closer,
+		getCamera:       cameraGetter,
+		classifierFunc:  cf,
+		detectorFunc:    df,
+		segmenter3DFunc: s3f,
+		defaultCamera:   defaultCamera,
+	}, nil
+}
+
+// DeprecatedNewService wraps the vision model in the struct that fulfills the vision service
+// interface. Register this service with DeprecatedRobotConstructor.
+func DeprecatedNewService(
+	name resource.Name,
+	r robot.Robot,
+	c func(ctx context.Context) error,
+	cf classification.Classifier,
+	df objectdetection.Detector,
+	s3f segmentation.Segmenter,
+	defaultCamera string,
+) (Service, error) {
+	// A model that supports none of the three vision tasks cannot back a service.
+	if cf == nil && df == nil && s3f == nil {
+		return nil, errors.Errorf(
+			"model %q does not fulfill any method of the vision service. It is neither a detector, nor classifier, nor 3D segmenter", name)
+	}
+
+	// Keyed literal: each capability is supported exactly when its function was supplied.
+	p := Properties{
+		ClassificationSupported: cf != nil,
+		DetectionSupported:      df != nil,
+		ObjectPCDsSupported:     s3f != nil,
+	}
+
+	// The robot may hand back a nil logger (injected test robots do). vizModel logs
+	// through this value unconditionally in CaptureAllFromCamera, so substitute a
+	// logger named after the service instead of storing nil.
+	logger := r.Logger()
+	if logger == nil {
+		logger = logging.NewLogger(name.String())
+	}
+
+	// Cameras are resolved through the robot at call time.
+	cameraGetter := func(cameraName string) (camera.Camera, error) {
+		return camera.FromRobot(r, cameraName)
+	}
+
+	return &vizModel{
+		Named:           name.AsNamed(),
+		logger:          logger,
+		properties:      p,
+		closerFunc:      c,
+		getCamera:       cameraGetter,
+		classifierFunc:  cf,
+		detectorFunc:    df,
+		segmenter3DFunc: s3f,
+		defaultCamera:   defaultCamera,
+	}, nil
+}
+
+// Detections returns the detections of given image if the model implements objectdetector.Detector.
+// It returns an error when the model was built without a detector. extra is not used.
+func (vm *vizModel) Detections(
+	ctx context.Context,
+	img image.Image,
+	extra map[string]interface{},
+) ([]objectdetection.Detection, error) {
+	ctx, span := trace.StartSpan(ctx, "service::vision::Detections::"+vm.Named.Name().String())
+	defer span.End()
+
+	if vm.detectorFunc == nil {
+		return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name())
+	}
+	return vm.detectorFunc(ctx, img)
+}
+
+// DetectionsFromCamera returns the detections of the next image from the given camera.
+func (vm *vizModel) DetectionsFromCamera(
+	ctx context.Context,
+	cameraName string,
+	extra map[string]interface{},
+) ([]objectdetection.Detection, error) {
+	ctx, span := trace.StartSpan(ctx, "service::vision::DetectionsFromCamera::"+vm.Named.Name().String())
+	defer span.End()
+
+	// An empty camera name falls back to the configured default, if there is one.
+	if cameraName == "" {
+		if vm.defaultCamera == "" {
+			return nil, errors.New("no camera name provided and no default camera found")
+		}
+		cameraName = vm.defaultCamera
+	}
+	if vm.detectorFunc == nil {
+		return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name())
+	}
+
+	source, err := vm.getCamera(cameraName)
+	if err != nil {
+		return nil, errors.Wrapf(err, "could not find camera named %s", cameraName)
+	}
+	frame, err := camera.DecodeImageFromCamera(ctx, "", extra, source)
+	if err != nil {
+		return nil, errors.Wrapf(err, "could not get image from %s", cameraName)
+	}
+	return vm.detectorFunc(ctx, frame)
+}
+
+// Classifications runs the classifier on img and returns the result of TopN(n) over
+// its output. It returns an error when the model was built without a classifier.
+func (vm *vizModel) Classifications(
+	ctx context.Context,
+	img image.Image,
+	n int,
+	extra map[string]interface{},
+) (classification.Classifications, error) {
+	ctx, span := trace.StartSpan(ctx, "service::vision::Classifications::"+vm.Named.Name().String())
+	defer span.End()
+
+	classify := vm.classifierFunc
+	if classify == nil {
+		return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name())
+	}
+	all, err := classify(ctx, img)
+	if err != nil {
+		return nil, errors.Wrap(err, "could not get classifications from image")
+	}
+	return all.TopN(n)
+}
+
+// ClassificationsFromCamera returns the classifications of the next image from the given camera.
+func (vm *vizModel) ClassificationsFromCamera(
+	ctx context.Context,
+	cameraName string,
+	n int,
+	extra map[string]interface{},
+) (classification.Classifications, error) {
+	ctx, span := trace.StartSpan(ctx, "service::vision::ClassificationsFromCamera::"+vm.Named.Name().String())
+	defer span.End()
+
+	// An empty camera name falls back to the configured default, if there is one.
+	if cameraName == "" {
+		if vm.defaultCamera == "" {
+			return nil, errors.New("no camera name provided and no default camera found")
+		}
+		cameraName = vm.defaultCamera
+	}
+	if vm.classifierFunc == nil {
+		return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name())
+	}
+
+	source, err := vm.getCamera(cameraName)
+	if err != nil {
+		return nil, errors.Wrapf(err, "could not find camera named %s", cameraName)
+	}
+	frame, err := camera.DecodeImageFromCamera(ctx, "", extra, source)
+	if err != nil {
+		return nil, errors.Wrapf(err, "could not get image from %s", cameraName)
+	}
+
+	all, err := vm.classifierFunc(ctx, frame)
+	if err != nil {
+		return nil, errors.Wrap(err, "could not get classifications from image")
+	}
+	return all.TopN(n)
+}
+
+// GetObjectPointClouds returns all the found objects in a 3D image if the model implements Segmenter3D.
+func (vm *vizModel) GetObjectPointClouds(
+	ctx context.Context,
+	cameraName string,
+	extra map[string]interface{},
+) ([]*viz.Object, error) {
+	ctx, span := trace.StartSpan(ctx, "service::vision::GetObjectPointClouds::"+vm.Named.Name().String())
+	defer span.End()
+
+	if vm.segmenter3DFunc == nil {
+		return nil, errors.Errorf("vision model %q does not implement a 3D segmenter", vm.Named.Name().String())
+	}
+	// An empty camera name falls back to the configured default, if there is one.
+	if cameraName == "" && vm.defaultCamera == "" {
+		return nil, errors.New("no camera name provided and no default camera found")
+	} else if cameraName == "" {
+		cameraName = vm.defaultCamera
+	}
+	cam, err := vm.getCamera(cameraName)
+	if err != nil {
+		return nil, err
+	}
+	return vm.segmenter3DFunc(ctx, cam)
+}
+
+// GetProperties returns a Properties object that details the vision capabilities of the model.
+func (vm *vizModel) GetProperties(ctx context.Context, extra map[string]interface{}) (*Properties, error) {
+	_, span := trace.StartSpan(ctx, "service::vision::GetProperties::"+vm.Named.Name().String())
+	defer span.End()
+
+	return &vm.properties, nil
+}
+
+// CaptureAllFromCamera grabs one image from the named camera (or the default camera
+// when cameraName is empty) and returns whichever of the image, detections,
+// classifications and object point clouds opt asks for.
+//
+// A requested output that the model does not support is logged at debug level and
+// left empty rather than treated as an error; an error from a supported task aborts
+// the capture. Object point clouds are obtained through GetObjectPointClouds, which
+// looks the camera up again by name.
+func (vm *vizModel) CaptureAllFromCamera(
+	ctx context.Context,
+	cameraName string,
+	opt viscapture.CaptureOptions,
+	extra map[string]interface{},
+) (viscapture.VisCapture, error) {
+	// The span is named after this method so traces are not attributed to
+	// ClassificationsFromCamera.
+	ctx, span := trace.StartSpan(ctx, "service::vision::CaptureAllFromCamera::"+vm.Named.Name().String())
+	defer span.End()
+
+	if cameraName == "" && vm.defaultCamera == "" {
+		return viscapture.VisCapture{}, errors.New("no camera name provided and no default camera found")
+	} else if cameraName == "" {
+		cameraName = vm.defaultCamera
+	}
+	cam, err := vm.getCamera(cameraName)
+	if err != nil {
+		return viscapture.VisCapture{}, errors.Wrapf(err, "could not find camera named %s", cameraName)
+	}
+	img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam)
+	if err != nil {
+		return viscapture.VisCapture{}, errors.Wrapf(err, "could not get image from %s", cameraName)
+	}
+
+	var detections []objectdetection.Detection
+	if opt.ReturnDetections {
+		if !vm.properties.DetectionSupported {
+			vm.logger.Debugf("detections requested but vision model %q does not implement a Detector", vm.Named.Name())
+		} else {
+			detections, err = vm.Detections(ctx, img, extra)
+			if err != nil {
+				return viscapture.VisCapture{}, err
+			}
+		}
+	}
+
+	var classifications classification.Classifications
+	if opt.ReturnClassifications {
+		if !vm.properties.ClassificationSupported {
+			vm.logger.Debugf("classifications requested in CaptureAll but vision model %q does not implement a Classifier",
+				vm.Named.Name())
+		} else {
+			classifications, err = vm.Classifications(ctx, img, 0, extra)
+			if err != nil {
+				return viscapture.VisCapture{}, err
+			}
+		}
+	}
+
+	var objPCD []*viz.Object
+	if opt.ReturnObject {
+		if !vm.properties.ObjectPCDsSupported {
+			vm.logger.Debugf("object point cloud requested in CaptureAll but vision model %q does not implement a 3D Segmenter", vm.Named.Name())
+		} else {
+			objPCD, err = vm.GetObjectPointClouds(ctx, cameraName, extra)
+			if err != nil {
+				return viscapture.VisCapture{}, err
+			}
+		}
+	}
+
+	// The image was needed for detection/classification above; drop it only now
+	// if the caller did not ask for it back.
+	if !opt.ReturnImage {
+		img = nil
+	}
+	return viscapture.VisCapture{
+		Image:           img,
+		Detections:      detections,
+		Classifications: classifications,
+		Objects:         objPCD,
+	}, nil
+}
+
+// Close runs the closer function supplied at construction, if any.
+func (vm *vizModel) Close(ctx context.Context) error {
+	if vm.closerFunc == nil {
+		return nil
+	}
+	return vm.closerFunc(ctx)
+}
diff --git a/services/vision/vision_service_builder_test.go b/services/vision/vision_service_builder_test.go
new file mode 100644
index 00000000000..c6418963a97
--- /dev/null
+++ b/services/vision/vision_service_builder_test.go
@@ -0,0 +1,143 @@
+package vision_test
+
+import (
+	"context"
+	"errors"
+	"image"
+	"testing"
+
+	"go.viam.com/test"
+
+	"go.viam.com/rdk/components/camera"
+	"go.viam.com/rdk/logging"
+	"go.viam.com/rdk/resource"
+	"go.viam.com/rdk/rimage"
+	"go.viam.com/rdk/services/vision"
+	"go.viam.com/rdk/testutils/inject"
+	
"go.viam.com/rdk/utils"
+	visionObject "go.viam.com/rdk/vision"
+	"go.viam.com/rdk/vision/classification"
+	"go.viam.com/rdk/vision/objectdetection"
+	"go.viam.com/rdk/vision/viscapture"
+)
+
+const testCameraName = "camera1"
+
+// Minimal stand-ins for the three model functions a vision service can wrap.
+type (
+	simpleDetector   struct{}
+	simpleClassifier struct{}
+	simpleSegmenter  struct{}
+)
+
+// Detect ignores its input and always reports a single detection with score 0.5.
+func (s *simpleDetector) Detect(ctx context.Context, img image.Image) ([]objectdetection.Detection, error) {
+	det1 := objectdetection.NewDetection(image.Rect(0, 0, 50, 50), image.Rect(0, 0, 10, 20), 0.5, "yes")
+	return []objectdetection.Detection{det1}, nil
+}
+
+// Classify ignores its input and always reports a single classification with score 0.5.
+func (s *simpleClassifier) Classify(context.Context, image.Image) (classification.Classifications, error) {
+	class1 := classification.NewClassification(0.5, "yes")
+	return classification.Classifications{class1}, nil
+}
+
+// Segment ignores the camera and always reports an empty, non-nil object list.
+func (s *simpleSegmenter) Segment(ctx context.Context, src camera.Camera) ([]*visionObject.Object, error) {
+	return []*visionObject.Object{}, nil
+}
+
+// TestNewService checks that a detector-only service can be built with
+// DeprecatedNewService and that Detections forwards to the wrapped detector.
+func TestNewService(t *testing.T) {
+	var r inject.Robot
+	r.LoggerFunc = func() logging.Logger { return nil }
+	var m simpleDetector
+	svc, err := vision.DeprecatedNewService(vision.Named("testService"), &r, nil, nil, m.Detect, nil, "")
+	test.That(t, err, test.ShouldBeNil)
+	test.That(t, svc, test.ShouldNotBeNil)
+	result, err := svc.Detections(context.Background(), nil, nil)
+	test.That(t, err, test.ShouldBeNil)
+	test.That(t, len(result), test.ShouldEqual, 1)
+	test.That(t, result[0].Score(), test.ShouldEqual, 0.5)
+}
+
+// TestDefaultCameraSettings exercises the camera-name resolution of every *FromCamera
+// method: default camera only, neither name nor default, and an explicit name that
+// must win over the default.
+func TestDefaultCameraSettings(t *testing.T) {
+	var r inject.Robot
+	var c simpleClassifier
+	var d simpleDetector
+	var s simpleSegmenter
+
+	fakeCamera := &inject.Camera{
+		ImageFunc: func(ctx context.Context, mimeType string, extra map[string]interface{}) ([]byte, camera.ImageMetadata, error) {
+			sourceImg := image.NewRGBA(image.Rect(0, 0, 3, 3))
+			imgBytes, err := rimage.EncodeImage(ctx, sourceImg, utils.MimeTypePNG)
+			test.That(t, err, test.ShouldBeNil)
+			return imgBytes, camera.ImageMetadata{MimeType: utils.MimeTypePNG}, nil
+		},
+	}
+
+	// LoggerFunc is set once, to a real test logger.
+	r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) {
+		return fakeCamera, nil
+	}
+	r.LoggerFunc = func() logging.Logger {
+		return logging.NewTestLogger(t)
+	}
+
+	svc, err := vision.DeprecatedNewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, testCameraName)
+	test.That(t, err, test.ShouldBeNil)
+	test.That(t, svc, test.ShouldNotBeNil)
+
+	// test *FromCamera methods with default camera and no camera name
+	_, err = svc.DetectionsFromCamera(context.Background(), "", nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.ClassificationsFromCamera(context.Background(), "", 1, nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.GetObjectPointClouds(context.Background(), "", nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.CaptureAllFromCamera(context.Background(), "", viscapture.CaptureOptions{}, nil)
+	test.That(t, err, test.ShouldBeNil)
+
+	// test *FromCamera methods with no default camera or camera name (should throw error)
+	noCameraError := "no camera name provided and no default camera found"
+
+	svc, err = vision.DeprecatedNewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, "")
+	test.That(t, err, test.ShouldBeNil)
+	test.That(t, svc, test.ShouldNotBeNil)
+
+	_, err = svc.DetectionsFromCamera(context.Background(), "", nil)
+	test.That(t, err, test.ShouldNotBeNil)
+	test.That(t, err.Error(), test.ShouldEqual, noCameraError)
+	_, err = svc.ClassificationsFromCamera(context.Background(), "", 1, nil)
+	test.That(t, err, test.ShouldNotBeNil)
+	test.That(t, err.Error(), test.ShouldEqual, noCameraError)
+	_, err = svc.GetObjectPointClouds(context.Background(), "", nil)
+	test.That(t, err, test.ShouldNotBeNil)
+	test.That(t, err.Error(), test.ShouldEqual, noCameraError)
+	_, err = svc.CaptureAllFromCamera(context.Background(), "", viscapture.CaptureOptions{}, nil)
+	test.That(t, err, test.ShouldNotBeNil)
+	test.That(t, err.Error(), test.ShouldEqual, noCameraError)
+
+	// test *FromCamera methods with camera name and default camera (should choose camera name)
+	// remove default camera from test robot to ensure that only camera name is used
+	secondCameraName := "used-camera"
+	r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) {
+		switch name {
+		case camera.Named(testCameraName):
+			return nil, errors.New("default camera is being used when camera name should instead")
+		case camera.Named(secondCameraName):
+			return fakeCamera, nil
+		default:
+			return nil, errors.New("camera not found")
+		}
+	}
+	svc, err = vision.DeprecatedNewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, testCameraName)
+	test.That(t, err, test.ShouldBeNil)
+	test.That(t, svc, test.ShouldNotBeNil)
+
+	_, err = svc.DetectionsFromCamera(context.Background(), secondCameraName, nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.ClassificationsFromCamera(context.Background(), secondCameraName, 1, nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.GetObjectPointClouds(context.Background(), secondCameraName, nil)
+	test.That(t, err, test.ShouldBeNil)
+	_, err = svc.CaptureAllFromCamera(context.Background(), secondCameraName, viscapture.CaptureOptions{}, nil)
+	test.That(t, err, test.ShouldBeNil)
+}
diff --git a/services/vision/vision_test.go b/services/vision/vision_test.go index eaddb6be01c..61a254e4f6f 100644 --- a/services/vision/vision_test.go +++ b/services/vision/vision_test.go @@ -2,29 +2,21 @@ package vision_test import ( "context" - "errors" "image" "testing" "go.viam.com/test" - "go.viam.com/rdk/components/camera" "go.viam.com/rdk/logging" "go.viam.com/rdk/resource" - "go.viam.com/rdk/rimage" "go.viam.com/rdk/services/vision" "go.viam.com/rdk/testutils/inject" - "go.viam.com/rdk/utils" - visionObject "go.viam.com/rdk/vision" - 
"go.viam.com/rdk/vision/classification" "go.viam.com/rdk/vision/objectdetection" - "go.viam.com/rdk/vision/viscapture" ) const ( - testVisionServiceName = "vision1" - testVisionServiceName2 = "vision2" - testCameraName = "camera1" + testVisionServiceName = "vision1" // Used both here and server_test.go + testVisionServiceName2 = "vision2" // Used in server_test.go, but not here ) func TestFromRobot(t *testing.T) { @@ -34,6 +26,9 @@ func TestFromRobot(t *testing.T) { return []objectdetection.Detection{det1}, nil } var r inject.Robot + r.LoggerFunc = func() logging.Logger { + return nil + } r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { return svc1, nil } @@ -45,118 +40,3 @@ func TestFromRobot(t *testing.T) { test.That(t, len(result), test.ShouldEqual, 1) test.That(t, result[0].Score(), test.ShouldEqual, 0.5) } - -type ( - simpleDetector struct{} - simpleClassifier struct{} - simpleSegmenter struct{} -) - -func (s *simpleDetector) Detect(ctx context.Context, img image.Image) ([]objectdetection.Detection, error) { - det1 := objectdetection.NewDetection(image.Rect(0, 0, 50, 50), image.Rect(0, 0, 10, 20), 0.5, "yes") - return []objectdetection.Detection{det1}, nil -} - -func (s *simpleClassifier) Classify(context.Context, image.Image) (classification.Classifications, error) { - class1 := classification.NewClassification(0.5, "yes") - return classification.Classifications{class1}, nil -} - -func (s *simpleSegmenter) Segment(ctx context.Context, src camera.Camera) ([]*visionObject.Object, error) { - return []*visionObject.Object{}, nil -} - -func TestNewService(t *testing.T) { - var r inject.Robot - var m simpleDetector - svc, err := vision.NewService(vision.Named("testService"), &r, nil, nil, m.Detect, nil, "") - test.That(t, err, test.ShouldBeNil) - test.That(t, svc, test.ShouldNotBeNil) - result, err := svc.Detections(context.Background(), nil, nil) - test.That(t, err, test.ShouldBeNil) - test.That(t, len(result), test.ShouldEqual, 1) - 
test.That(t, result[0].Score(), test.ShouldEqual, 0.5) -} - -func TestDefaultCameraSettings(t *testing.T) { - var r inject.Robot - var c simpleClassifier - var d simpleDetector - var s simpleSegmenter - - fakeCamera := &inject.Camera{ - ImageFunc: func(ctx context.Context, mimeType string, extra map[string]interface{}) ([]byte, camera.ImageMetadata, error) { - sourceImg := image.NewRGBA(image.Rect(0, 0, 3, 3)) - imgBytes, err := rimage.EncodeImage(ctx, sourceImg, utils.MimeTypePNG) - test.That(t, err, test.ShouldBeNil) - return imgBytes, camera.ImageMetadata{MimeType: utils.MimeTypePNG}, nil - }, - } - - r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { - return fakeCamera, nil - } - r.LoggerFunc = func() logging.Logger { - return logging.NewTestLogger(t) - } - - svc, err := vision.NewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, testCameraName) - test.That(t, err, test.ShouldBeNil) - test.That(t, svc, test.ShouldNotBeNil) - - // test *FromCamera methods with default camera and no camera name - _, err = svc.DetectionsFromCamera(context.Background(), "", nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.ClassificationsFromCamera(context.Background(), "", 1, nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.GetObjectPointClouds(context.Background(), "", nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.CaptureAllFromCamera(context.Background(), "", viscapture.CaptureOptions{}, nil) - test.That(t, err, test.ShouldBeNil) - - // test *FromCamera methods with no default camera or camera name (should throw error) - noCameraError := "no camera name provided and no default camera found" - - svc, err = vision.NewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, "") - test.That(t, err, test.ShouldBeNil) - test.That(t, svc, test.ShouldNotBeNil) - - _, err = svc.DetectionsFromCamera(context.Background(), "", nil) - test.That(t, err, test.ShouldNotBeNil) - 
test.That(t, err.Error(), test.ShouldEqual, noCameraError) - _, err = svc.ClassificationsFromCamera(context.Background(), "", 1, nil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err.Error(), test.ShouldEqual, noCameraError) - _, err = svc.GetObjectPointClouds(context.Background(), "", nil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err.Error(), test.ShouldEqual, noCameraError) - _, err = svc.CaptureAllFromCamera(context.Background(), "", viscapture.CaptureOptions{}, nil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err.Error(), test.ShouldEqual, noCameraError) - - // test *FromCamera methods with camera name and default camera (should choose camera name) - // remove default camera from test robot to ensure that only camera name is used - secondCameraName := "used-camera" - r.ResourceByNameFunc = func(name resource.Name) (resource.Resource, error) { - switch name { - case camera.Named(testCameraName): - return nil, errors.New("default camera is being used when camera name should instead") - case camera.Named(secondCameraName): - return fakeCamera, nil - default: - return nil, errors.New("camera not found") - } - } - svc, err = vision.NewService(vision.Named("testService"), &r, nil, c.Classify, d.Detect, s.Segment, testCameraName) - test.That(t, err, test.ShouldBeNil) - test.That(t, svc, test.ShouldNotBeNil) - - _, err = svc.DetectionsFromCamera(context.Background(), secondCameraName, nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.ClassificationsFromCamera(context.Background(), secondCameraName, 1, nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.GetObjectPointClouds(context.Background(), secondCameraName, nil) - test.That(t, err, test.ShouldBeNil) - _, err = svc.CaptureAllFromCamera(context.Background(), secondCameraName, viscapture.CaptureOptions{}, nil) - test.That(t, err, test.ShouldBeNil) -}