Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-2808 - Remove deprecated constructor from all vision service models #3415

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions services/mlmodel/mlmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,8 @@ func Named(name string) resource.Name {
func FromRobot(r robot.Robot, name string) (Service, error) {
return robot.ResourceFromRobot[Service](r, Named(name))
}

// FromDependencies is a helper for getting the named ML model service from a collection of dependencies.
func FromDependencies(deps resource.Dependencies, name string) (Service, error) {
return resource.FromDependencies[Service](deps, Named(name))
}
30 changes: 18 additions & 12 deletions services/vision/colordetector/color_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,33 @@ package colordetector

import (
"context"

"fmt"
"github.com/pkg/errors"
"go.opencensus.io/trace"

"go.viam.com/rdk/components/camera"
"go.viam.com/rdk/logging"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/robot"
"go.viam.com/rdk/services/vision"
"go.viam.com/rdk/utils"
objdet "go.viam.com/rdk/vision/objectdetection"
)

var model = resource.DefaultModelFamily.WithModel("color_detector")

func init() {
resource.RegisterService(vision.API, model, resource.Registration[vision.Service, *objdet.ColorDetectorConfig]{
DeprecatedRobotConstructor: func(
ctx context.Context, r any, c resource.Config, logger logging.Logger,
Constructor: func(
ctx context.Context, deps resource.Dependencies, c resource.Config, logger logging.Logger,
) (vision.Service, error) {
attrs, err := resource.NativeConfig[*objdet.ColorDetectorConfig](c)
if err != nil {
return nil, err
}
actualR, err := utils.AssertType[robot.Robot](r)
if err != nil {
return nil, err
}
return registerColorDetector(ctx, c.ResourceName(), attrs, actualR)

return registerColorDetector(ctx, c.ResourceName(), attrs, deps)
},
WeakDependencies: []resource.Matcher{
resource.SubtypeMatcher{Subtype: camera.SubtypeName},
},
})
}
Expand All @@ -41,16 +40,23 @@ func registerColorDetector(
ctx context.Context,
name resource.Name,
conf *objdet.ColorDetectorConfig,
r robot.Robot,
deps resource.Dependencies,
) (vision.Service, error) {
_, span := trace.StartSpan(ctx, "service::vision::registerColorDetector")
defer span.End()

fmt.Println("Printing COL DETECTOR DEPS")
for _, d := range deps {
fmt.Println(d)
fmt.Println()
}

if conf == nil {
return nil, errors.New("object detection config for color detector cannot be nil")
}
detector, err := objdet.NewColorDetector(conf)
if err != nil {
return nil, errors.Wrapf(err, "error registering color detector %q", name)
}
return vision.NewService(name, r, nil, nil, detector, nil)
return vision.NewService(name, deps, nil, nil, detector, nil)
}
10 changes: 5 additions & 5 deletions services/vision/colordetector/color_detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"go.viam.com/test"
"go.viam.com/utils/artifact"

"go.viam.com/rdk/resource"
"go.viam.com/rdk/rimage"
"go.viam.com/rdk/services/vision"
"go.viam.com/rdk/testutils/inject"
"go.viam.com/rdk/vision/objectdetection"
)

Expand All @@ -20,9 +20,9 @@ func TestColorDetector(t *testing.T) {
DetectColorString: "#4F3815",
}
ctx := context.Background()
r := &inject.Robot{}
deps := make(resource.Dependencies)
name := vision.Named("test_cd")
srv, err := registerColorDetector(ctx, name, &inp, r)
srv, err := registerColorDetector(ctx, name, &inp, deps)
test.That(t, err, test.ShouldBeNil)
test.That(t, srv.Name(), test.ShouldResemble, name)
img, err := rimage.NewImageFromFile(artifact.MustPath("vision/objectdetection/detection_test.jpg"))
Expand All @@ -40,10 +40,10 @@ func TestColorDetector(t *testing.T) {

// with error - bad parameters
inp.HueTolerance = 4.0 // value out of range
_, err = registerColorDetector(ctx, name, &inp, r)
_, err = registerColorDetector(ctx, name, &inp, deps)
test.That(t, err.Error(), test.ShouldContainSubstring, "hue_tolerance_pct must be between")

// with error - nil parameters
_, err = registerColorDetector(ctx, name, nil, r)
_, err = registerColorDetector(ctx, name, nil, deps)
test.That(t, err.Error(), test.ShouldContainSubstring, "cannot be nil")
}
23 changes: 11 additions & 12 deletions services/vision/detectionstosegments/detections_to_3dsegments.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ import (
"github.com/pkg/errors"
"go.opencensus.io/trace"

"go.viam.com/rdk/components/camera"
"go.viam.com/rdk/logging"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/robot"
"go.viam.com/rdk/services/vision"
"go.viam.com/rdk/utils"
"go.viam.com/rdk/vision/objectdetection"
"go.viam.com/rdk/vision/segmentation"
)
Expand All @@ -22,18 +21,18 @@ var model = resource.DefaultModelFamily.WithModel("detector_3d_segmenter")

func init() {
resource.RegisterService(vision.API, model, resource.Registration[vision.Service, *segmentation.DetectionSegmenterConfig]{
DeprecatedRobotConstructor: func(
ctx context.Context, r any, c resource.Config, logger logging.Logger,
Constructor: func(
ctx context.Context, deps resource.Dependencies, c resource.Config, logger logging.Logger,
) (vision.Service, error) {
attrs, err := resource.NativeConfig[*segmentation.DetectionSegmenterConfig](c)
if err != nil {
return nil, err
}
actualR, err := utils.AssertType[robot.Robot](r)
if err != nil {
return nil, err
}
return register3DSegmenterFromDetector(ctx, c.ResourceName(), attrs, actualR)
return register3DSegmenterFromDetector(ctx, c.ResourceName(), attrs, deps)
},
WeakDependencies: []resource.Matcher{
resource.SubtypeMatcher{Subtype: camera.SubtypeName},
resource.SubtypeMatcher{Subtype: vision.SubtypeName},
},
})
}
Expand All @@ -43,14 +42,14 @@ func register3DSegmenterFromDetector(
ctx context.Context,
name resource.Name,
conf *segmentation.DetectionSegmenterConfig,
r robot.Robot,
deps resource.Dependencies,
) (vision.Service, error) {
_, span := trace.StartSpan(ctx, "service::vision::register3DSegmenterFromDetector")
defer span.End()
if conf == nil {
return nil, errors.New("config for 3D segmenter made from a detector cannot be nil")
}
detectorService, err := vision.FromRobot(r, conf.DetectorName)
detectorService, err := vision.FromDependencies(deps, conf.DetectorName)
if err != nil {
return nil, errors.Wrapf(err, "could not find necessary dependency, detector %q", conf.DetectorName)
}
Expand All @@ -65,5 +64,5 @@ func register3DSegmenterFromDetector(
if err != nil {
return nil, errors.Wrap(err, "cannot create 3D segmenter from detector")
}
return vision.NewService(name, r, nil, nil, detector, segmenter)
return vision.NewService(name, deps, nil, nil, detector, segmenter)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ func (s *simpleDetector) Detect(context.Context, image.Image) ([]objectdetection
}

func Test3DSegmentsFromDetector(t *testing.T) {
r := &inject.Robot{}
deps := make(resource.Dependencies)
deps2 := make(resource.Dependencies)
m := &simpleDetector{}
name := vision.Named("testDetector")
svc, err := vision.NewService(name, r, nil, nil, m.Detect, nil)
svc, err := vision.NewService(name, deps, nil, nil, m.Detect, nil)
test.That(t, err, test.ShouldBeNil)
cam := &inject.Camera{}
cam := inject.NewCamera("fakeCamera")
cam.NextPointCloudFunc = func(ctx context.Context) (pc.PointCloud, error) {
return nil, errors.New("no pointcloud")
}
Expand All @@ -44,44 +45,35 @@ func Test3DSegmentsFromDetector(t *testing.T) {
cam.ProjectorFunc = func(ctx context.Context) (transform.Projector, error) {
return &transform.ParallelProjection{}, nil
}
r.ResourceNamesFunc = func() []resource.Name {
return []resource.Name{camera.Named("fakeCamera"), name}
}
r.ResourceByNameFunc = func(n resource.Name) (resource.Resource, error) {
switch n.Name {
case "fakeCamera":
return cam, nil
case "testDetector":
return svc, nil
default:
return nil, resource.NewNotFoundError(n)
}
}
// set up cams as dependencies
deps2[vision.Named("testDetector")] = svc
deps2[camera.Named("fakeCamera")] = cam

params := &segmentation.DetectionSegmenterConfig{
DetectorName: "testDetector",
ConfidenceThresh: 0.2,
}
// bad registration, no parameters
name2 := vision.Named("test_seg")
_, err = register3DSegmenterFromDetector(context.Background(), name2, nil, r)
_, err = register3DSegmenterFromDetector(context.Background(), name2, nil, deps2)
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "cannot be nil")
// bad registration, no such detector
params.DetectorName = "noDetector"
_, err = register3DSegmenterFromDetector(context.Background(), name2, params, r)
_, err = register3DSegmenterFromDetector(context.Background(), name2, params, deps2)
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "could not find necessary dependency")
// successful registration
params.DetectorName = "testDetector"
name3 := vision.Named("test_rcs")
seg, err := register3DSegmenterFromDetector(context.Background(), name3, params, r)
seg, err := register3DSegmenterFromDetector(context.Background(), name3, params, deps2)
test.That(t, err, test.ShouldBeNil)
test.That(t, seg.Name(), test.ShouldResemble, name3)

// fails on not finding camera
_, err = seg.GetObjectPointClouds(context.Background(), "no_camera", map[string]interface{}{})
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "not found")
test.That(t, err.Error(), test.ShouldContainSubstring, "missing")

// fails since camera cannot return images
_, err = seg.GetObjectPointClouds(context.Background(), "fakeCamera", map[string]interface{}{})
Expand Down
30 changes: 18 additions & 12 deletions services/vision/mlvision/ml_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ import (
"go.opencensus.io/trace"
"golang.org/x/exp/constraints"

"go.viam.com/rdk/components/camera"
"go.viam.com/rdk/logging"
"go.viam.com/rdk/ml"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/robot"
"go.viam.com/rdk/services/mlmodel"
"go.viam.com/rdk/services/vision"
"go.viam.com/rdk/utils"
)

var model = resource.DefaultModelFamily.WithModel("mlmodel")
Expand All @@ -40,24 +39,25 @@ const (

func init() {
resource.RegisterService(vision.API, model, resource.Registration[vision.Service, *MLModelConfig]{
DeprecatedRobotConstructor: func(
ctx context.Context, r any, c resource.Config, logger logging.Logger,
Constructor: func(
ctx context.Context, deps resource.Dependencies, c resource.Config, logger logging.Logger,
) (vision.Service, error) {
attrs, err := resource.NativeConfig[*MLModelConfig](c)
if err != nil {
return nil, err
}
actualR, err := utils.AssertType[robot.Robot](r)
if err != nil {
return nil, err
}
return registerMLModelVisionService(ctx, c.ResourceName(), attrs, actualR, logger)
return registerMLModelVisionService(ctx, c.ResourceName(), attrs, deps, logger)
},
WeakDependencies: []resource.Matcher{
resource.SubtypeMatcher{Subtype: camera.SubtypeName},
resource.SubtypeMatcher{Subtype: mlmodel.SubtypeName},
},
})
}

// MLModelConfig specifies the parameters needed to turn an ML model into a vision Model.
type MLModelConfig struct {
resource.TriviallyReconfigurable
ModelName string `json:"mlmodel_name"`
}

Expand All @@ -73,13 +73,19 @@ func registerMLModelVisionService(
ctx context.Context,
name resource.Name,
params *MLModelConfig,
r robot.Robot,
deps resource.Dependencies,
logger logging.Logger,
) (vision.Service, error) {
_, span := trace.StartSpan(ctx, "service::vision::registerMLModelVisionService")
defer span.End()

mlm, err := mlmodel.FromRobot(r, params.ModelName)
fmt.Println("Printing ML MODEL DEPS")
for _, d := range deps {
fmt.Println(d)
fmt.Println()
}

mlm, err := mlmodel.FromDependencies(deps, params.ModelName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -154,7 +160,7 @@ func registerMLModelVisionService(
}

// Don't return a close function, because you don't want to close the underlying ML service
return vision.NewService(name, r, nil, classifierFunc, detectorFunc, segmenter3DFunc)
return vision.NewService(name, deps, nil, classifierFunc, detectorFunc, segmenter3DFunc)
}

// getLabelsFromMetadata returns a slice of strings--the intended labels.
Expand Down
9 changes: 6 additions & 3 deletions services/vision/mlvision/ml_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import (
"go.viam.com/utils/artifact"

"go.viam.com/rdk/logging"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/rimage"
"go.viam.com/rdk/services/mlmodel"
"go.viam.com/rdk/services/mlmodel/tflitecpu"
"go.viam.com/rdk/testutils/inject"
"go.viam.com/rdk/vision/classification"
)

Expand All @@ -32,9 +32,11 @@ func BenchmarkAddMLVisionModel(b *testing.B) {
test.That(b, out, test.ShouldNotBeNil)
modelCfg := MLModelConfig{ModelName: name.Name}

deps := make(resource.Dependencies)

b.ResetTimer()
for i := 0; i < b.N; i++ {
service, err := registerMLModelVisionService(ctx, name, &modelCfg, &inject.Robot{}, logging.NewLogger("benchmark"))
service, err := registerMLModelVisionService(ctx, name, &modelCfg, deps, logging.NewLogger("benchmark"))
test.That(b, err, test.ShouldBeNil)
test.That(b, service, test.ShouldNotBeNil)
test.That(b, service.Name(), test.ShouldResemble, name)
Expand All @@ -57,7 +59,8 @@ func BenchmarkUseMLVisionModel(b *testing.B) {
test.That(b, out, test.ShouldNotBeNil)
modelCfg := MLModelConfig{ModelName: name.Name}

service, err := registerMLModelVisionService(ctx, name, &modelCfg, &inject.Robot{}, logging.NewLogger("benchmark"))
deps := make(resource.Dependencies)
service, err := registerMLModelVisionService(ctx, name, &modelCfg, deps, logging.NewLogger("benchmark"))
test.That(b, err, test.ShouldBeNil)
test.That(b, service, test.ShouldNotBeNil)
test.That(b, service.Name(), test.ShouldResemble, name)
Expand Down
Loading