From f26999c1e18e7ac9f99a639b62201ede11da74b0 Mon Sep 17 00:00:00 2001 From: Fredrik Fornwall Date: Fri, 11 Aug 2023 15:44:59 +0200 Subject: [PATCH] [wgsl-in] Fail on more repeated attributes --- src/front/wgsl/parse/mod.rs | 46 ++++++++++++++++++++++--------------- src/front/wgsl/tests.rs | 24 ++++++++++++++----- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index fc7ab87012..55377cc884 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -120,6 +120,13 @@ enum Rule { GeneralExpr, } +const fn fail_if_repeated_attribute<'a>(repeated: bool, name_span: Span) -> Result<(), Error<'a>> { + if repeated { + return Err(Error::RepeatedAttribute(name_span)); + } + Ok(()) +} + #[derive(Default)] struct BindingParser { location: Option, @@ -136,31 +143,24 @@ impl BindingParser { name: &'a str, name_span: Span, ) -> Result<(), Error<'a>> { - let fail_if_repeated = |repeated| { - if repeated { - return Err(Error::RepeatedAttribute(name_span)); - } - Ok(()) - }; - match name { "location" => { lexer.expect(Token::Paren('('))?; - fail_if_repeated(self.location.is_some())?; + fail_if_repeated_attribute(self.location.is_some(), name_span)?; self.location = Some(Parser::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } "builtin" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; - fail_if_repeated(self.built_in.is_some())?; + fail_if_repeated_attribute(self.built_in.is_some(), name_span)?; self.built_in = Some(conv::map_built_in(raw, span)?); lexer.expect(Token::Paren(')'))?; } "interpolate" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; - fail_if_repeated(self.interpolation.is_some())?; + fail_if_repeated_attribute(self.interpolation.is_some(), name_span)?; self.interpolation = Some(conv::map_interpolation(raw, span)?); if lexer.skip(Token::Separator(',')) { let (raw, span) = lexer.next_ident_with_span()?; @@ -169,7 +169,7 @@ impl BindingParser { lexer.expect(Token::Paren(')'))?; } "invariant" => { - fail_if_repeated(self.invariant)?; + fail_if_repeated_attribute(self.invariant, name_span)?; self.invariant = true; } _ => return Err(Error::UnknownAttribute(name_span)), @@ -1008,16 +1008,18 @@ impl Parser { let mut bind_parser = BindingParser::default(); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("size", _) => { + ("size", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; + fail_if_repeated_attribute(size.is_some(), name_span)?; size = Some((value, span)); } - ("align", _) => { + ("align", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; + fail_if_repeated_attribute(align.is_some(), name_span)?; align = Some((value, span)); } (word, word_span) => bind_parser.parse(lexer, word, word_span)?, @@ -2152,23 +2154,28 @@ impl Parser { self.push_rule_span(Rule::Attribute, lexer); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("binding", _) => { + ("binding", name_span) => { lexer.expect(Token::Paren('('))?; + fail_if_repeated_attribute(bind_index.is_some(), name_span)?; bind_index = Some(Self::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } - ("group", _) => { + ("group", name_span) => { lexer.expect(Token::Paren('('))?; + fail_if_repeated_attribute(bind_group.is_some(), name_span)?; bind_group = Some(Self::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } - ("vertex", _) => { + ("vertex", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Vertex); } - ("fragment", _) => { + ("fragment", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Fragment); } - ("compute", _) => { + ("compute", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Compute); } ("workgroup_size", _) => { @@ -2188,7 +2195,7 @@ impl Parser { } } } - ("early_depth_test", _) => { + ("early_depth_test", name_span) => { let conservative = if lexer.skip(Token::Paren('(')) { let (ident, ident_span) = lexer.next_ident_with_span()?; let value = conv::map_conservative_depth(ident, ident_span)?; @@ -2197,6 +2204,7 @@ impl Parser { } else { None }; + fail_if_repeated_attribute(early_depth_test.is_some(), name_span)?; early_depth_test = Some(crate::EarlyDepthTest { conservative }); } (_, word_span) => return Err(Error::UnknownAttribute(word_span)), diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index cabb6f80bb..af65127572 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -517,12 +517,23 @@ fn parse_repeated_attributes() { Span, }; - let template = "@vertex fn vs() -> __REPLACE__ vec4 { return vec4(0.0); }"; - for attribute in [ - "location(0)", - "builtin(position)", - "interpolate(flat)", - "invariant", + let template_vs = "@vertex fn vs() -> __REPLACE__ vec4 { return vec4(0.0); }"; + let template_struct = "struct A { __REPLACE__ data: vec3 }"; + let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array;"; + let template_stage = "__REPLACE__ fn vs() -> vec4 { return vec4(0.0); }"; + for (attribute, template) in [ + ("align(16)", template_struct), + ("binding(0)", template_resource), + ("builtin(position)", template_vs), + ("compute", template_stage), + ("fragment", template_stage), + ("group(0)", template_resource), + ("interpolate(flat)", template_vs), + ("invariant", template_vs), + ("location(0)", template_vs), + ("size(16)", template_struct), + ("vertex", template_stage), + ("early_depth_test(less_equal)", template_resource), ] { let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; @@ -531,6 +542,7 @@ fn parse_repeated_attributes() { let expected_span = Span::new(span_start, span_end); let result = Frontend::new().inner(&shader); + println!("WHAT? {} RESULT: {:?}", attribute, result); assert!(matches!( result.unwrap_err(), Error::RepeatedAttribute(span) if span == expected_span