Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pass macros #373

Merged
merged 3 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 5 additions & 41 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod r#type;
mod utility;

use dialect::DialectInput;
use parse::{DialectOperationSet, IdentifierList};
use parse::{DialectOperationSet, IdentifierList, PassSet};
use proc_macro::TokenStream;
use quote::quote;
use std::error::Error;
Expand Down Expand Up @@ -68,15 +68,6 @@ pub fn attribute_check_functions(stream: TokenStream) -> TokenStream {
convert_result(attribute::generate(identifiers.identifiers()))
}

#[proc_macro]
pub fn async_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Async").unwrap().into()
}))
}

#[proc_macro]
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
Expand All @@ -90,38 +81,11 @@ pub fn conversion_passes(stream: TokenStream) -> TokenStream {
}

#[proc_macro]
pub fn gpu_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("GPU").unwrap().into()
}))
}

#[proc_macro]
pub fn transform_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Transforms").unwrap().into()
}))
}

#[proc_macro]
pub fn linalg_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Linalg").unwrap().into()
}))
}

#[proc_macro]
pub fn sparse_tensor_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
pub fn passes(stream: TokenStream) -> TokenStream {
let set = parse_macro_input!(stream as PassSet);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("SparseTensor").unwrap().into()
convert_result(pass::generate(set.identifiers(), |name| {
name.strip_prefix(&set.prefix().value()).unwrap().into()
}))
}

Expand Down
33 changes: 32 additions & 1 deletion macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use syn::{
bracketed,
parse::{Parse, ParseStream},
punctuated::Punctuated,
Result, Token,
LitStr, Result, Token,
};

pub struct IdentifierList {
Expand Down Expand Up @@ -56,3 +56,34 @@ impl Parse for DialectOperationSet {
})
}
}

pub struct PassSet {
prefix: LitStr,
identifiers: IdentifierList,
}

impl PassSet {
pub const fn prefix(&self) -> &LitStr {
&self.prefix
}

pub fn identifiers(&self) -> &[Ident] {
self.identifiers.identifiers()
}
}

impl Parse for PassSet {
fn parse(input: ParseStream) -> Result<Self> {
let prefix = input.parse()?;
<Token![,]>::parse(input)?;

Ok(Self {
prefix,
identifiers: {
let content;
bracketed!(content in input);
content.parse::<IdentifierList>()?
},
})
}
}
17 changes: 10 additions & 7 deletions melior/src/pass/async.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
//! Async passes.

melior_macro::async_passes!(
mlirCreateAsyncAsyncFuncToAsyncRuntime,
mlirCreateAsyncAsyncParallelFor,
mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting,
mlirCreateAsyncAsyncRuntimeRefCounting,
mlirCreateAsyncAsyncRuntimeRefCountingOpt,
mlirCreateAsyncAsyncToAsyncRuntime,
melior_macro::passes!(
"Async",
[
mlirCreateAsyncAsyncFuncToAsyncRuntime,
mlirCreateAsyncAsyncParallelFor,
mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting,
mlirCreateAsyncAsyncRuntimeRefCounting,
mlirCreateAsyncAsyncRuntimeRefCountingOpt,
mlirCreateAsyncAsyncToAsyncRuntime,
]
);
15 changes: 9 additions & 6 deletions melior/src/pass/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! GPU passes.

melior_macro::gpu_passes!(
// spell-checker: disable-next-line
mlirCreateGPUGpuAsyncRegionPass,
mlirCreateGPUGpuKernelOutlining,
mlirCreateGPUGpuLaunchSinkIndexComputations,
mlirCreateGPUGpuMapParallelLoopsPass,
melior_macro::passes!(
"GPU",
[
// spell-checker: disable-next-line
mlirCreateGPUGpuAsyncRegionPass,
mlirCreateGPUGpuKernelOutlining,
mlirCreateGPUGpuLaunchSinkIndexComputations,
mlirCreateGPUGpuMapParallelLoopsPass,
]
);
27 changes: 15 additions & 12 deletions melior/src/pass/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
//! Linalg passes.

melior_macro::linalg_passes!(
mlirCreateLinalgConvertElementwiseToLinalg,
mlirCreateLinalgLinalgBufferize,
mlirCreateLinalgLinalgDetensorize,
mlirCreateLinalgLinalgElementwiseOpFusion,
mlirCreateLinalgLinalgFoldUnitExtentDims,
mlirCreateLinalgLinalgGeneralization,
mlirCreateLinalgLinalgInlineScalarOperands,
mlirCreateLinalgLinalgLowerToAffineLoops,
mlirCreateLinalgLinalgLowerToLoops,
mlirCreateLinalgLinalgLowerToParallelLoops,
mlirCreateLinalgLinalgNamedOpConversion,
melior_macro::passes!(
"Linalg",
[
mlirCreateLinalgConvertElementwiseToLinalg,
mlirCreateLinalgLinalgBufferize,
mlirCreateLinalgLinalgDetensorize,
mlirCreateLinalgLinalgElementwiseOpFusion,
mlirCreateLinalgLinalgFoldUnitExtentDims,
mlirCreateLinalgLinalgGeneralization,
mlirCreateLinalgLinalgInlineScalarOperands,
mlirCreateLinalgLinalgLowerToAffineLoops,
mlirCreateLinalgLinalgLowerToLoops,
mlirCreateLinalgLinalgLowerToParallelLoops,
mlirCreateLinalgLinalgNamedOpConversion,
]
);
21 changes: 12 additions & 9 deletions melior/src/pass/sparse_tensor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Sparse tensor passes.

melior_macro::sparse_tensor_passes!(
mlirCreateSparseTensorPostSparsificationRewrite,
mlirCreateSparseTensorPreSparsificationRewrite,
mlirCreateSparseTensorSparseBufferRewrite,
mlirCreateSparseTensorSparseTensorCodegen,
mlirCreateSparseTensorSparseTensorConversionPass,
mlirCreateSparseTensorSparseVectorization,
mlirCreateSparseTensorSparsificationPass,
mlirCreateSparseTensorStorageSpecifierToLLVM,
melior_macro::passes!(
"SparseTensor",
[
mlirCreateSparseTensorPostSparsificationRewrite,
mlirCreateSparseTensorPreSparsificationRewrite,
mlirCreateSparseTensorSparseBufferRewrite,
mlirCreateSparseTensorSparseTensorCodegen,
mlirCreateSparseTensorSparseTensorConversionPass,
mlirCreateSparseTensorSparseVectorization,
mlirCreateSparseTensorSparsificationPass,
mlirCreateSparseTensorStorageSpecifierToLLVM,
]
);
33 changes: 18 additions & 15 deletions melior/src/pass/transform.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
//! Transform passes.

melior_macro::transform_passes!(
mlirCreateTransformsCSE,
mlirCreateTransformsCanonicalizer,
mlirCreateTransformsControlFlowSink,
mlirCreateTransformsGenerateRuntimeVerification,
mlirCreateTransformsInliner,
mlirCreateTransformsLocationSnapshot,
mlirCreateTransformsLoopInvariantCodeMotion,
mlirCreateTransformsPrintOpStats,
mlirCreateTransformsSCCP,
mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE,
mlirCreateTransformsSymbolPrivatize,
mlirCreateTransformsTopologicalSort,
mlirCreateTransformsViewOpGraph,
melior_macro::passes!(
"Transforms",
[
mlirCreateTransformsCSE,
mlirCreateTransformsCanonicalizer,
mlirCreateTransformsControlFlowSink,
mlirCreateTransformsGenerateRuntimeVerification,
mlirCreateTransformsInliner,
mlirCreateTransformsLocationSnapshot,
mlirCreateTransformsLoopInvariantCodeMotion,
mlirCreateTransformsPrintOpStats,
mlirCreateTransformsSCCP,
mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE,
mlirCreateTransformsSymbolPrivatize,
mlirCreateTransformsTopologicalSort,
mlirCreateTransformsViewOpGraph,
]
);