Skip to content

Commit 3715f4c

Browse files
authored
split vizModel to a separate file (#5094)
vision.go was super long, and half of it was unrelated to the other half. So, it's now 2 files. This will make it easier to think about fixing RDSK-11093, which is about updating `vision.NewService()` to not use the deprecated constructor. No changes to functionality are intended: I just took a file and split it in half.
1 parent cba228d commit 3715f4c

File tree

4 files changed

+410
-381
lines changed

4 files changed

+410
-381
lines changed
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
package vision
2+
3+
import (
4+
"context"
5+
"image"
6+
7+
"github.com/pkg/errors"
8+
"go.opencensus.io/trace"
9+
10+
"go.viam.com/rdk/components/camera"
11+
"go.viam.com/rdk/resource"
12+
"go.viam.com/rdk/robot"
13+
viz "go.viam.com/rdk/vision"
14+
"go.viam.com/rdk/vision/classification"
15+
"go.viam.com/rdk/vision/objectdetection"
16+
"go.viam.com/rdk/vision/segmentation"
17+
"go.viam.com/rdk/vision/viscapture"
18+
)
19+
20+
// vizModel wraps the vision model with all the service interface methods.
21+
type vizModel struct {
22+
resource.Named
23+
resource.AlwaysRebuild
24+
r robot.Robot // in order to get access to all cameras
25+
properties Properties
26+
closerFunc func(ctx context.Context) error // close the underlying model
27+
classifierFunc classification.Classifier
28+
detectorFunc objectdetection.Detector
29+
segmenter3DFunc segmentation.Segmenter
30+
defaultCamera string
31+
}
32+
33+
// NewService wraps the vision model in the struct that fulfills the vision service interface.
34+
func NewService(
35+
name resource.Name,
36+
r robot.Robot,
37+
c func(ctx context.Context) error,
38+
cf classification.Classifier,
39+
df objectdetection.Detector,
40+
s3f segmentation.Segmenter,
41+
defaultCamera string,
42+
) (Service, error) {
43+
if cf == nil && df == nil && s3f == nil {
44+
return nil, errors.Errorf(
45+
"model %q does not fulfill any method of the vision service. It is neither a detector, nor classifier, nor 3D segmenter", name)
46+
}
47+
48+
p := Properties{false, false, false}
49+
if cf != nil {
50+
p.ClassificationSupported = true
51+
}
52+
if df != nil {
53+
p.DetectionSupported = true
54+
}
55+
if s3f != nil {
56+
p.ObjectPCDsSupported = true
57+
}
58+
59+
return &vizModel{
60+
Named: name.AsNamed(),
61+
r: r,
62+
properties: p,
63+
closerFunc: c,
64+
classifierFunc: cf,
65+
detectorFunc: df,
66+
segmenter3DFunc: s3f,
67+
defaultCamera: defaultCamera,
68+
}, nil
69+
}
70+
71+
// Detections returns the detections of given image if the model implements objectdetector.Detector.
72+
func (vm *vizModel) Detections(
73+
ctx context.Context,
74+
img image.Image,
75+
extra map[string]interface{},
76+
) ([]objectdetection.Detection, error) {
77+
ctx, span := trace.StartSpan(ctx, "service::vision::Detections::"+vm.Named.Name().String())
78+
defer span.End()
79+
if vm.detectorFunc == nil {
80+
return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name())
81+
}
82+
return vm.detectorFunc(ctx, img)
83+
}
84+
85+
// DetectionsFromCamera returns the detections of the next image from the given camera.
86+
func (vm *vizModel) DetectionsFromCamera(
87+
ctx context.Context,
88+
cameraName string,
89+
extra map[string]interface{},
90+
) ([]objectdetection.Detection, error) {
91+
ctx, span := trace.StartSpan(ctx, "service::vision::DetectionsFromCamera::"+vm.Named.Name().String())
92+
defer span.End()
93+
if cameraName == "" && vm.defaultCamera == "" {
94+
return nil, errors.New("no camera name provided and no default camera found")
95+
} else if cameraName == "" {
96+
cameraName = vm.defaultCamera
97+
}
98+
if vm.detectorFunc == nil {
99+
return nil, errors.Errorf("vision model %q does not implement a Detector", vm.Named.Name())
100+
}
101+
cam, err := camera.FromRobot(vm.r, cameraName)
102+
if err != nil {
103+
return nil, errors.Wrapf(err, "could not find camera named %s", cameraName)
104+
}
105+
img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam)
106+
if err != nil {
107+
return nil, errors.Wrapf(err, "could not get image from %s", cameraName)
108+
}
109+
return vm.detectorFunc(ctx, img)
110+
}
111+
112+
// Classifications returns the classifications of given image if the model implements classifications.Classifier.
113+
func (vm *vizModel) Classifications(
114+
ctx context.Context,
115+
img image.Image,
116+
n int,
117+
extra map[string]interface{},
118+
) (classification.Classifications, error) {
119+
ctx, span := trace.StartSpan(ctx, "service::vision::Classifications::"+vm.Named.Name().String())
120+
defer span.End()
121+
if vm.classifierFunc == nil {
122+
return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name())
123+
}
124+
fullClassifications, err := vm.classifierFunc(ctx, img)
125+
if err != nil {
126+
return nil, errors.Wrap(err, "could not get classifications from image")
127+
}
128+
return fullClassifications.TopN(n)
129+
}
130+
131+
// ClassificationsFromCamera returns the classifications of the next image from the given camera.
132+
func (vm *vizModel) ClassificationsFromCamera(
133+
ctx context.Context,
134+
cameraName string,
135+
n int,
136+
extra map[string]interface{},
137+
) (classification.Classifications, error) {
138+
ctx, span := trace.StartSpan(ctx, "service::vision::ClassificationsFromCamera::"+vm.Named.Name().String())
139+
defer span.End()
140+
if cameraName == "" && vm.defaultCamera == "" {
141+
return nil, errors.New("no camera name provided and no default camera found")
142+
} else if cameraName == "" {
143+
cameraName = vm.defaultCamera
144+
}
145+
if vm.classifierFunc == nil {
146+
return nil, errors.Errorf("vision model %q does not implement a Classifier", vm.Named.Name())
147+
}
148+
cam, err := camera.FromRobot(vm.r, cameraName)
149+
if err != nil {
150+
return nil, errors.Wrapf(err, "could not find camera named %s", cameraName)
151+
}
152+
img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam)
153+
if err != nil {
154+
return nil, errors.Wrapf(err, "could not get image from %s", cameraName)
155+
}
156+
fullClassifications, err := vm.classifierFunc(ctx, img)
157+
if err != nil {
158+
return nil, errors.Wrap(err, "could not get classifications from image")
159+
}
160+
return fullClassifications.TopN(n)
161+
}
162+
163+
// GetObjectPointClouds returns all the found objects in a 3D image if the model implements Segmenter3D.
164+
func (vm *vizModel) GetObjectPointClouds(
165+
ctx context.Context,
166+
cameraName string,
167+
extra map[string]interface{},
168+
) ([]*viz.Object, error) {
169+
if vm.segmenter3DFunc == nil {
170+
return nil, errors.Errorf("vision model %q does not implement a 3D segmenter", vm.Named.Name().String())
171+
}
172+
ctx, span := trace.StartSpan(ctx, "service::vision::GetObjectPointClouds::"+vm.Named.Name().String())
173+
defer span.End()
174+
if cameraName == "" && vm.defaultCamera == "" {
175+
return nil, errors.New("no camera name provided and no default camera found")
176+
} else if cameraName == "" {
177+
cameraName = vm.defaultCamera
178+
}
179+
cam, err := camera.FromRobot(vm.r, cameraName)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return vm.segmenter3DFunc(ctx, cam)
184+
}
185+
186+
// GetProperties returns a Properties object that details the vision capabilities of the model.
187+
func (vm *vizModel) GetProperties(ctx context.Context, extra map[string]interface{}) (*Properties, error) {
188+
_, span := trace.StartSpan(ctx, "service::vision::GetProperties::"+vm.Named.Name().String())
189+
defer span.End()
190+
191+
return &vm.properties, nil
192+
}
193+
194+
func (vm *vizModel) CaptureAllFromCamera(
195+
ctx context.Context,
196+
cameraName string,
197+
opt viscapture.CaptureOptions,
198+
extra map[string]interface{},
199+
) (viscapture.VisCapture, error) {
200+
ctx, span := trace.StartSpan(ctx, "service::vision::ClassificationsFromCamera::"+vm.Named.Name().String())
201+
defer span.End()
202+
if cameraName == "" && vm.defaultCamera == "" {
203+
return viscapture.VisCapture{}, errors.New("no camera name provided and no default camera found")
204+
} else if cameraName == "" {
205+
cameraName = vm.defaultCamera
206+
}
207+
cam, err := camera.FromRobot(vm.r, cameraName)
208+
if err != nil {
209+
return viscapture.VisCapture{}, errors.Wrapf(err, "could not find camera named %s", cameraName)
210+
}
211+
img, err := camera.DecodeImageFromCamera(ctx, "", extra, cam)
212+
if err != nil {
213+
return viscapture.VisCapture{}, errors.Wrapf(err, "could not get image from %s", cameraName)
214+
}
215+
logger := vm.r.Logger()
216+
var detections []objectdetection.Detection
217+
if opt.ReturnDetections {
218+
if !vm.properties.DetectionSupported {
219+
logger.Debugf("detections requested but vision model %q does not implement a Detector", vm.Named.Name())
220+
} else {
221+
detections, err = vm.Detections(ctx, img, extra)
222+
if err != nil {
223+
return viscapture.VisCapture{}, err
224+
}
225+
}
226+
}
227+
var classifications classification.Classifications
228+
if opt.ReturnClassifications {
229+
logger := vm.r.Logger()
230+
if !vm.properties.ClassificationSupported {
231+
logger.Debugf("classifications requested in CaptureAll but vision model %q does not implement a Classifier",
232+
vm.Named.Name())
233+
} else {
234+
classifications, err = vm.Classifications(ctx, img, 0, extra)
235+
if err != nil {
236+
return viscapture.VisCapture{}, err
237+
}
238+
}
239+
}
240+
241+
var objPCD []*viz.Object
242+
if opt.ReturnObject {
243+
if !vm.properties.ObjectPCDsSupported {
244+
logger := vm.r.Logger()
245+
logger.Debugf("object point cloud requested in CaptureAll but vision model %q does not implement a 3D Segmenter", vm.Named.Name())
246+
} else {
247+
objPCD, err = vm.GetObjectPointClouds(ctx, cameraName, extra)
248+
if err != nil {
249+
return viscapture.VisCapture{}, err
250+
}
251+
}
252+
}
253+
if !opt.ReturnImage {
254+
img = nil
255+
}
256+
return viscapture.VisCapture{
257+
Image: img,
258+
Detections: detections,
259+
Classifications: classifications,
260+
Objects: objPCD,
261+
}, nil
262+
}
263+
264+
func (vm *vizModel) Close(ctx context.Context) error {
265+
if vm.closerFunc == nil {
266+
return nil
267+
}
268+
return vm.closerFunc(ctx)
269+
}

0 commit comments

Comments
 (0)