Skip to content

Commit

Permalink
use 2 destination buffers for indirect dispatch validation
Browse files Browse the repository at this point in the history
This removes the required barrier prior to the validation dispatch.
  • Loading branch information
teoxoy authored and ErichDonGubler committed Oct 11, 2024
1 parent 2ec4da7 commit 65354cd
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 28 deletions.
3 changes: 3 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ async fn run_test(
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
}
// Issue multiple dispatches to test the internal destination buffer switching
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

Expand Down
21 changes: 10 additions & 11 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -942,13 +942,6 @@ fn dispatch_indirect(
state.raw_encoder.transition_buffers(src_barrier.as_slice());
}

unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
}]);
}

unsafe {
state.raw_encoder.dispatch([1, 1, 1]);
}
Expand Down Expand Up @@ -987,10 +980,16 @@ fn dispatch_indirect(
}

unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
}]);
state.raw_encoder.transition_buffers(&[
hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
},
hal::BufferBarrier {
buffer: params.other_dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
},
]);
}

state.flush_states(None)?;
Expand Down
93 changes: 76 additions & 17 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::atomic::AtomicBool;

use thiserror::Error;

use crate::{
Expand Down Expand Up @@ -34,14 +36,18 @@ pub struct IndirectValidation {
src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
pipeline: Box<dyn hal::DynComputePipeline>,
dst_buffer: Box<dyn hal::DynBuffer>,
dst_bind_group: Box<dyn hal::DynBindGroup>,
dst_buffer_0: Box<dyn hal::DynBuffer>,
dst_buffer_1: Box<dyn hal::DynBuffer>,
dst_bind_group_0: Box<dyn hal::DynBindGroup>,
dst_bind_group_1: Box<dyn hal::DynBindGroup>,
is_next_dst_0: AtomicBool,
}

pub struct Params<'a> {
pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
pub pipeline: &'a dyn hal::DynComputePipeline,
pub dst_buffer: &'a dyn hal::DynBuffer,
pub other_dst_buffer: &'a dyn hal::DynBuffer,
pub dst_bind_group: &'a dyn hal::DynBindGroup,
pub aligned_offset: u64,
pub offset_remainder: u64,
Expand All @@ -54,7 +60,8 @@ impl IndirectValidation {
) -> Result<Self, CreateDispatchIndirectValidationPipelineError> {
let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;

let src = format!("
let src = format!(
"
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>;
@group(1) @binding(0)
Expand All @@ -72,7 +79,8 @@ impl IndirectValidation {
dst[1] = res.y;
dst[2] = res.z;
}}
");
"
);

let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
CreateShaderModuleError::Parsing(naga::error::ShaderError {
Expand Down Expand Up @@ -207,10 +215,12 @@ impl IndirectValidation {
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
let dst_buffer =
let dst_buffer_0 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
let dst_buffer_1 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;

let dst_bind_group_desc = hal::BindGroupDescriptor {
let dst_bind_group_desc_0 = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
Expand All @@ -219,17 +229,40 @@ impl IndirectValidation {
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer.as_ref(),
buffer: dst_buffer_0.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group = unsafe {
let dst_bind_group_0 = unsafe {
device
.create_bind_group(&dst_bind_group_desc)
.create_bind_group(&dst_bind_group_desc_0)
.map_err(DeviceError::from_hal)
}?;

let dst_bind_group_desc_1 = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer_1.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group_1 = unsafe {
device
.create_bind_group(&dst_bind_group_desc_1)
.map_err(DeviceError::from_hal)
}?;

Expand All @@ -239,8 +272,11 @@ impl IndirectValidation {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: AtomicBool::new(false),
})
}

Expand Down Expand Up @@ -299,11 +335,29 @@ impl IndirectValidation {
let aligned_offset = aligned_offset.min(max_aligned_offset);
let offset_remainder = offset - aligned_offset;

let (dst_buffer, other_dst_buffer, dst_bind_group) = if self
.is_next_dst_0
.fetch_xor(true, core::sync::atomic::Ordering::AcqRel)
{
(
self.dst_buffer_0.as_ref(),
self.dst_buffer_1.as_ref(),
self.dst_bind_group_0.as_ref(),
)
} else {
(
self.dst_buffer_1.as_ref(),
self.dst_buffer_0.as_ref(),
self.dst_bind_group_1.as_ref(),
)
};

Params {
pipeline_layout: self.pipeline_layout.as_ref(),
pipeline: self.pipeline.as_ref(),
dst_buffer: self.dst_buffer.as_ref(),
dst_bind_group: self.dst_bind_group.as_ref(),
dst_buffer,
other_dst_buffer,
dst_bind_group,
aligned_offset,
offset_remainder,
}
Expand All @@ -316,13 +370,18 @@ impl IndirectValidation {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: _,
} = self;

unsafe {
device.destroy_bind_group(dst_bind_group);
device.destroy_buffer(dst_buffer);
device.destroy_bind_group(dst_bind_group_0);
device.destroy_bind_group(dst_bind_group_1);
device.destroy_buffer(dst_buffer_0);
device.destroy_buffer(dst_buffer_1);
device.destroy_compute_pipeline(pipeline);
device.destroy_pipeline_layout(pipeline_layout);
device.destroy_bind_group_layout(src_bind_group_layout);
Expand Down

0 comments on commit 65354cd

Please sign in to comment.