diff --git a/macro/src/lib.rs b/macro/src/lib.rs index fb8fb39339..64e01cf6e0 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -7,7 +7,7 @@ mod r#type; mod utility; use dialect::DialectInput; -use parse::{DialectOperationSet, IdentifierList}; +use parse::{DialectOperationSet, IdentifierList, PassSet}; use proc_macro::TokenStream; use quote::quote; use std::error::Error; @@ -68,15 +68,6 @@ pub fn attribute_check_functions(stream: TokenStream) -> TokenStream { convert_result(attribute::generate(identifiers.identifiers())) } -#[proc_macro] -pub fn async_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Async").unwrap().into() - })) -} - #[proc_macro] pub fn conversion_passes(stream: TokenStream) -> TokenStream { let identifiers = parse_macro_input!(stream as IdentifierList); @@ -90,38 +81,11 @@ pub fn conversion_passes(stream: TokenStream) -> TokenStream { } #[proc_macro] -pub fn gpu_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("GPU").unwrap().into() - })) -} - -#[proc_macro] -pub fn transform_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Transforms").unwrap().into() - })) -} - -#[proc_macro] -pub fn linalg_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Linalg").unwrap().into() - })) -} - -#[proc_macro] -pub fn sparse_tensor_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); +pub fn passes(stream: TokenStream) -> TokenStream { + let set = parse_macro_input!(stream as PassSet); - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("SparseTensor").unwrap().into() + convert_result(pass::generate(set.identifiers(), |name| { + name.strip_prefix(&set.prefix().value()).unwrap().into() })) } diff --git a/macro/src/parse.rs b/macro/src/parse.rs index d484f5bd80..0f8cdbae1d 100644 --- a/macro/src/parse.rs +++ b/macro/src/parse.rs @@ -3,7 +3,7 @@ use syn::{ bracketed, parse::{Parse, ParseStream}, punctuated::Punctuated, - Result, Token, + LitStr, Result, Token, }; pub struct IdentifierList { @@ -56,3 +56,34 @@ impl Parse for DialectOperationSet { }) } } + +pub struct PassSet { + prefix: LitStr, + identifiers: IdentifierList, +} + +impl PassSet { + pub const fn prefix(&self) -> &LitStr { + &self.prefix + } + + pub fn identifiers(&self) -> &[Ident] { + self.identifiers.identifiers() + } +} + +impl Parse for PassSet { + fn parse(input: ParseStream) -> Result { + let prefix = input.parse()?; + ::parse(input)?; + + Ok(Self { + prefix, + identifiers: { + let content; + bracketed!(content in input); + content.parse::()? + }, + }) + } +} diff --git a/melior/src/pass/async.rs b/melior/src/pass/async.rs index 53effd3f71..30ea238c7d 100644 --- a/melior/src/pass/async.rs +++ b/melior/src/pass/async.rs @@ -1,10 +1,13 @@ //! Async passes. -melior_macro::async_passes!( - mlirCreateAsyncAsyncFuncToAsyncRuntime, - mlirCreateAsyncAsyncParallelFor, - mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting, - mlirCreateAsyncAsyncRuntimeRefCounting, - mlirCreateAsyncAsyncRuntimeRefCountingOpt, - mlirCreateAsyncAsyncToAsyncRuntime, +melior_macro::passes!( + "Async", + [ + mlirCreateAsyncAsyncFuncToAsyncRuntime, + mlirCreateAsyncAsyncParallelFor, + mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting, + mlirCreateAsyncAsyncRuntimeRefCounting, + mlirCreateAsyncAsyncRuntimeRefCountingOpt, + mlirCreateAsyncAsyncToAsyncRuntime, + ] ); diff --git a/melior/src/pass/gpu.rs b/melior/src/pass/gpu.rs index d89bc36e87..daa2a3aea1 100644 --- a/melior/src/pass/gpu.rs +++ b/melior/src/pass/gpu.rs @@ -1,9 +1,12 @@ //! GPU passes. -melior_macro::gpu_passes!( - // spell-checker: disable-next-line - mlirCreateGPUGpuAsyncRegionPass, - mlirCreateGPUGpuKernelOutlining, - mlirCreateGPUGpuLaunchSinkIndexComputations, - mlirCreateGPUGpuMapParallelLoopsPass, +melior_macro::passes!( + "GPU", + [ + // spell-checker: disable-next-line + mlirCreateGPUGpuAsyncRegionPass, + mlirCreateGPUGpuKernelOutlining, + mlirCreateGPUGpuLaunchSinkIndexComputations, + mlirCreateGPUGpuMapParallelLoopsPass, + ] ); diff --git a/melior/src/pass/linalg.rs b/melior/src/pass/linalg.rs index 067127264d..00b091d6d1 100644 --- a/melior/src/pass/linalg.rs +++ b/melior/src/pass/linalg.rs @@ -1,15 +1,18 @@ //! Linalg passes. -melior_macro::linalg_passes!( - mlirCreateLinalgConvertElementwiseToLinalg, - mlirCreateLinalgLinalgBufferize, - mlirCreateLinalgLinalgDetensorize, - mlirCreateLinalgLinalgElementwiseOpFusion, - mlirCreateLinalgLinalgFoldUnitExtentDims, - mlirCreateLinalgLinalgGeneralization, - mlirCreateLinalgLinalgInlineScalarOperands, - mlirCreateLinalgLinalgLowerToAffineLoops, - mlirCreateLinalgLinalgLowerToLoops, - mlirCreateLinalgLinalgLowerToParallelLoops, - mlirCreateLinalgLinalgNamedOpConversion, +melior_macro::passes!( + "Linalg", + [ + mlirCreateLinalgConvertElementwiseToLinalg, + mlirCreateLinalgLinalgBufferize, + mlirCreateLinalgLinalgDetensorize, + mlirCreateLinalgLinalgElementwiseOpFusion, + mlirCreateLinalgLinalgFoldUnitExtentDims, + mlirCreateLinalgLinalgGeneralization, + mlirCreateLinalgLinalgInlineScalarOperands, + mlirCreateLinalgLinalgLowerToAffineLoops, + mlirCreateLinalgLinalgLowerToLoops, + mlirCreateLinalgLinalgLowerToParallelLoops, + mlirCreateLinalgLinalgNamedOpConversion, + ] ); diff --git a/melior/src/pass/sparse_tensor.rs b/melior/src/pass/sparse_tensor.rs index 60ef80b232..c8544bc46d 100644 --- a/melior/src/pass/sparse_tensor.rs +++ b/melior/src/pass/sparse_tensor.rs @@ -1,12 +1,15 @@ //! Sparse tensor passes. -melior_macro::sparse_tensor_passes!( - mlirCreateSparseTensorPostSparsificationRewrite, - mlirCreateSparseTensorPreSparsificationRewrite, - mlirCreateSparseTensorSparseBufferRewrite, - mlirCreateSparseTensorSparseTensorCodegen, - mlirCreateSparseTensorSparseTensorConversionPass, - mlirCreateSparseTensorSparseVectorization, - mlirCreateSparseTensorSparsificationPass, - mlirCreateSparseTensorStorageSpecifierToLLVM, +melior_macro::passes!( + "SparseTensor", + [ + mlirCreateSparseTensorPostSparsificationRewrite, + mlirCreateSparseTensorPreSparsificationRewrite, + mlirCreateSparseTensorSparseBufferRewrite, + mlirCreateSparseTensorSparseTensorCodegen, + mlirCreateSparseTensorSparseTensorConversionPass, + mlirCreateSparseTensorSparseVectorization, + mlirCreateSparseTensorSparsificationPass, + mlirCreateSparseTensorStorageSpecifierToLLVM, + ] ); diff --git a/melior/src/pass/transform.rs b/melior/src/pass/transform.rs index aaa09c6660..657e9f554a 100644 --- a/melior/src/pass/transform.rs +++ b/melior/src/pass/transform.rs @@ -1,18 +1,21 @@ //! Transform passes. -melior_macro::transform_passes!( - mlirCreateTransformsCSE, - mlirCreateTransformsCanonicalizer, - mlirCreateTransformsControlFlowSink, - mlirCreateTransformsGenerateRuntimeVerification, - mlirCreateTransformsInliner, - mlirCreateTransformsLocationSnapshot, - mlirCreateTransformsLoopInvariantCodeMotion, - mlirCreateTransformsPrintOpStats, - mlirCreateTransformsSCCP, - mlirCreateTransformsStripDebugInfo, - mlirCreateTransformsSymbolDCE, - mlirCreateTransformsSymbolPrivatize, - mlirCreateTransformsTopologicalSort, - mlirCreateTransformsViewOpGraph, +melior_macro::passes!( + "Transforms", + [ + mlirCreateTransformsCSE, + mlirCreateTransformsCanonicalizer, + mlirCreateTransformsControlFlowSink, + mlirCreateTransformsGenerateRuntimeVerification, + mlirCreateTransformsInliner, + mlirCreateTransformsLocationSnapshot, + mlirCreateTransformsLoopInvariantCodeMotion, + mlirCreateTransformsPrintOpStats, + mlirCreateTransformsSCCP, + mlirCreateTransformsStripDebugInfo, + mlirCreateTransformsSymbolDCE, + mlirCreateTransformsSymbolPrivatize, + mlirCreateTransformsTopologicalSort, + mlirCreateTransformsViewOpGraph, + ] );