Skip to content

Commit

Permalink
Add support for out and inout parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
NiiRoZz authored and SirLynix committed Dec 3, 2024
1 parent 4f87848 commit 7fb913a
Show file tree
Hide file tree
Showing 19 changed files with 355 additions and 21 deletions.
6 changes: 6 additions & 0 deletions include/NZSL/Ast/Compare.inl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ namespace nzsl::Ast

inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs, const ComparisonParams& params)
{
if (!Compare(lhs.semantic, rhs.semantic, params))
return false;

if (!Compare(lhs.name, rhs.name, params))
return false;

Expand Down Expand Up @@ -388,6 +391,9 @@ namespace nzsl::Ast
if (!Compare(lhs.parameters, rhs.parameters, params))
return false;

if (!Compare(lhs.parametersSemantic, rhs.parametersSemantic, params))
return false;

return true;
}

Expand Down
11 changes: 11 additions & 0 deletions include/NZSL/Ast/Nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,18 @@ namespace nzsl::Ast

struct NZSL_API CallFunctionExpression : Expression
{
enum class ParameterSemantic : Nz::UInt8
{
In,
Out,
InOut,
};

NodeType GetType() const override;
void Visit(ExpressionVisitor& visitor) override;

std::vector<ExpressionPtr> parameters;
std::vector<ParameterSemantic> parametersSemantic;
ExpressionPtr targetFunction;
};

Expand Down Expand Up @@ -361,6 +369,7 @@ namespace nzsl::Ast

struct Parameter
{
CallFunctionExpression::ParameterSemantic semantic;
std::optional<std::size_t> varIndex;
std::string name;
ExpressionValue<ExpressionType> type;
Expand Down Expand Up @@ -508,6 +517,8 @@ namespace nzsl::Ast
StatementPtr body;
};

std::string_view ToString(CallFunctionExpression::ParameterSemantic attributeType);

#define NZSL_SHADERAST_NODE(X, C) using X##C##Ptr = std::unique_ptr<X##C>;

#include <NZSL/Ast/NodeList.hpp>
Expand Down
2 changes: 2 additions & 0 deletions include/NZSL/Lang/ErrorList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ NZSL_SHADERLANG_PARSER_ERROR(DuplicateModule, "duplicate module")
NZSL_SHADERLANG_PARSER_ERROR(InvalidVersion, "\"{}\" is not a valid version", std::string)
NZSL_SHADERLANG_PARSER_ERROR(MissingAttribute, "missing attribute {}", Ast::AttributeType)
NZSL_SHADERLANG_PARSER_ERROR(ModuleFeatureMultipleUnique, "module feature {} has already been specified", Ast::ModuleFeature)
NZSL_SHADERLANG_PARSER_ERROR(FunctionParameterNonLValue, "non-L-value cannot be passed for parameter #{}", std::size_t)
NZSL_SHADERLANG_PARSER_ERROR(ReservedKeyword, "reserved keyword")
NZSL_SHADERLANG_PARSER_ERROR(UnknownAttribute, "unknown attribute \"{}\"", std::string)
NZSL_SHADERLANG_PARSER_ERROR(UnknownType, "unknown type")
Expand Down Expand Up @@ -99,6 +100,7 @@ NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallOutsideOfFunction, "function calls mu
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnexpectedEntryFunction, "{} is an entry function which cannot be called by the program", std::string)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterCount, "function {} expects {} parameter(s), but got {}", std::string, std::uint32_t, std::uint32_t)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterType, "function {} parameter #{} type mismatch (expected {}, got {})", std::string, std::uint32_t, std::string, std::string)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterSemanticType, "function {} parameter #{} semantic mismatch (expected {}, got {})", std::string, std::uint32_t, std::string, std::string)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionDeclarationInsideFunction, "a function cannot be defined inside another function")
NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnWithAValue, "return with a value, in function returning no value")
NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnWithNoValue, "return with no value, in function returning {}", std::string)
Expand Down
2 changes: 2 additions & 0 deletions include/NZSL/Lang/TokenList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ NZSL_SHADERLANG_TOKEN(Identifier)
NZSL_SHADERLANG_TOKEN(If)
NZSL_SHADERLANG_TOKEN(Import)
NZSL_SHADERLANG_TOKEN(In)
NZSL_SHADERLANG_TOKEN(InOut)
NZSL_SHADERLANG_TOKEN(LessThan)
NZSL_SHADERLANG_TOKEN(LessThanEqual)
NZSL_SHADERLANG_TOKEN(Let)
Expand All @@ -72,6 +73,7 @@ NZSL_SHADERLANG_TOKEN(OpenCurlyBracket)
NZSL_SHADERLANG_TOKEN(OpenSquareBracket)
NZSL_SHADERLANG_TOKEN(OpenParenthesis)
NZSL_SHADERLANG_TOKEN(Option)
NZSL_SHADERLANG_TOKEN(Out)
NZSL_SHADERLANG_TOKEN(Return)
NZSL_SHADERLANG_TOKEN(Semicolon)
NZSL_SHADERLANG_TOKEN(ShiftLeft)
Expand Down
1 change: 1 addition & 0 deletions include/NZSL/Parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ namespace nzsl
Ast::ExpressionPtr ParseConstSelectExpression();
Ast::ExpressionPtr ParseExpression(int exprPrecedence = 0);
std::vector<Ast::ExpressionPtr> ParseExpressionList(TokenType terminationToken, SourceLocation* terminationLocation);
std::vector<Ast::ExpressionPtr> ParseFunctionExpressionList(std::vector<Ast::CallFunctionExpression::ParameterSemantic>& parametersSemantic, SourceLocation& terminationLocation);
Ast::ExpressionPtr ParseExpressionStatement();
Ast::ExpressionPtr ParseFloatingPointExpression();
Ast::ExpressionPtr ParseIdentifier();
Expand Down
4 changes: 2 additions & 2 deletions include/NZSL/ShaderBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ namespace nzsl::ShaderBuilder

struct CallFunction
{
inline Ast::CallFunctionExpressionPtr operator()(std::string functionName, std::vector<Ast::ExpressionPtr> parameters) const;
inline Ast::CallFunctionExpressionPtr operator()(Ast::ExpressionPtr functionExpr, std::vector<Ast::ExpressionPtr> parameters) const;
inline Ast::CallFunctionExpressionPtr operator()(std::string functionName, std::vector<Ast::ExpressionPtr> parameters, std::vector<Ast::CallFunctionExpression::ParameterSemantic> parametersSemantic) const;
inline Ast::CallFunctionExpressionPtr operator()(Ast::ExpressionPtr functionExpr, std::vector<Ast::ExpressionPtr> parameters, std::vector<Ast::CallFunctionExpression::ParameterSemantic> parametersSemantic) const;
};

struct Cast
Expand Down
6 changes: 4 additions & 2 deletions include/NZSL/ShaderBuilder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,22 @@ namespace nzsl::ShaderBuilder
return branchNode;
}

inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(std::string functionName, std::vector<Ast::ExpressionPtr> parameters) const
inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(std::string functionName, std::vector<Ast::ExpressionPtr> parameters, std::vector<Ast::CallFunctionExpression::ParameterSemantic> parametersSemantic) const
{
auto callFunctionExpression = std::make_unique<Ast::CallFunctionExpression>();
callFunctionExpression->targetFunction = ShaderBuilder::Identifier(std::move(functionName));
callFunctionExpression->parameters = std::move(parameters);
callFunctionExpression->parametersSemantic = std::move(parametersSemantic);

return callFunctionExpression;
}

inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(Ast::ExpressionPtr functionExpr, std::vector<Ast::ExpressionPtr> parameters) const
inline Ast::CallFunctionExpressionPtr Impl::CallFunction::operator()(Ast::ExpressionPtr functionExpr, std::vector<Ast::ExpressionPtr> parameters, std::vector<Ast::CallFunctionExpression::ParameterSemantic> parametersSemantic) const
{
auto callFunctionExpression = std::make_unique<Ast::CallFunctionExpression>();
callFunctionExpression->targetFunction = std::move(functionExpr);
callFunctionExpression->parameters = std::move(parameters);
callFunctionExpression->parametersSemantic = std::move(parametersSemantic);

return callFunctionExpression;
}
Expand Down
4 changes: 4 additions & 0 deletions src/NZSL/Ast/AstSerializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ namespace nzsl::Ast
Container(node.parameters);
for (auto& param : node.parameters)
Node(param);

Container(node.parametersSemantic);
for (auto& paramAttribute : node.parametersSemantic)
Enum(paramAttribute);
}

void SerializerBase::Serialize(CallMethodExpression& node)
Expand Down
4 changes: 4 additions & 0 deletions src/NZSL/Ast/Cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ namespace nzsl::Ast
for (auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));

clone->parametersSemantic.reserve(node.parametersSemantic.size());
for (auto& parameterAttribute : node.parametersSemantic)
clone->parametersSemantic.push_back(parameterAttribute);

clone->cachedExpressionType = node.cachedExpressionType;
clone->sourceLocation = node.sourceLocation;

Expand Down
17 changes: 17 additions & 0 deletions src/NZSL/Ast/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ namespace nzsl::Ast
{
Node::~Node() = default;

std::string_view ToString(CallFunctionExpression::ParameterSemantic attributeType)
{
switch (attributeType)
{
case nzsl::Ast::CallFunctionExpression::ParameterSemantic::In:
return "in";
case nzsl::Ast::CallFunctionExpression::ParameterSemantic::Out:
return "out";
case nzsl::Ast::CallFunctionExpression::ParameterSemantic::InOut:
return "inout";
default:
break;
}

NAZARA_UNREACHABLE();
}

#define NZSL_SHADERAST_NODE(Node, Category) NodeType Node##Category::GetType() const \
{ \
return NodeType:: Node##Category; \
Expand Down
8 changes: 8 additions & 0 deletions src/NZSL/Ast/SanitizeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,10 @@ namespace nzsl::Ast
for (const auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));

clone->parametersSemantic.reserve(node.parametersSemantic.size());
for (const auto& parameterAttribute : node.parametersSemantic)
clone->parametersSemantic.push_back(parameterAttribute);

m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex);

Validate(*clone);
Expand Down Expand Up @@ -1677,6 +1681,7 @@ NAZARA_WARNING_POP()
for (auto& parameter : node.parameters)
{
auto& cloneParam = clone->parameters.emplace_back();
cloneParam.semantic = parameter.semantic;
cloneParam.name = parameter.name;
cloneParam.type = CloneType(parameter.type);
cloneParam.varIndex = parameter.varIndex;
Expand Down Expand Up @@ -4597,6 +4602,9 @@ NAZARA_WARNING_POP()

if (ResolveAlias(*parameterType) != ResolveAlias(referenceDeclaration->parameters[i].type.GetResultingValue()))
throw CompilerFunctionCallUnmatchingParameterTypeError{ node.parameters[i]->sourceLocation, referenceDeclaration->name, Nz::SafeCast<std::uint32_t>(i), ToString(referenceDeclaration->parameters[i].type.GetResultingValue(), referenceDeclaration->parameters[i].sourceLocation), ToString(*parameterType, node.parameters[i]->sourceLocation)};

if (node.parametersSemantic[i] != referenceDeclaration->parameters[i].semantic)
throw CompilerFunctionCallUnmatchingParameterSemanticTypeError{ node.parameters[i]->sourceLocation, referenceDeclaration->name, Nz::SafeCast<std::uint32_t>(i), Ast::ToString(referenceDeclaration->parameters[i].semantic), Ast::ToString(node.parametersSemantic[i]) };
}

if (node.parameters.size() != referenceDeclaration->parameters.size())
Expand Down
9 changes: 9 additions & 0 deletions src/NZSL/GlslWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,15 @@ namespace nzsl

first = false;

if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::InOut)
{
Append("inout ");
}
else if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::Out)
{
Append("out ");
}

AppendVariableDeclaration(parameter.type.GetResultingValue(), parameter.name);
}
AppendLine((forward) ? ");" : ")");
Expand Down
18 changes: 18 additions & 0 deletions src/NZSL/LangWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,15 @@ namespace nzsl
if (i != 0)
Append(", ");

if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::InOut)
{
Append("inout ");
}
else if (node.parametersSemantic[i] == Ast::CallFunctionExpression::ParameterSemantic::Out)
{
Append("out ");
}

node.parameters[i]->Visit(*this);
}
Append(")");
Expand Down Expand Up @@ -1429,6 +1438,15 @@ namespace nzsl
if (i != 0)
Append(", ");

if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::InOut)
{
Append("inout ");
}
else if (parameter.semantic == Ast::CallFunctionExpression::ParameterSemantic::Out)
{
Append("out ");
}

Append(parameter.name);
Append(": ");
Append(parameter.type);
Expand Down
2 changes: 2 additions & 0 deletions src/NZSL/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ namespace nzsl
{ "if", TokenType::If },
{ "import", TokenType::Import },
{ "in", TokenType::In },
{ "inout", TokenType::InOut },
{ "let", TokenType::Let },
{ "module", TokenType::Module },
{ "option", TokenType::Option },
{ "out", TokenType::Out },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue },
Expand Down
72 changes: 70 additions & 2 deletions src/NZSL/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <NZSL/ShaderBuilder.hpp>
#include <NZSL/Lang/Errors.hpp>
#include <NZSL/Lang/LangData.hpp>
#include <NZSL/Ast/Utils.hpp>
#include <frozen/string.h>
#include <frozen/unordered_map.h>
#include <array>
Expand Down Expand Up @@ -997,6 +998,27 @@ namespace nzsl
{
Ast::DeclareFunctionStatement::Parameter parameter;

const Token& t = Peek();
if (t.type == TokenType::InOut)
{
Consume();
parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::InOut;
}
else if (t.type == TokenType::Out)
{
Consume();
parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::Out;
}
else if (t.type == TokenType::In)
{
Consume();
parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::In;
}
else
{
parameter.semantic = Ast::CallFunctionExpression::ParameterSemantic::In;
}

parameter.name = ParseIdentifierAsName(&parameter.sourceLocation);

Expect(Advance(), TokenType::Colon);
Expand Down Expand Up @@ -1464,10 +1486,12 @@ namespace nzsl

// Function call
SourceLocation closingLocation;
auto parameters = ParseExpressionList(TokenType::ClosingParenthesis, &closingLocation);
std::vector<Ast::CallFunctionExpression::ParameterSemantic> parametersSemantic;
auto parameters = ParseFunctionExpressionList(parametersSemantic, closingLocation);
Nz::Assert(parameters.size() == parametersSemantic.size());

const SourceLocation& lhsLoc = lhs->sourceLocation;
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters), std::move(parametersSemantic));
lhs->sourceLocation = SourceLocation::BuildFromTo(lhsLoc, closingLocation);
continue;
}
Expand Down Expand Up @@ -1574,6 +1598,50 @@ namespace nzsl

return parameters;
}

std::vector<Ast::ExpressionPtr> Parser::ParseFunctionExpressionList(std::vector<Ast::CallFunctionExpression::ParameterSemantic>& parametersSemantic, SourceLocation& terminationLocation)
{
std::vector<Ast::ExpressionPtr> parameters;
bool first = true;
size_t parameterIndex = 0;
while (Peek().type != TokenType::ClosingParenthesis)
{
if (!first)
Expect(Advance(), TokenType::Comma);

TokenType tokenType = Peek().type;
if (tokenType == TokenType::InOut || tokenType == TokenType::Out || tokenType == TokenType::In)
{
Consume();

Ast::ExpressionPtr expressionPtr = ParseExpression();
const Ast::ExpressionCategory category = Ast::GetExpressionCategory(*expressionPtr);
if (category != Ast::ExpressionCategory::LValue)
throw ParserFunctionParameterNonLValueError{ expressionPtr->sourceLocation, parameterIndex };

parameters.push_back(std::move(expressionPtr));
if (tokenType == TokenType::InOut)
parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::InOut);
else if (tokenType == TokenType::Out)
parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::Out);
else
parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::In);
}
else
{
parameters.push_back(ParseExpression());
parametersSemantic.push_back(Ast::CallFunctionExpression::ParameterSemantic::In);
}

first = false;
parameterIndex++;
}

const Token& endToken = Expect(Advance(), TokenType::ClosingParenthesis);
terminationLocation = endToken.location;

return parameters;
}

Ast::ExpressionPtr Parser::ParseExpressionStatement()
{
Expand Down
Loading

0 comments on commit 7fb913a

Please sign in to comment.