diff --git a/include/occa/lang/exprNode.hpp b/include/occa/lang/exprNode.hpp index 55316e3ff..01e8c7046 100644 --- a/include/occa/lang/exprNode.hpp +++ b/include/occa/lang/exprNode.hpp @@ -50,6 +50,7 @@ namespace occa { extern const udim_t string; extern const udim_t identifier; extern const udim_t type; + extern const udim_t vartype; extern const udim_t variable; extern const udim_t function; extern const udim_t value; @@ -275,6 +276,7 @@ namespace occa { virtual void debugPrint(const std::string &prefix) const; }; + // |---[ Type ]-------------------- class typeNode : public exprNode { public: type_t &value; @@ -298,7 +300,35 @@ namespace occa { virtual void debugPrint(const std::string &prefix) const; }; + // |=============================== + // |---[ Vartype ]----------------- + class vartypeNode : public exprNode { + public: + vartype_t value; + + vartypeNode(token_t *token_, + const vartype_t &value_); + + vartypeNode(const vartypeNode& node); + + virtual ~vartypeNode(); + + virtual udim_t type() const; + + virtual exprNode* clone() const; + + virtual void setChildren(exprNodeRefVector &children); + + virtual bool hasAttribute(const std::string &attr) const; + + virtual void print(printer &pout) const; + + virtual void debugPrint(const std::string &prefix) const; + }; + // |=============================== + + // |---[ Variable ]---------------- class variableNode : public exprNode { public: variable_t &value; @@ -324,7 +354,9 @@ namespace occa { virtual void debugPrint(const std::string &prefix) const; }; + // |=============================== + // |---[ Function ]---------------- class functionNode : public exprNode { public: function_t &value; @@ -348,6 +380,7 @@ namespace occa { virtual void debugPrint(const std::string &prefix) const; }; + // |=============================== //================================== //---[ Operators ]------------------ diff --git a/include/occa/lang/parser.hpp b/include/occa/lang/parser.hpp index 12b3dc109..d30a06c78 100644 --- a/include/occa/lang/parser.hpp +++ b/include/occa/lang/parser.hpp @@ -163,7 +163,7 @@ namespace occa { void loadDeclarationAssignment(variableDeclaration &decl); void loadDeclarationBraceInitializer(variableDeclaration &decl); - vartype_t preloadType(); + vartype_t loadType(); void loadBaseType(vartype_t &vartype); diff --git a/include/occa/lang/token.hpp b/include/occa/lang/token.hpp index 084167016..89cdf7130 100644 --- a/include/occa/lang/token.hpp +++ b/include/occa/lang/token.hpp @@ -26,15 +26,14 @@ #include #include +#include namespace occa { namespace lang { class operator_t; class token_t; class qualifier_t; - class type_t; class variable_t; - class function_t; typedef std::vector tokenVector; @@ -75,6 +74,7 @@ namespace occa { extern const int qualifier; extern const int type; + extern const int vartype; extern const int variable; extern const int function; @@ -138,7 +138,8 @@ namespace occa { void debugPrint() const; }; - std::ostream& operator << (std::ostream &out, token_t &token); + std::ostream& operator << (std::ostream &out, + token_t &token); //---[ Unknown ]-------------------- class unknownToken : public token_t { @@ -262,6 +263,24 @@ namespace occa { }; //================================== + //---[ Vartype ]-------------------- + class vartypeToken : public token_t { + public: + vartype_t value; + + vartypeToken(const fileOrigin &origin_, + const vartype_t &value_); + + virtual ~vartypeToken(); + + virtual int type() const; + + virtual token_t* clone() const; + + virtual void print(std::ostream &out) const; + }; + //================================== + //---[ Variable ]------------------- class variableToken : public token_t { public: diff --git a/include/occa/lang/tokenContext.hpp b/include/occa/lang/tokenContext.hpp index 431e265c1..4ca492bed 100644 --- a/include/occa/lang/tokenContext.hpp +++ b/include/occa/lang/tokenContext.hpp @@ -112,10 +112,6 @@ namespace occa { int getNextOperator(const opType_t &opType); - exprNode* getExpression(); - exprNode* getExpression(const int start, - const int end); - void debugPrint(); }; } diff --git a/include/occa/lang/type.hpp b/include/occa/lang/type.hpp index 1101abb42..00b6f61da 100644 --- a/include/occa/lang/type.hpp +++ b/include/occa/lang/type.hpp @@ -128,6 +128,8 @@ namespace occa { void printError(const std::string &message) const; }; + std::ostream& operator << (std::ostream &out, + const type_t &type); printer& operator << (printer &pout, const type_t &type); //================================== @@ -153,6 +155,8 @@ namespace occa { void add(const qualifierWithSource &qualifier); }; + std::ostream& operator << (std::ostream &out, + const pointer_t &pointer); printer& operator << (printer &pout, const pointer_t &pointer); //================================== @@ -181,6 +185,8 @@ namespace occa { void printError(const std::string &message) const; }; + std::ostream& operator << (std::ostream &out, + const array_t &array); printer& operator << (printer &pout, const array_t &array); //================================== @@ -245,6 +251,8 @@ namespace occa { vartype_t& operator += (const array_t &array); vartype_t& operator += (const arrayVector &arrays_); + bool hasAttribute(const std::string &attr) const; + vartype_t declarationType() const; vartype_t flatten() const; @@ -260,6 +268,8 @@ namespace occa { void printError(const std::string &message) const; }; + std::ostream& operator << (std::ostream &out, + const vartype_t &type); printer& operator << (printer &pout, const vartype_t &type); //================================== diff --git a/src/lang/exprNode.cpp b/src/lang/exprNode.cpp index 9ba1fdfc1..88742e9b9 100644 --- a/src/lang/exprNode.cpp +++ b/src/lang/exprNode.cpp @@ -32,54 +32,56 @@ namespace occa { const udim_t string = (1L << 3); const udim_t identifier = (1L << 4); const udim_t type = (1L << 5); - const udim_t variable = (1L << 6); - const udim_t function = (1L << 7); + const udim_t vartype = (1L << 6); + const udim_t variable = (1L << 7); + const udim_t function = (1L << 8); const udim_t value = (primitive | type | + vartype | variable | function); - const udim_t rawOp = (1L << 8); - const udim_t leftUnary = (1L << 9); - const udim_t rightUnary = (1L << 10); - const udim_t binary = (1L << 11); - const udim_t ternary = (1L << 12); + const udim_t rawOp = (1L << 9); + const udim_t leftUnary = (1L << 10); + const udim_t rightUnary = (1L << 11); + const udim_t binary = (1L << 12); + const udim_t ternary = (1L << 13); const udim_t op = (leftUnary | rightUnary | binary | ternary); - const udim_t pair = (1L << 13); + const udim_t pair = (1L << 14); - const udim_t subscript = (1L << 14); - const udim_t call = (1L << 15); + const udim_t subscript = (1L << 15); + const udim_t call = (1L << 16); - const udim_t sizeof_ = (1L << 16); - const udim_t sizeof_pack_ = (1L << 17); - const udim_t new_ = (1L << 18); - const udim_t delete_ = (1L << 19); - const udim_t throw_ = (1L << 20); + const udim_t sizeof_ = (1L << 17); + const udim_t sizeof_pack_ = (1L << 18); + const udim_t new_ = (1L << 19); + const udim_t delete_ = (1L << 20); + const udim_t throw_ = (1L << 21); - const udim_t typeid_ = (1L << 21); - const udim_t noexcept_ = (1L << 22); - const udim_t alignof_ = (1L << 23); + const udim_t typeid_ = (1L << 22); + const udim_t noexcept_ = (1L << 23); + const udim_t alignof_ = (1L << 24); - const udim_t const_cast_ = (1L << 24); - const udim_t dynamic_cast_ = (1L << 25); - const udim_t static_cast_ = (1L << 26); - const udim_t reinterpret_cast_ = (1L << 27); + const udim_t const_cast_ = (1L << 25); + const udim_t dynamic_cast_ = (1L << 26); + const udim_t static_cast_ = (1L << 27); + const udim_t reinterpret_cast_ = (1L << 28); - const udim_t funcCast = (1L << 28); - const udim_t parenCast = (1L << 29); - const udim_t constCast = (1L << 30); - const udim_t staticCast = (1L << 31); - const udim_t reinterpretCast = (1L << 32); - const udim_t dynamicCast = (1L << 33); + const udim_t funcCast = (1L << 29); + const udim_t parenCast = (1L << 30); + const udim_t constCast = (1L << 31); + const udim_t staticCast = (1L << 32); + const udim_t reinterpretCast = (1L << 33); + const udim_t dynamicCast = (1L << 34); - const udim_t parentheses = (1L << 34); - const udim_t tuple = (1L << 35); - const udim_t cudaCall = (1L << 36); + const udim_t parentheses = (1L << 35); + const udim_t tuple = (1L << 36); + const udim_t cudaCall = (1L << 37); } exprNode::exprNode(token_t *token_) : @@ -348,7 +350,7 @@ namespace occa { } // |=============================== - // |---[ Type ]---------------- + // |---[ Type ]-------------------- typeNode::typeNode(token_t *token_, type_t &value_) : exprNode(token_), @@ -387,6 +389,45 @@ namespace occa { } // |=============================== + // |---[ Vartype ]----------------- + vartypeNode::vartypeNode(token_t *token_, + const vartype_t &value_) : + exprNode(token_), + value(value_) {} + + vartypeNode::vartypeNode(const vartypeNode &node) : + exprNode(node.token), + value(node.value) {} + + vartypeNode::~vartypeNode() {} + + udim_t vartypeNode::type() const { + return exprNodeType::vartype; + } + + exprNode* vartypeNode::clone() const { + return new vartypeNode(token, value); + } + + void vartypeNode::setChildren(exprNodeRefVector &children) {} + + bool vartypeNode::hasAttribute(const std::string &attr) const { + return value.hasAttribute(attr); + } + + void vartypeNode::print(printer &pout) const { + pout << value; + } + + void vartypeNode::debugPrint(const std::string &prefix) const { + printer pout(std::cerr); + std::cerr << prefix << "|\n" + << prefix << "|---["; + pout << (*this); + std::cerr << "] (vartype)\n"; + } + // |=============================== + // |---[ Variable ]---------------- variableNode::variableNode(token_t *token_, variable_t &value_) : diff --git a/src/lang/expression.cpp b/src/lang/expression.cpp index dfa8256ce..2cea5e2ee 100644 --- a/src/lang/expression.cpp +++ b/src/lang/expression.cpp @@ -27,6 +27,7 @@ namespace occa { namespace lang { static const int outputTokenType = (tokenType::identifier | tokenType::type | + tokenType::vartype | tokenType::variable | tokenType::function | tokenType::primitive | @@ -287,6 +288,10 @@ namespace occa { typeToken &t = token->to(); state.pushOutput(new typeNode(token, t.value)); } + else if (tokenType & tokenType::vartype) { + vartypeToken &t = token->to(); + state.pushOutput(new vartypeNode(token, t.value)); + } else if (tokenType & tokenType::primitive) { primitiveToken &t = token->to(); state.pushOutput(new primitiveNode(token, t.value)); @@ -388,7 +393,8 @@ namespace occa { } if (pair.opType() & operatorType::parentheses) { - if (pair.value->type() & exprNodeType::type) { + if (pair.value->type() & (exprNodeType::type | + exprNodeType::vartype)) { state.pushOperator( new leftUnaryOpNode(&opToken, op::parenCast, @@ -747,12 +753,20 @@ namespace occa { if (opType & operatorType::parenCast) { leftUnaryOpNode &parenOpNode = (leftUnaryOpNode&) opNode; - type_t &type = ((typeNode*) parenOpNode.value)->value; - state.pushOutput( - new parenCastNode(parenOpNode.token, - type, - value) - ); + exprNode *valueNode = parenOpNode.value; + if (valueNode->type() & exprNodeType::type) { + state.pushOutput( + new parenCastNode(parenOpNode.token, + ((typeNode*) valueNode)->value, + value) + ); + } else { + state.pushOutput( + new parenCastNode(parenOpNode.token, + ((vartypeNode*) valueNode)->value, + value) + ); + } } else if (opType & operatorType::sizeof_) { state.pushOutput( diff --git a/src/lang/parser.cpp b/src/lang/parser.cpp index 5be423d50..ce23c4532 100644 --- a/src/lang/parser.cpp +++ b/src/lang/parser.cpp @@ -292,17 +292,40 @@ namespace occa { keyword_t& parser_t::getKeyword(token_t *token) { static keyword_t noKeyword; + if (!token) { + return noKeyword; + } - if (!(token_t::safeType(token) & tokenType::identifier)) { + const int tType = token->type(); + if (!(tType & (tokenType::identifier | + tokenType::qualifier | + tokenType::type | + tokenType::variable | + tokenType::function))) { return noKeyword; } - std::string &identifier = token->to().value; + std::string identifier; + if (tType & tokenType::identifier) { + identifier = token->to().value; + } + else if (tType & tokenType::qualifier) { + identifier = token->to().qualifier.name; + } + else if (tType & tokenType::type) { + identifier = token->to().value.name(); + } + else if (tType & tokenType::variable) { + identifier = token->to().value.name(); + } + else if (tType & tokenType::function) { + identifier = token->to().value.name(); + } + keywordMapIterator it = keywords.find(identifier); if (it != keywords.end()) { return *(it->second); } - return up->getScopeKeyword(identifier); } @@ -321,16 +344,42 @@ namespace occa { exprNode* parser_t::getExpression(const int start, const int end) { + context.push(start, end); + const int tokenCount = context.size(); + tokenVector tokens; + tokens.reserve(tokenCount); + if (up) { - for (int i = start; i < end; ++i) { + // Replace identifier tokens with keywords if they exist + for (int i = 0; i < tokenCount; ++i) { token_t *token = context[i]; if (token->type() & tokenType::identifier) { - context.setToken(i, - replaceIdentifier((identifierToken&) *token)); + context.setToken(i, replaceIdentifier((identifierToken&) *token)); } } + while (context.size()) { + token_t *token = context[0]; + if (!(token->type() & (tokenType::qualifier | + tokenType::type))) { + context.set(1); + tokens.push_back(token->clone()); + continue; + } + + vartype_t vartype = loadType(); + if (!success) { + context.pop(); + freeTokenVector(tokens); + return NULL; + } + + tokens.push_back(new vartypeToken(token->origin, + vartype)); + } } - exprNode *expr = context.getExpression(start, end); + context.pop(); + + exprNode *expr = occa::lang::getExpression(tokens); success &= !!expr; return expr; } @@ -339,12 +388,17 @@ namespace occa { keyword_t &keyword = getKeyword(&identifier); const int kType = keyword.type(); - if (!(kType & (keywordType::type | - keywordType::variable | + if (!(kType & (keywordType::qualifier | + keywordType::type | + keywordType::variable | keywordType::function))) { return &identifier; } + if (kType & keywordType::qualifier) { + return new qualifierToken(identifier.origin, + ((qualifierKeyword&) keyword).qualifier); + } if (kType & keywordType::variable) { return new variableToken(identifier.origin, ((variableKeyword&) keyword).variable); @@ -629,7 +683,7 @@ namespace occa { attributeTokenMap attrs; loadAttributes(attrs); - vartype_t vartype = preloadType(); + vartype_t vartype = loadType(); variable_t var = (!isLoadingFunctionPointer() ? loadVariable(vartype) : loadFunctionPointer(vartype)); @@ -785,7 +839,7 @@ namespace occa { context.popAndSkip(); } - vartype_t parser_t::preloadType() { + vartype_t parser_t::loadType() { // TODO: Handle weird () cases: // int (*const (*const a)) -> int * const * const a; // int (*const (*const (*a)))() -> int (* const * const *a)(); @@ -983,7 +1037,7 @@ namespace occa { bool parser_t::isLoadingFunction() { context.push(); - vartype_t vartype = preloadType(); + vartype_t vartype = loadType(); if (!success) { context.popAndSkip(); return false; @@ -1381,7 +1435,7 @@ namespace occa { } statement_t* parser_t::loadFunctionStatement(attributeTokenMap &smntAttributes) { - vartype_t returnType = preloadType(); + vartype_t returnType = loadType(); if (!(token_t::safeType(context[0]) & tokenType::identifier)) { context.printError("Expected function name identifier"); diff --git a/src/lang/token.cpp b/src/lang/token.cpp index f70461594..34d10b286 100644 --- a/src/lang/token.cpp +++ b/src/lang/token.cpp @@ -80,18 +80,19 @@ namespace occa { const int qualifier = (1 << 8); const int type = (1 << 9); - const int variable = (1 << 10); - const int function = (1 << 11); + const int vartype = (1 << 10); + const int variable = (1 << 11); + const int function = (1 << 12); - const int primitive = (1 << 12); - const int op = (1 << 13); + const int primitive = (1 << 13); + const int op = (1 << 14); - const int char_ = (1 << 14); - const int string = (1 << 15); - const int withUDF = (1 << 16); + const int char_ = (1 << 15); + const int string = (1 << 16); + const int withUDF = (1 << 17); const int withEncoding = ((encodingType::ux | - encodingType::R) << 17); - const int encodingShift = 17; + encodingType::R) << 18); + const int encodingShift = 18; int getEncoding(const int tokenType) { return ((tokenType & withEncoding) >> encodingShift); @@ -301,6 +302,27 @@ namespace occa { } //================================== + //---[ Vartype ]-------------------- + vartypeToken::vartypeToken(const fileOrigin &origin_, + const vartype_t &value_) : + token_t(origin_), + value(value_) {} + + vartypeToken::~vartypeToken() {} + + int vartypeToken::type() const { + return tokenType::vartype; + } + + token_t* vartypeToken::clone() const { + return new vartypeToken(origin, value); + } + + void vartypeToken::print(std::ostream &out) const { + out << value; + } + //================================== + //---[ Variable ]------------------- variableToken::variableToken(const fileOrigin &origin_, variable_t &variable) : diff --git a/src/lang/tokenContext.cpp b/src/lang/tokenContext.cpp index 749ce0259..810536f32 100644 --- a/src/lang/tokenContext.cpp +++ b/src/lang/tokenContext.cpp @@ -346,23 +346,6 @@ namespace occa { return -1; } - exprNode* tokenContext::getExpression() { - if (tp.start == tp.end) { - return NULL; - } - tokenVector tokens_; - getAndCloneTokens(tokens_); - return occa::lang::getExpression(tokens_); - } - - exprNode* tokenContext::getExpression(const int start, - const int end) { - push(start, end); - exprNode *expr = getExpression(); - pop(); - return expr; - } - void tokenContext::debugPrint() { for (int i = tp.start; i < tp.end; ++i) { std::cout << '[' << *tokens[i] << "]\n"; diff --git a/src/lang/type.cpp b/src/lang/type.cpp index 826d03224..79ce740f6 100644 --- a/src/lang/type.cpp +++ b/src/lang/type.cpp @@ -131,6 +131,13 @@ namespace occa { } } + std::ostream& operator << (std::ostream &out, + const type_t &type) { + printer pout(out); + pout << type; + return out; + } + printer& operator << (printer &pout, const type_t &type) { pout << type.name(); @@ -172,6 +179,13 @@ namespace occa { qualifiers.add(qualifier); } + std::ostream& operator << (std::ostream &out, + const pointer_t &pointer) { + printer pout(out); + pout << pointer; + return out; + } + printer& operator << (printer &pout, const pointer_t &pointer) { pout << '*'; @@ -239,6 +253,13 @@ namespace occa { start->printError(message); } + std::ostream& operator << (std::ostream &out, + const array_t &array) { + printer pout(out); + pout << array; + return out; + } + printer& operator << (printer &pout, const array_t &array) { if (array.size) { @@ -496,6 +517,12 @@ namespace occa { return *this; } + bool vartype_t::hasAttribute(const std::string &attr) const { + return (type + ? type->hasAttribute(attr) + : false); + } + vartype_t vartype_t::declarationType() const { vartype_t other; other.type = type; @@ -598,6 +625,13 @@ namespace occa { } } + std::ostream& operator << (std::ostream &out, + const vartype_t &type) { + printer pout(out); + pout << type; + return out; + } + printer& operator << (printer &pout, const vartype_t &type) { type.printDeclaration(pout, "", true);