diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 6f95e429f6..cae745f97e 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -686,6 +686,8 @@ fn write_output( params: &Parameters, output_path: &str, ) -> anyhow::Result<()> { + use naga::back::pipeline_constants::ProcessOverridesOutput; + match Path::new(&output_path) .extension() .ok_or(CliError("Output filename has no extension"))? @@ -717,9 +719,14 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let ProcessOverridesOutput { module, info, .. } = + naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let pipeline_options = msl::PipelineOptions::default(); let (msl, _) = @@ -751,9 +758,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let spv = spv::write_vec(&module, &info, ¶ms.spv_out, pipeline_options).unwrap_pretty(); @@ -788,9 +803,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let mut buffer = String::new(); let mut writer = glsl::Writer::new( @@ -819,9 +842,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let mut buffer = String::new(); let pipeline_options = Default::default(); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7bc8289b9b..f34d2b23a8 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -52,7 +52,12 @@ use alloc::{ }; use core::fmt::{Error as FmtError, Write}; -use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo}; +use crate::{ + arena::Handle, + ir, + proc::index, + valid::{ModuleInfo, UnresolvedOverrides}, +}; mod keywords; pub mod sampler; @@ -431,6 +436,14 @@ pub struct PipelineOptions { /// point is not found, an error will be thrown while writing. pub entry_point: Option<(ir::ShaderStage, String)>, + /// Information about unresolved overrides. + /// + /// This struct is returned by `process_overrides`. It tells the writer + /// which items to omit from the output because they are not used and refer + /// to overrides that were not resolved to a concrete value. + #[cfg_attr(feature = "serialize", serde(skip))] + pub unresolved_overrides: UnresolvedOverrides, + /// Allow `BuiltIn::PointSize` and inject it if doesn't exist. /// /// Metal doesn't like this for non-point primitive topologies and requires it for diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233a..a3b5cfb716 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -21,9 +21,10 @@ use crate::{ proc::{ self, index::{self, BoundsCheck}, - NameKey, TypeResolution, + NameKey, ResolveArraySizeError, TypeResolution, }, - valid, FastHashMap, FastHashSet, + valid::{self, UnresolvedOverrides}, + FastHashMap, FastHashSet, }; #[cfg(test)] @@ -436,6 +437,7 @@ pub struct Writer { /// Set of (struct type, struct field index) denoting which fields require /// padding inserted **before** them (i.e. between fields at index - 1 and index) struct_member_pads: FastHashSet<(Handle, u32)>, + unresolved_overrides: Option, } impl crate::Scalar { @@ -775,6 +777,7 @@ impl Writer { #[cfg(test)] put_block_stack_pointers: Default::default(), struct_member_pads: FastHashSet::default(), + unresolved_overrides: None, } } @@ -4032,6 +4035,7 @@ impl Writer { ); self.wrapped_functions.clear(); self.struct_member_pads.clear(); + self.unresolved_overrides = Some(pipeline_options.unresolved_overrides.clone()); writeln!( self.out, @@ -4216,8 +4220,20 @@ impl Writer { first_time: false, }; - match size.resolve(module.to_ctx())? { - proc::IndexableLength::Known(size) => { + match size.resolve(module.to_ctx()) { + Err(ResolveArraySizeError::NonConstArrayLength) => { + // The array size was never resolved. This _should_ + // be because it is an override expression and the + // type is not needed for the entry point being + // written. + // TODO: do we want to assemble `UnresolvedOverrides.types` to make this safer? + // (And if so, do we also want to validate that those types are truly unused?) + continue; + } + Err(err @ ResolveArraySizeError::ExpectedPositiveArrayLength) => { + return Err(err.into()); + } + Ok(proc::IndexableLength::Known(size)) => { writeln!(self.out, "struct {name} {{")?; writeln!( self.out, @@ -4229,7 +4245,7 @@ impl Writer { )?; writeln!(self.out, "}};")?; } - proc::IndexableLength::Dynamic => { + Ok(proc::IndexableLength::Dynamic) => { writeln!(self.out, "typedef {base_name} {name}[1];")?; } } @@ -5757,6 +5773,17 @@ template fun_handle ); + if self + .unresolved_overrides + .as_ref() + .unwrap() + .functions + .contains_key(&fun_handle) + { + log::trace!("skipping due to unresolved overrides"); + continue; + } + let ctx = back::FunctionCtx { ty: back::FunctionType::Function(fun_handle), info: &mod_info[fun_handle], @@ -5880,6 +5907,19 @@ template }; for ep_index in ep_range { + if self + .unresolved_overrides + .as_ref() + .unwrap() + .entry_points + .contains_key(&ep_index) + { + log::error!( + "must write the same entry point that was passed to `process_overrides`" + ); + return Err(Error::Override); + } + let ep = &module.entry_points[ep_index]; let fun = &ep.function; let fun_info = mod_info.get_entry_point(ep_index); @@ -6288,7 +6328,15 @@ template // within the entry point. for (handle, var) in module.global_variables.iter() { let usage = fun_info[handle]; - if usage.is_empty() || var.space == crate::AddressSpace::Private { + if usage.is_empty() + || var.space == crate::AddressSpace::Private + || self + .unresolved_overrides + .as_ref() + .unwrap() + .global_variables + .contains_key(&handle) + { continue; } @@ -6942,9 +6990,20 @@ mod workgroup_mem_init { let mut access_stack = AccessStack::new(); - let vars = module.global_variables.iter().filter(|&(handle, var)| { - !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }); + let vars = module + .global_variables + .iter() + .filter(|&(handle, var)| { + !fun_info[handle].is_empty() + && var.space == crate::AddressSpace::WorkGroup + && !self + .unresolved_overrides + .as_ref() + .unwrap() + .global_variables + .contains_key(&handle) + }) + .collect::>(); for (handle, var) in vars { access_stack.enter(Access::GlobalVariable(handle), |access_stack| { diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 1cf1c80524..790e783ab7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -10,10 +10,14 @@ use thiserror::Error; use super::PipelineConstants; use crate::{ arena::HandleVec, - proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, - valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, - Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, - Span, Statement, TypeInner, WithSpan, + ir, + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter, U32EvalError}, + valid::{ + Capabilities, FunctionInfo, ModuleInfo, UnresolvedOverrides, ValidationError, + ValidationFlags, Validator, + }, + Arena, Block, Constant, Expression, FastHashMap, Function, Handle, Literal, Module, Override, + Range, Scalar, Span, Statement, TypeInner, WithSpan, }; #[cfg(no_std)] @@ -37,28 +41,89 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, + #[error("unable to evaluate workgroup_size override")] + WorkgroupSizeOverrideEvaluationError, } -/// Replace all overrides in `module` with constants. +// Returns the key to use for an override in `pipeline_constants`. +fn override_key(ov: &Override) -> Cow<'_, str> { + if let Some(id) = ov.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = ov.name { + Cow::Borrowed(name) + } else { + unreachable!() + } +} + +#[derive(Debug)] +pub struct ProcessOverridesOutput<'a> { + pub module: Cow<'a, Module>, + pub info: Cow<'a, ModuleInfo>, + pub unresolved: UnresolvedOverrides, +} + +/// Check the global usage in `fun_info` for any globals affected by unresolved +/// overrides. +/// +/// If any is found, returns `Some`, otherwise returns `None`. +fn check_for_unresolved_global_use<'a>( + globals: impl Iterator, &'a ir::GlobalVariable)>, + unresolved: &UnresolvedOverrides, + fun_info: &FunctionInfo, +) -> Option> { + for (var_handle, _) in globals { + match unresolved.global_variables.get(&var_handle) { + Some(&o_handle) if !fun_info[var_handle].is_empty() => { + return Some(o_handle); + } + _ => {} + } + } + None +} + +/// Replace overrides in `module` with constants. /// /// If no changes are needed, this just returns `Cow::Borrowed` /// references to `module` and `module_info`. Otherwise, it clones -/// `module`, edits its [`global_expressions`] arena to contain only -/// fully-evaluated expressions, and returns `Cow::Owned` values -/// holding the simplified module and its validation results. +/// `module`, updates it with evaluated override expressions, and returns +/// `Cow::Owned` values holding the simplified module and its validation +/// results. +/// +/// If `entry_point` is specified, then any override referenced by +/// that entry point must be supplied, and other overrides are +/// optional. The returned module may still have override expressions, +/// but they should not be reachable from the specified entry point. +/// +/// If `entry_point` is not specified, then all overrides must be specified. +/// +/// This function completely rewrites both the [`global`] and function-local +/// arenas, replacing [`Expression::Override`] with [`Expression::Constant`]. +/// It then updates expressions, statements, and initializers that refer to a +/// an updated expression. /// -/// In either case, the module returned has an empty `overrides` -/// arena, and the `global_expressions` arena contains only -/// fully-evaluated expressions. +/// The types arena is not updated. This means that the size of an array (in the +/// workgroup space, because this is the only place override-sized arrays are +/// permitted) may still require indirection through an override handle to the +/// initializer expression, which will be an evaluated constant. See +/// [#6787](https://github.com/gfx-rs/wgpu/pull/6787). /// -/// [`global_expressions`]: Module::global_expressions +/// [`global`]: Module::global_expressions pub fn process_overrides<'a>( module: &'a Module, module_info: &'a ModuleInfo, + entry_point: Option<(ir::ShaderStage, &str)>, pipeline_constants: &PipelineConstants, -) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { +) -> Result, PipelineConstantError> { + let mut unresolved = UnresolvedOverrides::default(); + if module.overrides.is_empty() { - return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); + return Ok(ProcessOverridesOutput { + module: Cow::Borrowed(module), + info: Cow::Borrowed(module_info), + unresolved, + }); } let mut module = module.clone(); @@ -84,6 +149,7 @@ pub fn process_overrides<'a>( let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + let mut global_expressions_missing_overrides = FastHashMap::default(); let mut layouter = crate::proc::Layouter::default(); // An iterator through the original overrides table, consumed in @@ -93,10 +159,9 @@ pub fn process_overrides<'a>( // Do two things in tandem: // - // - Rebuild the global expression arena from scratch, fully - // evaluating all expressions, and replacing each `Override` - // expression in `module.global_expressions` with a `Constant` - // expression. + // - Rebuild the global expression arena from scratch, replacing + // `Override` expressions in `module.global_expressions` that can + // now be evaluated with `Constant` expressions. // // - Build a new `Constant` in `module.constants` to take the // place of each `Override`. @@ -123,28 +188,43 @@ pub fn process_overrides<'a>( for (old_h, expr, span) in module.global_expressions.drain() { let mut expr = match expr { Expression::Override(h) => { - let c_h = if let Some(new_h) = override_map.get(h) { - *new_h - } else { - let mut new_h = None; - for entry in override_iter.by_ref() { - let stop = entry.0 == h; - new_h = Some(process_override( - entry, - pipeline_constants, - &mut module, - &mut override_map, - &adjusted_global_expressions, - &mut adjusted_constant_initializers, - &mut global_expression_kind_tracker, - )?); - if stop { - break; + match override_map.get(h) { + Some(&Some(new_h)) => { + // Already evaluated. + Expression::Constant(new_h) + } + Some(&None) => { + // Already processed and could not evaluate. Leave + // expression unchanged. + expr + } + None => { + let mut result = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + result = process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + if stop { + break; + } + } + match result { + None => { + // Could not evaluate. Leave expression + // unchanged. + expr + } + Some(new_h) => Expression::Constant(new_h), } } - new_h.unwrap() - }; - Expression::Constant(c_h) + } } Expression::Constant(c_h) => { if adjusted_constant_initializers.insert(c_h) { @@ -155,6 +235,9 @@ pub fn process_overrides<'a>( } expr => expr, }; + // Attempt constant evaluation now that overrides referenced by this + // expression may have been resolved. If they have not been resolved, + // the expression will remain with `ExpressionKind::Override`. let mut evaluator = ConstantEvaluator::for_wgsl_module( &mut module, &mut global_expression_kind_tracker, @@ -162,8 +245,17 @@ pub fn process_overrides<'a>( false, ); adjust_expr(&adjusted_global_expressions, &mut expr); - let h = evaluator.try_eval_and_append(expr, span)?; - adjusted_global_expressions.insert(old_h, h); + match evaluator.try_eval_and_append(expr, span) { + Err((expr, ConstantEvaluatorError::Override(ov_h))) => { + let h = module.global_expressions.append(expr, span); + global_expression_kind_tracker.insert(h, crate::proc::ExpressionKind::Override); + adjusted_global_expressions.insert(old_h, h); + global_expressions_missing_overrides.insert(h, ov_h); + log::trace!("global {:?} initializer missing override {:?}", h, ov_h); + } + Err((_, e)) => return Err(e.into()), + Ok(h) => adjusted_global_expressions.insert(old_h, h), + } } // Finish processing any overrides we didn't visit in the loop above. @@ -184,6 +276,9 @@ pub fn process_overrides<'a>( init: Some(ref mut init), .. } => { + // Anonymous override representing by an array size expression. + // These are not handled through `process_override`, are not + // replaced by a constant, and are not added to `override_map`. *init = adjusted_global_expressions[*init]; } _ => {} @@ -201,23 +296,138 @@ pub fn process_overrides<'a>( c.init = adjusted_global_expressions[c.init]; } - for (_, v) in module.global_variables.iter_mut() { + // Identify which global variables are still unusable due to missing + // overrides. Overrides can appear in the initializer, and in the + // case of workgroup space arrays, in the array size. + for (v_handle, v) in module.global_variables.iter_mut() { if let Some(ref mut init) = v.init { *init = adjusted_global_expressions[*init]; + if let Some(&o_handle) = global_expressions_missing_overrides.get(init) { + log::trace!( + "global {:?} initializer missing override {:?}", + v.name, + overrides[o_handle].name + ); + unresolved.global_variables.insert(v_handle, o_handle); + } + } else if let TypeInner::Array { + size: crate::ArraySize::Pending(o_handle), + .. + } = module.types[v.ty].inner + { + let resolved = match override_map.get(o_handle) { + Some(&Some(_)) => { + // Override was processed successfully. + true + } + Some(&None) => { + // Override could not be processed. + false + } + None => { + // Anonymous override for array size expression + // These are not added to override_map. + match overrides[o_handle].init { + Some(init) => global_expression_kind_tracker.is_const(init), + None => { + // This should not happen. + log::error!("anonymous override with no initializer?"); + true + } + } + } + }; + if !resolved { + log::trace!( + "array size of global {:?} missing override {:?}", + v.name, + overrides[o_handle].name + ); + unresolved.global_variables.insert(v_handle, o_handle); + } } } + // Process functions, taking note of which ones require overrides that were + // not specified. Like expressions, callees are guaranteed to appear before + // their callers. let mut functions = mem::take(&mut module.functions); - for (_, function) in functions.iter_mut() { - process_function(&mut module, &override_map, &mut layouter, function)?; + for (f_handle, function) in functions.iter_mut() { + let result = if let Some(o_handle) = process_function( + &mut module, + &override_map, + &unresolved.functions, + &mut layouter, + function, + )? { + log::trace!( + "function {:?} missing override {:?}", + function.name, + overrides[o_handle].name + ); + Some(o_handle) + } else { + check_for_unresolved_global_use( + module.global_variables.iter(), + &unresolved, + &module_info[f_handle], + ) + }; + if let Some(o_handle) = result { + unresolved.functions.insert(f_handle, o_handle); + } } module.functions = functions; + // Process entry points let mut entry_points = mem::take(&mut module.entry_points); - for ep in entry_points.iter_mut() { - process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; - process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; + for (ep_index, ep) in entry_points.iter_mut().enumerate() { + let result = if let Some(o_handle) = process_function( + &mut module, + &override_map, + &unresolved.functions, + &mut layouter, + &mut ep.function, + )? { + log::trace!( + "entry point {} missing override {:?}", + ep.name, + overrides[o_handle].name + ); + Some(o_handle) + } else if let Some(o_handle) = + process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)? + { + Some(o_handle) + } else { + check_for_unresolved_global_use( + module.global_variables.iter(), + &unresolved, + module_info.get_entry_point(ep_index), + ) + }; + if let Some(o_handle) = result { + // We found a missing override that is required by this entry point. + // Decide whether that is an error. + match entry_point { + Some((tgt_stage, tgt_name)) if ep.stage != tgt_stage || ep.name != tgt_name => { + // An entry point was specified, and we are not currently + // processing that one, so it is okay not to have this + // override. + unresolved.entry_points.insert(ep_index, o_handle); + } + _ => { + // Either we are missing an override for the active entry point, + // or no entry point was specified. Either way, this override + // is required. + return Err(PipelineConstantError::MissingValue( + override_key(&overrides[o_handle]).to_string(), + )); + } + } + } } + module.entry_points = entry_points; module.overrides = overrides; @@ -225,66 +435,84 @@ pub fn process_overrides<'a>( // recompute their types and other metadata. For the time being, // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); - let module_info = validator.validate_resolved_overrides(&module)?; + let module_info = validator.validate_with_resolved_overrides(&module, &unresolved)?; - Ok((Cow::Owned(module), Cow::Owned(module_info))) + Ok(ProcessOverridesOutput { + module: Cow::Owned(module), + info: Cow::Owned(module_info), + unresolved, + }) } +/// Process override expressions in the WGSL `@workgroup_size` attribute. +/// +/// If all expressions are resolved, returns `Ok(None)`. If any expression could +/// not be resolved due to missing override values, returns `Ok(Some(handle))`, +/// with the handle of the first identified missing override. The caller is +/// responsible for determining whether translation can proceed despite the +/// missing override. fn process_workgroup_size_override( module: &mut Module, adjusted_global_expressions: &HandleVec>, ep: &mut crate::EntryPoint, -) -> Result<(), PipelineConstantError> { +) -> Result>, PipelineConstantError> { match ep.workgroup_size_overrides { None => {} Some(overrides) => { - overrides.iter().enumerate().try_for_each( - |(i, overridden)| -> Result<(), PipelineConstantError> { - match *overridden { - None => Ok(()), - Some(h) => { - ep.workgroup_size[i] = module - .to_ctx() - .eval_expr_to_u32(adjusted_global_expressions[h]) - .map(|n| { - if n == 0 { - Err(PipelineConstantError::NegativeWorkgroupSize) - } else { - Ok(n) - } - }) - .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??; - Ok(()) + for (ov_index, ov) in overrides.iter().enumerate() { + match *ov { + None => continue, + Some(h) => { + match module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[h]) + { + Ok(n) => { + if n == 0 { + return Err(PipelineConstantError::NegativeWorkgroupSize); + } else { + ep.workgroup_size[ov_index] = n; + } + } + Err(U32EvalError::Override(handle)) => { + return Ok(Some(handle)); + } + Err(U32EvalError::Runtime) => { + return Err( + PipelineConstantError::WorkgroupSizeOverrideEvaluationError, + ); + } + Err(U32EvalError::Negative) => { + return Err(PipelineConstantError::NegativeWorkgroupSize); + } } } - }, - )?; + } + } ep.workgroup_size_overrides = None; } } - Ok(()) + Ok(None) } -/// Add a [`Constant`] to `module` for the override `old_h`. +/// If a value for the override `old_h` is given in `self.pipeline_constants`, +/// add a [`Constant`] for that override to `module`. /// -/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. +/// If a value is found, adds the new `Constant` to `override_map` and +/// `adjusted_constant_initializers`, and returns it. +/// +/// If no value is found, returns `Ok(None)`. The caller is responsible for +/// determining whether translation can proceed despite the missing override. fn process_override( (old_h, r#override, span): (Handle, &mut Override, &Span), pipeline_constants: &PipelineConstants, module: &mut Module, - override_map: &mut HandleVec>, + override_map: &mut HandleVec>>, adjusted_global_expressions: &HandleVec>, adjusted_constant_initializers: &mut HashSet>, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, -) -> Result, PipelineConstantError> { - // Determine which key to use for `r#override` in `pipeline_constants`. - let key = if let Some(id) = r#override.id { - Cow::Owned(id.to_string()) - } else if let Some(ref name) = r#override.name { - Cow::Borrowed(name) - } else { - unreachable!(); - }; +) -> Result>, PipelineConstantError> { + let key = override_key(r#override); // Generate a global expression for `r#override`'s value, either // from the provided `pipeline_constants` table or its initializer @@ -299,10 +527,16 @@ fn process_override( .append(Expression::Literal(literal), Span::UNDEFINED); global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); expr - } else if let Some(init) = r#override.init { - adjusted_global_expressions[init] } else { - return Err(PipelineConstantError::MissingValue(key.to_string())); + match r#override.init { + Some(init) if global_expression_kind_tracker.is_const(init) => { + adjusted_global_expressions[init] + } + _ => { + override_map.insert(old_h, None); + return Ok(None); + } + } }; // Generate a new `Constant` to represent the override's value. @@ -312,27 +546,35 @@ fn process_override( init, }; let h = module.constants.append(constant, *span); - override_map.insert(old_h, h); + override_map.insert(old_h, Some(h)); adjusted_constant_initializers.insert(h); r#override.init = Some(init); - Ok(h) + Ok(Some(h)) } -/// Replace all override expressions in `function` with fully-evaluated constants. +/// Replace override expressions in `function` with fully-evaluated constants. /// -/// Replace all `Expression::Override`s in `function`'s expression arena with +/// Replace `Expression::Override`s in `function`'s expression arena with /// the corresponding `Expression::Constant`s, as given in `override_map`. /// Replace any expressions whose values are now known with their fully /// evaluated form. /// /// If `h` is a `Handle`, then `override_map[h]` is the /// `Handle` for the override's final value. +/// +/// If all override expressions are replaced, returns `Ok(None)`. If any +/// expression could not be replaced due to missing override values, or if +/// the function calls another function that is present in +/// `functions_missing_overrides`, returns `Ok(Some(handle))`, with the handle +/// of the first identified missing override. The caller is responsible for +/// determining whether translation can proceed despite the missing override. fn process_function( module: &mut Module, - override_map: &HandleVec>, + override_map: &HandleVec>>, + functions_missing_overrides: &FastHashMap, Handle>, layouter: &mut crate::proc::Layouter, function: &mut Function, -) -> Result<(), ConstantEvaluatorError> { +) -> Result>, ConstantEvaluatorError> { // A map from original local expression handles to // handles in the new, local expression arena. let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len()); @@ -341,6 +583,8 @@ fn process_function( let mut expressions = mem::take(&mut function.expressions); + let mut missing_override = None; + // Dummy `emitter` and `block` for the constant evaluator. // We can ignore the concept of emitting expressions here since // expressions have already been covered by a `Statement::Emit` @@ -363,14 +607,29 @@ fn process_function( for (old_h, mut expr, span) in expressions.drain() { if let Expression::Override(h) = expr { - expr = Expression::Constant(override_map[h]); + if let Some(&Some(const_h)) = override_map.get(h) { + expr = Expression::Constant(const_h); + } else if missing_override.is_none() { + missing_override = Some(h); + } } adjust_expr(&adjusted_local_expressions, &mut expr); - let h = evaluator.try_eval_and_append(expr, span)?; + let h = evaluator + .try_eval_and_append(expr, span) + .map_err(|(_expr, err)| err)?; adjusted_local_expressions.insert(old_h, h); } - adjust_block(&adjusted_local_expressions, &mut function.body); + match adjust_block( + &adjusted_local_expressions, + functions_missing_overrides, + &mut function.body, + ) { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } filter_emits_in_block(&mut function.body, &function.expressions); @@ -390,7 +649,7 @@ fn process_function( .insert(adjusted_local_expressions[expr_h], name); } - Ok(()) + Ok(missing_override) } /// Replace every expression handle in `expr` with its counterpart @@ -606,15 +865,39 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E /// Replace every expression handle in `block` with its counterpart /// given by `new_pos`. -fn adjust_block(new_pos: &HandleVec>, block: &mut Block) { +/// +/// On success, returns `Ok(None)`. If `block` calls a function that is present +/// in `functions_missing_overrides`, returns `Ok(Some(handle))`, with the +/// handle of the first identified missing override. +fn adjust_block( + new_pos: &HandleVec>, + functions_missing_overrides: &FastHashMap, Handle>, + block: &mut Block, +) -> Option> { + let mut missing_override = None; for stmt in block.iter_mut() { - adjust_stmt(new_pos, stmt); + match adjust_stmt(new_pos, functions_missing_overrides, stmt) { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } } + missing_override } /// Replace every expression handle in `stmt` with its counterpart /// given by `new_pos`. -fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut Statement) { +/// +/// On success, returns `Ok(None)`. If `stmt` calls a function that is present +/// in `functions_missing_overrides`, returns `Ok(Some(handle))`, with the +/// handle of the first identified missing override. +fn adjust_stmt( + new_pos: &HandleVec>, + functions_missing_overrides: &FastHashMap, Handle>, + stmt: &mut Statement, +) -> Option> { + let mut missing_override = None; let adjust = |expr: &mut Handle| { *expr = new_pos[*expr]; }; @@ -627,7 +910,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } } Statement::Block(ref mut block) => { - adjust_block(new_pos, block); + adjust_block(new_pos, functions_missing_overrides, block); } Statement::If { ref mut condition, @@ -635,8 +918,8 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S ref mut reject, } => { adjust(condition); - adjust_block(new_pos, accept); - adjust_block(new_pos, reject); + adjust_block(new_pos, functions_missing_overrides, accept); + adjust_block(new_pos, functions_missing_overrides, reject); } Statement::Switch { ref mut selector, @@ -644,7 +927,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } => { adjust(selector); for case in cases.iter_mut() { - adjust_block(new_pos, &mut case.body); + adjust_block(new_pos, functions_missing_overrides, &mut case.body); } } Statement::Loop { @@ -652,8 +935,8 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S ref mut continuing, ref mut break_if, } => { - adjust_block(new_pos, body); - adjust_block(new_pos, continuing); + adjust_block(new_pos, functions_missing_overrides, body); + adjust_block(new_pos, functions_missing_overrides, continuing); if let Some(e) = break_if.as_mut() { adjust(e); } @@ -769,8 +1052,14 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S Statement::Call { ref mut arguments, ref mut result, - function: _, + function, } => { + match functions_missing_overrides.get(&function).copied() { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } for argument in arguments.iter_mut() { adjust(argument); } @@ -803,6 +1092,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} } + missing_override } /// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index e6c5546fb9..b151d7af8e 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -277,10 +277,11 @@ impl<'a> Context<'a> { ) }; - eval.try_eval_and_append(expr, meta).map_err(|e| Error { - kind: e.into(), - meta, - }) + eval.try_eval_and_append(expr, meta) + .map_err(|(_expr, err)| Error { + kind: err.into(), + meta, + }) } /// Add variable to current scope diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index 2eb3ec4b00..290bcc5443 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -219,7 +219,7 @@ impl<'source> ParsingContext<'source> { kind: ErrorKind::SemanticError("int constant overflows".into()), meta, }), - Err(U32EvalError::NonConst) => Err(Error { + Err(U32EvalError::Runtime | U32EvalError::Override(_)) => Err(Error { kind: ErrorKind::SemanticError("Expected a uint constant".into()), meta, }), diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 93ccb7143c..35c2e8a69c 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -520,7 +520,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<'source, Handle> { let mut eval = self.as_const_evaluator(); eval.try_eval_and_append(expr, span) - .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span))) + .map_err(|(_expr, err)| Box::new(Error::ConstantEvaluatorError(err.into(), span))) } fn const_eval_expr_to_u32( @@ -530,7 +530,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { if !ctx.local_expression_kind_tracker.is_const(handle) { - return Err(proc::U32EvalError::NonConst); + return Err(proc::U32EvalError::Runtime); } self.module @@ -544,7 +544,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .eval_expr_to_u32_from(handle, &ctx.function.expressions) } ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle), - ExpressionContextType::Override => Err(proc::U32EvalError::NonConst), + ExpressionContextType::Override => Err(proc::U32EvalError::Runtime), } } @@ -628,7 +628,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .to_ctx() .eval_expr_to_u32_from(expr, &rctx.function.expressions) .map_err(|err| match err { - proc::U32EvalError::NonConst => { + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { Error::ExpectedConstExprConcreteIntegerScalar(component_span) } proc::U32EvalError::Negative => Error::ExpectedNonNegative(component_span), @@ -1431,7 +1431,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Err(err) => { if let Error::ConstantEvaluatorError(ref ty, _) = *err { match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { + proc::ConstantEvaluatorError::Override(_) => { workgroup_size_overrides_out[i] = Some(self.workgroup_size_override( size_expr, @@ -1739,12 +1739,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .to_ctx() .eval_expr_to_literal_from(expr, &ctx.function.expressions) { - Some(ir::Literal::I32(value)) => { - ir::SwitchValue::I32(value) - } - Some(ir::Literal::U32(value)) => { - ir::SwitchValue::U32(value) - } + Ok(ir::Literal::I32(value)) => ir::SwitchValue::I32(value), + Ok(ir::Literal::U32(value)) => ir::SwitchValue::U32(value), _ => { return Err(Box::new(Error::InvalidSwitchCase { span, @@ -3587,7 +3583,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .to_ctx() .eval_expr_to_u32(expr) .map_err(|err| match err { - proc::U32EvalError::NonConst => Error::ExpectedConstExprConcreteIntegerScalar(span), + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), })?; Ok((value, span)) @@ -3606,7 +3604,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(value) => { let len = ctx.const_eval_expr_to_u32(value).map_err(|err| { Box::new(match err { - proc::U32EvalError::NonConst => { + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { Error::ExpectedConstExprConcreteIntegerScalar(span) } proc::U32EvalError::Negative => { @@ -3621,7 +3619,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Err(err) => { if let Error::ConstantEvaluatorError(ref ty, _) = *err { match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { + proc::ConstantEvaluatorError::Override(_) => { ir::ArraySize::Pending(self.array_size_override( expr, &mut ctx.as_global().as_override(), diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 27d6addc82..573cc12413 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -558,11 +558,11 @@ pub enum ConstantEvaluatorError { #[error(transparent)] Literal(#[from] crate::valid::LiteralError), #[error("Can't use pipeline-overridable constants in const-expressions")] - Override, + Override(Handle), #[error("Unexpected runtime-expression")] RuntimeExpr, - #[error("Unexpected override-expression")] - OverrideExpr, + #[error("Unexpectedly able to evaluate an override expression")] + EvaluatedOverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -740,7 +740,8 @@ impl<'a> ConstantEvaluator<'a> { /// contributing to some function's expression arena, then append `expr` to /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be /// contributing to the module's constant expression arena; since `expr`'s - /// value is not a constant, return an error. + /// value is not a constant, return an error (along with the original + /// expression, in case the caller needs it). /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -755,7 +756,7 @@ impl<'a> ConstantEvaluator<'a> { &mut self, expr: Expression, span: Span, - ) -> Result, ConstantEvaluatorError> { + ) -> Result, (Expression, ConstantEvaluatorError)> { match self.expression_kind_tracker.type_of_with_expr(&expr) { ExpressionKind::Const => { let eval_result = self.try_eval_and_append_impl(&expr, span); @@ -772,7 +773,7 @@ impl<'a> ConstantEvaluator<'a> { { Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) } else { - eval_result + eval_result.map_err(|err| (expr, err)) } } ExpressionKind::Override => match self.behavior { @@ -780,7 +781,19 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.append_expr(expr, span, ExpressionKind::Override)) } Behavior::Wgsl(WgslRestrictions::Const(_)) => { - Err(ConstantEvaluatorError::OverrideExpr) + // We should always get `ConstantEvaluatorError::Override` + // here. If we get something else, then it's probably a bug + // in the expression kind determination. We attempt evaluation + // here in order to identify the overrides that would be + // required to evaluate this expression, for use in diagnostics. + match self.try_eval_and_append_impl(&expr, span) { + Err(ov_err @ ConstantEvaluatorError::Override(_)) => Err((expr, ov_err)), + Err(err) => { + log::debug!("expected an override error, but got {:?}", err); + Err((expr, err)) + } + Ok(_) => Err((expr, ConstantEvaluatorError::EvaluatedOverrideExpr)), + } } Behavior::Glsl(_) => { unreachable!() @@ -790,7 +803,7 @@ impl<'a> ConstantEvaluator<'a> { if self.behavior.has_runtime_restrictions() { Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) } else { - Err(ConstantEvaluatorError::RuntimeExpr) + Err((expr, ConstantEvaluatorError::RuntimeExpr)) } } } @@ -830,7 +843,7 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } - Expression::Override(_) => Err(ConstantEvaluatorError::Override), + Expression::Override(ov) => Err(ConstantEvaluatorError::Override(ov)), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 0843e709b5..2b03f18dcb 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -414,8 +414,14 @@ impl crate::Module { } #[derive(Debug)] +#[cfg_attr(not(any(hlsl_out, msl_out, spv_out, glsl_out)), allow(dead_code))] pub(super) enum U32EvalError { - NonConst, + /// Expression is not constant. + Runtime, + + /// Expression is not constant because the indicated override value is not supplied. + Override(crate::Handle), + Negative, } @@ -444,11 +450,10 @@ impl GlobalCtx<'_> { arena: &crate::Arena, ) -> Result { match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::U32(value)) => Ok(value), - Some(crate::Literal::I32(value)) => { - value.try_into().map_err(|_| U32EvalError::Negative) - } - _ => Err(U32EvalError::NonConst), + Ok(crate::Literal::U32(value)) => Ok(value), + Ok(crate::Literal::I32(value)) => value.try_into().map_err(|_| U32EvalError::Negative), + Err(Some(ov_handle)) => Err(U32EvalError::Override(ov_handle)), + _ => Err(U32EvalError::Runtime), } } @@ -460,7 +465,7 @@ impl GlobalCtx<'_> { arena: &crate::Arena, ) -> Option { match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::Bool(value)) => Some(value), + Ok(crate::Literal::Bool(value)) => Some(value), _ => None, } } @@ -469,7 +474,7 @@ impl GlobalCtx<'_> { pub(crate) fn eval_expr_to_literal( &self, handle: crate::Handle, - ) -> Option { + ) -> Result>> { self.eval_expr_to_literal_from(handle, self.global_expressions) } @@ -477,25 +482,26 @@ impl GlobalCtx<'_> { &self, handle: crate::Handle, arena: &crate::Arena, - ) -> Option { + ) -> Result>> { fn get( gctx: GlobalCtx, handle: crate::Handle, arena: &crate::Arena, - ) -> Option { + ) -> Result>> { match arena[handle] { - crate::Expression::Literal(literal) => Some(literal), + crate::Expression::Literal(literal) => Ok(literal), crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner { - crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar), - _ => None, + crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar).ok_or(None), + _ => Err(None), }, - _ => None, + _ => Err(None), } } match arena[handle] { crate::Expression::Constant(c) => { get(*self, self.constants[c].init, self.global_expressions) } + crate::Expression::Override(handle) => Err(Some(handle)), _ => get(*self, handle, arena), } } @@ -531,7 +537,9 @@ impl crate::ArraySize { return Err(ResolveArraySizeError::NonConstArrayLength); }; let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err { - U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength, + U32EvalError::Runtime | U32EvalError::Override(_) => { + ResolveArraySizeError::NonConstArrayLength + } U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength, })?; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 63de450372..6466ca45d9 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -139,6 +139,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), + #[error("Missing value for pipeline-overridable constant {0:?}")] + UnresolvedOverride(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -194,7 +196,7 @@ impl core::ops::Index> for ExpressionTypeResolver<'_> } } -impl super::Validator { +impl super::Validator<'_> { pub(super) fn validate_const_expression( &self, handle: Handle, @@ -224,7 +226,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ if global_expr_kind.is_const(handle) || self.overrides_resolved => { + _ if global_expr_kind.is_const(handle) => { return Err(ConstExpressionError::NonFullyEvaluatedConst) } // the constant evaluator will report errors about override-expressions @@ -302,7 +304,9 @@ impl super::Validator { Err(crate::proc::U32EvalError::Negative) => { return Err(ExpressionError::NegativeIndex(base)) } - Err(crate::proc::U32EvalError::NonConst) => {} + Err( + crate::proc::U32EvalError::Runtime | crate::proc::U32EvalError::Override(_), + ) => {} } ShaderStages::all() @@ -373,7 +377,14 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Override(handle) => { + if self.overrides_resolved { + return Err(ExpressionError::UnresolvedOverride(handle)); + } else { + ShaderStages::all() + } + } E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 7865f1fc42..5b71b7289f 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,6 +217,8 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), + #[error("Missing value for pipeline-overridable constant {0:?}")] + UnresolvedOverride(Handle), } bitflags::bitflags! { @@ -318,7 +320,7 @@ impl<'a> BlockContext<'a> { } } -impl super::Validator { +impl super::Validator<'_> { fn validate_call( &mut self, function: Handle, @@ -1760,6 +1762,9 @@ impl super::Validator { &local_expr_kind, ) { Ok(stages) => info.available_stages &= stages, + Err(ExpressionError::UnresolvedOverride(handle)) => { + return Err(FunctionError::UnresolvedOverride(handle).with_span()) + } Err(source) => { return Err(FunctionError::Expression { handle, source } .with_span_handle(handle, &fun.expressions)) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 86285c2818..38787789a0 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -14,7 +14,7 @@ use crate::{Arena, UniqueArena}; #[cfg(test)] use alloc::string::ToString; -impl super::Validator { +impl super::Validator<'_> { /// Validates that all handles within `module` are: /// /// * Valid, in the sense that they contain indices within each arena structure inside the diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 3792c71abc..b50199abcb 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -487,7 +487,7 @@ impl VaryingContext<'_> { } } -impl super::Validator { +impl super::Validator<'_> { pub(super) fn validate_global_var( &self, var: &crate::GlobalVariable, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index c8a02db1af..11872ada1a 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -18,7 +18,7 @@ use bit_set::BitSet; use crate::{ arena::{Handle, HandleSet}, proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, - FastHashSet, + FastHashMap, FastHashSet, }; //TODO: analyze the model at the same time as we validate it, @@ -268,8 +268,21 @@ impl ops::Index> for ModuleInfo { } } +/// Information about overrides that remain unresolved after [`process_overrides`]. +/// +/// This struct may be passed to the various backend writers. +/// +/// [`process_overrides`]: crate::back::pipeline_constants::process_overrides +#[derive(Clone, Debug, Default)] +pub struct UnresolvedOverrides { + pub(crate) global_variables: + FastHashMap, Handle>, + pub(crate) functions: FastHashMap, Handle>, + pub(crate) entry_points: FastHashMap>, +} + #[derive(Debug)] -pub struct Validator { +pub struct Validator<'a> { flags: ValidationFlags, capabilities: Capabilities, subgroup_stages: ShaderStages, @@ -289,6 +302,8 @@ pub struct Validator { /// constant expressions as errors. overrides_resolved: bool, + unresolved_overrides: Option<&'a UnresolvedOverrides>, + /// A checklist of expressions that must be visited by a specific kind of /// statement. /// @@ -452,7 +467,7 @@ impl crate::TypeInner { } } -impl Validator { +impl<'a> Validator<'a> { /// Construct a new validator instance. pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) { @@ -487,6 +502,7 @@ impl Validator { valid_expression_set: HandleSet::new(), override_ids: FastHashSet::default(), overrides_resolved: false, + unresolved_overrides: None, needs_visit: HandleSet::new(), } } @@ -574,8 +590,6 @@ impl Validator { if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) { return Err(OverrideError::InvalidType); } - } else if self.overrides_resolved { - return Err(OverrideError::UninitializedOverride); } Ok(()) @@ -590,18 +604,19 @@ impl Validator { self.validate_impl(module) } - /// Check the given module to be valid, requiring overrides to be resolved. + /// Check the given module to be valid, after resolving overrides. /// - /// This is the same as [`validate`], except that any override - /// whose value is not a fully-evaluated constant expression is - /// treated as an error. + /// This is the same as [`validate`], but override expressions are allowed + /// in items that appear in one of the maps in `unresolved`. /// /// [`validate`]: Validator::validate - pub fn validate_resolved_overrides( + pub fn validate_with_resolved_overrides( &mut self, module: &crate::Module, + unresolved: &'a UnresolvedOverrides, ) -> Result> { self.overrides_resolved = true; + self.unresolved_overrides = Some(unresolved); self.validate_impl(module) } @@ -703,19 +718,36 @@ impl Validator { } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) - .map_err(|source| { - ValidationError::GlobalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - source, - } - .with_span_handle(var_handle, &module.global_variables) - })?; + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.global_variables.contains_key(&var_handle) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind); + self.overrides_resolved = save_overrides_resolved; + res.map_err(|source| { + ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(var_handle, &module.global_variables) + })?; } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.functions.contains_key(&handle) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_function(fun, module, &mod_info, false); + self.overrides_resolved = save_overrides_resolved; + match res { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -731,7 +763,7 @@ impl Validator { } let mut ep_map = FastHashSet::default(); - for ep in module.entry_points.iter() { + for (ep_index, ep) in module.entry_points.iter().enumerate() { if !ep_map.insert((ep.stage, &ep.name)) { return Err(ValidationError::EntryPoint { stage: ep.stage, @@ -741,7 +773,16 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.entry_points.contains_key(&ep_index) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_entry_point(ep, module, &mod_info); + self.overrides_resolved = save_overrides_resolved; + match res { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index b3ae13b7d4..67b91a8435 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -252,7 +252,7 @@ impl TypeInfo { } } -impl super::Validator { +impl super::Validator<'_> { const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { if self.capabilities.contains(capability) { Ok(()) diff --git a/naga/tests/in/wgsl/missing-unused-overrides.toml b/naga/tests/in/wgsl/missing-unused-overrides.toml new file mode 100644 index 0000000000..3eb9ae5c9c --- /dev/null +++ b/naga/tests/in/wgsl/missing-unused-overrides.toml @@ -0,0 +1,13 @@ +pipeline_constants = { ov_for_vertex = 1.5 } +#targets = "IR | ANALYSIS | SPIRV | METAL | HLSL | GLSL" +targets = "METAL" + +[msl] +lang_version = [2, 1] + +[msl_pipeline] +entry_point = ["Vertex", "vert_main"] + +[spv] +separate_entry_points = true +version = [1, 0] diff --git a/naga/tests/in/wgsl/missing-unused-overrides.wgsl b/naga/tests/in/wgsl/missing-unused-overrides.wgsl new file mode 100644 index 0000000000..2243a17c06 --- /dev/null +++ b/naga/tests/in/wgsl/missing-unused-overrides.wgsl @@ -0,0 +1,45 @@ +override ov_for_vertex: f32; + +@vertex +fn vert_main( + @location(0) pos : vec2, + @builtin(instance_index) ii: u32, + @builtin(vertex_index) vi: u32, +) -> @builtin(position) vec4 { + return vec4(pos.x * ov_for_vertex, pos.y, 0.0, 1.0); +} + +struct FragmentIn { + @location(0) color: vec4 +} + +override ov_for_fragment: f32; + +fn frag_helper(color: vec4) -> vec4 { + return color * ov_for_fragment; +} + +@fragment +fn frag_main(in: FragmentIn) -> @location(0) vec4 { + return frag_helper(in.color); +} + +override ov_global_init: u32; +var foo: u32 = ov_global_init; + +override ov_array_size: u32; +var arr: array; + +override ov_for_compute: u32; + +fn compute_helper() { + _ = foo; + _ = arr[0]; +} + +override ov_workgroup_size: u32; +@compute @workgroup_size(ov_workgroup_size) +fn compute_main() { + _ = ov_for_compute; + compute_helper(); +} diff --git a/naga/tests/naga/snapshots.rs b/naga/tests/naga/snapshots.rs index 931136ed8d..2cfd7b4048 100644 --- a/naga/tests/naga/snapshots.rs +++ b/naga/tests/naga/snapshots.rs @@ -145,6 +145,8 @@ struct Parameters { // -- HLSL options -- #[cfg(all(feature = "deserialize", hlsl_out))] hlsl: naga::back::hlsl::Options, + #[serde(default)] + hlsl_pipeline: naga::back::hlsl::PipelineOptions, // -- WGSL options -- wgsl: WgslOutParameters, @@ -233,6 +235,12 @@ impl Input { return None; } + if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") { + if !file_name.to_string_lossy().contains(&pat) { + return None; + } + } + let input = Input::new( subdirectory, file_name.file_stem().unwrap().to_str().unwrap(), @@ -548,6 +556,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& module, &info, ¶ms.hlsl, + ¶ms.hlsl_pipeline, ¶ms.pipeline_constants, frag_ep, ); @@ -598,9 +607,12 @@ fn write_output_spv( debug_info, }; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides(module, info, None, pipeline_constants) + .expect("override evaluation failed"); if params.separate_entry_points { for ep in module.entry_points.iter() { @@ -660,15 +672,28 @@ fn write_output_msl( ) { use naga::back::msl; - println!("generating MSL"); - - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + println!("generating MSL for {:?}", pipeline_options.entry_point); + + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + pipeline_options + .entry_point + .as_ref() + .map(|&(stage, ref name)| (stage, name.as_str())), + pipeline_constants, + ) + .expect("override evaluation failed"); let mut options = options.clone(); options.bounds_check_policies = bounds_check_policies; - let (string, tr_info) = msl::write_string(&module, &info, &options, pipeline_options) + let mut pipeline_options = pipeline_options.clone(); + pipeline_options.unresolved_overrides = unresolved; + let (string, tr_info) = msl::write_string(&module, &info, &options, &pipeline_options) .unwrap_or_else(|err| panic!("Metal write failed: {err}")); for (ep, result) in module.entry_points.iter().zip(tr_info.entry_point_names) { @@ -704,9 +729,12 @@ fn write_output_glsl( }; let mut buffer = String::new(); - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides(module, info, None, pipeline_constants) + .expect("override evaluation failed"); let mut writer = glsl::Writer::new( &mut buffer, &module, @@ -728,6 +756,7 @@ fn write_output_hlsl( module: &naga::Module, info: &naga::valid::ModuleInfo, options: &naga::back::hlsl::Options, + pipeline_options: &naga::back::hlsl::PipelineOptions, pipeline_constants: &naga::back::PipelineConstants, frag_ep: Option, ) { @@ -736,9 +765,20 @@ fn write_output_hlsl( println!("generating HLSL"); - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + pipeline_options + .entry_point + .as_ref() + .map(|&(stage, ref name)| (stage, name.as_str())), + pipeline_constants, + ) + .expect("override evaluation failed"); let mut buffer = String::new(); let pipeline_options = Default::default(); diff --git a/naga/tests/naga/validation.rs b/naga/tests/naga/validation.rs index 4b813d384e..d832cc0659 100644 --- a/naga/tests/naga/validation.rs +++ b/naga/tests/naga/validation.rs @@ -1,4 +1,8 @@ -use naga::{valid, Expression, Function, Scalar}; +use naga::{ + ir, + valid::{self, ModuleInfo}, + Expression, Function, Module, Scalar, +}; /// Validation should fail if `AtomicResult` expressions are not /// populated by `Atomic` statements. @@ -536,16 +540,16 @@ fn main(input: VertexOutput) {{ } #[allow(dead_code)] -struct BindingArrayFixture { +struct BindingArrayFixture<'a> { module: naga::Module, span: naga::Span, ty_u32: naga::Handle, ty_array: naga::Handle, ty_struct: naga::Handle, - validator: naga::valid::Validator, + validator: naga::valid::Validator<'a>, } -impl BindingArrayFixture { +impl BindingArrayFixture<'_> { fn new() -> Self { let mut module = naga::Module::default(); let span = naga::Span::default(); @@ -770,7 +774,6 @@ fn bad_texture_dimensions_level() { fn arity_check() { use ir::MathFunction as Mf; use naga::Span; - use naga::{ir, valid}; let _ = env_logger::builder().is_test(true).try_init(); type Result = core::result::Result; @@ -923,3 +926,157 @@ fn main() { naga::valid::GlobalUse::QUERY ); } + +fn parse_validate(source: &str) -> (Module, ModuleInfo) { + let module = naga::front::wgsl::parse_str(source).expect("module should parse"); + let info = valid::Validator::new(Default::default(), valid::Capabilities::all()) + .validate(&module) + .unwrap(); + (module, info) +} + +/// Helper for `process_overrides` tests. +/// +/// The goal of these tests is to verify that `process_overrides` accepts cases +/// where all necessary overrides are specified (even if some unnecessary ones +/// are not), and does not accept cases where necessary overrides are missing. +/// "Necessary" means that the entry point is referenced in some way by some +/// function reachable from the specified entry point. +/// +/// Each test passes a source snippet containing a compute entry point `used` +/// that makes use of the override `ov` in some way. We augment that with (1) +/// the definition of `ov` and (2) a dummy entrypoint that does not use the +/// override, and then test the matrix of (specified or not) x (used or not). +/// +/// Since `process_overrides` leaves unresolved overrides in the output module, +/// there could be bugs where a backend to reaches one of the remaining overrides +/// and fails. That is not exercised here, but is covered by the +/// `missing-unused-overrides` snapshot test. +fn override_test(test_case: &str) { + use hashbrown::HashMap; + use naga::back::pipeline_constants::PipelineConstantError; + + let source = [ + "override ov: u32;\n", + test_case, + "@compute @workgroup_size(64) +fn unused() { +} +", + ] + .concat(); + let (module, info) = parse_validate(&source); + + let overrides = HashMap::from([(String::from("ov"), 1.)]); + + // Can translate `unused` with or without the override + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "unused")), + &HashMap::new(), + ) + .unwrap(); + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "unused")), + &overrides, + ) + .unwrap(); + + // Cannot translate `used` without the override + let err = naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "used")), + &HashMap::new(), + ) + .unwrap_err(); + assert!(matches!(err, PipelineConstantError::MissingValue(name) if name == "ov")); + + // Can translate `used` if the override is specified + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "used")), + &overrides, + ) + .unwrap(); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_workgroup_size() { + override_test( + " +@compute @workgroup_size(ov) +fn used() { +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_function() { + override_test( + " +fn foo() -> u32 { + return ov; +} + +@compute @workgroup_size(64) +fn used() { + foo(); +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_entrypoint() { + override_test( + " +fn foo() -> u32 { + return ov; +} + +@compute @workgroup_size(64) +fn used() { + foo(); +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_array_size() { + override_test( + " +var arr: array; + +@compute @workgroup_size(64) +fn used() { + _ = arr[5]; +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_global_init() { + override_test( + " +var foo: u32 = ov; + +@compute @workgroup_size(64) +fn used() { + _ = foo; +} +", + ); +} diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index 71e5b87171..ad3818b7f5 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -3138,7 +3138,7 @@ fn local_const_from_override() { const c = o; } ", - r###"error: Unexpected override-expression + r###"error: Can't use pipeline-overridable constants in const-expressions ┌─ wgsl:4:23 │ 4 │ const c = o; diff --git a/naga/tests/out/msl/wgsl-missing-unused-overrides.msl b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl new file mode 100644 index 0000000000..e2a7262f4d --- /dev/null +++ b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl @@ -0,0 +1,26 @@ +// language: metal2.1 +#include +#include + +using metal::uint; + +struct FragmentIn { + metal::float4 color; +}; +constant float ov_for_vertex = 1.5; + +struct vert_mainInput { + metal::float2 pos [[attribute(0)]]; +}; +struct vert_mainOutput { + metal::float4 member [[position]]; +}; +vertex vert_mainOutput vert_main( + vert_mainInput varyings [[stage_in]] +, uint ii [[instance_id]] +, uint vi [[vertex_id]] +) { + const auto pos = varyings.pos; + return vert_mainOutput { metal::float4(pos.x * ov_for_vertex, pos.y, 0.0, 1.0) }; +} + diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index a8d1329fac..421a1a682e 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -371,11 +371,11 @@ impl ImplicitPipelineIds<'_> { } /// Create a validator with the given validation flags. -pub fn create_validator( +pub fn create_validator<'a>( features: wgt::Features, downlevel: wgt::DownlevelFlags, flags: naga::valid::ValidationFlags, -) -> naga::valid::Validator { +) -> naga::valid::Validator<'a> { use naga::valid::Capabilities as Caps; let mut caps = Caps::empty(); caps.set( diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 3d8100b9c0..1db2468fe7 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -276,9 +276,14 @@ impl super::Device { let stage_bit = auxil::map_naga_stage(naga_stage); - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &stage.module.naga.module, &stage.module.naga.info, + None, stage.constants, ) .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?; diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index c5539eae35..75369290d0 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -216,9 +216,14 @@ impl super::Device { multiview: context.multiview, }; - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &stage.module.naga.module, &stage.module.naga.info, + None, stage.constants, ) .map_err(|e| { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6ab22b0c3e..08b1e241fe 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -134,9 +134,14 @@ impl super::Device { panic!("load_shader required a naga shader"); }; let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info: module_info, + unresolved: unresolved_overrides, + } = naga::back::pipeline_constants::process_overrides( &naga_shader.module, &naga_shader.info, + Some((naga_stage, stage.entry_point)), stage.constants, ) .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {:?}", e)))?; @@ -182,6 +187,7 @@ impl super::Device { let pipeline_options = naga::back::msl::PipelineOptions { entry_point: Some((naga_stage, stage.entry_point.to_owned())), + unresolved_overrides, allow_and_force_point_size: match primitive_class { MTLPrimitiveTopologyClass::Point => true, _ => false, diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 4ae7656512..262a0820cc 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -889,9 +889,14 @@ impl super::Device { &self.naga_options }; - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &naga_shader.module, &naga_shader.info, + None, stage.constants, ) .map_err(|e| {