Commit

chore: cleanup
DonIsaac committed Jan 2, 2024
1 parent f9f985b commit d0f4ac9
Showing 1 changed file with 6 additions and 37 deletions.
43 changes: 6 additions & 37 deletions dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
@@ -8,24 +8,11 @@ use crate::{
 
 use alloc::{borrow::Cow, sync::Arc};
 
-fn to_entries<'a, I>(bindings: I) -> Vec<wgpu::BindGroupEntry<'a>>
-where
-    I: IntoIterator<Item = wgpu::BindingResource<'a>>,
-{
-    bindings
-        .into_iter()
-        .enumerate()
-        .map(|(i, binding)| wgpu::BindGroupEntry {
-            binding: i as u32,
-            resource: binding,
-        })
-        .collect()
-}
+/// Creates a [`BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
 macro_rules! webgpu_params {
     ($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
         {
             let bindings = [$($x.as_entire_binding()),+];
-            // let entries = to_entries([$($x.as_entire_binding()),+]);
             let entries: Vec<_> = bindings
                 .into_iter()
                 .enumerate()
@@ -130,13 +117,13 @@ pub trait BinaryOpWebgpuKernel<E> {
     /// Unique name for the kernel
     const MODULE_NAME: &'static str;
 
-    /// Name of function in the .cu file
+    /// Name of function in the .wgsl file
     const FWD_FN_NAME: &'static str;
 
-    /// Name of function in the .cu file
+    /// Name of function in the .wgsl file
     const BWD_LHS_FN_NAME: &'static str;
 
-    /// Name of function in the .cu file
+    /// Name of function in the .wgsl file
     const BWD_RHS_FN_NAME: &'static str;
 
     const ALL_FN_NAMES: [&'static str; 3] = [
@@ -149,9 +136,6 @@ macro_rules! wgpu_binary {
     ($Op:path, $TypeName:ty, $Wgsl:tt, $Mod:tt, $Fwd:tt, $Bwd_Lhs:tt, $Bwd_Rhs:tt) => {
         impl crate::tensor_ops::webgpu_kernels::BinaryOpWebgpuKernel<$TypeName> for $Op {
             const HAS_CONST_DF: bool = false;
-            // const WGSL_SRC: &'static str = include_str!(concat!("./", $Wgsl, ".wgsl"));
-            // const MODULE_NAME: &'static str = $Wgsl;
-            // const WGSL_SRC: &'static str = include_str!($Wgsl);
             const WGSL_SRC: &'static str = $Wgsl;
             const MODULE_NAME: &'static str = $Mod;
             const FWD_FN_NAME: &'static str = $Fwd;
@@ -162,9 +146,6 @@ macro_rules! wgpu_binary {
     (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Mod:tt, $Fwd:tt, $Bwd_Lhs:tt, $Bwd_Rhs:tt) => {
         impl crate::tensor_ops::webgpu_kernels::BinaryOpWebgpuKernel<$TypeName> for $Op {
             const HAS_CONST_DF: bool = true;
-            // const WGSL_SRC: &'static str = include_str!(concat!("./", $Wgsl, ".wgsl"));
-            // const MODULE_NAME: &'static str = $Wgsl;
-            // const WGSL_SRC: &'static str = include_str!($Wgsl);
             const WGSL_SRC: &'static str = $Wgsl;
             const MODULE_NAME: &'static str = $Mod;
             const FWD_FN_NAME: &'static str = $Fwd;
@@ -194,7 +175,6 @@ impl<E: Dtype, K: BinaryOpWebgpuKernel<E> + Clone> BinaryKernel<K, E> for Webgpu
         // https://github.com/WebAssembly/memory64
         let work_groups: (u32, u32, u32) = (numel as u32, 1, 1);
 
-        let layout = self.binary_op_layout();
         // todo: pipeline caching
         let fwd_pipeline = self.load_binary_pipeline(K::MODULE_NAME, K::WGSL_SRC, K::FWD_FN_NAME);
 
@@ -207,18 +187,7 @@ impl<E: Dtype, K: BinaryOpWebgpuKernel<E> + Clone> BinaryKernel<K, E> for Webgpu
             // let (lhs, rhs) = (&lhs, &rhs);
             let lhs: &Tensor<S, E, Self> = lhs.as_ref();
             let rhs: &Tensor<S, E, Self> = rhs.as_ref();
-            let params: wgpu::BindGroup = {
-                let entries = to_entries([
-                    lhs.data.as_entire_binding(),
-                    rhs.data.as_entire_binding(),
-                    output.as_entire_binding()
-                ]);
-                self.dev.create_bind_group(&wgpu::BindGroupDescriptor {
-                    layout,
-                    entries: &entries,
-                    label: Some(K::FWD_FN_NAME),
-                })
-            };
+            let params: wgpu::BindGroup = webgpu_params!(self, fwd_pipeline; lhs.data, rhs.data, output);
             let _idx = self.submit_basic_op(&fwd_pipeline, &params, Some(K::FWD_FN_NAME), &work_groups);
         }
         Ok(self.build_tensor(shape, strides, output))
@@ -233,6 +202,6 @@ impl<E: Dtype, K: BinaryOpWebgpuKernel<E> + Clone> BinaryKernel<K, E> for Webgpu
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Error> {
-        todo!("Webgpu#backward()")
+        todo!("Webgpu binary backwards")
     }
 }
