Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed Dec 4, 2023
1 parent 7410514 commit d352179
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 47 deletions.
46 changes: 5 additions & 41 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod r#type;
mod utility;

use dialect::DialectInput;
use parse::{DialectOperationSet, IdentifierList};
use parse::{DialectOperationSet, IdentifierList, PassSet};
use proc_macro::TokenStream;
use quote::quote;
use std::error::Error;
Expand Down Expand Up @@ -68,15 +68,6 @@ pub fn attribute_check_functions(stream: TokenStream) -> TokenStream {
convert_result(attribute::generate(identifiers.identifiers()))
}

#[proc_macro]
pub fn async_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Async").unwrap().into()
}))
}

#[proc_macro]
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
Expand All @@ -90,38 +81,11 @@ pub fn conversion_passes(stream: TokenStream) -> TokenStream {
}

#[proc_macro]
pub fn gpu_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("GPU").unwrap().into()
}))
}

#[proc_macro]
pub fn transform_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Transforms").unwrap().into()
}))
}

#[proc_macro]
pub fn linalg_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Linalg").unwrap().into()
}))
}

#[proc_macro]
pub fn sparse_tensor_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
pub fn passes(stream: TokenStream) -> TokenStream {
let set = parse_macro_input!(stream as PassSet);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("SparseTensor").unwrap().into()
convert_result(pass::generate(set.identifiers(), |name| {
name.strip_prefix(&set.prefix().value()).unwrap().into()
}))
}

Expand Down
12 changes: 6 additions & 6 deletions macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use syn::{
bracketed,
parse::{Parse, ParseStream},
punctuated::Punctuated,
Lit, Result, Token,
LitStr, Result, Token,
};

pub struct IdentifierList {
Expand Down Expand Up @@ -58,13 +58,13 @@ impl Parse for DialectOperationSet {
}

pub struct PassSet {
prefix: Lit,
prefix: LitStr,
identifiers: IdentifierList,
}

impl PassSet {
pub const fn dialect(&self) -> &Ident {
&self.dialect
pub const fn prefix(&self) -> &LitStr {
&self.prefix
}

pub fn identifiers(&self) -> &[Ident] {
Expand All @@ -74,11 +74,11 @@ impl PassSet {

impl Parse for PassSet {
fn parse(input: ParseStream) -> Result<Self> {
let dialect = Ident::parse(input)?;
let prefix = input.parse()?;
<Token![,]>::parse(input)?;

Ok(Self {
dialect,
prefix,
identifiers: {
let content;
bracketed!(content in input);
Expand Down

0 comments on commit d352179

Please sign in to comment.