Skip to content

Commit

Permalink
Allow extensions_options to accept Option field (apache#14664)
Browse files Browse the repository at this point in the history
* implement ConfigField for extension option

* fix compile

* fix doc test
  • Loading branch information
goldmedal authored Feb 16, 2025
1 parent 40bb75f commit 2238680
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 23 deletions.
83 changes: 60 additions & 23 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1241,35 +1241,72 @@ macro_rules! extensions_options {
Box::new(self.clone())
}

fn set(&mut self, key: &str, value: &str) -> $crate::Result<()> {
match key {
$(
stringify!($field_name) => {
self.$field_name = value.parse().map_err(|e| {
$crate::DataFusionError::Context(
format!(concat!("Error parsing {} as ", stringify!($t),), value),
Box::new($crate::DataFusionError::External(Box::new(e))),
)
})?;
Ok(())
}
)*
_ => Err($crate::DataFusionError::Configuration(
format!(concat!("Config value \"{}\" not found on ", stringify!($struct_name)), key)
))
}
fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> {
$crate::config::ConfigField::set(self, key, value)
}

fn entries(&self) -> Vec<$crate::config::ConfigEntry> {
vec![
struct Visitor(Vec<$crate::config::ConfigEntry>);

impl $crate::config::Visit for Visitor {
fn some<V: std::fmt::Display>(
&mut self,
key: &str,
value: V,
description: &'static str,
) {
self.0.push($crate::config::ConfigEntry {
key: key.to_string(),
value: Some(value.to_string()),
description,
})
}

fn none(&mut self, key: &str, description: &'static str) {
self.0.push($crate::config::ConfigEntry {
key: key.to_string(),
value: None,
description,
})
}
}

let mut v = Visitor(vec![]);
// The prefix is not used for extensions.
// The description is generated in ConfigField::visit.
// We can just pass empty strings here.
$crate::config::ConfigField::visit(self, &mut v, "", "");
v.0
}
}

impl $crate::config::ConfigField for $struct_name {
fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> {
let (key, rem) = key.split_once('.').unwrap_or((key, ""));
match key {
$(
$crate::config::ConfigEntry {
key: stringify!($field_name).to_owned(),
value: (self.$field_name != $default).then(|| self.$field_name.to_string()),
description: concat!($($d),*).trim(),
stringify!($field_name) => {
// Safely apply deprecated attribute if present
// $(#[allow(deprecated)])?
{
#[allow(deprecated)]
self.$field_name.set(rem, value.as_ref())
}
},
)*
]
_ => return $crate::error::_config_err!(
"Config value \"{}\" not found on {}", key, stringify!($struct_name)
)
}
}

fn visit<V: $crate::config::Visit>(&self, v: &mut V, _key_prefix: &str, _description: &'static str) {
$(
let key = stringify!($field_name).to_string();
let desc = concat!($($d),*).trim();
#[allow(deprecated)]
self.$field_name.visit(v, key.as_str(), desc);
)*
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions datafusion/execution/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ mod tests {
extensions_options! {
struct TestExtension {
value: usize, default = 42
option_value: Option<usize>, default = None
}
}

Expand All @@ -229,6 +230,7 @@ mod tests {

let mut config = ConfigOptions::new().with_extensions(extensions);
config.set("test.value", "24")?;
config.set("test.option_value", "42")?;
let session_config = SessionConfig::from(config);

let task_context = TaskContext::new(
Expand All @@ -249,6 +251,39 @@ mod tests {
assert!(test.is_some());

assert_eq!(test.unwrap().value, 24);
assert_eq!(test.unwrap().option_value, Some(42));

Ok(())
}

#[test]
fn task_context_extensions_default() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let mut extensions = Extensions::new();
extensions.insert(TestExtension::default());

let config = ConfigOptions::new().with_extensions(extensions);
let session_config = SessionConfig::from(config);

let task_context = TaskContext::new(
Some("task_id".to_string()),
"session_id".to_string(),
session_config,
HashMap::default(),
HashMap::default(),
HashMap::default(),
runtime,
);

let test = task_context
.session_config()
.options()
.extensions
.get::<TestExtension>();
assert!(test.is_some());

assert_eq!(test.unwrap().value, 42);
assert_eq!(test.unwrap().option_value, None);

Ok(())
}
Expand Down

0 comments on commit 2238680

Please sign in to comment.