Skip to content

Commit

Permalink
Modify API to return 2D ROI coords
Browse files Browse the repository at this point in the history
  • Loading branch information
fiona-gladwin committed Aug 24, 2023
1 parent 0beaecc commit f106edd
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions rocAL/rocAL/include/pipeline/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class Node
std::shared_ptr<Graph> graph() { return _graph; }
void set_meta_data(pMetaDataBatch meta_data_info) { _meta_data_info = meta_data_info; }
bool _is_ssd = false;
ROI2DCords *get_src_roi() { return (ROI2DCords *)_inputs[0]->info().roi().get_ptr(); }
ROI2DCords *get_dst_roi() { return (ROI2DCords *)_outputs[0]->info().roi().get_ptr(); }
ROI2DCords *get_src_roi() { return _inputs[0]->info().roi().get_2D_roi(); }
ROI2DCords *get_dst_roi() { return _outputs[0]->info().roi().get_2D_roi(); }
protected:
virtual void create_node() = 0;
virtual void update_node() = 0;
Expand Down
6 changes: 5 additions & 1 deletion rocAL/rocAL/include/pipeline/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ void allocate_host_or_pinned_mem(void **ptr, size_t size, RocalMemType mem_type)

struct ROI {
unsigned *get_ptr() { return _roi_ptr.get(); }
ROI2DCords* get_2D_roi() { return reinterpret_cast<ROI2DCords*>(_roi_ptr.get()); }
ROI2DCords* get_2D_roi() {
if (_dims != 2)
THROW("ROI has more than 2 dimensions. Cannot return ROI2DCords")
return reinterpret_cast<ROI2DCords*>(_roi_ptr.get());
}
void set_ptr(unsigned *ptr, RocalMemType mem_type, unsigned dims = 0) {
if(!_dims) _dims = dims;
_stride = _dims * 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void CropNode::create_node() {
}

void CropNode::update_node() {
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
std::vector<uint32_t> crop_h_dims, crop_w_dims;
_crop_param->get_crop_dimensions(crop_w_dims, crop_h_dims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void CropMirrorNormalizeNode::create_node() {
}

void CropMirrorNormalizeNode::update_node() {
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
std::vector<uint32_t> crop_h_dims, crop_w_dims;
_crop_param->get_crop_dimensions(crop_w_dims, crop_h_dims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void CropResizeNode::create_node()
}

void CropResizeNode::update_node() {
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
std::vector<uint32_t> crop_h_dims, crop_w_dims;
_crop_param->get_crop_dimensions(crop_w_dims, crop_h_dims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void RandomCropNode::create_node()

void RandomCropNode::update_node()
{
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
std::vector<uint32_t> crop_h_dims, crop_w_dims;
_crop_param->get_crop_dimensions(crop_w_dims, crop_h_dims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void ResizeCropMirrorNode::create_node() {
}

void ResizeCropMirrorNode::update_node() {
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
std::vector<uint32_t> crop_h_dims, crop_w_dims;
_crop_param->get_crop_dimensions(crop_w_dims, crop_h_dims);
Expand Down
2 changes: 1 addition & 1 deletion rocAL/rocAL/source/augmentations/node_ssd_random_crop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ inline double ssd_BBoxIntersectionOverUnion(const BoundingBoxCord &box1, const B

void SSDRandomCropNode::update_node()
{
_crop_param->set_image_dimensions(reinterpret_cast<ROI2DCords *>(_inputs[0]->info().roi().get_ptr()));
_crop_param->set_image_dimensions(_inputs[0]->info().roi().get_2D_roi());
_crop_param->update_array();
ROI2DCords *crop_dims = static_cast<ROI2DCords *>(_crop_coordinates); // ROI to be cropped from source
std::random_device rd;
Expand Down

0 comments on commit f106edd

Please sign in to comment.