From 63074d6889941b2ed8bb49c8e20f8387bde6e65b Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:23:13 +0200 Subject: [PATCH] ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer --- tests/tests/dispatch_workgroups_indirect.rs | 244 +++++++++++++ tests/tests/root.rs | 1 + wgpu-core/Cargo.toml | 4 + wgpu-core/src/command/bind.rs | 22 +- wgpu-core/src/command/compute.rs | 160 ++++++++- wgpu-core/src/command/mod.rs | 7 +- wgpu-core/src/device/global.rs | 7 +- wgpu-core/src/device/mod.rs | 2 +- wgpu-core/src/device/resource.rs | 74 +++- wgpu-core/src/indirect_validation.rs | 371 ++++++++++++++++++++ wgpu-core/src/lib.rs | 2 + wgpu-core/src/pipeline.rs | 2 +- wgpu-core/src/present.rs | 7 +- wgpu-core/src/resource.rs | 41 ++- wgpu-core/src/snatch.rs | 8 +- wgpu/Cargo.toml | 6 + 16 files changed, 924 insertions(+), 34 deletions(-) create mode 100644 tests/tests/dispatch_workgroups_indirect.rs create mode 100644 wgpu-core/src/indirect_validation.rs diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs new file mode 100644 index 00000000000..745de5f295b --- /dev/null +++ b/tests/tests/dispatch_workgroups_indirect.rs @@ -0,0 +1,244 @@ +use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext}; + +/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12). 
+#[gpu_test] +static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }) + .expect_fail(FailureCase::backend(wgt::Backends::DX12)), + ) + .run_async(|ctx| async move { + let num_workgroups = [1, 2, 3]; + let res = run_test(&ctx, &num_workgroups, false).await; + assert_eq!(res, num_workgroups); + }); + +/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit. +#[gpu_test] +static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_compute_workgroups_per_dimension: 10, + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }) + .expect_fail(FailureCase::backend(wgt::Backends::DX12)), + ) + .run_async(|ctx| async move { + let max = ctx.device.limits().max_compute_workgroups_per_dimension; + + let res = run_test(&ctx, &[max, max, max], false).await; + assert_eq!(res, [max; 3]); + + let res = run_test(&ctx, &[max + 1, 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, max + 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, 1, max + 1], false).await; + assert_eq!(res, [0; 3]); + }); + +/// Make sure that resetting the bind groups set by the validation code works properly. 
+#[gpu_test] +static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }), + ) + .run_async(|ctx| async move { + ctx.device.push_error_scope(wgpu::ErrorFilter::Validation); + + let _ = run_test(&ctx, &[0, 0, 0], true).await; + + let error = pollster::block_on(ctx.device.pop_error_scope()); + assert!(error.map_or(false, |error| { + format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0") + })); + }); + +async fn run_test( + ctx: &TestingContext, + num_workgroups: &[u32; 3], + forget_to_set_bind_group: bool, +) -> [u32; 3] { + const SHADER_SRC: &str = " + struct TestOffsetPc { + inner: u32, + } + + // `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly. 
+ var test_offset: TestOffsetPc; + + @group(0) @binding(0) + var out: array; + + @compute @workgroup_size(1) + fn main(@builtin(num_workgroups) num_workgroups: vec3, @builtin(workgroup_id) workgroup_id: vec3) { + if (all(workgroup_id == vec3())) { + out[0] = num_workgroups.x + test_offset.inner; + out[1] = num_workgroups.y + test_offset.inner; + out[2] = num_workgroups.z + test_offset.inner; + } + } + "; + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }], + }); + + let layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bgl], + push_constant_ranges: &[wgt::PushConstantRange { + stages: wgt::ShaderStages::COMPUTE, + range: 0..4, + }], + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: Some(&layout), + module: &module, + entry_point: "main", + compilation_options: Default::default(), + cache: None, + }); + + let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: 
&pipeline.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: out_buffer.as_entire_binding(), + }], + }); + + let mut res = None; + + for (indirect_offset, indirect_buffer_size) in [ + // internal src buffer binding size will be buffer.size + (0, 12), + (4, 4 + 12), + (4, 8 + 12), + (256 * 2 - 4 - 12, 256 * 2 - 4), + // internal src buffer binding size will be 256 * 2 + x + (0, 256 * 2 * 2 + 4), + (256, 256 * 2 * 2 + 8), + (256 + 4, 256 * 2 * 2 + 12), + (256 * 2 + 16, 256 * 2 * 2 + 16), + (256 * 2 * 2, 256 * 2 * 2 + 32), + (256 + 12, 256 * 2 * 2 + 64), + ] { + let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: indirect_buffer_size, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + }); + + ctx.queue.write_buffer( + &indirect_buffer, + indirect_offset, + bytemuck::bytes_of(num_workgroups), + ); + + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor::default()); + { + let mut compute_pass = + encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default()); + compute_pass.set_pipeline(&pipeline); + compute_pass.set_push_constants(0, &[0, 0, 0, 0]); + if !forget_to_set_bind_group { + compute_pass.set_bind_group(0, &bind_group, &[]); + } + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); + } + + encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12); + + ctx.queue.submit(Some(encoder.finish())); + + readback_buffer + .slice(..) 
+ .map_async(wgpu::MapMode::Read, |_| {}); + + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let view = readback_buffer.slice(..).get_mapped_range(); + + let current_res = *bytemuck::from_bytes(&view); + drop(view); + readback_buffer.unmap(); + + if let Some(past_res) = res { + assert_eq!(past_res, current_res); + } else { + res = Some(current_res); + } + } + + res.unwrap() +} diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 384cfcf78fc..910af40ca14 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -16,6 +16,7 @@ mod clear_texture; mod compute_pass_ownership; mod create_surface_error; mod device; +mod dispatch_workgroups_indirect; mod encoder; mod external_texture; mod float32_filterable; diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index d6fe5346296..2d08c86fec7 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"] ## to the validation carried out at public APIs in all builds. strict_asserts = ["wgt/strict_asserts"] +## Validates indirect draw/dispatch calls. This will also enable naga's +## WGSL frontend since we use a WGSL compute shader to do the validation. +indirect-validation = ["naga/wgsl-in"] + ## Enables serialization via `serde` on common wgpu types. 
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"] diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index 64d534b558b..c2cb560ba52 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -207,13 +207,17 @@ mod compat { entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(), } } - fn make_range(&self, start_index: usize) -> Range { + + pub fn nr_of_valid_entries(&self) -> usize { // find first incompatible entry - let end = self - .entries + self.entries .iter() .position(|e| e.is_incompatible()) - .unwrap_or(self.entries.len()); + .unwrap_or(self.entries.len()) + } + + fn make_range(&self, start_index: usize) -> Range { + let end = self.nr_of_valid_entries(); start_index..end.max(start_index) } @@ -425,6 +429,16 @@ impl Binder { .map(move |index| payloads[index].group.as_ref().unwrap()) } + #[cfg(feature = "indirect-validation")] + pub(super) fn list_valid<'a>( + &'a self, + ) -> impl Iterator)> + '_ { + self.payloads + .iter() + .take(self.manager.nr_of_valid_entries()) + .enumerate() + } + pub(super) fn check_compatibility( &self, pipeline: &T, diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 4f8d549780c..917011c7070 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -19,7 +19,6 @@ use crate::{ pipeline::ComputePipeline, resource::{ self, Buffer, DestroyedResourceError, Labeled, MissingBufferUsageError, ParentDevice, - Trackable, }, snatch::SnatchGuard, track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope}, @@ -225,6 +224,8 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi> { string_offset: usize, active_query: Option<(Arc>, u32)>, + push_constants: Vec, + intermediate_trackers: Tracker, /// Immediate texture inits required because of prior discards. 
Need to @@ -480,6 +481,8 @@ impl Global { string_offset: 0, active_query: None, + push_constants: Vec::new(), + intermediate_trackers: Tracker::new(), pending_discard_init_fixups: SurfacesInDiscardState::new(), @@ -784,6 +787,21 @@ fn set_pipeline( } } + // TODO: integrate this in the code below once we simplify push constants + state.push_constants.clear(); + // Note that there can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error. + if let Some(push_constant_range) = + pipeline.layout.push_constant_ranges.iter().find_map(|pcr| { + pcr.stages + .contains(wgt::ShaderStages::COMPUTE) + .then_some(pcr.range.clone()) + }) + { + // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502 + let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize; + state.push_constants.extend(core::iter::repeat(0).take(len)); + } + // Clear push constant ranges let non_overlapping = super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges); @@ -818,7 +836,7 @@ fn set_push_constant( .binder .pipeline_layout .as_ref() - //TODO: don't error here, lazily update the push constants + // TODO: don't error here, lazily update the push constants using `state.push_constants` .ok_or(ComputePassErrorInner::Dispatch( DispatchError::MissingPipeline, ))?; @@ -829,6 +847,11 @@ fn set_push_constant( end_offset_bytes, )?; + let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; + let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; + state.push_constants[offset_in_elements..offset_in_elements + size_in_elements] + .copy_from_slice(data_slice); + unsafe { state.raw_encoder.set_push_constants( pipeline_layout.raw(), @@ -882,10 +905,6 @@ fn dispatch_indirect( .device .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; - state - .scope - .buffers - .merge_single(&buffer, hal::BufferUses::INDIRECT)?; 
buffer.check_usage(wgt::BufferUsages::INDIRECT)?; if offset % 4 != 0 { @@ -902,7 +921,6 @@ fn dispatch_indirect( } let stride = 3 * 4; // 3 integers, x/y/z group size - state .buffer_memory_init_actions .extend(buffer.initialization_status.read().create_action( @@ -911,12 +929,132 @@ fn dispatch_indirect( MemoryInitKind::NeedsInitializedMemory, )); - state.flush_states(Some(buffer.tracker_index()))?; + #[cfg(feature = "indirect-validation")] + { + let params = state.device.indirect_validation.as_ref().unwrap().params( + &state.device.limits, + offset, + buffer.size, + ); - let buf_raw = buffer.try_raw(&state.snatch_guard)?; - unsafe { - state.raw_encoder.dispatch_indirect(buf_raw, offset); + unsafe { + state.raw_encoder.set_compute_pipeline(params.pipeline); + } + + unsafe { + state.raw_encoder.set_push_constants( + params.pipeline_layout, + wgt::ShaderStages::COMPUTE, + 0, + &[params.offset_remainder as u32 / 4], + ); + } + + unsafe { + state + .raw_encoder + .set_bind_group(params.pipeline_layout, 0, params.dst_bind_group, &[]); + } + unsafe { + state.raw_encoder.set_bind_group( + params.pipeline_layout, + 1, + buffer + .raw_indirect_validation_bind_group + .get(&state.snatch_guard) + .unwrap(), + &[params.aligned_offset as u32], + ); + } + + let src_transition = state + .intermediate_trackers + .buffers + .set_single(&buffer, hal::BufferUses::STORAGE_READ); + let src_barrier = + src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard)); + unsafe { + state + .raw_encoder + .transition_buffers(src_barrier.into_iter()); + } + + unsafe { + state + .raw_encoder + .transition_buffers(std::iter::once(hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, + })); + } + + unsafe { + state.raw_encoder.dispatch([1, 1, 1]); + } + + // reset state + { + let pipeline = state.pipeline.as_ref().unwrap(); + + unsafe { + state.raw_encoder.set_compute_pipeline(pipeline.raw()); + } + + if 
!state.push_constants.is_empty() { + unsafe { + state.raw_encoder.set_push_constants( + pipeline.layout.raw(), + wgt::ShaderStages::COMPUTE, + 0, + &state.push_constants, + ); + } + } + + for (i, e) in state.binder.list_valid() { + let group = e.group.as_ref().unwrap(); + let raw_bg = group.try_raw(&state.snatch_guard)?; + unsafe { + state.raw_encoder.set_bind_group( + pipeline.layout.raw(), + i as u32, + raw_bg, + &e.dynamic_offsets, + ); + } + } + } + + unsafe { + state + .raw_encoder + .transition_buffers(std::iter::once(hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, + })); + } + + state.flush_states(None)?; + unsafe { + state.raw_encoder.dispatch_indirect(params.dst_buffer, 0); + } + }; + #[cfg(not(feature = "indirect-validation"))] + { + state + .scope + .buffers + .merge_single(&buffer, hal::BufferUses::INDIRECT)?; + + use crate::resource::Trackable; + state.flush_states(Some(buffer.tracker_index()))?; + + let buf_raw = buffer.try_raw(&state.snatch_guard)?; + unsafe { + state.raw_encoder.dispatch_indirect(buf_raw, offset); + } } + Ok(()) } diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index e73a5bc0b0e..59d99576b84 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -18,9 +18,10 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, - dyn_compute_pass::DynComputePass, dyn_render_pass::DynRenderPass, query::*, render::*, - render_command::RenderCommand, transfer::*, + bundle::*, clear::ClearError, compute::*, compute_command::ArcComputeCommand, + compute_command::ComputeCommand, draw::*, dyn_compute_pass::DynComputePass, + dyn_render_pass::DynRenderPass, query::*, render::*, render_command::RenderCommand, + transfer::*, }; pub(crate) use allocator::CommandAllocator; diff --git a/wgpu-core/src/device/global.rs 
b/wgpu-core/src/device/global.rs index 69a9ebf32c1..de022f677fe 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -565,7 +565,12 @@ impl Global { trace.add(trace::Action::CreateBuffer(fid.id(), desc.clone())); } - let buffer = device.create_buffer_from_hal(hal_buffer, desc); + let buffer = match device.create_buffer_from_hal(hal_buffer, desc) { + Ok(buffer) => buffer, + Err(e) => { + break 'error e; + } + }; let id = fid.assign(buffer); api_log!("Device::create_buffer -> {id:?}"); diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 222c50248a3..b8a44798f9f 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -39,7 +39,7 @@ pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10; // See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this. const CLEANUP_WAIT_MS: u32 = 60000; -const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid"; +pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid"; pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor>; diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index e0f2ddfe579..63ba77fba35 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -148,6 +148,9 @@ pub struct Device { #[cfg(feature = "trace")] pub(crate) trace: Mutex>, pub(crate) usage_scopes: UsageScopePool, + + #[cfg(feature = "indirect-validation")] + pub(crate) indirect_validation: Option>, } pub(crate) enum DeferredDestroy { @@ -174,6 +177,13 @@ impl Drop for Device { let pending_writes = unsafe { ManuallyDrop::take(&mut self.pending_writes.lock()) }; pending_writes.dispose(&raw); self.command_allocator.dispose(&raw); + #[cfg(feature = "indirect-validation")] + { + // `indirect_validation` is `None` when the adapter lacks `DownlevelFlags::INDIRECT_EXECUTION`; unwrapping here would panic while dropping such a device. + if let Some(indirect_validation) = self.indirect_validation.take() { + indirect_validation.dispose(&raw); + } + } unsafe { raw.destroy_buffer(self.zero_buffer.take().unwrap()); raw.destroy_fence(self.fence.write().take().unwrap()); @@ -189,6 +199,8 
@@ pub enum CreateDeviceError { OutOfMemory, #[error("Failed to create internal buffer for initializing textures")] FailedToCreateZeroBuffer(#[from] DeviceError), + #[error("Device initialization failed due to implementation specific errors")] + Internal, } impl Device { @@ -270,6 +277,25 @@ impl Device { let alignments = adapter.raw.capabilities.alignments.clone(); let downlevel = adapter.raw.capabilities.downlevel.clone(); + #[cfg(feature = "indirect-validation")] + let indirect_validation = if downlevel + .flags + .contains(wgt::DownlevelFlags::INDIRECT_EXECUTION) + { + match crate::indirect_validation::IndirectValidation::new( + &raw_device, + &desc.required_limits, + ) { + Ok(indirect_validation) => Some(indirect_validation), + Err(e) => { + log::error!("indirect-validation error: {e:?}"); + return Err(CreateDeviceError::Internal); + } + } + } else { + None + }; + Ok(Self { raw: Some(raw_device), adapter: adapter.clone(), @@ -315,6 +341,8 @@ impl Device { ), deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()), usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()), + #[cfg(feature = "indirect-validation")] + indirect_validation, }) } @@ -352,7 +380,7 @@ impl Device { let Some(view) = view.upgrade() else { continue; }; - let Some(raw_view) = view.raw.snatch(self.snatchable_lock.write()) else { + let Some(raw_view) = view.raw.snatch(&mut self.snatchable_lock.write()) else { continue; }; @@ -367,7 +395,8 @@ impl Device { let Some(bind_group) = bind_group.upgrade() else { continue; }; - let Some(raw_bind_group) = bind_group.raw.snatch(self.snatchable_lock.write()) + let Some(raw_bind_group) = + bind_group.raw.snatch(&mut self.snatchable_lock.write()) else { continue; }; @@ -539,6 +568,13 @@ impl Device { let mut usage = conv::map_buffer_usage(desc.usage); + if desc.usage.contains(wgt::BufferUsages::INDIRECT) { + self.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; + // We are going to be reading from it, 
internally; + // when validating the content of the buffer + usage |= hal::BufferUses::STORAGE_READ | hal::BufferUses::STORAGE_READ_WRITE; + } + if desc.mapped_at_creation { if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 { return Err(resource::CreateBufferError::UnalignedSize); @@ -577,6 +613,10 @@ impl Device { }; let buffer = unsafe { self.raw().create_buffer(&hal_desc) }.map_err(DeviceError::from)?; + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = + self.create_indirect_validation_bind_group(&buffer, desc.size, desc.usage)?; + let buffer = Buffer { raw: Snatchable::new(buffer), device: self.clone(), @@ -591,6 +631,8 @@ impl Device { label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }; let buffer = Arc::new(buffer); @@ -673,7 +715,11 @@ impl Device { self: &Arc, hal_buffer: A::Buffer, desc: &resource::BufferDescriptor, - ) -> Arc> { + ) -> Result>, resource::CreateBufferError> { + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = + self.create_indirect_validation_bind_group(&hal_buffer, desc.size, desc.usage)?; + let buffer = Buffer { raw: Snatchable::new(hal_buffer), device: self.clone(), @@ -688,6 +734,8 @@ impl Device { label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }; let buffer = Arc::new(buffer); @@ -697,7 +745,25 @@ impl Device { .buffers .insert_single(&buffer, hal::BufferUses::empty()); - buffer + Ok(buffer) + } + + #[cfg(feature = "indirect-validation")] + fn create_indirect_validation_bind_group( + &self, + raw_buffer: &A::Buffer, + buffer_size: u64, + usage: wgt::BufferUsages, + ) -> Result, 
resource::CreateBufferError> { + if usage.contains(wgt::BufferUsages::INDIRECT) { + let indirect_validation = self.indirect_validation.as_ref().unwrap(); + let bind_group = indirect_validation + .create_src_bind_group(self.raw(), &self.limits, buffer_size, raw_buffer) + .map_err(resource::CreateBufferError::IndirectValidationBindGroup)?; + Ok(Snatchable::new(bind_group)) + } else { + Ok(Snatchable::empty()) + } } pub(crate) fn create_texture( diff --git a/wgpu-core/src/indirect_validation.rs b/wgpu-core/src/indirect_validation.rs new file mode 100644 index 00000000000..5e802172a29 --- /dev/null +++ b/wgpu-core/src/indirect_validation.rs @@ -0,0 +1,371 @@ +use thiserror::Error; + +use crate::{ + device::DeviceError, + hal_api::HalApi, + pipeline::{CreateComputePipelineError, CreateShaderModuleError}, +}; +use hal::Device as _; + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum CreateDispatchIndirectValidationPipelineError { + #[error(transparent)] + DeviceError(#[from] DeviceError), + #[error(transparent)] + ShaderModule(#[from] CreateShaderModuleError), + #[error(transparent)] + ComputePipeline(#[from] CreateComputePipelineError), +} + +/// This machinery requires the following limits: +/// +/// - max_bind_groups: 2, +/// - max_dynamic_storage_buffers_per_pipeline_layout: 1, +/// - max_storage_buffers_per_shader_stage: 2, +/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment, +/// - max_push_constant_size: 4, +/// - max_compute_invocations_per_workgroup: 1 +/// +/// Which are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`. 
+#[derive(Debug)] +pub struct IndirectValidation { + module: A::ShaderModule, + dst_bind_group_layout: A::BindGroupLayout, + src_bind_group_layout: A::BindGroupLayout, + pipeline_layout: A::PipelineLayout, + pipeline: A::ComputePipeline, + dst_buffer: A::Buffer, + dst_bind_group: A::BindGroup, +} + +pub struct Params<'a, A: HalApi> { + pub pipeline_layout: &'a A::PipelineLayout, + pub pipeline: &'a A::ComputePipeline, + pub dst_buffer: &'a A::Buffer, + pub dst_bind_group: &'a A::BindGroup, + pub aligned_offset: u64, + pub offset_remainder: u64, +} + +impl IndirectValidation { + pub fn new( + device: &A::Device, + limits: &wgt::Limits, + ) -> Result { + let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension; + + let src = format!(" + @group(0) @binding(0) + var dst: array; + @group(1) @binding(0) + var src: array; + struct OffsetPc {{ + inner: u32, + }} + var offset: OffsetPc; + + @compute @workgroup_size(1) + fn main() {{ + let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]); + let res = select(src, vec3(), src > vec3({max_compute_workgroups_per_dimension}u)); + dst[0] = res.x; + dst[1] = res.y; + dst[2] = res.z; + }} + "); + + let module = naga::front::wgsl::parse_str(&src).map_err(|inner| { + CreateShaderModuleError::Parsing(naga::error::ShaderError { + source: src.clone(), + label: None, + inner: Box::new(inner), + }) + })?; + let info = crate::device::create_validator( + wgt::Features::PUSH_CONSTANTS, + wgt::DownlevelFlags::empty(), + naga::valid::ValidationFlags::all(), + ) + .validate(&module) + .map_err(|inner| { + CreateShaderModuleError::Validation(naga::error::ShaderError { + source: src, + label: None, + inner: Box::new(inner), + }) + })?; + let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { + module: std::borrow::Cow::Owned(module), + info, + debug_source: None, + }); + let hal_desc = hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }; + let module = + unsafe { 
device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| { + match error { + hal::ShaderError::Device(error) => { + CreateShaderModuleError::Device(error.into()) + } + hal::ShaderError::Compilation(ref msg) => { + log::error!("Shader error: {}", msg); + CreateShaderModuleError::Generation + } + } + })?; + + let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }], + }; + let dst_bind_group_layout = unsafe { + device + .create_bind_group_layout(&dst_bind_group_layout_desc) + .map_err(DeviceError::from)? + }; + + let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: true, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }], + }; + let src_bind_group_layout = unsafe { + device + .create_bind_group_layout(&src_bind_group_layout_desc) + .map_err(DeviceError::from)? + }; + + let pipeline_layout_desc = hal::PipelineLayoutDescriptor { + label: None, + flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE, + bind_group_layouts: &[&dst_bind_group_layout, &src_bind_group_layout], + push_constant_ranges: &[wgt::PushConstantRange { + stages: wgt::ShaderStages::COMPUTE, + range: 0..4, + }], + }; + let pipeline_layout = unsafe { + device + .create_pipeline_layout(&pipeline_layout_desc) + .map_err(DeviceError::from)? 
+ }; + + let pipeline_desc = hal::ComputePipelineDescriptor { + label: None, + layout: &pipeline_layout, + stage: hal::ProgrammableStage { + module: &module, + entry_point: "main", + constants: &Default::default(), + zero_initialize_workgroup_memory: false, + }, + cache: None, + }; + let pipeline = + unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err { + hal::PipelineError::Device(error) => { + CreateComputePipelineError::Device(error.into()) + } + hal::PipelineError::Linkage(_stages, msg) => { + CreateComputePipelineError::Internal(msg) + } + hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal( + crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(), + ), + })?; + + let dst_buffer_desc = hal::BufferDescriptor { + label: None, + size: 4 * 3, + usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE, + memory_flags: hal::MemoryFlags::empty(), + }; + let dst_buffer = + unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?; + + let dst_bind_group_desc = hal::BindGroupDescriptor { + label: None, + layout: &dst_bind_group_layout, + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer: &dst_buffer, + offset: 0, + size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + let dst_bind_group = unsafe { + device + .create_bind_group(&dst_bind_group_desc) + .map_err(DeviceError::from) + }?; + + Ok(Self { + module, + dst_bind_group_layout, + src_bind_group_layout, + pipeline_layout, + pipeline, + dst_buffer, + dst_bind_group, + }) + } + + pub fn create_src_bind_group( + &self, + device: &A::Device, + limits: &wgt::Limits, + buffer_size: u64, + buffer: &A::Buffer, + ) -> Result { + let binding_size = calculate_src_buffer_binding_size(buffer_size, limits); + let hal_desc = hal::BindGroupDescriptor { + label: None, + layout: 
&self.src_bind_group_layout, + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer, + offset: 0, + size: Some(std::num::NonZeroU64::new(binding_size).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + unsafe { + device + .create_bind_group(&hal_desc) + .map_err(DeviceError::from) + } + } + + pub fn params<'a>( + &'a self, + limits: &wgt::Limits, + offset: u64, + buffer_size: u64, + ) -> Params<'a, A> { + // The offset we receive is only required to be aligned to 4 bytes. + // + // Binding offsets and dynamic offsets are required to be aligned to + // min_storage_buffer_offset_alignment (256 bytes by default). + // + // So, we work around this limitation by calculating an aligned offset + // and pass the remainder through a push constant. + // + // We could bind the whole buffer and only have to pass the offset + // through a push constant but we might run into the + // max_storage_buffer_binding_size limit. + // + // See the inner docs of `calculate_src_buffer_binding_size` to + // see how we get the appropriate `binding_size`. + let alignment = limits.min_storage_buffer_offset_alignment as u64; + let binding_size = calculate_src_buffer_binding_size(buffer_size, limits); + let aligned_offset = offset - offset % alignment; + // This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`. 
+ let max_aligned_offset = buffer_size - binding_size; + let aligned_offset = aligned_offset.min(max_aligned_offset); + let offset_remainder = offset - aligned_offset; + + Params { + pipeline_layout: &self.pipeline_layout, + pipeline: &self.pipeline, + dst_buffer: &self.dst_buffer, + dst_bind_group: &self.dst_bind_group, + aligned_offset, + offset_remainder, + } + } + + pub fn dispose(self, device: &A::Device) { + let IndirectValidation { + module, + dst_bind_group_layout, + src_bind_group_layout, + pipeline_layout, + pipeline, + dst_buffer, + dst_bind_group, + } = self; + + use hal::Device; + unsafe { + device.destroy_bind_group(dst_bind_group); + device.destroy_buffer(dst_buffer); + device.destroy_compute_pipeline(pipeline); + device.destroy_pipeline_layout(pipeline_layout); + device.destroy_bind_group_layout(src_bind_group_layout); + device.destroy_bind_group_layout(dst_bind_group_layout); + device.destroy_shader_module(module); + } + } +} + +fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 { + let alignment = limits.min_storage_buffer_offset_alignment as u64; + + // We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking + // into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`. + + // Given the known variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`. + + // Let `chunks = floor(buffer_size / alignment)`. + // Let `chunk` be the interval `[0, chunks]`. + // Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4]. + // Let `binding` be the interval `[offset, offset + 12]`. + // Let `aligned_offset = alignment * chunk`. + // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`. + // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`. + // Let `min_aligned_binding_size = alignment + 8`. 
+ + // `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer + // but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must + // pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and + // `binding_size >= min_aligned_binding_size`. + + // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4]. + // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks]. + // => `binding_size = buffer_size - last_aligned_offset` + // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)` + // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u` + // => `binding_size = sr + alignment * u` + // => `min_aligned_binding_size <= sr + alignment * u` + // => `alignment + 8 <= sr + alignment * u` + // => `u` must be at least 2 + // => `binding_size = sr + alignment * 2` + + let binding_size = 2 * alignment + (buffer_size % alignment); + binding_size.min(buffer_size) +} diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 7bc6cfcefe6..64f0557b241 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -73,6 +73,8 @@ mod hash_utils; pub mod hub; pub mod id; pub mod identity; +#[cfg(feature = "indirect-validation")] +mod indirect_validation; mod init_tracker; pub mod instance; mod lock; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 2ab49f83d0c..6ac51082f60 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -94,7 +94,7 @@ impl ShaderModule { #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum CreateShaderModuleError { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] #[error(transparent)] Parsing(#[from] ShaderError), #[cfg(feature = "glsl")] diff --git a/wgpu-core/src/present.rs b/wgpu-core/src/present.rs 
index b59493f3164..c061008eb33 100644 --- a/wgpu-core/src/present.rs +++ b/wgpu-core/src/present.rs @@ -407,8 +407,11 @@ impl Global { .textures .remove(texture.tracker_index()); let suf = A::surface_as_hal(&surface); - let exclusive_snatch_guard = device.snatchable_lock.write(); - match texture.inner.snatch(exclusive_snatch_guard).unwrap() { + match texture + .inner + .snatch(&mut device.snatchable_lock.write()) + .unwrap() + { resource::TextureInner::Surface { mut raw, parent_id } => { if surface_id == parent_id { unsafe { suf.unwrap().discard_texture(raw.take().unwrap()) }; diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 5b11525126f..7601d569c63 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -440,14 +440,22 @@ pub struct Buffer { pub(crate) tracking_data: TrackingData, pub(crate) map_state: Mutex>, pub(crate) bind_groups: Mutex>>>, + #[cfg(feature = "indirect-validation")] + pub(crate) raw_indirect_validation_bind_group: Snatchable, } impl Drop for Buffer { fn drop(&mut self) { + use hal::Device; + #[cfg(feature = "indirect-validation")] + if let Some(raw) = self.raw_indirect_validation_bind_group.take() { + unsafe { + self.device.raw().destroy_bind_group(raw); + } + } if let Some(raw) = self.raw.take() { resource_log!("Destroy raw {}", self.error_ident()); unsafe { - use hal::Device; self.device.raw().destroy_buffer(raw); } } @@ -700,14 +708,22 @@ impl Buffer { let device = &self.device; let temp = { - let snatch_guard = device.snatchable_lock.write(); - let raw = match self.raw.snatch(snatch_guard) { + let mut snatch_guard = device.snatchable_lock.write(); + + let raw = match self.raw.snatch(&mut snatch_guard) { Some(raw) => raw, None => { return Err(DestroyError::AlreadyDestroyed); } }; + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = self + .raw_indirect_validation_bind_group + .snatch(&mut snatch_guard); + + drop(snatch_guard); + let bind_groups = { let mut guard = 
self.bind_groups.lock(); mem::take(&mut *guard) @@ -718,6 +734,8 @@ impl Buffer { device: Arc::clone(&self.device), label: self.label().to_owned(), bind_groups, + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }) }; @@ -753,6 +771,8 @@ pub enum CreateBufferError { MaxBufferSize { requested: u64, maximum: u64 }, #[error(transparent)] MissingDownlevelFlags(#[from] MissingDownlevelFlags), + #[error("Failed to create bind group for indirect buffer validation: {0}")] + IndirectValidationBindGroup(DeviceError), } crate::impl_resource_type!(Buffer); @@ -768,6 +788,8 @@ pub struct DestroyedBuffer { device: Arc>, label: String, bind_groups: Vec>>, + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group: Option, } impl DestroyedBuffer { @@ -778,17 +800,25 @@ impl DestroyedBuffer { impl Drop for DestroyedBuffer { fn drop(&mut self) { + use hal::Device; + let mut deferred = self.device.deferred_destroy.lock(); for bind_group in self.bind_groups.drain(..) { deferred.push(DeferredDestroy::BindGroup(bind_group)); } drop(deferred); + #[cfg(feature = "indirect-validation")] + if let Some(raw) = self.raw_indirect_validation_bind_group.take() { + unsafe { + self.device.raw().destroy_bind_group(raw); + } + } + if let Some(raw) = self.raw.take() { resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label()); unsafe { - use hal::Device; self.device.raw().destroy_buffer(raw); } } @@ -1155,8 +1185,7 @@ impl Texture { let device = &self.device; let temp = { - let snatch_guard = device.snatchable_lock.write(); - let raw = match self.inner.snatch(snatch_guard) { + let raw = match self.inner.snatch(&mut device.snatchable_lock.write()) { Some(TextureInner::Native { raw }) => raw, Some(TextureInner::Surface { .. 
}) => { return Ok(()); diff --git a/wgpu-core/src/snatch.rs b/wgpu-core/src/snatch.rs index 6f60f45d857..55fe04aa9d0 100644 --- a/wgpu-core/src/snatch.rs +++ b/wgpu-core/src/snatch.rs @@ -32,6 +32,12 @@ impl Snatchable { } } + pub fn empty() -> Self { + Snatchable { + value: UnsafeCell::new(None), + } + } + /// Get read access to the value. Requires a the snatchable lock's read guard. pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> { unsafe { (*self.value.get()).as_ref() } @@ -43,7 +49,7 @@ impl Snatchable { } /// Take the value. Requires a the snatchable lock's write guard. - pub fn snatch(&self, _guard: ExclusiveSnatchGuard) -> Option { + pub fn snatch(&self, _guard: &mut ExclusiveSnatchGuard) -> Option { unsafe { (*self.value.get()).take() } } diff --git a/wgpu/Cargo.toml b/wgpu/Cargo.toml index cd73f5dc9e7..b092fd9c611 100644 --- a/wgpu/Cargo.toml +++ b/wgpu/Cargo.toml @@ -130,6 +130,12 @@ features = ["raw-window-handle"] workspace = true features = ["raw-window-handle"] +# If we are not targeting WebGL, enable indirect-validation. +# WebGL doesn't support indirect execution so this is not needed. +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.wgc] +workspace = true +features = ["indirect-validation"] + # Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to # enable the Metal backend while being no-op on other targets. [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc]