Skip to content

Commit

Permalink
refactor settings_changed(old, new, fwd) & Decimate example block
Browse files Browse the repository at this point in the history
... to account for pre- and post-(/forward) settings

Signed-off-by: rstein <[email protected]>
Signed-off-by: Ralph J. Steinhagen <[email protected]>
  • Loading branch information
RalphSteinhagen authored and wirew0rm committed Sep 15, 2023
1 parent 39c680c commit 61e2bf8
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 22 deletions.
35 changes: 23 additions & 12 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,15 @@ struct node : protected std::tuple<Arguments...> {
init(std::shared_ptr<gr::Sequence> progress_, std::shared_ptr<fair::thread_pool::BasicThreadPool> ioThreadPool_) {
progress = std::move(progress_);
ioThreadPool = std::move(ioThreadPool_);
std::ignore = settings().apply_staged_parameters();
if (const auto forward_parameters = settings().apply_staged_parameters(); !forward_parameters.empty()) {
std::for_each(_tags_at_output.begin(), _tags_at_output.end(), [&forward_parameters](tag_t &tag) {
for (const auto &[key, value] : forward_parameters) {
tag.map.insert_or_assign(key, value);
}
});
_output_tags_changed = true;
}

// TODO: expand on this init function:
// * store initial setting -> needed for `reset()` call
// * ...
Expand Down Expand Up @@ -598,7 +606,7 @@ struct node : protected std::tuple<Arguments...> {

constexpr void
forward_tags() noexcept {
if (!_output_tags_changed) {
if (!_output_tags_changed && !_input_tags_present) {
return;
}
std::size_t port_id = 0; // TODO absorb this as optional tuple_for_each argument
Expand Down Expand Up @@ -730,10 +738,9 @@ struct node : protected std::tuple<Arguments...> {
if constexpr (node_template_parameters::template contains<PerformDecimationInterpolation>) {
if (numerator != 1_UZ || denominator != 1_UZ) {
// TODO: this ill-defined checks can be done only once after parameters were changed
const double ratio = static_cast<double>(numerator) / static_cast<double>(denominator);
bool is_ill_defined = (denominator > ports_status.in_max_samples)
|| (static_cast<double>(ports_status.in_min_samples) * ratio > static_cast<double>(ports_status.out_max_samples))
|| (static_cast<double>(ports_status.in_max_samples) * ratio < static_cast<double>(ports_status.out_min_samples));
const double ratio = static_cast<double>(numerator) / static_cast<double>(denominator);
bool is_ill_defined = (denominator > ports_status.in_max_samples) || (static_cast<double>(ports_status.in_min_samples) * ratio > static_cast<double>(ports_status.out_max_samples))
|| (static_cast<double>(ports_status.in_max_samples) * ratio < static_cast<double>(ports_status.out_min_samples));
assert(!is_ill_defined);
if (is_ill_defined) {
return { requested_work, 0_UZ, work_return_status_t::ERROR };
Expand Down Expand Up @@ -767,8 +774,6 @@ struct node : protected std::tuple<Arguments...> {
}
}

_input_tags_present = false;
_output_tags_changed = false;
if (ports_status.in_samples_to_next_tag == 0) {
if constexpr (HasProcessOneFunction<Derived>) {
ports_status.in_samples = 1; // N.B. limit to one so that only one process_on(...) invocation receives the tag
Expand All @@ -790,8 +795,10 @@ struct node : protected std::tuple<Arguments...> {
if ((readPos == -1 && tags[0].index <= 0) // first tag on initialised stream
|| tag_stream_pos <= 0) {
for (const auto &[index, map] : tags) {
tag_at_present_input.map.insert(map.begin(), map.end());
merged_tag_map.insert(map.begin(), map.end());
for (const auto &[key, value] : map) {
tag_at_present_input.map.insert_or_assign(key, value);
merged_tag_map.insert_or_assign(key, value);
}
}
std::ignore = input_port.tagReader().consume(1_UZ);
}
Expand All @@ -809,9 +816,13 @@ struct node : protected std::tuple<Arguments...> {
}
}

if (settings().changed()) {
if (settings().changed() || _input_tags_present || _output_tags_changed) {
if (const auto forward_parameters = settings().apply_staged_parameters(); !forward_parameters.empty()) {
std::for_each(_tags_at_output.begin(), _tags_at_output.end(), [&forward_parameters](tag_t &tag) { tag.map.insert(forward_parameters.cbegin(), forward_parameters.cend()); });
std::for_each(_tags_at_output.begin(), _tags_at_output.end(), [&forward_parameters](tag_t &tag) {
for (const auto &[key, value] : forward_parameters) {
tag.map.insert_or_assign(key, value);
}
});
_output_tags_changed = true;
}
settings()._changed.store(false);
Expand Down
20 changes: 13 additions & 7 deletions include/settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ class basic_settings : public settings_base {
std::lock_guard lg(_lock);

property_map oldSettings;
if constexpr (requires(Node d, const property_map &map) { d.settings_changed(map, map); }) {
if constexpr (
requires(const property_map &cmap, property_map &map) { _node->settings_changed(/* old settings */ cmap, /* new settings */ map); } or //
requires(const property_map &cmap, property_map &map) { _node->settings_changed(/* old settings */ cmap, /* new settings */ map, /* new forward settings */ map); }) {
// take a copy of the field -> map value of the old settings
if constexpr (refl::is_reflectable<Node>()) {
auto iterate_over_member = [&, this](auto member) {
Expand All @@ -421,14 +423,16 @@ class basic_settings : public settings_base {
if constexpr (is_writable(member) && (std::integral<Type> || std::floating_point<Type> || std::is_same_v<Type, std::string> || fair::meta::vector_type<Type>) ) {
if (std::string(get_display_name(member)) == key && std::holds_alternative<Type>(staged_value)) {
member(*_node) = std::get<Type>(staged_value);
if constexpr (requires { _node->settings_changed(/* old settings */ _active, /* new settings */ staged); }) {
if constexpr (
requires { _node->settings_changed(/* old settings */ _active, /* new settings */ staged); } or //
requires { _node->settings_changed(/* old settings */ _active, /* new settings */ staged, /* new forward settings */ forward_parameters); }) {
staged.insert_or_assign(key, staged_value);
} else {
std::ignore = staged; // help clang to see why staged is not unused
}
if (_auto_forward.contains(get_display_name(member))) {
forward_parameters.insert_or_assign(key, staged_value);
}
}
if (_auto_forward.contains(key)) {
forward_parameters.insert_or_assign(key, staged_value);
}
}
};
Expand All @@ -449,9 +453,11 @@ class basic_settings : public settings_base {
}
refl::util::for_each(refl::reflect<Node>().members, iterate_over_member);

if constexpr (requires(Node d, const property_map &map) { d.settings_changed(map, map); }) {
if (!staged.empty()) {
if (!staged.empty()) {
if constexpr (requires { _node->settings_changed(/* old settings */ _active, /* new settings */ staged); }) {
_node->settings_changed(/* old settings */ oldSettings, /* new settings */ staged);
} else if constexpr (requires { _node->settings_changed(/* old settings */ _active, /* new settings */ staged, /* new forward settings */ forward_parameters); }) {
_node->settings_changed(/* old settings */ oldSettings, /* new settings */ staged, /* new forward settings */ forward_parameters);
}
}
_staged.clear();
Expand Down
78 changes: 75 additions & 3 deletions test/qa_settings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ format_variant(const auto &value) noexcept {
}
},
value);
};
}

void
printChanges(const property_map &oldMap, const property_map &newMap) noexcept {
Expand Down Expand Up @@ -92,7 +92,7 @@ struct Source : public node<Source<T>> {
float sample_rate = 1000.0f;

void
settings_changed(const property_map & /*old_settings*/, const property_map & /*new_settings*/) {
settings_changed(const property_map & /*old_settings*/, property_map & /*new_settings*/) {
// optional init function that is called after construction and whenever settings change
fair::graph::publish_tag(out, { { "n_samples_max", n_samples_max } }, static_cast<std::size_t>(n_tag_offset));
}
Expand Down Expand Up @@ -133,7 +133,7 @@ struct TestBlock : public node<TestBlock<T>, BlockingIO<true>, TestBlockDoc, Sup
bool debug = false;

void
settings_changed(const property_map &old_settings, const property_map &new_settings) noexcept {
settings_changed(const property_map &old_settings, property_map &new_settings) noexcept {
// optional function that is called whenever settings change
update_count++;

Expand All @@ -154,6 +154,48 @@ static_assert(NodeType<TestBlock<int>>);
static_assert(NodeType<TestBlock<float>>);
static_assert(NodeType<TestBlock<double>>);

template<typename T, bool Average = false>
struct Decimate : public node<Decimate<T, Average>, SupportedTypes<float, double>, PerformDecimationInterpolation, Doc<R""(
@brief reduces sample rate by given fraction controlled by denominator
)"">> {
IN<T> in{};
OUT<T> out{};
A<float, "sample rate", Visible> sample_rate = 1.f;

void
settings_changed(const property_map & /*old_settings*/, property_map &new_settings, property_map &fwd_settings) noexcept {
if (new_settings.contains(std::string(fair::graph::tag::SIGNAL_RATE.shortKey())) || new_settings.contains("denominator")) {
const float fwdSampleRate = sample_rate / static_cast<float>(this->denominator);
fwd_settings[std::string(fair::graph::tag::SIGNAL_RATE.shortKey())] = fwdSampleRate; // TODO: handle 'gr:sample_rate' vs 'sample_rate';
}
}

constexpr work_return_status_t
process_bulk(std::span<const T> input, std::span<T> output) noexcept {
assert(this->numerator == std::size_t(1) && "block implements only basic decimation");
assert(this->denominator != std::size_t(0) && "denominator must be non-zero");

auto outputIt = output.begin();
if constexpr (Average) {
for (std::size_t start = 0; start < input.size(); start += this->denominator) {
constexpr auto chunk_begin = input.begin() + start;
constexpr auto chunk_end = chunk_begin + std::min(this->denominator, std::distance(chunk_begin, input.end()));
*outputIt++ = std::reduce(chunk_begin, chunk_end, T(0)) / static_cast<T>(this->denominator);
}
} else {
for (std::size_t i = 0; i < input.size(); i += this->denominator) {
*outputIt++ = input[i];
}
}

return work_return_status_t::OK;
}
};

static_assert(NodeType<Decimate<int>>);
static_assert(NodeType<Decimate<float>>);
static_assert(NodeType<Decimate<double>>);

template<typename T>
struct Sink : public node<Sink<T>> {
IN<T> in;
Expand Down Expand Up @@ -189,6 +231,7 @@ struct Sink : public node<Sink<T>> {

ENABLE_REFLECTION_FOR_TEMPLATE_FULL((typename T), (fair::graph::setting_test::Source<T>), out, n_samples_produced, n_samples_max, n_tag_offset, sample_rate);
ENABLE_REFLECTION_FOR_TEMPLATE_FULL((typename T), (fair::graph::setting_test::TestBlock<T>), in, out, scaling_factor, context, n_samples_max, sample_rate, vector_setting);
ENABLE_REFLECTION_FOR_TEMPLATE_FULL((typename T, bool Average), (fair::graph::setting_test::Decimate<T, Average>), in, out, sample_rate);
ENABLE_REFLECTION_FOR_TEMPLATE_FULL((typename T), (fair::graph::setting_test::Sink<T>), in, n_samples_consumed, n_samples_max, last_tag_position, sample_rate);

const boost::ut::suite SettingsTests = [] {
Expand Down Expand Up @@ -360,6 +403,35 @@ const boost::ut::suite SettingsTests = [] {
(wrapped2.meta_information())["key"] = "value";
expect(eq(std::get<std::string>(wrapped2.meta_information().at("key")), "value"sv)) << "node_model meta-information";
};

"basic decimation test"_test = []() {
graph flow_graph;
constexpr std::int32_t n_samples = gr::util::round_up(1'000'000, 1024);
auto &src = flow_graph.make_node<Source<float>>({ { "n_samples_max", n_samples }, { "sample_rate", 1000.0f } });
auto &block1 = flow_graph.make_node<Decimate<float>>({ { "name", "Decimate1" }, { "denominator", std::size_t(2) } });
auto &block2 = flow_graph.make_node<Decimate<float>>({ { "name", "Decimate2" }, { "denominator", std::size_t(5) } });
auto &sink = flow_graph.make_node<Sink<float>>();

// check denominator
expect(eq(block1.denominator, std::size_t(2)));
expect(eq(block2.denominator, std::size_t(5)));

// src -> block1 -> block2 -> sink
expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(src).to<"in">(block1)));
expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block1).to<"in">(block2)));
expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block2).to<"in">(sink)));

fair::graph::scheduler::simple sched{ std::move(flow_graph) };
sched.run_and_wait();

expect(eq(src.n_samples_produced, n_samples)) << "did not produce enough output samples";
expect(eq(sink.n_samples_consumed, n_samples / (2 * 5))) << "did not consume enough input samples";

expect(eq(src.sample_rate, 1000.0f)) << "src matching sample_rate";
expect(eq(block1.sample_rate, 1000.0f)) << "block1 matching sample_rate";
expect(eq(block2.sample_rate, 500.0f)) << "block2 matching sample_rate";
expect(eq(sink.sample_rate, 100.0f)) << "sink matching src sample_rate";
};
};

const boost::ut::suite AnnotationTests = [] {
Expand Down

0 comments on commit 61e2bf8

Please sign in to comment.