diff --git a/compute/cker/include/cker/ShapeIterator.h b/compute/cker/include/cker/ShapeIterator.h new file mode 100644 index 00000000000..72fe3aa48a4 --- /dev/null +++ b/compute/cker/include/cker/ShapeIterator.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_CKER_SHAPE_ITERATOR_H__ +#define __NNFW_CKER_SHAPE_ITERATOR_H__ + +#include +#include "cker/Shape.h" + +namespace nnfw +{ +namespace cker +{ +struct ShapeIterator +{ + /// Definition of this iterator's traits that can be accessed by std::iterator_traits + using value_type = decltype(std::declval().Dims(0)); + using difference_type = std::ptrdiff_t; + using pointer = value_type *; + using reference = value_type &; + using iterator_category = std::bidirectional_iterator_tag; + + ShapeIterator(const Shape &s) : _shape{s}, _current{0}, _last{s.DimensionsCount()} {} + static ShapeIterator end_iterator(const Shape &s) { return ShapeIterator(s, EndIteratorTag{}); } + + ShapeIterator &operator++() + { + ++_current; + return *this; + } + + // postincrement + ShapeIterator operator++(int) + { + auto copy = *this; + ++_current; + return copy; + } + + ShapeIterator &operator--() + { + --_current; + return *this; + } + + ShapeIterator operator--(int) + { + auto copy = *this; + --_current; + return copy; + } + + bool operator!=(const ShapeIterator &other) const { return _current != other._current; } + bool operator==(const ShapeIterator &other) const { return _current == other._current; } + + /// Because the underlying method returns by-value, this operator does the same + /// instead of returning by-reference like most iterators do. + value_type operator*() const { return _shape.Dims(_current); } + +private: + struct EndIteratorTag + { + }; + // Creates an iterator instance pointing to the past-the-end element + // This iterator doesn't point to a valid element and thus its dereference is undefined behavior + ShapeIterator(const Shape &s, EndIteratorTag) + : _shape{s}, _current{s.DimensionsCount()}, _last{s.DimensionsCount()} + { + } + + const Shape &_shape; + int32_t _current = 0, _last = 0; +}; + +inline ShapeIterator begin(const Shape &s) { return ShapeIterator(s); } +inline ShapeIterator end(const Shape &s) { return ShapeIterator::end_iterator(s); } + +} // namespace cker +} // namespace nnfw + +#endif // diff --git a/compute/cker/include/cker/Utils.h b/compute/cker/include/cker/Utils.h index 9aae0a957bc..f3cbf5c3b86 100644 --- a/compute/cker/include/cker/Utils.h +++ b/compute/cker/include/cker/Utils.h @@ -19,11 +19,14 @@ #define __NNFW_CKER_UTILS_H__ #include "Shape.h" +#include "ShapeIterator.h" #include "neon/neon_check.h" #include #include +#include +#include #include namespace nnfw @@ -480,6 +483,30 @@ template class SequentialTensorWriter T *output_ptr_; }; +inline std::ostream &operator<<(std::ostream &os, const Shape &shape) +{ + using std::begin; + using std::end; + + std::string formatted = + std::accumulate(begin(shape), end(shape), std::string{"["}, + [](std::string joined, ShapeIterator::value_type dim) { + return std::move(joined).append(std::to_string(dim)).append(","); + }); + + if (formatted.back() == '[') + { + formatted.push_back(']'); + } + else + { + formatted.back() = ']'; + } + + os << formatted; + return os; +} + } // namespace cker } // namespace nnfw diff --git a/compute/cker/src/ShapeIterator.test.cc b/compute/cker/src/ShapeIterator.test.cc new file mode 100644 index 00000000000..97c752b5387 --- /dev/null +++ b/compute/cker/src/ShapeIterator.test.cc @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace nnfw::cker; + +TEST(CKer_Utils, ShapeIterator_basic) +{ + const Shape test_shape{1, 3, 1024, 768}; + { + // test the front and back iterability with basic operators + ShapeIterator it{test_shape}; + EXPECT_EQ(*it, 1); + ++it; + EXPECT_EQ(*it, 3); + it++; + EXPECT_EQ(*it, 1024); + --it; + EXPECT_EQ(*it, 3); + it--; + EXPECT_EQ(*it, 1); + } + { + // test the iterator's compatibility with STL iterator functions + ShapeIterator it{test_shape}; + auto it2 = std::next(it); + EXPECT_EQ(*it2, 3); + EXPECT_EQ(*it, 1); // make sure the original iterator is untouched + + std::advance(it2, 2); + EXPECT_EQ(*it2, 768); + + std::advance(it2, -1); + EXPECT_EQ(*it2, 1024); + } + { + // postincrement operator test + ShapeIterator it{test_shape}; + const auto it2 = it++; + EXPECT_EQ(*it, 3); + EXPECT_EQ(*it2, 1); + } + { + // test the ability to iterate over a Shape with range-based loops + int expected_dims[] = {1, 3, 1024, 768}; + int i = 0; + for (auto &&dim : test_shape) + { + EXPECT_EQ(dim, expected_dims[i++]); + } + } + { + // test the ability to retrieve iterators using begin & end + const auto first = begin(test_shape); + const auto last = end(test_shape); + EXPECT_GT(std::distance(first, last), 0); + EXPECT_EQ(std::distance(first, last), test_shape.DimensionsCount()); + } + + { + // test and demostrate the usage of iterators with STL algos + const auto first = begin(test_shape); + const auto last = end(test_shape); + const auto shape_elems = + std::accumulate(first, last, 1, std::multiplies{}); + EXPECT_EQ(shape_elems, test_shape.FlatSize()); + } + + { + // Shape and ofstream interoperability test + std::stringstream ss; + ss << test_shape; + EXPECT_EQ(ss.str(), "[1,3,1024,768]"); + } +} + +TEST(CKer_Utils, neg_ShapeIterator_empty_shape) +{ + const Shape test_shape{}; + { + const auto first = begin(test_shape); + const auto last = end(test_shape); + EXPECT_EQ(first, last); + } + + { + std::stringstream ss; + ss << test_shape; + EXPECT_EQ(ss.str(), "[]"); + } +}