Skip to content

Commit

Permalink
[wgsl-in] Fail on more repeated attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall committed Aug 11, 2023
1 parent 5a44632 commit f26999c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 25 deletions.
46 changes: 27 additions & 19 deletions src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ enum Rule {
GeneralExpr,
}

const fn fail_if_repeated_attribute<'a>(repeated: bool, name_span: Span) -> Result<(), Error<'a>> {
if repeated {
return Err(Error::RepeatedAttribute(name_span));
}
Ok(())
}

#[derive(Default)]
struct BindingParser {
location: Option<u32>,
Expand All @@ -136,31 +143,24 @@ impl BindingParser {
name: &'a str,
name_span: Span,
) -> Result<(), Error<'a>> {
let fail_if_repeated = |repeated| {
if repeated {
return Err(Error::RepeatedAttribute(name_span));
}
Ok(())
};

match name {
"location" => {
lexer.expect(Token::Paren('('))?;
fail_if_repeated(self.location.is_some())?;
fail_if_repeated_attribute(self.location.is_some(), name_span)?;
self.location = Some(Parser::non_negative_i32_literal(lexer)?);
lexer.expect(Token::Paren(')'))?;
}
"builtin" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
fail_if_repeated(self.built_in.is_some())?;
fail_if_repeated_attribute(self.built_in.is_some(), name_span)?;
self.built_in = Some(conv::map_built_in(raw, span)?);
lexer.expect(Token::Paren(')'))?;
}
"interpolate" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
fail_if_repeated(self.interpolation.is_some())?;
fail_if_repeated_attribute(self.interpolation.is_some(), name_span)?;
self.interpolation = Some(conv::map_interpolation(raw, span)?);
if lexer.skip(Token::Separator(',')) {
let (raw, span) = lexer.next_ident_with_span()?;
Expand All @@ -169,7 +169,7 @@ impl BindingParser {
lexer.expect(Token::Paren(')'))?;
}
"invariant" => {
fail_if_repeated(self.invariant)?;
fail_if_repeated_attribute(self.invariant, name_span)?;
self.invariant = true;
}
_ => return Err(Error::UnknownAttribute(name_span)),
Expand Down Expand Up @@ -1008,16 +1008,18 @@ impl Parser {
let mut bind_parser = BindingParser::default();
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("size", _) => {
("size", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
fail_if_repeated_attribute(size.is_some(), name_span)?;
size = Some((value, span));
}
("align", _) => {
("align", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
fail_if_repeated_attribute(align.is_some(), name_span)?;
align = Some((value, span));
}
(word, word_span) => bind_parser.parse(lexer, word, word_span)?,
Expand Down Expand Up @@ -2152,23 +2154,28 @@ impl Parser {
self.push_rule_span(Rule::Attribute, lexer);
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("binding", _) => {
("binding", name_span) => {
lexer.expect(Token::Paren('('))?;
fail_if_repeated_attribute(bind_index.is_some(), name_span)?;
bind_index = Some(Self::non_negative_i32_literal(lexer)?);
lexer.expect(Token::Paren(')'))?;
}
("group", _) => {
("group", name_span) => {
lexer.expect(Token::Paren('('))?;
fail_if_repeated_attribute(bind_group.is_some(), name_span)?;
bind_group = Some(Self::non_negative_i32_literal(lexer)?);
lexer.expect(Token::Paren(')'))?;
}
("vertex", _) => {
("vertex", name_span) => {
fail_if_repeated_attribute(stage.is_some(), name_span)?;
stage = Some(crate::ShaderStage::Vertex);
}
("fragment", _) => {
("fragment", name_span) => {
fail_if_repeated_attribute(stage.is_some(), name_span)?;
stage = Some(crate::ShaderStage::Fragment);
}
("compute", _) => {
("compute", name_span) => {
fail_if_repeated_attribute(stage.is_some(), name_span)?;
stage = Some(crate::ShaderStage::Compute);
}
("workgroup_size", _) => {
Expand All @@ -2188,7 +2195,7 @@ impl Parser {
}
}
}
("early_depth_test", _) => {
("early_depth_test", name_span) => {
let conservative = if lexer.skip(Token::Paren('(')) {
let (ident, ident_span) = lexer.next_ident_with_span()?;
let value = conv::map_conservative_depth(ident, ident_span)?;
Expand All @@ -2197,6 +2204,7 @@ impl Parser {
} else {
None
};
fail_if_repeated_attribute(early_depth_test.is_some(), name_span)?;
early_depth_test = Some(crate::EarlyDepthTest { conservative });
}
(_, word_span) => return Err(Error::UnknownAttribute(word_span)),
Expand Down
24 changes: 18 additions & 6 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,23 @@ fn parse_repeated_attributes() {
Span,
};

let template = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
for attribute in [
"location(0)",
"builtin(position)",
"interpolate(flat)",
"invariant",
let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
for (attribute, template) in [
("align(16)", template_struct),
("binding(0)", template_resource),
("builtin(position)", template_vs),
("compute", template_stage),
("fragment", template_stage),
("group(0)", template_resource),
("interpolate(flat)", template_vs),
("invariant", template_vs),
("location(0)", template_vs),
("size(16)", template_struct),
("vertex", template_stage),
("early_depth_test(less_equal)", template_resource),
] {
let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
Expand All @@ -531,6 +542,7 @@ fn parse_repeated_attributes() {
let expected_span = Span::new(span_start, span_end);

let result = Frontend::new().inner(&shader);
println!("WHAT? {} RESULT: {:?}", attribute, result);
assert!(matches!(
result.unwrap_err(),
Error::RepeatedAttribute(span) if span == expected_span
Expand Down

0 comments on commit f26999c

Please sign in to comment.