Skip to content

[hal metal] ray tracing acceleration structures #7660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5b34aeb
Removes Option<> around AccelerationStructureTriangleIndices::buffer.
Lichtso May 3, 2025
503f197
Removes Option<> around AccelerationStructureTriangles::vertex_buffer.
Lichtso May 3, 2025
8ef0163
Removes Option<> around AccelerationStructureAABBs::buffer.
Lichtso May 3, 2025
293d51b
Removes Option<> around AccelerationStructureInstances::buffer.
Lichtso May 3, 2025
58a96b6
Fixes index_buffer label in ray_traced_triangle example.
Lichtso May 2, 2025
2b4cb04
Fixes min_push_constant_size in ray_shadows example.
Lichtso May 3, 2025
9b472f0
Updates CHANGELOG.
Lichtso Apr 30, 2025
7a1c0d2
Adds feature detection.
Lichtso Apr 30, 2025
f0a71a8
Sets raw_tlas_instance_size.
Lichtso Apr 30, 2025
b4f4ddd
Sets ray_tracing_scratch_buffer_alignment.
Lichtso May 2, 2025
59fd0d4
Adds conv::map_index_format().
Lichtso Apr 30, 2025
6564edf
Adds conv::map_acceleration_structure_descriptor().
Lichtso Apr 30, 2025
063667b
Adds AccelerationStructurePtr.
Lichtso Apr 30, 2025
06c161a
Implements AccelerationStructure.
Lichtso Apr 30, 2025
f0578d5
Adds CommandState::acceleration_structure_builder.
Lichtso Apr 30, 2025
681e8fa
Implements CommandEncoder::copy_acceleration_structure_to_acceleratio…
Lichtso Apr 30, 2025
c7456d0
Implements CommandEncoder::build_acceleration_structures().
Lichtso Apr 30, 2025
0a5c426
Implements CommandEncoder::place_acceleration_structure_barrier().
Lichtso May 2, 2025
5f47288
Implements CommandEncoder::read_acceleration_structure_compact_size().
Lichtso Apr 30, 2025
d5c7573
Implements Device::get_acceleration_structure_build_sizes().
Lichtso Apr 30, 2025
131d60e
Implements Device::get_acceleration_structure_device_address().
Lichtso May 2, 2025
e96f73d
Implements Device::create_acceleration_structure().
Lichtso Apr 30, 2025
0b64717
Implements Device::destroy_acceleration_structure().
Lichtso Apr 30, 2025
af4dfa6
Implements Device::tlas_instance_to_bytes().
Lichtso Apr 30, 2025
e2a6498
Implements resource binding.
Lichtso Apr 30, 2025
d8ad785
Memorizes BLAS dependencies of TLAS to call use_resource() in set_bin…
Lichtso May 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ Naga now infers the correct binding layout when a resource appears only in an as

- Use highest SPIR-V version supported by Vulkan API version. By @robamler in [#7595](https://github.com/gfx-rs/wgpu/pull/7595)

#### Metal

- Implements ray-tracing acceleration structures for metal backend. By @lichtso in [#7660](https://github.com/gfx-rs/wgpu/pull/7660)

### Bug Fixes

#### Naga
Expand Down
4 changes: 2 additions & 2 deletions examples/features/src/ray_shadows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl crate::framework::Example for Example {

fn required_limits() -> wgpu::Limits {
wgpu::Limits {
max_push_constant_size: 12,
max_push_constant_size: 16,
..wgpu::Limits::default()
}
}
Expand Down Expand Up @@ -209,7 +209,7 @@ impl crate::framework::Example for Example {
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[wgpu::PushConstantRange {
stages: wgpu::ShaderStages::FRAGMENT,
range: 0..12,
range: 0..16,
}],
});

Expand Down
1 change: 1 addition & 0 deletions examples/features/src/ray_shadows/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var acc_struct: acceleration_structure;

struct PushConstants {
light: vec3<f32>,
padding: f32,
Copy link
Contributor Author

@Lichtso Lichtso May 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that metal always sends at least 16 bytes for push constants, even if we only pass in 12 bytes. And then the shader validation complains that the receiver here only expects 12 bytes.

}
var<push_constant> pc: PushConstants;

Expand Down
2 changes: 1 addition & 1 deletion examples/features/src/ray_traced_triangle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl crate::framework::Example for Example {
});

let index_buffer = device.create_buffer_init(&BufferInitDescriptor {
label: Some("vertex buffer"),
label: Some("index buffer"),
contents: bytemuck::cast_slice(&indices),
usage: BufferUsages::BLAS_INPUT,
});
Expand Down
15 changes: 11 additions & 4 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl Global {
tlas,
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(instance_buffer),
buffer: instance_buffer,
offset: 0,
count: entry.instance_count,
},
Expand Down Expand Up @@ -584,6 +584,13 @@ impl Global {
dependencies.push(blas.clone());
}

let dependencies_raw = dependencies
.iter()
.map(|blas| blas.try_raw(&snatch_guard).unwrap())
.collect::<Vec<_>>();
let tlas_raw = tlas.try_raw(&snatch_guard)?;
tlas_raw.set_dependencies(&dependencies_raw);

build_command.tlas_s_built.push(TlasBuild {
tlas: tlas.clone(),
dependencies,
Expand All @@ -602,7 +609,7 @@ impl Global {
tlas: tlas.clone(),
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(tlas.instance_buffer.as_ref()),
buffer: tlas.instance_buffer.as_ref(),
offset: 0,
count: instance_count,
},
Expand Down Expand Up @@ -1141,7 +1148,7 @@ fn iter_buffers<'a, 'b>(
};

let triangles = hal::AccelerationStructureTriangles {
vertex_buffer: Some(vertex_buffer),
vertex_buffer,
vertex_format: mesh.size.vertex_format,
first_vertex: mesh.first_vertex,
vertex_count: mesh.size.vertex_count,
Expand All @@ -1150,7 +1157,7 @@ fn iter_buffers<'a, 'b>(
let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
format: mesh.size.index_format.unwrap(),
buffer: Some(index_buffer),
buffer: index_buffer,
offset: mesh.first_index.unwrap() * index_stride,
count: mesh.size.index_count.unwrap(),
}
Expand Down
6 changes: 3 additions & 3 deletions wgpu-core/src/device/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Device {
dyn hal::DynBuffer,
> {
format: desc.index_format.unwrap(),
buffer: None,
buffer: self.zero_buffer.as_ref(),
offset: 0,
count,
});
Expand Down Expand Up @@ -78,7 +78,7 @@ impl Device {
}

entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
vertex_buffer: None,
vertex_buffer: self.zero_buffer.as_ref(),
vertex_format: desc.vertex_format,
first_vertex: 0,
vertex_count: desc.vertex_count,
Expand Down Expand Up @@ -158,7 +158,7 @@ impl Device {
&hal::GetAccelerationStructureBuildSizesDescriptor {
entries: &hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: None,
buffer: self.zero_buffer.as_ref(),
offset: 0,
count: desc.max_instances,
},
Expand Down
1 change: 1 addition & 0 deletions wgpu-hal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ metal = [
"naga/msl-out",
"dep:arrayvec",
"dep:block",
"dep:bytemuck",
"dep:core-graphics-types",
"dep:hashbrown",
"dep:libc",
Expand Down
182 changes: 87 additions & 95 deletions wgpu-hal/examples/ray-traced-triangle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ impl<A: hal::Api> Example<A> {
};

let blas_triangles = vec![hal::AccelerationStructureTriangles {
vertex_buffer: Some(&vertices_buffer),
vertex_buffer: &vertices_buffer,
first_vertex: 0,
vertex_format: wgpu_types::VertexFormat::Float32x3,
// each vertex is 3 floats, and floats are stored raw in the array
vertex_count: vertices.len() as u32 / 3,
vertex_stride: 3 * 4,
indices: indices_buffer.as_ref().map(|(buf, len)| {
indices: indices_buffer.as_ref().map(|(buffer, len)| {
hal::AccelerationStructureTriangleIndices {
buffer: Some(buf),
buffer,
format: wgpu_types::IndexFormat::Uint32,
offset: 0,
count: *len as u32,
Expand All @@ -493,13 +493,6 @@ impl<A: hal::Api> Example<A> {
}];
let blas_entries = hal::AccelerationStructureEntries::Triangles(blas_triangles);

let mut tlas_entries =
hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances {
buffer: None,
count: 3,
offset: 0,
});

let blas_sizes = unsafe {
device.get_acceleration_structure_build_sizes(
&hal::GetAccelerationStructureBuildSizesDescriptor {
Expand All @@ -509,6 +502,89 @@ impl<A: hal::Api> Example<A> {
)
};

let blas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("blas"),
size: blas_sizes.acceleration_structure_size,
format: hal::AccelerationStructureFormat::BottomLevel,
allow_compaction: false,
})
}
.unwrap();

let instances = [
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 0.0,
y: 0.0,
z: 0.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: -1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
];

let instances_buffer_size = instances.len() * size_of::<AccelerationStructureInstance>();

let instances_buffer = unsafe {
let instances_buffer = device
.create_buffer(&hal::BufferDescriptor {
label: Some("instances_buffer"),
size: instances_buffer_size as u64,
usage: wgpu_types::BufferUses::MAP_WRITE
| wgpu_types::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT,
})
.unwrap();

let mapping = device
.map_buffer(&instances_buffer, 0..instances_buffer_size as u64)
.unwrap();
ptr::copy_nonoverlapping(
instances.as_ptr() as *const u8,
mapping.ptr.as_ptr(),
instances_buffer_size,
);
device.unmap_buffer(&instances_buffer);
assert!(mapping.is_coherent);

instances_buffer
};

let tlas_entries =
hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances {
buffer: &instances_buffer,
count: 3,
offset: 0,
});

let tlas_flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE
| hal::AccelerationStructureBuildFlags::ALLOW_UPDATE;

Expand All @@ -521,16 +597,6 @@ impl<A: hal::Api> Example<A> {
)
};

let blas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("blas"),
size: blas_sizes.acceleration_structure_size,
format: hal::AccelerationStructureFormat::BottomLevel,
allow_compaction: false,
})
}
.unwrap();

let tlas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("tlas"),
Expand Down Expand Up @@ -653,80 +719,6 @@ impl<A: hal::Api> Example<A> {
.unwrap()
};

let instances = [
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 0.0,
y: 0.0,
z: 0.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: -1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
];

let instances_buffer_size = instances.len() * size_of::<AccelerationStructureInstance>();

let instances_buffer = unsafe {
let instances_buffer = device
.create_buffer(&hal::BufferDescriptor {
label: Some("instances_buffer"),
size: instances_buffer_size as u64,
usage: wgpu_types::BufferUses::MAP_WRITE
| wgpu_types::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT,
})
.unwrap();

let mapping = device
.map_buffer(&instances_buffer, 0..instances_buffer_size as u64)
.unwrap();
ptr::copy_nonoverlapping(
instances.as_ptr() as *const u8,
mapping.ptr.as_ptr(),
instances_buffer_size,
);
device.unmap_buffer(&instances_buffer);
assert!(mapping.is_coherent);

instances_buffer
};

if let hal::AccelerationStructureEntries::Instances(ref mut i) = tlas_entries {
i.buffer = Some(&instances_buffer);
assert!(
instances.len() <= i.count as usize,
"Tlas allocation to small"
);
}

let cmd_encoder_desc = hal::CommandEncoderDescriptor {
label: None,
queue: &queue,
Expand Down Expand Up @@ -903,7 +895,7 @@ impl<A: hal::Api> Example<A> {
ctx.encoder.begin_encoding(Some("frame")).unwrap();

let instances = hal::AccelerationStructureInstances {
buffer: Some(&self.instances_buffer),
buffer: &self.instances_buffer,
count: self.instances.len() as u32,
offset: 0,
};
Expand Down
Loading