From 7fb913aab339b272ddf98c8bf51e4d6a0416498e Mon Sep 17 00:00:00 2001 From: NiiRoZz Date: Mon, 2 Dec 2024 23:46:05 +0100 Subject: [PATCH] Add support for out and inout parameters. --- include/NZSL/Ast/Compare.inl | 6 ++ include/NZSL/Ast/Nodes.hpp | 11 +++ include/NZSL/Lang/ErrorList.hpp | 2 + include/NZSL/Lang/TokenList.hpp | 2 + include/NZSL/Parser.hpp | 1 + include/NZSL/ShaderBuilder.hpp | 4 +- include/NZSL/ShaderBuilder.inl | 6 +- src/NZSL/Ast/AstSerializer.cpp | 4 + src/NZSL/Ast/Cloner.cpp | 4 + src/NZSL/Ast/Nodes.cpp | 17 ++++ src/NZSL/Ast/SanitizeVisitor.cpp | 8 ++ src/NZSL/GlslWriter.cpp | 9 ++ src/NZSL/LangWriter.cpp | 18 ++++ src/NZSL/Lexer.cpp | 2 + src/NZSL/Parser.cpp | 72 ++++++++++++++- src/NZSL/SpirV/SpirvAstVisitor.cpp | 22 ++++- tests/src/Tests/ErrorsTests.cpp | 31 +++++++ tests/src/Tests/FunctionsTests.cpp | 133 ++++++++++++++++++++++++++++ tests/src/Tests/IdentifierTests.cpp | 24 ++--- 19 files changed, 355 insertions(+), 21 deletions(-) diff --git a/include/NZSL/Ast/Compare.inl b/include/NZSL/Ast/Compare.inl index bce8fd1..6027813 100644 --- a/include/NZSL/Ast/Compare.inl +++ b/include/NZSL/Ast/Compare.inl @@ -224,6 +224,9 @@ namespace nzsl::Ast inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs, const ComparisonParams& params) { + if (!Compare(lhs.semantic, rhs.semantic, params)) + return false; + if (!Compare(lhs.name, rhs.name, params)) return false; @@ -388,6 +391,9 @@ namespace nzsl::Ast if (!Compare(lhs.parameters, rhs.parameters, params)) return false; + if (!Compare(lhs.parametersSemantic, rhs.parametersSemantic, params)) + return false; + return true; } diff --git a/include/NZSL/Ast/Nodes.hpp b/include/NZSL/Ast/Nodes.hpp index b43d2dd..7dc1f12 100644 --- a/include/NZSL/Ast/Nodes.hpp +++ b/include/NZSL/Ast/Nodes.hpp @@ -117,10 +117,18 @@ namespace nzsl::Ast struct NZSL_API CallFunctionExpression : Expression { + enum class ParameterSemantic : Nz::UInt8 + { + In, + Out, + InOut, + }; + NodeType GetType() const override; void Visit(ExpressionVisitor& visitor) override; std::vector parameters; + std::vector parametersSemantic; ExpressionPtr targetFunction; }; @@ -361,6 +369,7 @@ namespace nzsl::Ast struct Parameter { + CallFunctionExpression::ParameterSemantic semantic; std::optional varIndex; std::string name; ExpressionValue type; @@ -508,6 +517,8 @@ namespace nzsl::Ast StatementPtr body; }; + std::string_view ToString(CallFunctionExpression::ParameterSemantic attributeType); + #define NZSL_SHADERAST_NODE(X, C) using X##C##Ptr = std::unique_ptr; #include diff --git a/include/NZSL/Lang/ErrorList.hpp b/include/NZSL/Lang/ErrorList.hpp index 0798fe9..4984ca0 100644 --- a/include/NZSL/Lang/ErrorList.hpp +++ b/include/NZSL/Lang/ErrorList.hpp @@ -45,6 +45,7 @@ NZSL_SHADERLANG_PARSER_ERROR(DuplicateModule, "duplicate module") NZSL_SHADERLANG_PARSER_ERROR(InvalidVersion, "\"{}\" is not a valid version", std::string) NZSL_SHADERLANG_PARSER_ERROR(MissingAttribute, "missing attribute {}", Ast::AttributeType) NZSL_SHADERLANG_PARSER_ERROR(ModuleFeatureMultipleUnique, "module feature {} has already been specified", Ast::ModuleFeature) +NZSL_SHADERLANG_PARSER_ERROR(FunctionParameterNonLValue, "non-L-value cannot be passed for parameter #{}", std::size_t) NZSL_SHADERLANG_PARSER_ERROR(ReservedKeyword, "reserved keyword") NZSL_SHADERLANG_PARSER_ERROR(UnknownAttribute, "unknown attribute \"{}\"", std::string) NZSL_SHADERLANG_PARSER_ERROR(UnknownType, "unknown type") @@ -99,6 +100,7 @@ NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallOutsideOfFunction, "function calls mu NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnexpectedEntryFunction, "{} is an entry function which cannot be called by the program", std::string) NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterCount, "function {} expects {} parameter(s), but got {}", std::string, std::uint32_t, std::uint32_t) NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterType, "function {} parameter #{} type mismatch (expected {}, got {})", std::string, std::uint32_t, std::string, std::string) +NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterSemanticType, "function {} parameter #{} semantic mismatch (expected {}, got {})", std::string, std::uint32_t, std::string, std::string) NZSL_SHADERLANG_COMPILER_ERROR(FunctionDeclarationInsideFunction, "a function cannot be defined inside another function") NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnWithAValue, "return with a value, in function returning no value") NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnWithNoValue, "return with no value, in function returning {}", std::string) diff --git a/include/NZSL/Lang/TokenList.hpp b/include/NZSL/Lang/TokenList.hpp index b796031..5ac3b77 100644 --- a/include/NZSL/Lang/TokenList.hpp +++ b/include/NZSL/Lang/TokenList.hpp @@ -50,6 +50,7 @@ NZSL_SHADERLANG_TOKEN(Identifier) NZSL_SHADERLANG_TOKEN(If) NZSL_SHADERLANG_TOKEN(Import) NZSL_SHADERLANG_TOKEN(In) +NZSL_SHADERLANG_TOKEN(InOut) NZSL_SHADERLANG_TOKEN(LessThan) NZSL_SHADERLANG_TOKEN(LessThanEqual) NZSL_SHADERLANG_TOKEN(Let) @@ -72,6 +73,7 @@ NZSL_SHADERLANG_TOKEN(OpenCurlyBracket) NZSL_SHADERLANG_TOKEN(OpenSquareBracket) NZSL_SHADERLANG_TOKEN(OpenParenthesis) NZSL_SHADERLANG_TOKEN(Option) +NZSL_SHADERLANG_TOKEN(Out) NZSL_SHADERLANG_TOKEN(Return) NZSL_SHADERLANG_TOKEN(Semicolon) NZSL_SHADERLANG_TOKEN(ShiftLeft) diff --git a/include/NZSL/Parser.hpp b/include/NZSL/Parser.hpp index 0d4c745..aacd9b2 100644 --- a/include/NZSL/Parser.hpp +++ b/include/NZSL/Parser.hpp @@ -85,6 +85,7 @@ namespace nzsl Ast::ExpressionPtr ParseConstSelectExpression(); Ast::ExpressionPtr ParseExpression(int exprPrecedence = 0); std::vector ParseExpressionList(TokenType terminationToken, SourceLocation* terminationLocation); + std::vector ParseFunctionExpressionList(std::vector& parametersSemantic, SourceLocation& terminationLocation); Ast::ExpressionPtr ParseExpressionStatement(); Ast::ExpressionPtr ParseFloatingPointExpression(); Ast::ExpressionPtr ParseIdentifier(); diff --git a/include/NZSL/ShaderBuilder.hpp b/include/NZSL/ShaderBuilder.hpp index 9089612..452b971 100644 --- a/include/NZSL/ShaderBuilder.hpp +++ b/include/NZSL/ShaderBuilder.hpp @@ -49,8 +49,8 @@ namespace nzsl::ShaderBuilder struct CallFunction { - inline Ast::CallFunctionExpressionPtr operator()(std::string functionName, std::vector parameters) const; - inline Ast::CallFunctionExpressionPtr operator()(Ast::ExpressionPtr functionExpr, std::vector parameters) const; + inline Ast::CallFunctionExpressionPtr operator()(std::string functionName, std::vector parameters, std::vector parametersSemantic) const; + inline Ast::CallFunctionExpressionPtr operator()(Ast::ExpressionPtr functionExpr, std::vector parameters, std::vector parametersSemantic) const; }; struct Cast diff --git a/include/NZSL/ShaderBuilder.inl b/include/NZSL/ShaderBuilder.inl index ab3f36b..c3a435d 100644 --- a/include/NZSL/ShaderBuilder.inl +++ b/include/NZSL/ShaderBuilder.inl @@ -105,20 +105,22 @@ namespace nzsl::ShaderBuilder return branchNode; } - inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(std::string functionName, std::vector parameters) const + inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(std::string functionName, std::vector parameters, std::vector parametersSemantic) const { auto callFunctionExpression = std::make_unique(); callFunctionExpression->targetFunction = ShaderBuilder::Identifier(std::move(functionName)); callFunctionExpression->parameters = std::move(parameters); + callFunctionExpression->parametersSemantic = std::move(parametersSemantic); return callFunctionExpression; } - inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(Ast::ExpressionPtr functionExpr, std::vector parameters) const + inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(Ast::ExpressionPtr functionExpr, std::vector parameters, std::vector parametersSemantic) const { auto callFunctionExpression = std::make_unique(); callFunctionExpression->targetFunction = std::move(functionExpr); callFunctionExpression->parameters = std::move(parameters); + callFunctionExpression->parametersSemantic = std::move(parametersSemantic); return callFunctionExpression; } diff --git a/src/NZSL/Ast/AstSerializer.cpp b/src/NZSL/Ast/AstSerializer.cpp index e21348b..6ed3ea2 100644 --- a/src/NZSL/Ast/AstSerializer.cpp +++ b/src/NZSL/Ast/AstSerializer.cpp @@ -83,6 +83,10 @@ namespace nzsl::Ast Container(node.parameters); for (auto& param : node.parameters) Node(param); + + Container(node.parametersSemantic); + for (auto& paramAttribute : node.parametersSemantic) + Enum(paramAttribute); } void SerializerBase::Serialize(CallMethodExpression& node) diff --git a/src/NZSL/Ast/Cloner.cpp b/src/NZSL/Ast/Cloner.cpp index 62b0f70..4734925 100644 --- a/src/NZSL/Ast/Cloner.cpp +++ b/src/NZSL/Ast/Cloner.cpp @@ -423,6 +423,10 @@ namespace nzsl::Ast for (auto& parameter : node.parameters) clone->parameters.push_back(CloneExpression(parameter)); + clone->parametersSemantic.reserve(node.parametersSemantic.size()); + for (auto& parameterAttribute : node.parametersSemantic) + clone->parametersSemantic.push_back(parameterAttribute); + clone->cachedExpressionType = node.cachedExpressionType; clone->sourceLocation = node.sourceLocation; diff --git a/src/NZSL/Ast/Nodes.cpp b/src/NZSL/Ast/Nodes.cpp index e8e1e2a..ea4a6c5 100644 --- a/src/NZSL/Ast/Nodes.cpp +++ b/src/NZSL/Ast/Nodes.cpp @@ -12,6 +12,23 @@ namespace nzsl::Ast { Node::~Node() = default; + std::string_view ToString(CallFunctionExpression::ParameterSemantic attributeType) + { + switch (attributeType) + { + case nzsl::Ast::CallFunctionExpression::ParameterSemantic::In: + return "in"; + case nzsl::Ast::CallFunctionExpression::ParameterSemantic::Out: + return "out"; + case nzsl::Ast::CallFunctionExpression::ParameterSemantic::InOut: + return "inout"; + default: + break; + } + + NAZARA_UNREACHABLE(); + } + #define NZSL_SHADERAST_NODE(Node, Category) NodeType Node##Category::GetType() const \ { \ return NodeType:: Node##Category; \ diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index 9f90944..2ee64ea 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -834,6 +834,10 @@ namespace nzsl::Ast for (const auto& parameter : node.parameters) clone->parameters.push_back(CloneExpression(parameter)); + clone->parametersSemantic.reserve(node.parametersSemantic.size()); + for (const auto& parameterAttribute : node.parametersSemantic) + clone->parametersSemantic.push_back(parameterAttribute); + m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); Validate(*clone); @@ -1677,6 +1681,7 @@ NAZARA_WARNING_POP() for (auto& parameter : node.parameters) { auto& cloneParam = clone->parameters.emplace_back(); + cloneParam.semantic = parameter.semantic; cloneParam.name = parameter.name; cloneParam.type = CloneType(parameter.type); cloneParam.varIndex = parameter.varIndex; @@ -4597,6 +4602,9 @@ NAZARA_WARNING_POP() if (ResolveAlias(*parameterType) != ResolveAlias(referenceDeclaration->parameters[i].type.GetResultingValue())) throw CompilerFunctionCallUnmatchingParameterTypeError{ node.parameters[i]->sourceLocation, referenceDeclaration->name, Nz::SafeCast(i), ToString(referenceDeclaration->parameters[i].type.GetResultingValue(), referenceDeclaration->parameters[i].sourceLocation), ToString(*parameterType, node.parameters[i]->sourceLocation)}; + + if (node.parametersSemantic[i] != referenceDeclaration->parameters[i].semantic) + throw CompilerFunctionCallUnmatchingParameterSemanticTypeError{ node.parameters[i]->sourceLocation, referenceDeclaration->name, Nz::SafeCast(i), Ast::ToString(referenceDeclaration->parameters[i].semantic), Ast::ToString(node.parametersSemantic[i]) }; } if (node.parameters.size() != referenceDeclaration->parameters.size()) diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index 11d47c2..901b32c 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -931,6 +931,15 @@ namespace nzsl first = false; + if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::InOut) + { + Append("inout "); + } + else if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::Out) + { + Append("out "); + } + AppendVariableDeclaration(parameter.type.GetResultingValue(), parameter.name); } AppendLine((forward) ? ");" : ")"); diff --git a/src/NZSL/LangWriter.cpp b/src/NZSL/LangWriter.cpp index 28a7efe..a25fde0 100644 --- a/src/NZSL/LangWriter.cpp +++ b/src/NZSL/LangWriter.cpp @@ -1055,6 +1055,15 @@ namespace nzsl if (i != 0) Append(", "); + if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::InOut) + { + Append("inout "); + } + else if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::Out) + { + Append("out "); + } + node.parameters[i]->Visit(*this); } Append(")"); @@ -1429,6 +1438,15 @@ namespace nzsl if (i != 0) Append(", "); + if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::InOut) + { + Append("inout "); + } + else if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::Out) + { + Append("out "); + } + Append(parameter.name); Append(": "); Append(parameter.type); diff --git a/src/NZSL/Lexer.cpp b/src/NZSL/Lexer.cpp index 5ba674c..c5ab3c4 100644 --- a/src/NZSL/Lexer.cpp +++ b/src/NZSL/Lexer.cpp @@ -38,9 +38,11 @@ namespace nzsl { "if", TokenType::If }, { "import", TokenType::Import }, { "in", TokenType::In }, + { "inout", TokenType::InOut }, { "let", TokenType::Let }, { "module", TokenType::Module }, { "option", TokenType::Option }, + { "out", TokenType::Out }, { "return", TokenType::Return }, { "struct", TokenType::Struct }, { "true", TokenType::BoolTrue }, diff --git a/src/NZSL/Parser.cpp b/src/NZSL/Parser.cpp index 4dbf01f..5831611 100644 --- a/src/NZSL/Parser.cpp +++ b/src/NZSL/Parser.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -997,6 +998,27 @@ namespace nzsl { Ast::DeclareFunctionStatement::Parameter parameter; + const Token& t = Peek(); + if (t.type == TokenType::InOut) + { + Consume(); + parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::InOut; + } + else if (t.type == TokenType::Out) + { + Consume(); + parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::Out; + } + else if (t.type == TokenType::In) + { + Consume(); + parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::In; + } + else + { + parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::In; + } + parameter.name = ParseIdentifierAsName(¶meter.sourceLocation); Expect(Advance(), TokenType::Colon); @@ -1464,10 +1486,12 @@ namespace nzsl // Function call SourceLocation closingLocation; - auto parameters = ParseExpressionList(TokenType::ClosingParenthesis, &closingLocation); + std::vector parametersSemantic; + auto parameters = ParseFunctionExpressionList(parametersSemantic, closingLocation); + Nz::Assert(parameters.size() == parametersSemantic.size()); const SourceLocation& lhsLoc = lhs->sourceLocation; - lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters)); + lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters), std::move(parametersSemantic)); lhs->sourceLocation = SourceLocation::BuildFromTo(lhsLoc, closingLocation); continue; } @@ -1574,6 +1598,50 @@ namespace nzsl return parameters; } + + std::vector Parser::ParseFunctionExpressionList(std::vector& parametersSemantic, SourceLocation& terminationLocation) + { + std::vector parameters; + bool first = true; + size_t parameterIndex = 0; + while (Peek().type != TokenType::ClosingParenthesis) + { + if (!first) + Expect(Advance(), TokenType::Comma); + + TokenType tokenType = Peek().type; + if (tokenType == TokenType::InOut || tokenType == TokenType::Out || tokenType == TokenType::In) + { + Consume(); + + Ast::ExpressionPtr expressionPtr = ParseExpression(); + const Ast::ExpressionCategory category = Ast::GetExpressionCategory(*expressionPtr); + if (category != Ast::ExpressionCategory::LValue) + throw ParserFunctionParameterNonLValueError{ expressionPtr->sourceLocation, parameterIndex }; + + parameters.push_back(std::move(expressionPtr)); + if (tokenType == TokenType::InOut) + parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::InOut); + else if (tokenType == TokenType::Out) + parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::Out); + else + parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::In); + } + else + { + parameters.push_back(ParseExpression()); + parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::In); + } + + first = false; + parameterIndex++; + } + + const Token& endToken = Expect(Advance(), TokenType::ClosingParenthesis); + terminationLocation = endToken.location; + + return parameters; + } Ast::ExpressionPtr Parser::ParseExpressionStatement() { diff --git a/src/NZSL/SpirV/SpirvAstVisitor.cpp b/src/NZSL/SpirV/SpirvAstVisitor.cpp index 785188f..ba878d9 100644 --- a/src/NZSL/SpirV/SpirvAstVisitor.cpp +++ b/src/NZSL/SpirV/SpirvAstVisitor.cpp @@ -492,11 +492,15 @@ namespace nzsl Nz::StackArray parameterIds = NazaraStackArrayNoInit(std::uint32_t, node.parameters.size()); for (std::size_t i = 0; i < node.parameters.size(); ++i) { - std::uint32_t resultId = EvaluateExpression(*node.parameters[i]); std::uint32_t varId = m_currentFunc->variables[funcCall.firstVarIndex + i].varId; - m_currentBlock->Append(SpirvOp::OpStore, varId, resultId); - parameterIds[i] = varId; + + //Don't generate OpLoad and OpStore for out arguments + if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::Out) + continue; + + std::uint32_t resultId = EvaluateExpression(*node.parameters[i]); + m_currentBlock->Append(SpirvOp::OpStore, varId, resultId); } HandleSourceLocation(node.sourceLocation); @@ -512,6 +516,18 @@ namespace nzsl appender(parameterIds[i]); }); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + { + //Don't generate OpLoad and OpStore for in arguments + if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::In) + continue; + + std::uint32_t paramResultId = AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpLoad, targetFunc.parameters[i].typeId, paramResultId, parameterIds[i]); + SpirvExpressionStore storeVisitor(m_writer, *this, *m_currentBlock); + storeVisitor.Store(node.parameters[i], paramResultId); + } + PushResultId(resultId); } diff --git a/tests/src/Tests/ErrorsTests.cpp b/tests/src/Tests/ErrorsTests.cpp index f68e4d7..8ca8e57 100644 --- a/tests/src/Tests/ErrorsTests.cpp +++ b/tests/src/Tests/ErrorsTests.cpp @@ -739,6 +739,37 @@ fn test() return 10; } )"), "(7,2 -> 11): CFunctionReturnWithAValue error: return with a value, in function returning no value"); + + CHECK_THROWS_WITH(Compile(R"( +[nzsl_version("1.0")] +module; + +fn Test(inout color: vec3[f32]) +{ + color *= 2.0; +} + +fn main() +{ + let x = vec3[f32](1.0, 1.0, 1.0); + Test(x); +} +)"), "(13, 7): CFunctionCallUnmatchingParameterSemanticType error: function Test parameter #0 semantic mismatch (expected inout, got in)"); + + CHECK_THROWS_WITH(Compile(R"( +[nzsl_version("1.0")] +module; + +fn Test(inout color: vec3[f32]) +{ + color *= 2.0; +} + +fn main() +{ + Test(inout 2.0); +} +)"), "(12,13 -> 15): PFunctionParameterNonLValue error: non-L-value cannot be passed for parameter #0"); } /************************************************************************/ diff --git a/tests/src/Tests/FunctionsTests.cpp b/tests/src/Tests/FunctionsTests.cpp index 34bbb26..d39c354 100644 --- a/tests/src/Tests/FunctionsTests.cpp +++ b/tests/src/Tests/FunctionsTests.cpp @@ -210,6 +210,139 @@ OpFunction OpLabel OpFunctionCall OpReturnValue +OpFunctionEnd)"); + } + + SECTION("inout function call") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct FragOut +{ + [location(0)] value: f32, + [location(1)] value2: f32 +} + +fn Half(inout color: vec3[f32], out value: f32, in inValue: f32, inValue2: f32) +{ + color *= 2.0; + value = 10.0; +} + +[entry(frag)] +fn main() -> FragOut +{ + let output: FragOut; + let mainColor = vec3[f32](1.0, 1.0, 1.0); + let inValue = 2.0; + let inValue2 = 1.0; + Half(inout mainColor, out output.value2, in inValue, inValue2); + output.value = mainColor.x; + + return output; +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + shaderModule = SanitizeModule(*shaderModule); + + ExpectGLSL(*shaderModule, R"( +void Half(inout vec3 color, out float value, float inValue, float inValue2) +{ + color *= 2.0; + value = 10.0; +} + +/*************** Outputs ***************/ +layout(location = 0) out float _nzslOutvalue; +layout(location = 1) out float _nzslOutvalue2; + +void main() +{ + FragOut output_; + vec3 mainColor = vec3(1.0, 1.0, 1.0); + float inValue = 2.0; + float inValue2 = 1.0; + Half(mainColor, output_.value2, inValue, inValue2); + output_.value = mainColor.x; + + _nzslOutvalue = output_.value; + _nzslOutvalue2 = output_.value2; + return; +} +)"); + + ExpectNZSL(*shaderModule, R"( +fn Half(inout color: vec3[f32], out value: f32, inValue: f32, inValue2: f32) +{ + color *= 2.0; + value = 10.0; +} + +[entry(frag)] +fn main() -> FragOut +{ + let output: FragOut; + let mainColor: vec3[f32] = vec3[f32](1.0, 1.0, 1.0); + let inValue: f32 = 2.0; + let inValue2: f32 = 1.0; + Half(inout mainColor, out output.value2, inValue, inValue2); + output.value = mainColor.x; + return output; +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpFunctionParameter +OpFunctionParameter +OpFunctionParameter +OpFunctionParameter +OpLabel +OpLoad +OpVectorTimesScalar +OpStore +OpStore +OpReturn +OpFunctionEnd +OpFunction +OpLabel +OpVariable +OpVariable +OpVariable +OpVariable +OpVariable +OpVariable +OpVariable +OpVariable +OpCompositeConstruct +OpStore +OpStore +OpStore +OpLoad +OpStore +OpLoad +OpStore +OpLoad +OpStore +OpFunctionCall +OpLoad +OpStore +OpLoad +OpAccessChain +OpStore +OpLoad +OpCompositeExtract +OpAccessChain +OpStore +OpLoad +OpCompositeExtract +OpStore +OpCompositeExtract +OpStore +OpReturn OpFunctionEnd)"); } } diff --git a/tests/src/Tests/IdentifierTests.cpp b/tests/src/Tests/IdentifierTests.cpp index b5d2549..1f6fff9 100644 --- a/tests/src/Tests/IdentifierTests.cpp +++ b/tests/src/Tests/IdentifierTests.cpp @@ -39,10 +39,10 @@ fn main() -> output let fl__oa________t = 42.0; - let out: output; - out.active = (f32(input) + fl__oa________t).xxx; + let outValue: output; + outValue.active = (f32(input) + fl__oa________t).xxx; - return out; + return outValue; } )"; @@ -76,14 +76,14 @@ void main() int input_ = int_(); int input2_2 = 0; float fl2_oa8_t = 42.0; - output_ out_; + output_ outValue; float cachedResult = (float(input_)) + fl2_oa8_t; - out_.active_ = vec3(cachedResult, cachedResult, cachedResult); + outValue.active_ = vec3(cachedResult, cachedResult, cachedResult); - _nzslOutactive_ = out_.active_; - _nzslOutactive2_2 = out_.active2_2; - _nzslOut_ = out_._; - _nzslOut_2_2 = out_._2_2; + _nzslOutactive_ = outValue.active_; + _nzslOutactive2_2 = outValue.active2_2; + _nzslOut_ = outValue._; + _nzslOut_2_2 = outValue._2_2; return; } )"); @@ -116,9 +116,9 @@ fn main() -> output let input: i32 = int(); let input_: i32 = 0; let fl__oa________t: f32 = 42.0; - let out: output; - out.active = ((f32(input)) + fl__oa________t).xxx; - return out; + let outValue: output; + outValue.active = ((f32(input)) + fl__oa________t).xxx; + return outValue; } )");