Skip to content

Commit

Permalink
Use two destination buffers for indirect dispatch validation
Browse files: browse the repository at this point in the history
This removes the barrier that was previously required before the validation dispatch.
  • Loading branch information
teoxoy committed Jul 24, 2024
1 parent 63074d6 commit 40761af
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 30 deletions.
3 changes: 3 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ async fn run_test(
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, &bind_group, &[]);
}
// Issue multiple dispatches to exercise the switching between the two internal destination buffers
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

Expand Down
28 changes: 13 additions & 15 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -979,15 +979,6 @@ fn dispatch_indirect<A: HalApi>(
.transition_buffers(src_barrier.into_iter());
}

unsafe {
state
.raw_encoder
.transition_buffers(std::iter::once(hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
}));
}

unsafe {
state.raw_encoder.dispatch([1, 1, 1]);
}
Expand Down Expand Up @@ -1026,12 +1017,19 @@ fn dispatch_indirect<A: HalApi>(
}

unsafe {
state
.raw_encoder
.transition_buffers(std::iter::once(hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
}));
state.raw_encoder.transition_buffers(
[
hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
},
hal::BufferBarrier {
buffer: params.other_dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
},
]
.into_iter(),
);
}

state.flush_states(None)?;
Expand Down
87 changes: 72 additions & 15 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::atomic::AtomicBool;

use thiserror::Error;

use crate::{
Expand Down Expand Up @@ -35,14 +37,18 @@ pub struct IndirectValidation<A: HalApi> {
src_bind_group_layout: A::BindGroupLayout,
pipeline_layout: A::PipelineLayout,
pipeline: A::ComputePipeline,
dst_buffer: A::Buffer,
dst_bind_group: A::BindGroup,
dst_buffer_0: A::Buffer,
dst_buffer_1: A::Buffer,
dst_bind_group_0: A::BindGroup,
dst_bind_group_1: A::BindGroup,
is_next_dst_0: AtomicBool,
}

pub struct Params<'a, A: HalApi> {
pub pipeline_layout: &'a A::PipelineLayout,
pub pipeline: &'a A::ComputePipeline,
pub dst_buffer: &'a A::Buffer,
pub other_dst_buffer: &'a A::Buffer,
pub dst_bind_group: &'a A::BindGroup,
pub aligned_offset: u64,
pub offset_remainder: u64,
Expand Down Expand Up @@ -202,10 +208,12 @@ impl<A: HalApi> IndirectValidation<A> {
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
let dst_buffer =
let dst_buffer_0 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?;
let dst_buffer_1 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?;

let dst_bind_group_desc = hal::BindGroupDescriptor {
let dst_bind_group_desc_0 = hal::BindGroupDescriptor {
label: None,
layout: &dst_bind_group_layout,
entries: &[hal::BindGroupEntry {
Expand All @@ -214,17 +222,40 @@ impl<A: HalApi> IndirectValidation<A> {
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: &dst_buffer,
buffer: &dst_buffer_0,
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group = unsafe {
let dst_bind_group_0 = unsafe {
device
.create_bind_group(&dst_bind_group_desc)
.create_bind_group(&dst_bind_group_desc_0)
.map_err(DeviceError::from)
}?;

let dst_bind_group_desc_1 = hal::BindGroupDescriptor {
label: None,
layout: &dst_bind_group_layout,
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: &dst_buffer_1,
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group_1 = unsafe {
device
.create_bind_group(&dst_bind_group_desc_1)
.map_err(DeviceError::from)
}?;

Expand All @@ -234,8 +265,11 @@ impl<A: HalApi> IndirectValidation<A> {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: AtomicBool::new(false),
})
}

Expand Down Expand Up @@ -299,11 +333,29 @@ impl<A: HalApi> IndirectValidation<A> {
let aligned_offset = aligned_offset.min(max_aligned_offset);
let offset_remainder = offset - aligned_offset;

let (dst_buffer, other_dst_buffer, dst_bind_group) = if self
.is_next_dst_0
.fetch_xor(true, core::sync::atomic::Ordering::AcqRel)
{
(
&self.dst_buffer_0,
&self.dst_buffer_1,
&self.dst_bind_group_0,
)
} else {
(
&self.dst_buffer_1,
&self.dst_buffer_0,
&self.dst_bind_group_1,
)
};

Params {
pipeline_layout: &self.pipeline_layout,
pipeline: &self.pipeline,
dst_buffer: &self.dst_buffer,
dst_bind_group: &self.dst_bind_group,
dst_buffer,
other_dst_buffer,
dst_bind_group,
aligned_offset,
offset_remainder,
}
Expand All @@ -316,14 +368,19 @@ impl<A: HalApi> IndirectValidation<A> {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: _,
} = self;

use hal::Device;
unsafe {
device.destroy_bind_group(dst_bind_group);
device.destroy_buffer(dst_buffer);
device.destroy_bind_group(dst_bind_group_0);
device.destroy_bind_group(dst_bind_group_1);
device.destroy_buffer(dst_buffer_0);
device.destroy_buffer(dst_buffer_1);
device.destroy_compute_pipeline(pipeline);
device.destroy_pipeline_layout(pipeline_layout);
device.destroy_bind_group_layout(src_bind_group_layout);
Expand Down

0 comments on commit 40761af

Please sign in to comment.