Skip to content

Commit

Permalink
Refactor dialect macros (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Dec 4, 2023
1 parent 7b5ab0b commit eb075e7
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 45 deletions.
20 changes: 10 additions & 10 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser};
const LLVM_MAJOR_VERSION: usize = 17;

pub fn generate_dialect(input: DialectInput) -> Result<TokenStream, Box<dyn std::error::Error>> {
let mut td_parser = TableGenParser::new();
let mut parser = TableGenParser::new();

if let Some(source) = input.tablegen() {
td_parser = td_parser.add_source(source).map_err(create_syn_error)?;
if let Some(source) = input.table_gen() {
parser = parser.add_source(source).map_err(create_syn_error)?;
}

if let Some(file) = input.td_file() {
td_parser = td_parser.add_source_file(file).map_err(create_syn_error)?;
parser = parser.add_source_file(file).map_err(create_syn_error)?;
}

// spell-checker: disable-next-line
for include in input.includes().chain([&*llvm_config("--includedir")?]) {
td_parser = td_parser.add_include_path(include);
for path in input.includes().chain([&*llvm_config("--includedir")?]) {
parser = parser.add_include_path(path);
}

let keeper = td_parser.parse().map_err(Error::Parse)?;
let keeper = parser.parse().map_err(Error::Parse)?;

let dialect = dialect_module(
let dialect = generate_dialect_module(
input.name(),
keeper
.all_derived_definitions("Dialect")
.find(|def| def.str_value("name") == Ok(input.name()))
.find(|definition| definition.str_value("name") == Ok(input.name()))
.ok_or_else(|| create_syn_error("dialect not found"))?,
&keeper,
)
Expand All @@ -49,7 +49,7 @@ pub fn generate_dialect(input: DialectInput) -> Result<TokenStream, Box<dyn std:
Ok(quote! { #dialect }.into())
}

fn dialect_module(
fn generate_dialect_module(
name: &str,
dialect: Record,
record_keeper: &RecordKeeper,
Expand Down
14 changes: 7 additions & 7 deletions macro/src/dialect/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use syn::{bracketed, parse::Parse, punctuated::Punctuated, LitStr, Token};

pub struct DialectInput {
name: String,
tablegen: Option<String>,
table_gen: Option<String>,
td_file: Option<String>,
includes: Vec<String>,
}
Expand All @@ -15,8 +15,8 @@ impl DialectInput {
&self.name
}

pub fn tablegen(&self) -> Option<&str> {
self.tablegen.as_deref()
pub fn table_gen(&self) -> Option<&str> {
self.table_gen.as_deref()
}

pub fn td_file(&self) -> Option<&str> {
Expand All @@ -31,14 +31,14 @@ impl DialectInput {
impl Parse for DialectInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut name = None;
let mut tablegen = None;
let mut table_gen = None;
let mut td_file = None;
let mut includes = vec![];

for item in Punctuated::<InputField, Token![,]>::parse_terminated(input)? {
match item {
InputField::Name(field) => name = Some(field.value()),
InputField::TableGen(td) => tablegen = Some(td.value()),
InputField::TableGen(td) => table_gen = Some(td.value()),
InputField::TdFile(file) => td_file = Some(file.value()),
InputField::Includes(field) => {
includes = field.into_iter().map(|literal| literal.value()).collect()
Expand All @@ -48,7 +48,7 @@ impl Parse for DialectInput {

Ok(Self {
name: name.ok_or(input.error("dialect name required"))?,
tablegen,
table_gen,
td_file,
includes,
})
Expand All @@ -70,7 +70,7 @@ impl Parse for InputField {

if ident == format_ident!("name") {
Ok(Self::Name(input.parse()?))
} else if ident == format_ident!("tablegen") {
} else if ident == format_ident!("table_gen") {
Ok(Self::TableGen(input.parse()?))
} else if ident == format_ident!("td_file") {
Ok(Self::TdFile(input.parse()?))
Expand Down
5 changes: 3 additions & 2 deletions macro/src/dialect/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ pub fn sanitize_snake_case_name(name: &str) -> Result<Ident, Error> {
}

fn sanitize_name(name: &str) -> Result<Ident, Error> {
// Replace any "." with "_"
// Replace any "." with "_".
let mut name = name.replace('.', "_");

// Add "_" suffix to avoid conflicts with existing methods
// Add "_" suffix to avoid conflicts with existing methods.
if RESERVED_NAMES.contains(&name.as_str())
|| name
.chars()
Expand Down Expand Up @@ -44,6 +44,7 @@ pub fn sanitize_documentation(string: &str) -> Result<String, Error> {
};

if block.info.is_empty() {
// Mark them not in Rust to prevent documentation tests.
block.info = "text".into();
}
}
Expand Down
2 changes: 1 addition & 1 deletion macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use syn::parse_macro_input;
/// ```rust
/// melior::dialect! {
/// name: "func",
/// tablegen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
/// table_gen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
/// }
/// ```
#[proc_macro]
Expand Down
50 changes: 25 additions & 25 deletions melior/src/dialect/ods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,104 +9,104 @@ pub mod __private {

melior_macro::dialect! {
name: "affine",
tablegen: r#"include "mlir/Dialect/Affine/IR/AffineOps.td""#
table_gen: r#"include "mlir/Dialect/Affine/IR/AffineOps.td""#
}
melior_macro::dialect! {
name: "amdgpu",
tablegen: r#"include "mlir/Dialect/AMDGPU/IR/AMDGPU.td""#
table_gen: r#"include "mlir/Dialect/AMDGPU/IR/AMDGPU.td""#
}
melior_macro::dialect! {
name: "arith",
tablegen: r#"include "mlir/Dialect/Arith/IR/ArithOps.td""#
table_gen: r#"include "mlir/Dialect/Arith/IR/ArithOps.td""#
}
melior_macro::dialect! {
name: "arm_neon",
tablegen: r#"include "mlir/Dialect/ArmNeon/ArmNeon.td""#
table_gen: r#"include "mlir/Dialect/ArmNeon/ArmNeon.td""#
}
melior_macro::dialect! {
name: "arm_sve",
tablegen: r#"include "mlir/Dialect/ArmSVE/ArmSVE.td""#
table_gen: r#"include "mlir/Dialect/ArmSVE/ArmSVE.td""#
}
melior_macro::dialect! {
name: "async",
tablegen: r#"include "mlir/Dialect/Async/IR/AsyncOps.td""#
table_gen: r#"include "mlir/Dialect/Async/IR/AsyncOps.td""#
}
melior_macro::dialect! {
name: "bufferization",
tablegen: r#"include "mlir/Dialect/Bufferization/IR/BufferizationOps.td""#
table_gen: r#"include "mlir/Dialect/Bufferization/IR/BufferizationOps.td""#
}
melior_macro::dialect! {
name: "cf",
tablegen: r#"include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td""#
table_gen: r#"include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td""#
}
melior_macro::dialect! {
name: "func",
tablegen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
table_gen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
}
melior_macro::dialect! {
name: "index",
tablegen: r#"include "mlir/Dialect/Index/IR/IndexOps.td""#
table_gen: r#"include "mlir/Dialect/Index/IR/IndexOps.td""#
}
melior_macro::dialect! {
name: "llvm",
// spell-checker: disable-next-line
tablegen: r#"include "mlir/Dialect/LLVMIR/LLVMOps.td""#
table_gen: r#"include "mlir/Dialect/LLVMIR/LLVMOps.td""#
}
melior_macro::dialect! {
name: "memref",
tablegen: r#"include "mlir/Dialect/MemRef/IR/MemRefOps.td""#
table_gen: r#"include "mlir/Dialect/MemRef/IR/MemRefOps.td""#
}
melior_macro::dialect! {
name: "scf",
tablegen: r#"include "mlir/Dialect/SCF/IR/SCFOps.td""#
table_gen: r#"include "mlir/Dialect/SCF/IR/SCFOps.td""#
}
melior_macro::dialect! {
name: "pdl",
tablegen: r#"include "mlir/Dialect/PDL/IR/PDLOps.td""#
table_gen: r#"include "mlir/Dialect/PDL/IR/PDLOps.td""#
}
melior_macro::dialect! {
name: "pdl_interp",
tablegen: r#"include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td""#
table_gen: r#"include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td""#
}
melior_macro::dialect! {
name: "math",
tablegen: r#"include "mlir/Dialect/Math/IR/MathOps.td""#
table_gen: r#"include "mlir/Dialect/Math/IR/MathOps.td""#
}
melior_macro::dialect! {
name: "gpu",
tablegen: r#"include "mlir/Dialect/GPU/IR/GPUOps.td""#
table_gen: r#"include "mlir/Dialect/GPU/IR/GPUOps.td""#
}
melior_macro::dialect! {
name: "linalg",
tablegen: r#"include "mlir/Dialect/Linalg/IR/LinalgOps.td""#
table_gen: r#"include "mlir/Dialect/Linalg/IR/LinalgOps.td""#
}
melior_macro::dialect! {
name: "quant",
tablegen: r#"include "mlir/Dialect/Quant/QuantOps.td""#
table_gen: r#"include "mlir/Dialect/Quant/QuantOps.td""#
}
melior_macro::dialect! {
name: "shape",
tablegen: r#"include "mlir/Dialect/Shape/IR/ShapeOps.td""#
table_gen: r#"include "mlir/Dialect/Shape/IR/ShapeOps.td""#
}
melior_macro::dialect! {
name: "sparse_tensor",
tablegen: r#"include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td""#
table_gen: r#"include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td""#
}
melior_macro::dialect! {
name: "tensor",
tablegen: r#"include "mlir/Dialect/Tensor/IR/TensorOps.td""#
table_gen: r#"include "mlir/Dialect/Tensor/IR/TensorOps.td""#
}
melior_macro::dialect! {
name: "tosa",
tablegen: r#"include "mlir/Dialect/Tosa/IR/TosaOps.td""#
table_gen: r#"include "mlir/Dialect/Tosa/IR/TosaOps.td""#
}
melior_macro::dialect! {
name: "transform",
tablegen: r#"include "mlir/Dialect/Transform/IR/TransformOps.td""#
table_gen: r#"include "mlir/Dialect/Transform/IR/TransformOps.td""#
}
melior_macro::dialect! {
name: "vector",
tablegen: r#"include "mlir/Dialect/Vector/IR/VectorOps.td""#
table_gen: r#"include "mlir/Dialect/Vector/IR/VectorOps.td""#
}

#[cfg(test)]
Expand Down

0 comments on commit eb075e7

Please sign in to comment.