From 40761af4d62ada4a66ab704cf0d2e8f7ae43d8c2 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:38:17 +0200 Subject: [PATCH] use 2 destination buffers for indirect dispatch validation This removes the required barrier prior to the validation dispatch. --- tests/tests/dispatch_workgroups_indirect.rs | 3 + wgpu-core/src/command/compute.rs | 28 +++---- wgpu-core/src/indirect_validation.rs | 87 +++++++++++++++++---- 3 files changed, 88 insertions(+), 30 deletions(-) diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs index 745de5f295b..b37a1f2bc02 100644 --- a/tests/tests/dispatch_workgroups_indirect.rs +++ b/tests/tests/dispatch_workgroups_indirect.rs @@ -212,6 +212,9 @@ async fn run_test( if !forget_to_set_bind_group { compute_pass.set_bind_group(0, &bind_group, &[]); } + // Issue multiple dispatches to test the internal destination buffer switching + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); } diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 917011c7070..2d8fd661fc7 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -979,15 +979,6 @@ fn dispatch_indirect( .transition_buffers(src_barrier.into_iter()); } - unsafe { - state - .raw_encoder - .transition_buffers(std::iter::once(hal::BufferBarrier { - buffer: params.dst_buffer, - usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, - })); - } - unsafe { state.raw_encoder.dispatch([1, 1, 1]); } @@ -1026,12 +1017,19 @@ fn dispatch_indirect( } unsafe { - state - .raw_encoder - .transition_buffers(std::iter::once(hal::BufferBarrier { - buffer: params.dst_buffer, - usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, - })); + state.raw_encoder.transition_buffers( + [ + hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, + }, + hal::BufferBarrier { + buffer: params.other_dst_buffer, + usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, + }, + ] + .into_iter(), + ); } state.flush_states(None)?; diff --git a/wgpu-core/src/indirect_validation.rs b/wgpu-core/src/indirect_validation.rs index 5e802172a29..9a719e80a05 100644 --- a/wgpu-core/src/indirect_validation.rs +++ b/wgpu-core/src/indirect_validation.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::AtomicBool; + use thiserror::Error; use crate::{ @@ -35,14 +37,18 @@ pub struct IndirectValidation { src_bind_group_layout: A::BindGroupLayout, pipeline_layout: A::PipelineLayout, pipeline: A::ComputePipeline, - dst_buffer: A::Buffer, - dst_bind_group: A::BindGroup, + dst_buffer_0: A::Buffer, + dst_buffer_1: A::Buffer, + dst_bind_group_0: A::BindGroup, + dst_bind_group_1: A::BindGroup, + is_next_dst_0: AtomicBool, } pub struct Params<'a, A: HalApi> { pub pipeline_layout: &'a A::PipelineLayout, pub pipeline: &'a A::ComputePipeline, pub dst_buffer: &'a A::Buffer, + pub other_dst_buffer: &'a A::Buffer, pub dst_bind_group: &'a A::BindGroup, pub aligned_offset: u64, pub offset_remainder: u64, @@ -202,10 +208,12 @@ impl IndirectValidation { usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE, memory_flags: hal::MemoryFlags::empty(), }; - let dst_buffer = + let dst_buffer_0 = + unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?; + let dst_buffer_1 = unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?; - let dst_bind_group_desc = hal::BindGroupDescriptor { + let dst_bind_group_desc_0 = hal::BindGroupDescriptor { label: None, layout: &dst_bind_group_layout, entries: &[hal::BindGroupEntry { @@ -214,7 +222,7 @@ impl IndirectValidation { count: 1, }], buffers: &[hal::BufferBinding { - buffer: &dst_buffer, + buffer: &dst_buffer_0, offset: 0, size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), }], @@ -222,9 +230,32 @@ impl IndirectValidation { textures: &[], acceleration_structures: &[], }; - let dst_bind_group = unsafe { + let dst_bind_group_0 = unsafe { device - .create_bind_group(&dst_bind_group_desc) + .create_bind_group(&dst_bind_group_desc_0) + .map_err(DeviceError::from) + }?; + + let dst_bind_group_desc_1 = hal::BindGroupDescriptor { + label: None, + layout: &dst_bind_group_layout, + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer: &dst_buffer_1, + offset: 0, + size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + let dst_bind_group_1 = unsafe { + device + .create_bind_group(&dst_bind_group_desc_1) .map_err(DeviceError::from) }?; @@ -234,8 +265,11 @@ impl IndirectValidation { src_bind_group_layout, pipeline_layout, pipeline, - dst_buffer, - dst_bind_group, + dst_buffer_0, + dst_buffer_1, + dst_bind_group_0, + dst_bind_group_1, + is_next_dst_0: AtomicBool::new(false), }) } @@ -299,11 +333,29 @@ impl IndirectValidation { let aligned_offset = aligned_offset.min(max_aligned_offset); let offset_remainder = offset - aligned_offset; + let (dst_buffer, other_dst_buffer, dst_bind_group) = if self + .is_next_dst_0 + .fetch_xor(true, core::sync::atomic::Ordering::AcqRel) + { + ( + &self.dst_buffer_0, + &self.dst_buffer_1, + &self.dst_bind_group_0, + ) + } else { + ( + &self.dst_buffer_1, + &self.dst_buffer_0, + &self.dst_bind_group_1, + ) + }; + Params { pipeline_layout: &self.pipeline_layout, pipeline: &self.pipeline, - dst_buffer: &self.dst_buffer, - dst_bind_group: &self.dst_bind_group, + dst_buffer, + other_dst_buffer, + dst_bind_group, aligned_offset, offset_remainder, } @@ -316,14 +368,19 @@ impl IndirectValidation { src_bind_group_layout, pipeline_layout, pipeline, - dst_buffer, - dst_bind_group, + dst_buffer_0, + dst_buffer_1, + dst_bind_group_0, + dst_bind_group_1, + is_next_dst_0: _, } = self; use hal::Device; unsafe { - device.destroy_bind_group(dst_bind_group); - device.destroy_buffer(dst_buffer); + device.destroy_bind_group(dst_bind_group_0); + device.destroy_bind_group(dst_bind_group_1); + device.destroy_buffer(dst_buffer_0); + device.destroy_buffer(dst_buffer_1); device.destroy_compute_pipeline(pipeline); device.destroy_pipeline_layout(pipeline_layout); device.destroy_bind_group_layout(src_bind_group_layout);