From 4ea0e1e85f6332f4c746eac5a14ae12cb3e48673 Mon Sep 17 00:00:00 2001 From: ReversedCausality Date: Fri, 20 Oct 2023 18:26:13 +0100 Subject: [PATCH] Added `repr(u32)` enum Value derives. --- luisa_compute/src/lang/types.rs | 6 +- luisa_compute/src/lang/types/core.rs | 13 +-- luisa_compute_derive/src/lib.rs | 2 +- .../src/bin/derive-debug.rs | 2 +- luisa_compute_derive_impl/src/lib.rs | 84 ++++++++++++++++++- 5 files changed, 96 insertions(+), 11 deletions(-) diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 9659212b..05007dce 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -60,8 +60,8 @@ pub trait SoaValue: Value { pub trait ExprProxy: Copy + 'static { type Value: Value; - fn from_expr(expr: Expr) -> Self; fn as_expr_from_proxy(&self) -> &Expr; + fn from_expr(expr: Expr) -> Self; } /// A trait for implementing remote impls on top of an [`Var`] using [`Deref`]. @@ -71,15 +71,15 @@ pub trait ExprProxy: Copy + 'static { /// impls. pub trait VarProxy: Copy + 'static + Deref> { type Value: Value; - fn as_var_from_proxy(&self) -> &Var; + fn as_var_from_proxy(&self) -> &Var; fn from_var(expr: Var) -> Self; } pub unsafe trait AtomicRefProxy: Copy + 'static { type Value: Value; - fn as_atomic_ref_from_proxy(&self) -> &AtomicRef; + fn as_atomic_ref_from_proxy(&self) -> &AtomicRef; fn from_atomic_ref(expr: AtomicRef) -> Self; } diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index de89b310..d5349060 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -262,7 +262,7 @@ impl Value for T { type AtomicRef = PrimitiveAtomicRef; fn expr(self) -> Expr { - let node = __current_scope(|s| -> NodeRef { s.const_(self.const_()) }); + let node = __current_scope(|s| s.const_(self.const_())); Expr::::from_node(node.into()) } } @@ -288,7 +288,10 @@ macro_rules! impl_atomic { lower_atomic_ref( self.node().get(), Func::AtomicCompareExchange, - &[expected.as_expr().node().get(), desired.as_expr().node().get()], + &[ + expected.as_expr().node().get(), + desired.as_expr().node().get(), + ], ) } pub fn exchange(&self, operand: impl AsExpr) -> Expr<$t> { @@ -375,9 +378,9 @@ fn lower_atomic_ref(node: NodeRef, op: Func, args: &[NodeRef]) -> Expr .chain(args.iter()) .map(|n| *n) .collect::>(); - Expr::::from_node(__current_scope(|b| { - b.call(op, &new_args, ::type_()) - }).into()) + Expr::::from_node( + __current_scope(|b| b.call(op, &new_args, ::type_())).into(), + ) } _ => unreachable!("{:?}", inst), }, diff --git a/luisa_compute_derive/src/lib.rs b/luisa_compute_derive/src/lib.rs index 6bdf800d..41e4ee27 100644 --- a/luisa_compute_derive/src/lib.rs +++ b/luisa_compute_derive/src/lib.rs @@ -5,7 +5,7 @@ use syn::spanned::Spanned; #[proc_macro_derive(Value, attributes(value_new))] pub fn derive_value(item: TokenStream) -> TokenStream { - let item: syn::ItemStruct = syn::parse(item).unwrap(); + let item: syn::Item = syn::parse(item).unwrap(); let compiler = luisa_compute_derive_impl::Compiler; compiler.derive_value(&item).into() } diff --git a/luisa_compute_derive_impl/src/bin/derive-debug.rs b/luisa_compute_derive_impl/src/bin/derive-debug.rs index 2720beac..2429cb1e 100644 --- a/luisa_compute_derive_impl/src/bin/derive-debug.rs +++ b/luisa_compute_derive_impl/src/bin/derive-debug.rs @@ -15,6 +15,6 @@ fn main() { ) .unwrap(); println!("{:?}", item.to_token_stream()); - let out = compiler.derive_value(&item); + let out = compiler.derive_value_for_struct(&item); println!("{:?}", out.to_string()); } diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index d28259a1..b97dd353 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -212,7 +212,89 @@ impl Compiler { } ) } - pub fn derive_value(&self, struct_: &ItemStruct) -> TokenStream { + pub fn derive_value(&self, item: &Item) -> TokenStream { + match item { + Item::Struct(struct_) => self.derive_value_for_struct(struct_), + Item::Enum(enum_) => self.derive_value_for_enum(enum_), + _ => todo!(), + } + } + pub fn derive_value_for_enum(&self, enum_: &ItemEnum) -> TokenStream { + let repr = enum_ + .attrs + .iter() + .find_map(|attr| { + let meta = &attr.meta; + match meta { + syn::Meta::List(list) => { + let path = &list.path; + if path.is_ident("repr") { + list.parse_args::().ok() + } else { + None + } + } + _ => None, + } + }) + .expect("Enum must have repr attribute."); + let span = enum_.span(); + let lang_path = self.lang_path(); + let name = &enum_.ident; + let expr_proxy_name = syn::Ident::new(&format!("{}Expr", name), name.span()); + let var_proxy_name = syn::Ident::new(&format!("{}Var", name), name.span()); + let atomic_ref_proxy_name = syn::Ident::new(&format!("{}AtomicRef", name), name.span()); + let as_repr = syn::Ident::new(&format!("as_{}", repr), repr.span()); + if !(["bool", "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"] + .contains(&&*repr.to_string())) + { + panic!("Enum repr must be one of bool, u8, u16, u32, u64, i8, i16, i32, i64"); + } + quote_spanned! {span=> + impl #lang_path::types::Value for #name { + type Expr = #expr_proxy_name; + type Var = #var_proxy_name; + type AtomicRef = #atomic_ref_proxy_name; + + fn expr(self) -> Expr { + let node = #lang_path::__current_scope(|s| s.const_(<#repr as #lang_path::types::core::Primitive>::const_(&(self as #repr)))); + as #lang_path::FromNode>::from_node(node.into()) + } + } + impl #lang_path::ir::TypeOf for #name { + fn type_() -> #lang_path::ir::CArc<#lang_path::ir::Type> { + <#repr as #lang_path::ir::TypeOf>::type_() + } + } + + ::luisa_compute::impl_simple_expr_proxy!(#expr_proxy_name for #name); + ::luisa_compute::impl_simple_var_proxy!(#var_proxy_name for #name); + ::luisa_compute::impl_simple_atomic_ref_proxy!(#atomic_ref_proxy_name for #name); + + impl #expr_proxy_name { + pub fn #as_repr(&self) -> #lang_path::types::Expr<#repr> { + use #lang_path::ToNode; + use #lang_path::types::ExprProxy; + #lang_path::FromNode::from_node(self.as_expr_from_proxy().node()) + } + } + impl #var_proxy_name { + pub fn #as_repr(&self) -> #lang_path::types::Var<#repr> { + use #lang_path::ToNode; + use #lang_path::types::VarProxy; + #lang_path::FromNode::from_node(self.as_var_from_proxy().node()) + } + } + impl #atomic_ref_proxy_name { + pub fn #as_repr(&self) -> #lang_path::types::AtomicRef<#repr> { + use #lang_path::ToNode; + use #lang_path::types::AtomicRefProxy; + #lang_path::FromNode::from_node(self.as_atomic_ref_from_proxy().node()) + } + } + } + } + pub fn derive_value_for_struct(&self, struct_: &ItemStruct) -> TokenStream { let ordering = self.value_attributes(&struct_.attrs); let span = struct_.span(); let lang_path = self.lang_path();