ensure safety of indirect dispatch
by injecting a compute shader that validates the content of the indirect buffer

also adds missing indirect buffer offset validation
teoxoy committed Jun 11, 2024
1 parent be4eabc commit ab03bc6
Showing 20 changed files with 862 additions and 21 deletions.
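Before the file-by-file diff, here is a minimal sketch of the kind of WGSL validation shader the commit message describes, embedded in a Rust constant. The struct layout, binding indices, and use of a push constant for the limit are assumptions for illustration; this is not the shader actually added by this commit.

// Hypothetical sketch: the injected pass copies the three dispatch dimensions
// from the user's indirect buffer into an internal buffer, zeroing them when
// any dimension exceeds `max_compute_workgroups_per_dimension`, so the real
// indirect dispatch only ever reads validated arguments.
const DISPATCH_VALIDATION_SHADER_SKETCH: &str = "
    struct DispatchArgs { x: u32, y: u32, z: u32 }

    @group(0) @binding(0) var<storage, read> src: DispatchArgs;
    @group(0) @binding(1) var<storage, read_write> dst: DispatchArgs;

    // Supplied by the host; needing push constants here is why the internal
    // pipeline layout gets the `ignore_push_constant_check` escape hatch below.
    var<push_constant> max_per_dim: u32;

    @compute @workgroup_size(1)
    fn main() {
        let ok = src.x <= max_per_dim && src.y <= max_per_dim && src.z <= max_per_dim;
        dst.x = select(0u, src.x, ok);
        dst.y = select(0u, src.y, ok);
        dst.z = select(0u, src.z, ok);
    }
";

This matches the behavior exercised by the DISCARD_DISPATCH test below: a dispatch with any out-of-range dimension becomes a (0, 0, 0) dispatch.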
2 changes: 2 additions & 0 deletions deno_webgpu/binding.rs
@@ -224,6 +224,7 @@ pub fn op_webgpu_create_pipeline_layout(
label: Some(label),
bind_group_layouts: Cow::from(bind_group_layouts),
push_constant_ranges: Default::default(),
ignore_push_constant_check: false,
};

gfx_put!(device => instance.device_create_pipeline_layout(
@@ -288,6 +289,7 @@ pub fn op_webgpu_create_bind_group(
buffer_id: buffer_resource.1,
offset: entry.offset.unwrap_or(0),
size: std::num::NonZeroU64::new(entry.size.unwrap_or(0)),
allow_indirect_as_storage: false,
},
)
}
1 change: 1 addition & 0 deletions deno_webgpu/shader.rs
@@ -43,6 +43,7 @@ pub fn op_webgpu_create_shader_module(
let descriptor = wgpu_core::pipeline::ShaderModuleDescriptor {
label: Some(label),
shader_bound_checks: wgpu_types::ShaderBoundChecks::default(),
ignore_push_constant_check: false,
};

gfx_put!(device => instance.device_create_shader_module(
197 changes: 197 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
@@ -0,0 +1,197 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits::downlevel_defaults())
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;

let res = run_test(&ctx, &[max, max, max], false).await;
assert_eq!(res, [max; 3]);

let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, max + 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, 1, max + 1], false).await;
assert_eq!(res, [0; 3]);
});

/// Make sure that unsetting the bind group set by the validation code works properly.
#[gpu_test]
static UNSET_INTERNAL_BIND_GROUP: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits::downlevel_defaults()),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

let _ = run_test(&ctx, &[0, 0, 0], true).await;

let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| format!("{error}")
.contains("Expected bind group is missing")));
});

async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> out: array<u32, 3>;
@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3<u32>, @builtin(workgroup_id) workgroup_id: vec3<u32>) {
if (all(workgroup_id == vec3<u32>())) {
out[0] = num_workgroups.x;
out[1] = num_workgroups.y;
out[2] = num_workgroups.z;
}
}
";

let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut res = None;

for (indirect_offset, indirect_buffer_size) in [
// internal src buffer binding size will be buffer.size
(0, 12),
(4, 4 + 12),
(4, 8 + 12),
(256 * 2 - 4 - 12, 256 * 2 - 4),
// internal src buffer binding size will be 256 * 2 + x
(0, 256 * 2 * 2 + 4),
(256, 256 * 2 * 2 + 8),
(256 + 4, 256 * 2 * 2 + 12),
(256 * 2 + 16, 256 * 2 * 2 + 16),
(256 * 2 * 2, 256 * 2 * 2 + 32),
(256 + 12, 256 * 2 * 2 + 64),
] {
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: indirect_buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});

ctx.queue.write_buffer(
&indirect_buffer,
indirect_offset,
bytemuck::bytes_of(num_workgroups),
);

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut compute_pass =
encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
compute_pass.set_pipeline(&pipeline);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, &bind_group, &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();

let current_res = *bytemuck::from_bytes(&view);
drop(view);
readback_buffer.unmap();

if let Some(past_res) = res {
assert_eq!(past_res, current_res);
} else {
res = Some(current_res);
}
}

res.unwrap()
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
@@ -14,6 +14,7 @@ mod clear_texture;
mod compute_pass_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
4 changes: 4 additions & 0 deletions wgpu-core/Cargo.toml
@@ -46,6 +46,10 @@ renderdoc = ["hal/renderdoc"]
## to the validation carried out at public APIs in all builds.
strict_asserts = ["wgt/strict_asserts"]

## Validates indirect draw/dispatch calls. This will also enable naga's
## WGSL frontend since we use a WGSL compute shader to do the validation.
indirect-validation = ["naga/wgsl-in"]

## Enables serialization via `serde` on common wgpu types.
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]

7 changes: 7 additions & 0 deletions wgpu-core/src/binding_model.rs
@@ -617,6 +617,10 @@ pub struct PipelineLayoutDescriptor<'a> {
/// [`Features::PUSH_CONSTANTS`](wgt::Features::PUSH_CONSTANTS) feature must
/// be enabled.
pub push_constant_ranges: Cow<'a, [wgt::PushConstantRange]>,
/// This is an internal flag used by indirect validation.
/// It allows usage of push constants without having the
/// [`Features::PUSH_CONSTANTS`](wgt::Features::PUSH_CONSTANTS) feature enabled.
pub ignore_push_constant_check: bool,
}

#[derive(Debug)]
@@ -758,6 +762,9 @@ pub struct BufferBinding {
pub buffer_id: BufferId,
pub offset: wgt::BufferAddress,
pub size: Option<wgt::BufferSize>,
/// This is an internal flag used by indirect validation.
/// It allows indirect buffers to be bound as storage buffers.
pub allow_indirect_as_storage: bool,
}

// Note: Duplicated in `wgpu-rs` as `BindingResource`
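The two internal flags added in this file are presumably set to true by the validation machinery when it creates its own pipeline layout and binds the user's indirect buffer as a storage buffer. A hedged sketch under those assumptions follows; only the descriptor field names come from this diff, while the function, its signature, the label, and the 0..4 push-constant range are illustrative.

use std::borrow::Cow;
use std::num::NonZeroU64;

use wgpu_core::binding_model::{BufferBinding, PipelineLayoutDescriptor};
use wgpu_core::id::{BindGroupLayoutId, BufferId};
use wgpu_types as wgt;

// Hedged sketch: descriptors the internal validation setup might build.
// Everything except the field names shown in the diff above is assumed.
fn validation_descriptors(
    bgls: &[BindGroupLayoutId],
    src_buffer: BufferId,
    offset: u64,
    binding_size: u64,
) -> (PipelineLayoutDescriptor<'_>, BufferBinding) {
    let layout = PipelineLayoutDescriptor {
        label: Some(Cow::Borrowed("indirect-dispatch validation")),
        bind_group_layouts: Cow::Borrowed(bgls),
        // The validation shader receives data via push constants, so the
        // internal layout opts out of the PUSH_CONSTANTS feature check.
        push_constant_ranges: Cow::Owned(vec![wgt::PushConstantRange {
            stages: wgt::ShaderStages::COMPUTE,
            range: 0..4,
        }]),
        ignore_push_constant_check: true,
    };
    let src_binding = BufferBinding {
        buffer_id: src_buffer,
        offset,
        size: NonZeroU64::new(binding_size),
        // The user's INDIRECT-usage buffer is read by the validation shader
        // as a storage buffer, which normal binding validation would reject.
        allow_indirect_as_storage: true,
    };
    (layout, src_binding)
}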
14 changes: 13 additions & 1 deletion wgpu-core/src/command/bind.rs
@@ -131,7 +131,7 @@ mod compat {
diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}"))
}
} else {
diff.push("Assigned bind group layout not found (internal error)".to_owned());
diff.push("Expected bind group is missing".to_owned());
}
} else {
diff.push("Expected bind group layout not found (internal error)".to_owned());
@@ -191,6 +191,10 @@
self.make_range(index)
}

pub fn unassign(&mut self, index: usize) {
self.entries[index].assigned = None;
}

pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
self.entries
.iter()
@@ -358,6 +362,14 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}

pub(super) fn unassign_group(&mut self, index: usize) {
log::trace!("\tBinding [{}] = null", index);

self.payloads[index].reset();

self.manager.unassign(index);
}

pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
31 changes: 31 additions & 0 deletions wgpu-core/src/command/compute.rs
@@ -182,6 +182,8 @@ pub enum ComputePassErrorInner {
InvalidQuerySet(id::QuerySetId),
#[error("Indirect buffer {0:?} is invalid or destroyed")]
InvalidIndirectBuffer(id::BufferId),
#[error("Indirect buffer offset {0:?} is not a multiple of 4")]
UnalignedIndirectBufferOffset(BufferAddress),
#[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
IndirectBufferOverrun {
offset: u64,
@@ -473,6 +475,16 @@ impl Global {
.map_pass_err(pass_scope);
}

#[cfg(feature = "indirect-validation")]
let mut base = base;
#[cfg(feature = "indirect-validation")]
device
.indirect_validation
.get()
.unwrap()
.inject_dispatch_indirect_validation(device, &mut base)
.map_pass_err(pass_scope)?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

@@ -654,6 +666,20 @@
}
}
}
ArcComputeCommand::UnsetBindGroup { index } => {
let scope = PassErrorScope::UnsetBindGroup(index);

let max_bind_groups = cmd_buf.limits.max_bind_groups;
if index >= max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
index,
max: max_bind_groups,
})
.map_pass_err(scope);
}

state.binder.unassign_group(index as usize);
}
ArcComputeCommand::SetPipeline(pipeline) => {
let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
@@ -811,6 +837,11 @@
check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
.map_pass_err(scope)?;

if offset % 4 != 0 {
return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset))
.map_pass_err(scope);
}

let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
if end_offset > buffer.size {
return Err(ComputePassErrorInner::IndirectBufferOverrun {
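For reference, the offset checks added in the compute.rs hunk above amount to the following standalone helper, a simplified sketch in plain Rust: the real code returns `ComputePassErrorInner` variants and takes the argument size from `wgt::DispatchIndirectArgs`, while the error strings here are illustrative.

// Simplified restatement of the checks added above: the indirect offset must
// be 4-byte aligned and the three u32 dispatch arguments (12 bytes) must fit
// entirely inside the indirect buffer.
fn check_dispatch_indirect_args(offset: u64, buffer_size: u64) -> Result<(), String> {
    const ARGS_SIZE: u64 = 12; // size_of::<wgt::DispatchIndirectArgs>(): x, y, z as u32

    if offset % 4 != 0 {
        return Err(format!("indirect buffer offset {offset} is not a multiple of 4"));
    }
    let end_offset = offset + ARGS_SIZE;
    if end_offset > buffer_size {
        return Err(format!(
            "indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}"
        ));
    }
    Ok(())
}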
16 changes: 16 additions & 0 deletions wgpu-core/src/command/compute_command.rs
@@ -19,6 +19,10 @@ pub enum ComputeCommand {
bind_group_id: id::BindGroupId,
},

UnsetBindGroup {
index: u32,
},

SetPipeline(id::ComputePipelineId),

/// Set a range of push constants to values stored in `push_constant_data`.
@@ -103,6 +107,10 @@
})?,
},

ComputeCommand::UnsetBindGroup { index } => {
ArcComputeCommand::UnsetBindGroup { index }
}

ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
pipelines_guard
.get_owned(pipeline_id)
@@ -194,6 +202,10 @@ pub enum ArcComputeCommand<A: HalApi> {
bind_group: Arc<BindGroup<A>>,
},

UnsetBindGroup {
index: u32,
},

SetPipeline(Arc<ComputePipeline<A>>),

/// Set a range of push constants to values stored in `push_constant_data`.
@@ -261,6 +273,10 @@ impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
bind_group_id: bind_group.as_info().id(),
},

ArcComputeCommand::UnsetBindGroup { index } => {
ComputeCommand::UnsetBindGroup { index: *index }
}

ArcComputeCommand::SetPipeline(pipeline) => {
ComputeCommand::SetPipeline(pipeline.as_info().id())
}