Skip to content

Commit

Permalink
prevent SOA overflow SRV limit on dx
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 8, 2023
1 parent 7d31694 commit f8fd379
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
18 changes: 14 additions & 4 deletions luisa_compute/src/lang/soa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::types::SoaValue;
pub struct SoaBuffer<T: SoaValue> {
pub(crate) device: Device,
pub(crate) storage: Arc<ByteBuffer>,
pub(crate) metadata_buf: Buffer<SoaMetadata>,
pub(crate) metadata_buf: Arc<Buffer<SoaMetadata>>,
pub(crate) metadata: SoaMetadata,
pub(crate) copy_kernel: Mutex<Option<SoaBufferCopyKernel<T>>>,
pub(crate) _marker: std::marker::PhantomData<T>,
Expand Down Expand Up @@ -51,6 +51,7 @@ impl<T: SoaValue> SoaBuffer<T> {
pub fn len_expr(&self) -> Expr<u64> {
self.metadata_buf.read(0).count
}

pub fn view<S: RangeBounds<usize>>(&self, range: S) -> SoaBufferView<T> {
let lower = range.start_bound();
let upper = range.end_bound();
Expand All @@ -71,8 +72,13 @@ impl<T: SoaValue> SoaBuffer<T> {
view_start: lower as u64,
view_count: (upper - lower) as u64,
};
let is_full = lower == 0 && upper == self.len();
SoaBufferView {
metadata_buf: self.device.create_buffer_from_slice(&[metadata]),
metadata_buf: if is_full {
self.metadata_buf.clone()
} else {
Arc::new(self.device.create_buffer_from_slice(&[metadata]))
},
metadata,
buffer: self,
}
Expand All @@ -97,6 +103,10 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> {
*copy_kernel = Some(SoaBufferCopyKernel::new(&self.buffer.device));
}
}

/// **WARNING** when capturing the view, if the view is not equal to the full range, a new metadata buffer will be created.
/// However, DX has a limit on the number of SRVs, so it is not recommended to call this method repeatedly.
/// Instead, call it once per view and store the result.
pub fn var(&self) -> SoaBufferVar<T> {
SoaBufferVar {
proxy: T::SoaBuffer::from_soa_storage(
Expand Down Expand Up @@ -137,7 +147,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> {
submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(buffer)]);
}
}
#[derive(Clone, Copy, Value)]
#[derive(Clone, Copy, Value, PartialEq, Eq, Hash, Debug)]
#[repr(C)]
pub struct SoaMetadata {
/// number of elements in the global buffer
Expand All @@ -147,7 +157,7 @@ pub struct SoaMetadata {
pub view_count: u64,
}
pub struct SoaBufferView<'a, T: SoaValue> {
pub(crate) metadata_buf: Buffer<SoaMetadata>,
pub(crate) metadata_buf: Arc<Buffer<SoaMetadata>>,
pub(crate) metadata: SoaMetadata,
pub(crate) buffer: &'a SoaBuffer<T>,
}
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl Device {
view_start: 0,
view_count: count as u64,
};
let metadata_buf = self.create_buffer_from_slice(&[metadata]);
let metadata_buf = Arc::new(self.create_buffer_from_slice(&[metadata]));
let buffer = SoaBuffer {
storage,
metadata_buf,
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_sys/LuisaCompute

0 comments on commit f8fd379

Please sign in to comment.