From a5bebb07eef83d6cba7afbe16a5946b0eee82822 Mon Sep 17 00:00:00 2001
From: teoxoy <28601907+teoxoy@users.noreply.github.com>
Date: Fri, 17 May 2024 17:36:17 +0200
Subject: [PATCH] ensure safety of indirect dispatch by injecting a compute
 shader that validates the content of the indirect buffer

---
 tests/tests/dispatch_workgroups_indirect.rs | 125 ++++++++++++
 tests/tests/root.rs                         |   1 +
 wgpu-core/src/command/bind.rs               |  14 +-
 wgpu-core/src/command/compute.rs            | 207 +++++++++++++++++++-
 wgpu-core/src/command/compute_command.rs    |  16 ++
 wgpu-core/src/command/mod.rs                |   2 +
 wgpu-core/src/device/global.rs              | 126 ++++++++++++
 wgpu-core/src/device/resource.rs            |  12 ++
 wgpu-core/src/lock/rank.rs                  |   1 +
 wgpu-core/src/pipeline.rs                   |  15 ++
 10 files changed, 517 insertions(+), 2 deletions(-)
 create mode 100644 tests/tests/dispatch_workgroups_indirect.rs

diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs
new file mode 100644
index 00000000000..e01e893eaf0
--- /dev/null
+++ b/tests/tests/dispatch_workgroups_indirect.rs
@@ -0,0 +1,125 @@
+use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};
+
+const SHADER_SRC: &str = "
+@group(0) @binding(0)
+var<storage, read_write> out: u32;
+
+@compute @workgroup_size(1)
+fn main() {
+    out = 1u;
+}
+";
+
+#[gpu_test]
+static CHECK_NUM_WORKGROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
+    .parameters(
+        TestParameters::default()
+            .downlevel_flags(
+                wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
+            )
+            .limits(wgpu::Limits {
+                max_compute_workgroups_per_dimension: 10,
+                ..wgpu::Limits::downlevel_defaults()
+            }),
+    )
+    .run_async(|ctx| async move {
+        let module = ctx
+            .device
+            .create_shader_module(wgpu::ShaderModuleDescriptor {
+                label: None,
+                source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
+            });
+
+        let pipeline = ctx
+            .device
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: None,
+                layout: None,
+                module: &module,
+                entry_point: "main",
+                compilation_options: Default::default(),
+                cache: None,
+            });
+
+        let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
+            label: None,
+            size: 4,
+            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        });
+
+        let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
+            label: None,
+            size: 4,
+            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+
+        let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
+            label: None,
+            size: 12,
+            usage: wgpu::BufferUsages::INDIRECT
+                | wgpu::BufferUsages::COPY_DST
+                | wgpu::BufferUsages::UNIFORM,
+            mapped_at_creation: false,
+        });
+
+        let max = ctx.device.limits().max_compute_workgroups_per_dimension;
+        ctx.queue
+            .write_buffer(&indirect_buffer, 0, bytemuck::bytes_of(&[max + 1, 1, 1]));
+
+        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
+            label: None,
+            layout: &pipeline.get_bind_group_layout(0),
+            entries: &[wgpu::BindGroupEntry {
+                binding: 0,
+                resource: out_buffer.as_entire_binding(),
+            }],
+        });
+
+        let mut encoder = ctx
+            .device
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
+
+        {
+            let mut compute_pass =
+                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
+            compute_pass.set_pipeline(&pipeline);
+            compute_pass.set_bind_group(0, &bind_group, &[]);
+            compute_pass.dispatch_workgroups_indirect(&indirect_buffer, 0);
+        }
+
+        encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 4);
+
+        ctx.queue.submit(Some(encoder.finish()));
+
+        readback_buffer
+            .slice(..)
+            .map_async(wgpu::MapMode::Read, |_| {});
+
+        ctx.async_poll(wgpu::Maintain::wait())
+            .await
+            .panic_on_timeout();
+
+        let view = readback_buffer.slice(..).get_mapped_range();
+        // Make sure the dispatch was discarded
+        assert!(view.iter().all(|v| *v == 0));
+
+        // Test that unsetting the bind group works properly
+        {
+            ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);
+            let mut encoder = ctx
+                .device
+                .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
+            {
+                let mut compute_pass =
+                    encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
+                compute_pass.set_pipeline(&pipeline);
+                compute_pass.dispatch_workgroups_indirect(&indirect_buffer, 0);
+            }
+            let _ = encoder.finish();
+            let error = pollster::block_on(ctx.device.pop_error_scope());
+            assert!(error.map_or(false, |error| format!("{error}")
+                .contains("Expected bind group is missing")));
+        }
+    });
diff --git a/tests/tests/root.rs b/tests/tests/root.rs
index 29f894ede91..01341720c19 100644
--- a/tests/tests/root.rs
+++ b/tests/tests/root.rs
@@ -14,6 +14,7 @@ mod clear_texture;
 mod compute_pass_resource_ownership;
 mod create_surface_error;
 mod device;
+mod dispatch_workgroups_indirect;
 mod encoder;
 mod external_texture;
 mod float32_filterable;
diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs
index c643611a967..75d566c5963 100644
--- a/wgpu-core/src/command/bind.rs
+++ b/wgpu-core/src/command/bind.rs
@@ -131,7 +131,7 @@ mod compat {
                         diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}"))
                     }
                 } else {
-                    diff.push("Assigned bind group layout not found (internal error)".to_owned());
+                    diff.push("Expected bind group is missing".to_owned());
                 }
             } else {
                 diff.push("Expected bind group layout not found (internal error)".to_owned());
@@ -191,6 +191,10 @@ mod compat {
             self.make_range(index)
         }
 
+        pub fn unassign(&mut self, index: usize) {
+            self.entries[index].assigned = None;
+        }
+
         pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
             self.entries
                 .iter()
@@ -358,6 +362,14 @@ impl<A: HalApi> Binder<A> {
         &self.payloads[bind_range]
     }
 
+    pub(super) fn unassign_group(&mut self, index: usize) {
+        log::trace!("\tBinding [{}] = null", index);
+
+        self.payloads[index].reset();
+
+        self.manager.unassign(index);
+    }
+
     pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
         let payloads = &self.payloads;
         self.manager
diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs
index 997c62e8b1f..66746c15a2e 100644
--- a/wgpu-core/src/command/compute.rs
+++ b/wgpu-core/src/command/compute.rs
@@ -146,6 +146,20 @@ pub enum ComputePassErrorInner {
     MissingFeatures(#[from] MissingFeatures),
     #[error(transparent)]
     MissingDownlevelFlags(#[from] MissingDownlevelFlags),
+    #[error(transparent)]
+    IndirectValidation(#[from] ComputePassIndirectValidationError),
+}
+
+#[derive(Clone, Debug, Error)]
+pub enum ComputePassIndirectValidationError {
+    #[error(transparent)]
+    ValidationPipeline(
+        #[from] crate::pipeline::CreateDispatchWorkgroupsIndirectValidationPipelineError,
+    ),
+    #[error(transparent)]
+    Buffer(#[from] crate::resource::CreateBufferError),
+    #[error(transparent)]
+    BindGroup(#[from] crate::binding_model::CreateBindGroupError),
 }
 
 impl PrettyError for ComputePassErrorInner {
@@ -283,9 +297,26 @@ impl Global {
         &self,
         pass: &ComputePass<A>,
     ) -> Result<(), ComputePassError> {
+        let mut base = pass.base.as_ref();
+
+        let new_commands = base
+            .commands
+            .iter()
+            .any(|cmd| matches!(cmd, ArcComputeCommand::DispatchIndirect { .. }))
+            .then(|| {
+                self.command_encoder_inject_dispatch_workgroups_indirect_validation(
+                    pass.parent_id,
+                    base.commands,
+                )
+            })
+            .transpose()?;
+        if let Some(new_commands) = new_commands.as_ref() {
+            base.commands = new_commands;
+        }
+
         self.command_encoder_run_compute_pass_impl(
             pass.parent_id,
-            pass.base.as_ref(),
+            base,
             pass.timestamp_writes.as_ref(),
         )
     }
@@ -515,6 +546,20 @@ impl Global {
                         }
                     }
                 }
+                ArcComputeCommand::UnsetBindGroup { index } => {
+                    let scope = PassErrorScope::UnsetBindGroup(*index);
+
+                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
+                    if *index >= max_bind_groups {
+                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
+                            index: *index,
+                            max: max_bind_groups,
+                        })
+                        .map_pass_err(scope);
+                    }
+
+                    state.binder.unassign_group(*index as usize);
+                }
                 ArcComputeCommand::SetPipeline(pipeline) => {
                     let pipeline_id = pipeline.as_info().id();
                     let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
@@ -823,6 +868,166 @@ impl Global {
 
         Ok(())
     }
+
+    fn command_encoder_inject_dispatch_workgroups_indirect_validation<A: HalApi>(
+        &self,
+        encoder_id: id::CommandEncoderId,
+        commands: &[ArcComputeCommand<A>],
+    ) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
+        profiling::scope!("CommandEncoder::inject_dispatch_workgroups_indirect_validation");
+        let scope = PassErrorScope::Pass(encoder_id);
+
+        let hub = A::hub(self);
+
+        let cmd_buf: Arc<CommandBuffer<A>> =
+            CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
+        let device = &cmd_buf.device;
+        if !device.is_valid() {
+            return Err(ComputePassErrorInner::InvalidDevice(
+                cmd_buf.device.as_info().id(),
+            ))
+            .map_pass_err(scope);
+        }
+        let device_id = device.as_info().id();
+
+        let mut new_commands = Vec::with_capacity(commands.len());
+        let mut current_pipeline = None;
+        let mut current_bind_group_0 = None;
+
+        for command in commands {
+            match command {
+                ArcComputeCommand::SetBindGroup {
+                    index: 0,
+                    num_dynamic_offsets,
+                    bind_group,
+                } => {
+                    current_bind_group_0 = Some((*num_dynamic_offsets, bind_group.clone()));
+                    new_commands.push(command.clone());
+                }
+                ArcComputeCommand::SetPipeline(pipeline) => {
+                    current_pipeline = Some(pipeline.clone());
+                    new_commands.push(command.clone());
+                }
+                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
+                    // If there is no pipeline set, don't inject the validation commands as we will error anyway.
+                    if let Some(original_pipeline) = current_pipeline.clone() {
+                        let (validation_pipeline_id, validation_bgl_id) = self
+                            .device_get_or_create_dispatch_workgroups_indirect_validation_pipeline::<A>(
+                                device_id,
+                            )
+                            .map_err(ComputePassIndirectValidationError::ValidationPipeline)
+                            .map_pass_err(scope)?;
+
+                        let (dst_buffer_id, error) = self.device_create_buffer::<A>(
+                            device_id,
+                            &crate::resource::BufferDescriptor {
+                                label: None,
+                                size: 4 * 3,
+                                usage: wgt::BufferUsages::INDIRECT | wgt::BufferUsages::STORAGE,
+                                mapped_at_creation: false,
+                            },
+                            None,
+                        );
+                        if let Some(error) = error {
+                            return Err(ComputePassIndirectValidationError::Buffer(error))
+                                .map_pass_err(scope)?;
+                        }
+
+                        let (bind_group_id, error) = self.device_create_bind_group::<A>(
+                            device_id,
+                            &crate::binding_model::BindGroupDescriptor {
+                                label: None,
+                                layout: validation_bgl_id,
+                                entries: std::borrow::Cow::Borrowed(&[
+                                    crate::binding_model::BindGroupEntry {
+                                        binding: 0,
+                                        resource: crate::binding_model::BindingResource::Buffer(
+                                            crate::binding_model::BufferBinding {
+                                                buffer_id: buffer.as_info().id(),
+                                                offset: *offset,
+                                                size: Some(
+                                                    std::num::NonZeroU64::new(4 * 3).unwrap(),
+                                                ),
+                                            },
+                                        ),
+                                    },
+                                    crate::binding_model::BindGroupEntry {
+                                        binding: 1,
+                                        resource: crate::binding_model::BindingResource::Buffer(
+                                            crate::binding_model::BufferBinding {
+                                                buffer_id: dst_buffer_id,
+                                                offset: 0,
+                                                size: Some(
+                                                    std::num::NonZeroU64::new(4 * 3).unwrap(),
+                                                ),
+                                            },
+                                        ),
+                                    },
+                                ]),
+                            },
+                            None,
+                        );
+                        if let Some(error) = error {
+                            return Err(ComputePassIndirectValidationError::BindGroup(error))
+                                .map_pass_err(scope)?;
+                        }
+
+                        let validation_pipeline = hub
+                            .compute_pipelines
+                            .read()
+                            .get_owned(validation_pipeline_id)
+                            .map_err(|_| {
+                                ComputePassErrorInner::InvalidPipeline(validation_pipeline_id)
+                            })
+                            .map_pass_err(scope)?;
+
+                        let bind_group = hub
+                            .bind_groups
+                            .read()
+                            .get_owned(bind_group_id)
+                            .map_err(|_| ComputePassErrorInner::InvalidBindGroup(0))
+                            .map_pass_err(scope)?;
+
+                        let dst_buffer = hub
+                            .buffers
+                            .read()
+                            .get_owned(dst_buffer_id)
+                            .map_err(|_| ComputePassErrorInner::InvalidBuffer(dst_buffer_id))
+                            .map_pass_err(scope)?;
+
+                        new_commands.push(ArcComputeCommand::SetPipeline(validation_pipeline));
+                        new_commands.push(ArcComputeCommand::SetBindGroup {
+                            index: 0,
+                            num_dynamic_offsets: 0,
+                            bind_group,
+                        });
+                        new_commands.push(ArcComputeCommand::Dispatch([1, 1, 1]));
+
+                        new_commands.push(ArcComputeCommand::SetPipeline(original_pipeline));
+                        if let Some((num_dynamic_offsets, bind_group)) =
+                            current_bind_group_0.clone()
+                        {
+                            new_commands.push(ArcComputeCommand::SetBindGroup {
+                                index: 0,
+                                num_dynamic_offsets,
+                                bind_group,
+                            });
+                        } else {
+                            new_commands.push(ArcComputeCommand::UnsetBindGroup { index: 0 });
+                        }
+                        new_commands.push(ArcComputeCommand::DispatchIndirect {
+                            buffer: dst_buffer,
+                            offset: 0,
+                        });
+                    } else {
+                        new_commands.push(command.clone())
+                    }
+                }
+                command => new_commands.push(command.clone()),
+            }
+        }
+        Ok(new_commands)
+    }
 }
 
 // Recording a compute pass.
diff --git a/wgpu-core/src/command/compute_command.rs b/wgpu-core/src/command/compute_command.rs
index 49fdbbec24d..fd3b531628e 100644
--- a/wgpu-core/src/command/compute_command.rs
+++ b/wgpu-core/src/command/compute_command.rs
@@ -19,6 +19,10 @@ pub enum ComputeCommand {
         bind_group_id: id::BindGroupId,
     },
 
+    UnsetBindGroup {
+        index: u32,
+    },
+
     SetPipeline(id::ComputePipelineId),
 
     /// Set a range of push constants to values stored in `push_constant_data`.
@@ -103,6 +107,10 @@ impl ComputeCommand {
                     })?,
                 },
 
+                ComputeCommand::UnsetBindGroup { index } => {
+                    ArcComputeCommand::UnsetBindGroup { index }
+                }
+
                 ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
                     pipelines_guard
                         .get_owned(pipeline_id)
@@ -194,6 +202,10 @@ pub enum ArcComputeCommand<A: HalApi> {
         bind_group: Arc<BindGroup<A>>,
    },
 
+    UnsetBindGroup {
+        index: u32,
+    },
+
     SetPipeline(Arc<ComputePipeline<A>>),
 
     /// Set a range of push constants to values stored in `push_constant_data`.
@@ -261,6 +273,10 @@ impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
                 bind_group_id: bind_group.as_info().id(),
             },
 
+            ArcComputeCommand::UnsetBindGroup { index } => {
+                ComputeCommand::UnsetBindGroup { index: *index }
+            }
+
             ArcComputeCommand::SetPipeline(pipeline) => {
                 ComputeCommand::SetPipeline(pipeline.as_info().id())
             }
diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs
index 5159d6fa85c..32dd15a5e6f 100644
--- a/wgpu-core/src/command/mod.rs
+++ b/wgpu-core/src/command/mod.rs
@@ -808,6 +808,8 @@ pub enum PassErrorScope {
     Pass(id::CommandEncoderId),
     #[error("In a set_bind_group command")]
     SetBindGroup(id::BindGroupId),
+    #[error("In an unset_bind_group command, slot: {0}")]
+    UnsetBindGroup(u32),
     #[error("In a set_pipeline command")]
     SetPipelineRender(id::RenderPipelineId),
     #[error("In a set_pipeline command")]
diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs
index a5c51b269f7..0f7303cc69b 100644
--- a/wgpu-core/src/device/global.rs
+++ b/wgpu-core/src/device/global.rs
@@ -1767,6 +1767,132 @@ impl Global {
         (id, Some(error))
     }
 
+    pub(crate) fn device_get_or_create_dispatch_workgroups_indirect_validation_pipeline<
+        A: HalApi,
+    >(
+        &self,
+        device_id: DeviceId,
+    ) -> Result<
+        (id::ComputePipelineId, id::BindGroupLayoutId),
+        pipeline::CreateDispatchWorkgroupsIndirectValidationPipelineError,
+    > {
+        profiling::scope!("Device::get_or_create_dispatch_workgroups_indirect_validation_pipeline");
+
+        let hub = A::hub(self);
+
+        let device = hub
+            .devices
+            .get(device_id)
+            .map_err(|_| DeviceError::Invalid)?;
+        if !device.is_valid() {
+            return Err(DeviceError::Lost.into());
+        }
+
+        let mut dispatch_workgroups_indirect_validation_pipeline_guard = device
+            .dispatch_workgroups_indirect_validation_pipeline
+            .lock();
+        if let Some(res) = *dispatch_workgroups_indirect_validation_pipeline_guard {
+            Ok(res)
+        } else {
+            let max_compute_workgroups_per_dimension =
+                device.limits.max_compute_workgroups_per_dimension;
+
+            let src = format!("
+                @group(0) @binding(0)
+                var<uniform> src: vec3<u32>;
+                @group(0) @binding(1)
+                var<storage, read_write> dst: vec3<u32>;
+
+                @compute @workgroup_size(1)
+                fn main() {{
+                    dst = select(src, vec3<u32>(), src > vec3<u32>({max_compute_workgroups_per_dimension}u));
+                }}
+            ");
+
+            let (module, error) = self.device_create_shader_module::<A>(
+                device_id,
+                &crate::pipeline::ShaderModuleDescriptor {
+                    label: None,
+                    shader_bound_checks: wgt::ShaderBoundChecks::default(),
+                },
+                crate::pipeline::ShaderModuleSource::Wgsl(std::borrow::Cow::Owned(src)),
+                None,
+            );
+            if let Some(error) = error {
+                return Err(error.into());
+            }
+
+            let (bgl, error) = self.device_create_bind_group_layout::<A>(
+                device_id,
+                &crate::binding_model::BindGroupLayoutDescriptor {
+                    label: None,
+                    entries: std::borrow::Cow::Borrowed(&[
+                        wgt::BindGroupLayoutEntry {
+                            binding: 0,
+                            visibility: wgt::ShaderStages::COMPUTE,
+                            ty: wgt::BindingType::Buffer {
+                                ty: wgt::BufferBindingType::Uniform,
+                                has_dynamic_offset: false,
+                                min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
+                            },
+                            count: None,
+                        },
+                        wgt::BindGroupLayoutEntry {
+                            binding: 1,
+                            visibility: wgt::ShaderStages::COMPUTE,
+                            ty: wgt::BindingType::Buffer {
+                                ty: wgt::BufferBindingType::Storage { read_only: false },
+                                has_dynamic_offset: false,
+                                min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
+                            },
+                            count: None,
+                        },
+                    ]),
+                },
+                None,
+            );
+            if let Some(error) = error {
+                return Err(error.into());
+            }
+
+            let (layout, error) = self.device_create_pipeline_layout::<A>(
+                device_id,
+                &crate::binding_model::PipelineLayoutDescriptor {
+                    label: None,
+                    bind_group_layouts: std::borrow::Cow::Borrowed(&[bgl]),
+                    push_constant_ranges: std::borrow::Cow::Borrowed(&[]),
+                },
+                None,
+            );
+            if let Some(error) = error {
+                return Err(error.into());
+            }
+
+            let (pipeline, error) = self.device_create_compute_pipeline::<A>(
+                device_id,
+                &crate::pipeline::ComputePipelineDescriptor {
+                    label: None,
+                    layout: Some(layout),
+                    stage: crate::pipeline::ProgrammableStageDescriptor {
+                        module,
+                        entry_point: Some(std::borrow::Cow::Borrowed("main")),
+                        constants: Default::default(),
+                        zero_initialize_workgroup_memory: true,
+                    },
+                    cache: None,
+                },
+                None,
+                None,
+            );
+            if let Some(error) = error {
+                return Err(error.into());
+            }
+
+            *dispatch_workgroups_indirect_validation_pipeline_guard = Some((pipeline, bgl));
+            Ok((pipeline, bgl))
+        }
+    }
+
     /// Get an ID of one of the bind group layouts. The ID adds a refcount,
     /// which needs to be released by calling `bind_group_layout_drop`.
     pub fn compute_pipeline_get_bind_group_layout<A: HalApi>(
diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs
index 7ac3878ef8f..12262905bc9 100644
--- a/wgpu-core/src/device/resource.rs
+++ b/wgpu-core/src/device/resource.rs
@@ -96,6 +96,8 @@ pub struct Device<A: HalApi> {
     pub(crate) queue: OnceCell<Weak<Queue<A>>>,
     queue_to_drop: OnceCell<A::Queue>,
     pub(crate) zero_buffer: Option<A::Buffer>,
+    pub(crate) dispatch_workgroups_indirect_validation_pipeline:
+        Mutex<Option<(id::ComputePipelineId, id::BindGroupLayoutId)>>,
     pub(crate) info: ResourceInfo<Device<A>>,
 
     pub(crate) command_allocator: command::CommandAllocator<A>,
@@ -270,6 +272,10 @@ impl<A: HalApi> Device<A> {
             queue: OnceCell::new(),
             queue_to_drop: OnceCell::new(),
             zero_buffer: Some(zero_buffer),
+            dispatch_workgroups_indirect_validation_pipeline: Mutex::new(
+                rank::DEVICE_DISPATCH_WORKGROUPS_INDIRECT_VALIDATION_PIPELINE,
+                None,
+            ),
             info: ResourceInfo::new("", None),
             command_allocator,
             active_submission_index: AtomicU64::new(0),
@@ -588,6 +594,12 @@ impl<A: HalApi> Device<A> {
             return Err(resource::CreateBufferError::InvalidUsage(desc.usage));
         }
 
+        if desc.usage.contains(wgt::BufferUsages::INDIRECT) {
+            // We are going to be reading from it internally,
+            // when validating the content of the buffer.
+            usage |= hal::BufferUses::UNIFORM;
+        }
+
         if !self
             .features
             .contains(wgt::Features::MAPPABLE_PRIMARY_BUFFERS)
diff --git a/wgpu-core/src/lock/rank.rs b/wgpu-core/src/lock/rank.rs
index 4387b8d138e..6565e3890f0 100644
--- a/wgpu-core/src/lock/rank.rs
+++ b/wgpu-core/src/lock/rank.rs
@@ -141,6 +141,7 @@ define_lock_ranks! {
     rank DEVICE_TRACE "Device::trace" followed by { }
     rank DEVICE_TRACKERS "Device::trackers" followed by { }
     rank DEVICE_USAGE_SCOPES "Device::usage_scopes" followed by { }
+    rank DEVICE_DISPATCH_WORKGROUPS_INDIRECT_VALIDATION_PIPELINE "Device::dispatch_workgroups_indirect_validation_pipeline" followed by { }
     rank IDENTITY_MANAGER_VALUES "IdentityManager::values" followed by { }
     rank REGISTRY_STORAGE "Registry::storage" followed by { }
     rank RENDER_BUNDLE_SCOPE_BUFFERS "RenderBundleScope::buffers" followed by { }
diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs
index bfb2c331d8a..84b5cad6b96 100644
--- a/wgpu-core/src/pipeline.rs
+++ b/wgpu-core/src/pipeline.rs
@@ -214,6 +214,21 @@ pub enum CreateComputePipelineError {
     MissingDownlevelFlags(#[from] MissingDownlevelFlags),
 }
 
+#[derive(Clone, Debug, Error)]
+#[non_exhaustive]
+pub enum CreateDispatchWorkgroupsIndirectValidationPipelineError {
+    #[error(transparent)]
+    Device(#[from] DeviceError),
+    #[error(transparent)]
+    ShaderModule(#[from] CreateShaderModuleError),
+    #[error(transparent)]
+    BindGroupLayout(#[from] CreateBindGroupLayoutError),
+    #[error(transparent)]
+    PipelineLayout(#[from] CreatePipelineLayoutError),
+    #[error(transparent)]
+    ComputePipeline(#[from] CreateComputePipelineError),
+}
+
 #[derive(Debug)]
 pub struct ComputePipeline<A: HalApi> {
     pub(crate) raw: Option<A::ComputePipeline>,
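
Note on the validation logic: the injected WGSL entry point checks the three
indirect dispatch arguments component-wise. `select(src, vec3<u32>(), src >
vec3<u32>(max))` zeroes every workgroup count that exceeds
`max_compute_workgroups_per_dimension`, and a zeroed component makes the
re-issued `DispatchIndirect` dispatch zero workgroups, which is what the test
above observes. A host-side Rust sketch of that check follows; the function
name `validate_indirect_args` is illustrative and does not appear in the
patch:

    // Mirrors the injected shader's component-wise `select`:
    // any workgroup count over the limit becomes zero.
    fn validate_indirect_args(src: [u32; 3], max: u32) -> [u32; 3] {
        src.map(|v| if v > max { 0 } else { v })
    }

    fn main() {
        // As in the test: [max + 1, 1, 1] with max = 10 dispatches nothing.
        assert_eq!(validate_indirect_args([11, 1, 1], 10), [0, 1, 1]);
        assert_eq!(validate_indirect_args([10, 1, 1], 10), [10, 1, 1]);
    }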