Skip to content

Commit

Permalink
Refactor out stuff and hopefully get backward working.
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo committed Dec 27, 2023
1 parent 8edbaf0 commit 161e597
Showing 1 changed file with 140 additions and 62 deletions.
202 changes: 140 additions & 62 deletions dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,60 +75,8 @@ impl<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu
let cs_module = self
.get_shader_module(TypeId::of::<K>())
.ok_or(Error::WebgpuSourceLoadError)?;
let mut entries = Vec::new();
if std::mem::size_of::<K>() > 0 {
entries.push(wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
}
entries.push(wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
entries.push(wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: false },
min_binding_size: None,
},
count: None,
});

let bind_group_layout =
self.dev
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &entries,
});
let pipeline_layout = self
.dev
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline_desc = ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &cs_module,
entry_point: K::FWD_FN_NAME,
};
let pipeline = self.dev.create_compute_pipeline(&pipeline_desc);
let bind_group_layout = self.create_bind_group_layout_fwd::<E, K>();
let pipeline = self.create_pipeline::<E, K>(bind_group_layout, cs_module);
let bind_group_layout = pipeline.get_bind_group_layout(0);
let op_storage = self.alloc_init::<K>(&[op])?;
let numel = inp.data.len::<E>();
Expand Down Expand Up @@ -223,14 +171,8 @@ impl<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu
let cs_module = self
.get_shader_module(TypeId::of::<K>())
.ok_or(Error::WebgpuSourceLoadError)?;
let pipeline = self
.dev
.create_compute_pipeline(&ComputePipelineDescriptor {
label: None,
layout: None,
module: &cs_module,
entry_point: "main",
});
let bind_group_layout = self.create_bind_group_layout_bwd::<E, K>();
let pipeline = self.create_pipeline::<E, K>(bind_group_layout, cs_module);
let bind_group_layout = pipeline.get_bind_group_layout(0);
let op_storage = self.alloc_init::<K>(&[op])?;
let numel = inp.len();
Expand Down Expand Up @@ -322,3 +264,139 @@ impl<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu
Ok(())
}
}

impl Webgpu {
fn create_bind_group_layout_fwd<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static>(
&self,
) -> wgpu::BindGroupLayout {
let mut entries = Vec::new();
if std::mem::size_of::<K>() > 0 {
entries.push(wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
}
entries.push(wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
entries.push(wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: false },
min_binding_size: None,
},
count: None,
});

let bind_group_layout =
self.dev
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &entries,
});
bind_group_layout
}

fn create_bind_group_layout_bwd<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static>(
&self,
) -> wgpu::BindGroupLayout {
let mut entries = Vec::new();
if std::mem::size_of::<K>() > 0 {
entries.push(wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
}
entries.push(wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
entries.push(wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});
entries.push(wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: false },
min_binding_size: None,
},
count: None,
});
entries.push(wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
has_dynamic_offset: false,
ty: wgpu::BufferBindingType::Storage { read_only: true },
min_binding_size: None,
},
count: None,
});

let bind_group_layout =
self.dev
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &entries,
});
bind_group_layout
}

fn create_pipeline<E: Dtype, K: UnaryOpWebgpuKernel<E> + 'static>(
&self,
bind_group_layout: wgpu::BindGroupLayout,
cs_module: Arc<wgpu::ShaderModule>,
) -> wgpu::ComputePipeline {
let pipeline_layout = self
.dev
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline_desc = ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &cs_module,
entry_point: K::FWD_FN_NAME,
};
let pipeline = self.dev.create_compute_pipeline(&pipeline_desc);
pipeline
}
}

0 comments on commit 161e597

Please sign in to comment.