Skip to content

Commit

Permalink
Refactor operation builders in dialect macros (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Feb 16, 2024
1 parent 43140fe commit 45965f0
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 199 deletions.
4 changes: 2 additions & 2 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod utility;
use self::{
error::Error,
generation::generate_operation,
utility::{sanitize_documentation, sanitize_snake_case_name},
utility::{sanitize_documentation, sanitize_snake_case_identifier},
};
pub use input::DialectInput;
use operation::Operation;
Expand Down Expand Up @@ -81,7 +81,7 @@ fn generate_dialect_module(
"`{name}` dialect.\n\n{}",
sanitize_documentation(dialect.str_value("description").unwrap_or(""),)?
);
let name = sanitize_snake_case_name(name)?;
let name = sanitize_snake_case_identifier(name)?;

Ok(quote! {
#[doc = #doc]
Expand Down
11 changes: 7 additions & 4 deletions macro/src/dialect/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ mod field_accessor;
mod operation_builder;

use self::{
attribute_accessor::generate_attribute_accessors, field_accessor::generate_accessor,
operation_builder::generate_operation_builder,
attribute_accessor::generate_attribute_accessors,
field_accessor::generate_accessor,
operation_builder::{
generate_default_constructor, generate_operation_builder, generate_operation_builder_fn,
},
};
use super::operation::{Operation, OperationBuilder};
use crate::dialect::error::Error;
Expand All @@ -28,8 +31,8 @@ pub fn generate_operation(operation: &Operation) -> Result<TokenStream, Error> {

let builder = OperationBuilder::new(operation)?;
let builder_tokens = generate_operation_builder(&builder)?;
let builder_fn = builder.create_op_builder_fn()?;
let default_constructor = builder.create_default_constructor()?;
let builder_fn = generate_operation_builder_fn(&builder)?;
let default_constructor = generate_default_constructor(&builder)?;

Ok(quote! {
#[doc = #summary]
Expand Down
10 changes: 5 additions & 5 deletions macro/src/dialect/generation/attribute_accessor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::dialect::{
error::Error,
operation::{Attribute, OperationFieldLike},
utility::sanitize_snake_case_name,
utility::sanitize_snake_case_identifier,
};
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -21,7 +21,7 @@ pub fn generate_attribute_accessors(attribute: &Attribute) -> Result<TokenStream
fn generate_getter(attribute: &Attribute) -> Result<TokenStream, Error> {
let name = attribute.name();

let ident = attribute.sanitized_name();
let identifier = attribute.singular_identifier();
let return_type = attribute.return_type();
let body = if attribute.is_unit() {
quote! { self.operation.attribute(#name).is_some() }
Expand All @@ -32,7 +32,7 @@ fn generate_getter(attribute: &Attribute) -> Result<TokenStream, Error> {

Ok(quote! {
#[allow(clippy::needless_question_mark)]
pub fn #ident(&self, context: &'c ::melior::Context) -> #return_type {
pub fn #identifier(&self, context: &'c ::melior::Context) -> #return_type {
#body
}
})
Expand All @@ -55,7 +55,7 @@ fn generate_setter(attribute: &Attribute) -> Result<TokenStream, Error> {
}
};

let ident = sanitize_snake_case_name(&format!("set_{}", attribute.name()))?;
let ident = sanitize_snake_case_identifier(&format!("set_{}", attribute.name()))?;
let r#type = attribute.parameter_type();

Ok(quote! {
Expand All @@ -68,7 +68,7 @@ fn generate_setter(attribute: &Attribute) -> Result<TokenStream, Error> {
fn generate_remover(attribute: &Attribute) -> Result<Option<TokenStream>, Error> {
Ok(if attribute.is_unit() || attribute.is_optional() {
let name = attribute.name();
let ident = sanitize_snake_case_name(&format!("remove_{}", attribute.name()))?;
let ident = sanitize_snake_case_identifier(&format!("remove_{}", attribute.name()))?;

Some(quote! {
pub fn #ident(&mut self, context: &'c ::melior::Context) -> Result<(), ::melior::Error> {
Expand Down
166 changes: 153 additions & 13 deletions macro/src/dialect/generation/operation_builder.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use crate::dialect::{
error::Error, operation::OperationBuilder, utility::sanitize_snake_case_name,
error::Error, operation::OperationBuilder, utility::sanitize_snake_case_identifier,
};
use proc_macro2::TokenStream;
use quote::quote;
use quote::{format_ident, quote};
use syn::Ident;

pub fn generate_operation_builder(builder: &OperationBuilder) -> Result<TokenStream, Error> {
let field_names = builder
.type_state()
.field_names()
.map(sanitize_snake_case_name)
.map(sanitize_snake_case_identifier)
.collect::<Result<Vec<_>, _>>()?;

let phantom_fields =
Expand All @@ -27,29 +28,168 @@ pub fn generate_operation_builder(builder: &OperationBuilder) -> Result<TokenStr
.map(|name| quote! { #name: ::std::marker::PhantomData })
.collect::<Vec<_>>();

let builder_fns = builder
.create_builder_fns(&field_names, phantom_arguments.as_slice())
.collect::<Result<Vec<_>, _>>()?;
let builder_fns = generate_builder_fns(builder, &field_names, phantom_arguments.as_slice())?;

let new = builder.create_new_fn(phantom_arguments.as_slice())?;
let build = builder.create_build_fn()?;
let new_fn = generate_new_fn(builder, phantom_arguments.as_slice())?;
let build_fn = generate_build_fn(builder)?;

let builder_identifier = builder.identifier();
let doc = format!("Builder for {}", builder.operation().summary()?);
let iter_arguments = builder.type_state().parameters();
let doc = format!("A builder for {}", builder.operation().summary()?);
let type_arguments = builder.type_state().parameters();

Ok(quote! {
#[doc = #doc]
pub struct #builder_identifier<'c, #(#iter_arguments),*> {
pub struct #builder_identifier<'c, #(#type_arguments),*> {
builder: ::melior::ir::operation::OperationBuilder<'c>,
context: &'c ::melior::Context,
#(#phantom_fields),*
}

#new
#new_fn

#(#builder_fns)*

#build
#build_fn
})
}

fn generate_builder_fns(
builder: &OperationBuilder,
field_names: &[Ident],
phantoms: &[TokenStream],
) -> Result<Vec<TokenStream>, Error> {
builder.operation().fields().map(move |field| {
let builder_identifier = builder.identifier();
let identifier = sanitize_snake_case_identifier(field.name())?;
let parameter_type = field.parameter_type();
let argument = quote! { #identifier: #parameter_type };
let add = format_ident!("add_{}", field.plural_kind_identifier());

// Argument types can be singular and variadic. But `add` functions in Melior
// are always variadic, so we need to create a slice or `Vec` for singular
// arguments.
let add_arguments = field.add_arguments(&identifier);

Ok(if field.is_optional() {
let parameters = builder.type_state().parameters().collect::<Vec<_>>();

quote! {
impl<'c, #(#parameters),*> #builder_identifier<'c, #(#parameters),*> {
pub fn #identifier(mut self, #argument) -> #builder_identifier<'c, #(#parameters),*> {
self.builder = self.builder.#add(#add_arguments);
self
}
}
}
} else if field.is_result() && builder.operation().can_infer_type() {
quote!()
} else {
let parameters = builder.type_state().parameters_without(field.name());
let arguments_set = builder.type_state().arguments_set(field.name(), true);
let arguments_unset = builder.type_state().arguments_set(field.name(), false);

quote! {
impl<'c, #(#parameters),*> #builder_identifier<'c, #(#arguments_unset),*> {
pub fn #identifier(mut self, #argument) -> #builder_identifier<'c, #(#arguments_set),*> {
self.builder = self.builder.#add(#add_arguments);
let Self { context, mut builder, #(#field_names),* } = self;
#builder_identifier {
context,
builder,
#(#phantoms),*
}
}
}
}
})
}).collect()
}

fn generate_build_fn(builder: &OperationBuilder) -> Result<TokenStream, Error> {
let builder_ident = builder.identifier();
let arguments = builder.type_state().arguments_all_set(true);
let class_name = format_ident!("{}", &builder.operation().class_name()?);
let error = format!("should be a valid {class_name}");
let maybe_infer = builder
.operation()
.can_infer_type()
.then_some(quote! { .enable_result_type_inference() });

Ok(quote! {
impl<'c> #builder_ident<'c, #(#arguments),*> {
pub fn build(self) -> #class_name<'c> {
self.builder #maybe_infer.build().expect("valid operation").try_into().expect(#error)
}
}
})
}

fn generate_new_fn(
builder: &OperationBuilder,
phantoms: &[TokenStream],
) -> Result<TokenStream, Error> {
let builder_ident = builder.identifier();
let name = &builder.operation().full_name()?;
let arguments = builder.type_state().arguments_all_set(false);

Ok(quote! {
impl<'c> #builder_ident<'c, #(#arguments),*> {
pub fn new(context: &'c ::melior::Context, location: ::melior::ir::Location<'c>) -> Self {
Self {
context,
builder: ::melior::ir::operation::OperationBuilder::new( #name, location),
#(#phantoms),*
}
}
}
})
}

pub fn generate_operation_builder_fn(builder: &OperationBuilder) -> Result<TokenStream, Error> {
let builder_ident = builder.identifier();
let arguments = builder.type_state().arguments_all_set(false);

Ok(quote! {
pub fn builder(
context: &'c ::melior::Context,
location: ::melior::ir::Location<'c>
) -> #builder_ident<'c, #(#arguments),*> {
#builder_ident::new(context, location)
}
})
}

pub fn generate_default_constructor(builder: &OperationBuilder) -> Result<TokenStream, Error> {
let class_name = format_ident!("{}", &builder.operation().class_name()?);
let name = sanitize_snake_case_identifier(builder.operation().short_name()?)?;
let arguments = builder
.operation()
.required_fields()
.map(|field| {
let parameter_type = &field.parameter_type();
let parameter_name = &field.singular_identifier();

quote! { #parameter_name: #parameter_type }
})
.chain([quote! { location: ::melior::ir::Location<'c> }])
.collect::<Vec<_>>();
let builder_calls = builder
.operation()
.required_fields()
.map(|field| {
let parameter_name = &field.singular_identifier();

quote! { .#parameter_name(#parameter_name) }
})
.collect::<Vec<_>>();

let doc = format!("Creates a new {}", builder.operation().summary()?);

Ok(quote! {
#[allow(clippy::too_many_arguments)]
#[doc = #doc]
pub fn #name<'c>(context: &'c ::melior::Context, #(#arguments),*) -> #class_name<'c> {
#class_name::builder(context, location)#(#builder_calls)*.build()
}
})
}
9 changes: 9 additions & 0 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ impl<'a> Operation<'a> {
})
}

pub fn can_infer_type(&self) -> bool {
self.can_infer_type
}

fn dialect(&self) -> Result<Record, Error> {
Ok(self.definition.def_value("opDialect")?)
}
Expand Down Expand Up @@ -148,6 +152,11 @@ impl<'a> Operation<'a> {
self.attributes.iter().chain(&self.derived_attributes)
}

pub fn required_fields(&self) -> impl Iterator<Item = &dyn OperationFieldLike> {
self.fields()
.filter(|field| (!field.is_result() || !self.can_infer_type) && !field.is_optional())
}

fn collect_successors(definition: Record<'a>) -> Result<Vec<OperationField>, Error> {
let successors_dag = definition.dag_value("successors")?;
let len = successors_dag.num_args();
Expand Down
17 changes: 9 additions & 8 deletions macro/src/dialect/operation/attribute.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::dialect::{
error::Error,
operation::operation_field::OperationFieldLike,
utility::{generate_result_type, sanitize_snake_case_name},
utility::{generate_result_type, sanitize_snake_case_identifier},
};
use once_cell::sync::Lazy;
use proc_macro2::{Ident, TokenStream};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use std::collections::HashMap;
use syn::Ident;
use syn::{parse_quote, Type};
use tblgen::{error::TableGenError, Record};

Expand Down Expand Up @@ -60,7 +61,7 @@ static ATTRIBUTE_TYPES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(||
#[derive(Debug)]
pub struct Attribute<'a> {
name: &'a str,
sanitized_name: Ident,
singular_identifier: Ident,
storage_type_string: String,
storage_type: Type,
optional: bool,
Expand All @@ -73,7 +74,7 @@ impl<'a> Attribute<'a> {

Ok(Self {
name,
sanitized_name: sanitize_snake_case_name(name)?,
singular_identifier: sanitize_snake_case_identifier(name)?,
storage_type: syn::parse_str(
ATTRIBUTE_TYPES
.get(storage_type_string.trim())
Expand Down Expand Up @@ -114,12 +115,12 @@ impl OperationFieldLike for Attribute<'_> {
self.name
}

fn plural_identifier(&self) -> &str {
"attributes"
fn singular_identifier(&self) -> &Ident {
&self.singular_identifier
}

fn sanitized_name(&self) -> &Ident {
&self.sanitized_name
fn plural_kind_identifier(&self) -> Ident {
Ident::new("attributes", Span::call_site())
}

fn parameter_type(&self) -> Type {
Expand Down
Loading

0 comments on commit 45965f0

Please sign in to comment.