diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs index f199ec2f..c762c81d 100644 --- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -75,60 +75,8 @@ impl + 'static> UnaryKernel for Webgpu let cs_module = self .get_shader_module(TypeId::of::()) .ok_or(Error::WebgpuSourceLoadError)?; - let mut entries = Vec::new(); - if std::mem::size_of::() > 0 { - entries.push(wgpu::BindGroupLayoutEntry { - binding: 0, - visibility: ShaderStages::COMPUTE, - ty: BindingType::Buffer { - has_dynamic_offset: false, - ty: wgpu::BufferBindingType::Storage { read_only: true }, - min_binding_size: None, - }, - count: None, - }); - } - entries.push(wgpu::BindGroupLayoutEntry { - binding: 1, - visibility: ShaderStages::COMPUTE, - ty: BindingType::Buffer { - has_dynamic_offset: false, - ty: wgpu::BufferBindingType::Storage { read_only: true }, - min_binding_size: None, - }, - count: None, - }); - entries.push(wgpu::BindGroupLayoutEntry { - binding: 2, - visibility: ShaderStages::COMPUTE, - ty: BindingType::Buffer { - has_dynamic_offset: false, - ty: wgpu::BufferBindingType::Storage { read_only: false }, - min_binding_size: None, - }, - count: None, - }); - - let bind_group_layout = - self.dev - .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { - label: None, - entries: &entries, - }); - let pipeline_layout = self - .dev - .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { - label: None, - bind_group_layouts: &[&bind_group_layout], - push_constant_ranges: &[], - }); - let pipeline_desc = ComputePipelineDescriptor { - label: None, - layout: Some(&pipeline_layout), - module: &cs_module, - entry_point: K::FWD_FN_NAME, - }; - let pipeline = self.dev.create_compute_pipeline(&pipeline_desc); + let bind_group_layout = self.create_bind_group_layout_fwd::(); + let pipeline = self.create_pipeline::(bind_group_layout, cs_module); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.data.len::(); @@ -223,14 +171,8 @@ impl + 'static> UnaryKernel for Webgpu let cs_module = self .get_shader_module(TypeId::of::()) .ok_or(Error::WebgpuSourceLoadError)?; - let pipeline = self - .dev - .create_compute_pipeline(&ComputePipelineDescriptor { - label: None, - layout: None, - module: &cs_module, - entry_point: "main", - }); + let bind_group_layout = self.create_bind_group_layout_bwd::(); + let pipeline = self.create_pipeline::(bind_group_layout, cs_module); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.len(); @@ -322,3 +264,139 @@ impl + 'static> UnaryKernel for Webgpu Ok(()) } } + +impl Webgpu { + fn create_bind_group_layout_fwd + 'static>( + &self, + ) -> wgpu::BindGroupLayout { + let mut entries = Vec::new(); + if std::mem::size_of::() > 0 { + entries.push(wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + } + entries.push(wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + entries.push(wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: false }, + min_binding_size: None, + }, + count: None, + }); + + let bind_group_layout = + self.dev + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &entries, + }); + bind_group_layout + } + + fn create_bind_group_layout_bwd + 'static>( + &self, + ) -> wgpu::BindGroupLayout { + let mut entries = Vec::new(); + if std::mem::size_of::() > 0 { + entries.push(wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + } + entries.push(wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + entries.push(wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + entries.push(wgpu::BindGroupLayoutEntry { + binding: 3, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: false }, + min_binding_size: None, + }, + count: None, + }); + entries.push(wgpu::BindGroupLayoutEntry { + binding: 4, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + has_dynamic_offset: false, + ty: wgpu::BufferBindingType::Storage { read_only: true }, + min_binding_size: None, + }, + count: None, + }); + + let bind_group_layout = + self.dev + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &entries, + }); + bind_group_layout + } + + fn create_pipeline + 'static>( + &self, + bind_group_layout: wgpu::BindGroupLayout, + cs_module: Arc, + ) -> wgpu::ComputePipeline { + let pipeline_layout = self + .dev + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + let pipeline_desc = ComputePipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + module: &cs_module, + entry_point: K::FWD_FN_NAME, + }; + let pipeline = self.dev.create_compute_pipeline(&pipeline_desc); + pipeline + } +}