From f85bca88e5de1ef5ed5e07c7bda1019366aa5f40 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Tue, 5 Dec 2023 21:43:01 +0800 Subject: [PATCH] Refactor dialect macros (#376) --- macro/src/dialect.rs | 9 +- macro/src/dialect/dialect.rs | 104 ++++ macro/src/dialect/operation.rs | 504 +++++------------- macro/src/dialect/operation/accessors.rs | 14 +- macro/src/dialect/operation/element_kind.rs | 14 + macro/src/dialect/operation/field_kind.rs | 162 ++++++ .../src/dialect/operation/operation_field.rs | 77 +++ macro/src/dialect/operation/sequence_info.rs | 5 + macro/src/dialect/operation/variadic_kind.rs | 32 ++ macro/src/dialect/utility.rs | 2 +- 10 files changed, 526 insertions(+), 397 deletions(-) create mode 100644 macro/src/dialect/dialect.rs create mode 100644 macro/src/dialect/operation/element_kind.rs create mode 100644 macro/src/dialect/operation/field_kind.rs create mode 100644 macro/src/dialect/operation/operation_field.rs create mode 100644 macro/src/dialect/operation/sequence_info.rs create mode 100644 macro/src/dialect/operation/variadic_kind.rs diff --git a/macro/src/dialect.rs b/macro/src/dialect.rs index ad24b8d49a..67833e57fd 100644 --- a/macro/src/dialect.rs +++ b/macro/src/dialect.rs @@ -54,19 +54,18 @@ fn generate_dialect_module( dialect: Record, record_keeper: &RecordKeeper, ) -> Result { + let dialect_name = dialect.name()?; let operations = record_keeper .all_derived_definitions("Op") - .map(Operation::from_def) + .map(Operation::new) .collect::, _>>()? .into_iter() - .filter(|operation| operation.dialect.name() == dialect.name()) + .filter(|operation| operation.dialect_name() == dialect_name) .collect::>(); let doc = format!( "`{name}` dialect.\n\n{}", - sanitize_documentation(&unindent::unindent( - dialect.str_value("description").unwrap_or(""), - ))? + sanitize_documentation(dialect.str_value("description").unwrap_or(""),)? ); let name = sanitize_snake_case_name(name)?; diff --git a/macro/src/dialect/dialect.rs b/macro/src/dialect/dialect.rs new file mode 100644 index 0000000000..67833e57fd --- /dev/null +++ b/macro/src/dialect/dialect.rs @@ -0,0 +1,104 @@ +mod error; +mod input; +mod operation; +mod types; +mod utility; + +use self::{ + error::Error, + utility::{sanitize_documentation, sanitize_snake_case_name}, +}; +pub use input::DialectInput; +use operation::Operation; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use std::{env, fmt::Display, path::Path, process::Command, str}; +use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser}; + +const LLVM_MAJOR_VERSION: usize = 17; + +pub fn generate_dialect(input: DialectInput) -> Result> { + let mut parser = TableGenParser::new(); + + if let Some(source) = input.table_gen() { + parser = parser.add_source(source).map_err(create_syn_error)?; + } + + if let Some(file) = input.td_file() { + parser = parser.add_source_file(file).map_err(create_syn_error)?; + } + + // spell-checker: disable-next-line + for path in input.includes().chain([&*llvm_config("--includedir")?]) { + parser = parser.add_include_path(path); + } + + let keeper = parser.parse().map_err(Error::Parse)?; + + let dialect = generate_dialect_module( + input.name(), + keeper + .all_derived_definitions("Dialect") + .find(|definition| definition.str_value("name") == Ok(input.name())) + .ok_or_else(|| create_syn_error("dialect not found"))?, + &keeper, + ) + .map_err(|error| error.add_source_info(keeper.source_info()))?; + + Ok(quote! { #dialect }.into()) +} + +fn generate_dialect_module( + name: &str, + dialect: Record, + record_keeper: &RecordKeeper, +) -> Result { + let dialect_name = dialect.name()?; + let operations = record_keeper + .all_derived_definitions("Op") + .map(Operation::new) + .collect::, _>>()? + .into_iter() + .filter(|operation| operation.dialect_name() == dialect_name) + .collect::>(); + + let doc = format!( + "`{name}` dialect.\n\n{}", + sanitize_documentation(dialect.str_value("description").unwrap_or(""),)? + ); + let name = sanitize_snake_case_name(name)?; + + Ok(quote! { + #[doc = #doc] + pub mod #name { + #(#operations)* + } + }) +} + +fn llvm_config(argument: &str) -> Result> { + let prefix = env::var(format!("MLIR_SYS_{}0_PREFIX", LLVM_MAJOR_VERSION)) + .map(|path| Path::new(&path).join("bin")) + .unwrap_or_default(); + let call = format!( + "{} --link-static {}", + prefix.join("llvm-config").display(), + argument + ); + + Ok(str::from_utf8( + &if cfg!(target_os = "windows") { + Command::new("cmd").args(["/C", &call]).output()? + } else { + Command::new("sh").arg("-c").arg(&call).output()? + } + .stdout, + )? + .trim() + .to_string()) +} + +fn create_syn_error(error: impl Display) -> syn::Error { + syn::Error::new(Span::call_site(), format!("{}", error)) +} diff --git a/macro/src/dialect/operation.rs b/macro/src/dialect/operation.rs index 2b6ce9073c..106426fb0b 100644 --- a/macro/src/dialect/operation.rs +++ b/macro/src/dialect/operation.rs @@ -1,312 +1,116 @@ mod accessors; mod builder; - -use self::builder::OperationBuilder; -use super::utility::{sanitize_documentation, sanitize_snake_case_name}; +mod element_kind; +mod field_kind; +mod operation_field; +mod sequence_info; +mod variadic_kind; + +use self::{ + builder::OperationBuilder, element_kind::ElementKind, field_kind::FieldKind, + operation_field::OperationField, sequence_info::SequenceInfo, variadic_kind::VariadicKind, +}; +use super::utility::sanitize_documentation; use crate::dialect::{ error::{Error, OdsError}, types::{AttributeConstraint, RegionConstraint, SuccessorConstraint, Trait, TypeConstraint}, }; -use proc_macro2::{Ident, TokenStream}; +use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens, TokenStreamExt}; -use syn::{parse_quote, Type}; use tblgen::{error::WithLocation, record::Record}; -#[derive(Debug, Clone, Copy)] -pub enum ElementKind { - Operand, - Result, -} - -impl ElementKind { - pub fn as_str(&self) -> &'static str { - match self { - Self::Operand => "operand", - Self::Result => "result", - } - } -} - -#[derive(Debug, Clone)] -pub enum FieldKind<'a> { - Element { - kind: ElementKind, - constraint: TypeConstraint<'a>, - sequence_info: SequenceInfo, - variadic_kind: VariadicKind, - }, - Attribute { - constraint: AttributeConstraint<'a>, - }, - Successor { - constraint: SuccessorConstraint<'a>, - sequence_info: SequenceInfo, - }, - Region { - constraint: RegionConstraint<'a>, - sequence_info: SequenceInfo, - }, -} - -impl<'a> FieldKind<'a> { - pub fn as_str(&self) -> &'static str { - match self { - Self::Element { kind, .. } => kind.as_str(), - Self::Attribute { .. } => "attribute", - Self::Successor { .. } => "successor", - Self::Region { .. } => "region", - } - } - - pub fn is_optional(&self) -> Result { - Ok(match self { - Self::Element { constraint, .. } => constraint.is_optional(), - Self::Attribute { constraint, .. } => { - constraint.is_optional()? || constraint.has_default_value()? - } - Self::Successor { .. } | Self::Region { .. } => false, - }) - } - - pub fn is_result(&self) -> bool { - matches!( - self, - Self::Element { - kind: ElementKind::Result, - .. - } - ) - } - - pub fn parameter_type(&self) -> Result { - Ok(match self { - Self::Element { - kind, constraint, .. - } => { - let base_type: Type = match kind { - ElementKind::Operand => { - parse_quote!(::melior::ir::Value<'c, '_>) - } - ElementKind::Result => { - parse_quote!(::melior::ir::Type<'c>) - } - }; - if constraint.is_variadic() { - parse_quote! { &[#base_type] } - } else { - base_type - } - } - Self::Attribute { constraint } => { - if constraint.is_unit()? { - parse_quote!(bool) - } else { - let r#type: Type = syn::parse_str(constraint.storage_type()?)?; - parse_quote!(#r#type<'c>) - } - } - Self::Successor { constraint, .. } => { - let r#type: Type = parse_quote!(&::melior::ir::Block<'c>); - if constraint.is_variadic() { - parse_quote!(&[#r#type]) - } else { - r#type - } - } - Self::Region { constraint, .. } => { - let r#type: Type = parse_quote!(::melior::ir::Region<'c>); - if constraint.is_variadic() { - parse_quote!(Vec<#r#type>) - } else { - r#type - } - } - }) - } - - fn create_result_type(r#type: Type) -> Type { - parse_quote!(Result<#r#type, ::melior::Error>) - } - - fn create_iterator_type(r#type: Type) -> Type { - parse_quote!(impl Iterator) - } - - pub fn return_type(&self) -> Result { - Ok(match self { - Self::Element { - kind, - constraint, - variadic_kind, - .. - } => { - let base_type: Type = match kind { - ElementKind::Operand => { - parse_quote!(::melior::ir::Value<'c, '_>) - } - ElementKind::Result => { - parse_quote!(::melior::ir::operation::OperationResult<'c, '_>) - } - }; - if !constraint.is_variadic() { - Self::create_result_type(base_type) - } else if let VariadicKind::AttrSized {} = variadic_kind { - Self::create_result_type(Self::create_iterator_type(base_type)) - } else { - Self::create_iterator_type(base_type) - } - } - Self::Attribute { constraint } => { - if constraint.is_unit()? { - parse_quote!(bool) - } else { - Self::create_result_type(self.parameter_type()?) - } - } - Self::Successor { constraint, .. } => { - let r#type: Type = parse_quote!(::melior::ir::BlockRef<'c, '_>); - if constraint.is_variadic() { - Self::create_iterator_type(r#type) - } else { - Self::create_result_type(r#type) - } - } - Self::Region { constraint, .. } => { - let r#type: Type = parse_quote!(::melior::ir::RegionRef<'c, '_>); - if constraint.is_variadic() { - Self::create_iterator_type(r#type) - } else { - Self::create_result_type(r#type) - } - } - }) - } -} - -#[derive(Debug, Clone)] -pub struct SequenceInfo { - index: usize, - len: usize, -} - #[derive(Clone, Debug)] -pub enum VariadicKind { - Simple { - seen_variable_length: bool, - }, - SameSize { - num_variable_length: usize, - num_preceding_simple: usize, - num_preceding_variadic: usize, - }, - AttrSized {}, +pub struct Operation<'a> { + dialect_name: &'a str, + short_name: &'a str, + full_name: String, + class_name: &'a str, + summary: String, + can_infer_type: bool, + description: String, + regions: Vec>, + successors: Vec>, + results: Vec>, + operands: Vec>, + attributes: Vec>, + derived_attributes: Vec>, } -impl VariadicKind { - pub fn new(num_variable_length: usize, same_size: bool, attr_sized: bool) -> Self { - if num_variable_length <= 1 { - VariadicKind::Simple { - seen_variable_length: false, - } - } else if same_size { - VariadicKind::SameSize { - num_variable_length, - num_preceding_simple: 0, - num_preceding_variadic: 0, - } - } else if attr_sized { - VariadicKind::AttrSized {} - } else { - unimplemented!() - } - } -} +impl<'a> Operation<'a> { + pub fn new(definition: Record<'a>) -> Result { + let dialect = definition.def_value("opDialect")?; + let traits = Self::collect_traits(definition)?; + let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name)); + + let arguments = Self::dag_constraints(definition, "arguments")?; + let regions = Self::collect_regions(definition)?; + let (results, variable_length_results_count) = Self::collect_results( + definition, + has_trait("::mlir::OpTrait::SameVariadicResultSize"), + has_trait("::mlir::OpTrait::AttrSizedResultSegments"), + )?; -#[derive(Debug, Clone)] -pub struct OperationField<'a> { - pub(crate) name: &'a str, - pub(crate) sanitized_name: Ident, - pub(crate) kind: FieldKind<'a>, -} + let name = definition.name()?; + let class_name = if name.starts_with('_') { + name + } else if let Some(name) = name.split('_').nth(1) { + // Trim dialect prefix from name. + name + } else { + name + }; + let short_name = definition.str_value("opName")?; -impl<'a> OperationField<'a> { - fn new(name: &'a str, kind: FieldKind<'a>) -> Result { Ok(Self { - name, - sanitized_name: sanitize_snake_case_name(name)?, - kind, - }) - } - - fn new_attribute(name: &'a str, constraint: AttributeConstraint<'a>) -> Result { - Self::new(name, FieldKind::Attribute { constraint }) - } + dialect_name: dialect.name()?, + short_name, + full_name: { + let dialect_name = dialect.string_value("name")?; - fn new_region( - name: &'a str, - constraint: RegionConstraint<'a>, - sequence_info: SequenceInfo, - ) -> Result { - Self::new( - name, - FieldKind::Region { - constraint, - sequence_info, + if dialect_name.is_empty() { + short_name.into() + } else { + format!("{dialect_name}.{short_name}") + } }, - ) - } + class_name, + successors: Self::collect_successors(definition)?, + operands: Self::collect_operands( + &arguments, + has_trait("::mlir::OpTrait::SameVariadicOperandSize"), + has_trait("::mlir::OpTrait::AttrSizedOperandSegments"), + )?, + results, + attributes: Self::collect_attributes(&arguments)?, + derived_attributes: Self::collect_derived_attributes(definition)?, + can_infer_type: traits.iter().any(|r#trait| { + (r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") + || r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType")) + && variable_length_results_count == 0 + || r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty() + }), + summary: { + let summary = definition.str_value("summary")?; - fn new_successor( - name: &'a str, - constraint: SuccessorConstraint<'a>, - sequence_info: SequenceInfo, - ) -> Result { - Self::new( - name, - FieldKind::Successor { - constraint, - sequence_info, + [ + format!("[`{short_name}`]({class_name}) operation."), + if summary.is_empty() { + Default::default() + } else { + summary[0..1].to_uppercase() + &summary[1..] + "." + }, + ] + .join(" ") }, - ) + description: sanitize_documentation(definition.str_value("description")?)?, + regions, + }) } - fn new_element( - name: &'a str, - constraint: TypeConstraint<'a>, - kind: ElementKind, - sequence_info: SequenceInfo, - variadic_kind: VariadicKind, - ) -> Result { - Self::new( - name, - FieldKind::Element { - kind, - constraint, - sequence_info, - variadic_kind, - }, - ) + pub fn dialect_name(&self) -> &str { + self.dialect_name } -} -#[derive(Debug, Clone)] -pub struct Operation<'a> { - pub(crate) dialect: Record<'a>, - pub(crate) short_name: &'a str, - pub(crate) full_name: String, - pub(crate) class_name: &'a str, - pub(crate) summary: String, - pub(crate) can_infer_type: bool, - description: String, - regions: Vec>, - successors: Vec>, - results: Vec>, - operands: Vec>, - attributes: Vec>, - derived_attributes: Vec>, -} - -impl<'a> Operation<'a> { pub fn fields(&self) -> impl Iterator> + Clone { self.results .iter() @@ -317,8 +121,8 @@ impl<'a> Operation<'a> { .chain(self.derived_attributes.iter()) } - fn collect_successors(def: Record<'a>) -> Result, Error> { - let successors_dag = def.dag_value("successors")?; + fn collect_successors(definition: Record<'a>) -> Result, Error> { + let successors_dag = definition.dag_value("successors")?; let len = successors_dag.num_args(); successors_dag .args() @@ -329,7 +133,7 @@ impl<'a> Operation<'a> { SuccessorConstraint::new( value .try_into() - .map_err(|error: tblgen::Error| error.set_location(def))?, + .map_err(|error: tblgen::Error| error.set_location(definition))?, ), SequenceInfo { index, len }, ) @@ -337,8 +141,8 @@ impl<'a> Operation<'a> { .collect() } - fn collect_regions(def: Record<'a>) -> Result, Error> { - let regions_dag = def.dag_value("regions")?; + fn collect_regions(definition: Record<'a>) -> Result, Error> { + let regions_dag = definition.dag_value("regions")?; let len = regions_dag.num_args(); regions_dag .args() @@ -349,7 +153,7 @@ impl<'a> Operation<'a> { RegionConstraint::new( value .try_into() - .map_err(|error: tblgen::Error| error.set_location(def))?, + .map_err(|error: tblgen::Error| error.set_location(definition))?, ), SequenceInfo { index, len }, ) @@ -357,15 +161,15 @@ impl<'a> Operation<'a> { .collect() } - fn collect_traits(def: Record<'a>) -> Result, Error> { - let mut work_list = vec![def.list_value("traits")?]; + fn collect_traits(definition: Record<'a>) -> Result, Error> { + let mut work_list = vec![definition.list_value("traits")?]; let mut traits = Vec::new(); - while let Some(trait_def) = work_list.pop() { - for value in trait_def.iter() { + while let Some(trait_definition) = work_list.pop() { + for value in trait_definition.iter() { let trait_def: Record = value .try_into() - .map_err(|error: tblgen::Error| error.set_location(def))?; + .map_err(|error: tblgen::Error| error.set_location(definition))?; if trait_def.subclass_of("TraitList") { work_list.push(trait_def.list_value("traits")?); @@ -382,21 +186,22 @@ impl<'a> Operation<'a> { } fn dag_constraints( - def: Record<'a>, + definition: Record<'a>, dag_field_name: &str, ) -> Result)>, Error> { - def.dag_value(dag_field_name)? + definition + .dag_value(dag_field_name)? .args() - .map(|(name, arg)| { - let mut arg_def: Record = arg + .map(|(name, argument)| { + let mut argument_definition: Record = argument .try_into() - .map_err(|error: tblgen::Error| error.set_location(def))?; + .map_err(|error: tblgen::Error| error.set_location(definition))?; - if arg_def.subclass_of("OpVariable") { - arg_def = arg_def.def_value("constraint")?; + if argument_definition.subclass_of("OpVariable") { + argument_definition = argument_definition.def_value("constraint")?; } - Ok((name, arg_def)) + Ok((name, argument_definition)) }) .collect() } @@ -441,11 +246,11 @@ impl<'a> Operation<'a> { same_size: bool, attr_sized: bool, ) -> Result<(Vec>, usize), Error> { - let num_variable_length = elements + let variable_length_count = elements .iter() .filter(|(_, constraint)| constraint.has_variable_length()) .count(); - let mut variadic_kind = VariadicKind::new(num_variable_length, same_size, attr_sized); + let mut variadic_kind = VariadicKind::new(variable_length_count, same_size, attr_sized); let mut fields = vec![]; for (index, (name, constraint)) in elements.iter().enumerate() { @@ -462,28 +267,28 @@ impl<'a> Operation<'a> { match &mut variadic_kind { VariadicKind::Simple { - seen_variable_length, + variable_length_seen: seen_variable_length, } => { if constraint.has_variable_length() { *seen_variable_length = true; } } VariadicKind::SameSize { - num_preceding_simple, - num_preceding_variadic, + preceding_simple_count, + preceding_variadic_count, .. } => { if constraint.has_variable_length() { - *num_preceding_variadic += 1; + *preceding_variadic_count += 1; } else { - *num_preceding_simple += 1; + *preceding_simple_count += 1; } } VariadicKind::AttrSized {} => {} } } - Ok((fields, num_variable_length)) + Ok((fields, variable_length_count)) } fn collect_attributes( @@ -520,78 +325,10 @@ impl<'a> Operation<'a> { }) .collect() } - - pub fn from_def(def: Record<'a>) -> Result { - let dialect = def.def_value("opDialect")?; - let traits = Self::collect_traits(def)?; - let has_trait = |name: &str| traits.iter().any(|r#trait| r#trait.has_name(name)); - - let arguments = Self::dag_constraints(def, "arguments")?; - let regions = Self::collect_regions(def)?; - let (results, num_variable_length_results) = Self::collect_results( - def, - has_trait("::mlir::OpTrait::SameVariadicResultSize"), - has_trait("::mlir::OpTrait::AttrSizedResultSegments"), - )?; - - let name = def.name()?; - let class_name = if name.starts_with('_') { - name - } else if let Some(name) = name.split('_').nth(1) { - // Trim dialect prefix from name - name - } else { - name - }; - let short_name = def.str_value("opName")?; - - Ok(Self { - dialect, - short_name, - full_name: { - let dialect_name = dialect.string_value("name")?; - - if dialect_name.is_empty() { - short_name.into() - } else { - format!("{dialect_name}.{short_name}") - } - }, - class_name, - successors: Self::collect_successors(def)?, - operands: Self::collect_operands( - &arguments, - has_trait("::mlir::OpTrait::SameVariadicOperandSize"), - has_trait("::mlir::OpTrait::AttrSizedOperandSegments"), - )?, - results, - attributes: Self::collect_attributes(&arguments)?, - derived_attributes: Self::collect_derived_attributes(def)?, - can_infer_type: traits.iter().any(|r#trait| { - (r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") - || r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType")) - && num_variable_length_results == 0 - || r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty() - }), - summary: { - let summary = def.str_value("summary")?; - - if summary.is_empty() { - format!("[`{short_name}`]({class_name}) operation") - } else { - format!( - "[`{short_name}`]({class_name}) operation: {}", - summary[0..1].to_uppercase() + &summary[1..] - ) - } - }, - description: unindent::unindent(def.str_value("description")?), - regions, - }) - } } impl<'a> ToTokens for Operation<'a> { + // TODO Compile values for proper error handling and remove `Result::expect()`. fn to_tokens(&self, tokens: &mut TokenStream) { let class_name = format_ident!("{}", &self.class_name); let name = &self.full_name; @@ -605,8 +342,7 @@ impl<'a> ToTokens for Operation<'a> { .create_default_constructor() .expect("valid constructor"); let summary = &self.summary; - let description = - sanitize_documentation(&self.description).expect("valid Markdown documentation"); + let description = &self.description; tokens.append_all(quote! { #[doc = #summary] diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index b2f87ff374..ee08465c89 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -26,7 +26,7 @@ impl<'a> OperationField<'a> { Some(match variadic_kind { VariadicKind::Simple { - seen_variable_length, + variable_length_seen: seen_variable_length, } => { if constraint.is_optional() { // Optional element, and some singular elements. @@ -62,14 +62,14 @@ impl<'a> OperationField<'a> { } } VariadicKind::SameSize { - num_variable_length, - num_preceding_simple, - num_preceding_variadic, + variable_length_count, + preceding_simple_count, + preceding_variadic_count, } => { let compute_start_length = quote! { - let total_var_len = self.operation.#count() - #num_variable_length + 1; - let group_len = total_var_len / #num_variable_length; - let start = #num_preceding_simple + #num_preceding_variadic * group_len; + let total_var_len = self.operation.#count() - #variable_length_count + 1; + let group_len = total_var_len / #variable_length_count; + let start = #preceding_simple_count + #preceding_variadic_count * group_len; }; let get_elements = if constraint.has_variable_length() { quote! { diff --git a/macro/src/dialect/operation/element_kind.rs b/macro/src/dialect/operation/element_kind.rs new file mode 100644 index 0000000000..6f1de942ef --- /dev/null +++ b/macro/src/dialect/operation/element_kind.rs @@ -0,0 +1,14 @@ +#[derive(Debug, Clone, Copy)] +pub enum ElementKind { + Operand, + Result, +} + +impl ElementKind { + pub fn as_str(&self) -> &'static str { + match self { + Self::Operand => "operand", + Self::Result => "result", + } + } +} diff --git a/macro/src/dialect/operation/field_kind.rs b/macro/src/dialect/operation/field_kind.rs new file mode 100644 index 0000000000..944d3b2495 --- /dev/null +++ b/macro/src/dialect/operation/field_kind.rs @@ -0,0 +1,162 @@ +use super::{element_kind::ElementKind, SequenceInfo, VariadicKind}; +use crate::dialect::{ + error::Error, + types::{AttributeConstraint, RegionConstraint, SuccessorConstraint, TypeConstraint}, +}; +use syn::{parse_quote, Type}; + +#[derive(Debug, Clone)] +pub enum FieldKind<'a> { + Element { + kind: ElementKind, + constraint: TypeConstraint<'a>, + sequence_info: SequenceInfo, + variadic_kind: VariadicKind, + }, + Attribute { + constraint: AttributeConstraint<'a>, + }, + Successor { + constraint: SuccessorConstraint<'a>, + sequence_info: SequenceInfo, + }, + Region { + constraint: RegionConstraint<'a>, + sequence_info: SequenceInfo, + }, +} + +impl<'a> FieldKind<'a> { + pub fn as_str(&self) -> &'static str { + match self { + Self::Element { kind, .. } => kind.as_str(), + Self::Attribute { .. } => "attribute", + Self::Successor { .. } => "successor", + Self::Region { .. } => "region", + } + } + + pub fn is_optional(&self) -> Result { + Ok(match self { + Self::Element { constraint, .. } => constraint.is_optional(), + Self::Attribute { constraint, .. } => { + constraint.is_optional()? || constraint.has_default_value()? + } + Self::Successor { .. } | Self::Region { .. } => false, + }) + } + + pub fn is_result(&self) -> bool { + matches!( + self, + Self::Element { + kind: ElementKind::Result, + .. + } + ) + } + + pub fn parameter_type(&self) -> Result { + Ok(match self { + Self::Element { + kind, constraint, .. + } => { + let base_type: Type = match kind { + ElementKind::Operand => { + parse_quote!(::melior::ir::Value<'c, '_>) + } + ElementKind::Result => { + parse_quote!(::melior::ir::Type<'c>) + } + }; + if constraint.is_variadic() { + parse_quote! { &[#base_type] } + } else { + base_type + } + } + Self::Attribute { constraint } => { + if constraint.is_unit()? { + parse_quote!(bool) + } else { + let r#type: Type = syn::parse_str(constraint.storage_type()?)?; + parse_quote!(#r#type<'c>) + } + } + Self::Successor { constraint, .. } => { + let r#type: Type = parse_quote!(&::melior::ir::Block<'c>); + if constraint.is_variadic() { + parse_quote!(&[#r#type]) + } else { + r#type + } + } + Self::Region { constraint, .. } => { + let r#type: Type = parse_quote!(::melior::ir::Region<'c>); + if constraint.is_variadic() { + parse_quote!(Vec<#r#type>) + } else { + r#type + } + } + }) + } + + fn create_result_type(r#type: Type) -> Type { + parse_quote!(Result<#r#type, ::melior::Error>) + } + + fn create_iterator_type(r#type: Type) -> Type { + parse_quote!(impl Iterator) + } + + pub fn return_type(&self) -> Result { + Ok(match self { + Self::Element { + kind, + constraint, + variadic_kind, + .. + } => { + let base_type: Type = match kind { + ElementKind::Operand => { + parse_quote!(::melior::ir::Value<'c, '_>) + } + ElementKind::Result => { + parse_quote!(::melior::ir::operation::OperationResult<'c, '_>) + } + }; + if !constraint.is_variadic() { + Self::create_result_type(base_type) + } else if let VariadicKind::AttrSized {} = variadic_kind { + Self::create_result_type(Self::create_iterator_type(base_type)) + } else { + Self::create_iterator_type(base_type) + } + } + Self::Attribute { constraint } => { + if constraint.is_unit()? { + parse_quote!(bool) + } else { + Self::create_result_type(self.parameter_type()?) + } + } + Self::Successor { constraint, .. } => { + let r#type: Type = parse_quote!(::melior::ir::BlockRef<'c, '_>); + if constraint.is_variadic() { + Self::create_iterator_type(r#type) + } else { + Self::create_result_type(r#type) + } + } + Self::Region { constraint, .. } => { + let r#type: Type = parse_quote!(::melior::ir::RegionRef<'c, '_>); + if constraint.is_variadic() { + Self::create_iterator_type(r#type) + } else { + Self::create_result_type(r#type) + } + } + }) + } +} diff --git a/macro/src/dialect/operation/operation_field.rs b/macro/src/dialect/operation/operation_field.rs new file mode 100644 index 0000000000..1977823b79 --- /dev/null +++ b/macro/src/dialect/operation/operation_field.rs @@ -0,0 +1,77 @@ +use super::{element_kind::ElementKind, field_kind::FieldKind, SequenceInfo, VariadicKind}; +use crate::dialect::{ + error::Error, + types::{AttributeConstraint, RegionConstraint, SuccessorConstraint, TypeConstraint}, + utility::sanitize_snake_case_name, +}; +use proc_macro2::Ident; + +#[derive(Debug, Clone)] +pub struct OperationField<'a> { + pub(crate) name: &'a str, + pub(crate) sanitized_name: Ident, + pub(crate) kind: FieldKind<'a>, +} + +impl<'a> OperationField<'a> { + fn new(name: &'a str, kind: FieldKind<'a>) -> Result { + Ok(Self { + name, + sanitized_name: sanitize_snake_case_name(name)?, + kind, + }) + } + + pub fn new_attribute( + name: &'a str, + constraint: AttributeConstraint<'a>, + ) -> Result { + Self::new(name, FieldKind::Attribute { constraint }) + } + + pub fn new_region( + name: &'a str, + constraint: RegionConstraint<'a>, + sequence_info: SequenceInfo, + ) -> Result { + Self::new( + name, + FieldKind::Region { + constraint, + sequence_info, + }, + ) + } + + pub fn new_successor( + name: &'a str, + constraint: SuccessorConstraint<'a>, + sequence_info: SequenceInfo, + ) -> Result { + Self::new( + name, + FieldKind::Successor { + constraint, + sequence_info, + }, + ) + } + + pub fn new_element( + name: &'a str, + constraint: TypeConstraint<'a>, + kind: ElementKind, + sequence_info: SequenceInfo, + variadic_kind: VariadicKind, + ) -> Result { + Self::new( + name, + FieldKind::Element { + kind, + constraint, + sequence_info, + variadic_kind, + }, + ) + } +} diff --git a/macro/src/dialect/operation/sequence_info.rs b/macro/src/dialect/operation/sequence_info.rs new file mode 100644 index 0000000000..1dd7f85553 --- /dev/null +++ b/macro/src/dialect/operation/sequence_info.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone)] +pub struct SequenceInfo { + pub index: usize, + pub len: usize, +} diff --git a/macro/src/dialect/operation/variadic_kind.rs b/macro/src/dialect/operation/variadic_kind.rs new file mode 100644 index 0000000000..68531d4b5e --- /dev/null +++ b/macro/src/dialect/operation/variadic_kind.rs @@ -0,0 +1,32 @@ +#[derive(Clone, Debug)] +pub enum VariadicKind { + Simple { + variable_length_seen: bool, + }, + SameSize { + variable_length_count: usize, + preceding_simple_count: usize, + preceding_variadic_count: usize, + }, + AttrSized {}, +} + +impl VariadicKind { + pub fn new(variable_length_count: usize, same_size: bool, attr_sized: bool) -> Self { + if variable_length_count <= 1 { + VariadicKind::Simple { + variable_length_seen: false, + } + } else if same_size { + VariadicKind::SameSize { + variable_length_count, + preceding_simple_count: 0, + preceding_variadic_count: 0, + } + } else if attr_sized { + VariadicKind::AttrSized {} + } else { + unimplemented!() + } + } +} diff --git a/macro/src/dialect/utility.rs b/macro/src/dialect/utility.rs index bc43fae820..a13cf22235 100644 --- a/macro/src/dialect/utility.rs +++ b/macro/src/dialect/utility.rs @@ -32,7 +32,7 @@ fn sanitize_name(name: &str) -> Result { pub fn sanitize_documentation(string: &str) -> Result { let arena = Arena::new(); - let node = parse_document(&arena, string, &Default::default()); + let node = parse_document(&arena, &unindent::unindent(string), &Default::default()); for node in node.traverse() { let NodeEdge::Start(node) = node else {