From 06ea8e355e4ea53f6c18f4e33d714f1df4fdf6b3 Mon Sep 17 00:00:00 2001 From: Rohan Yadav Date: Wed, 17 Mar 2021 12:11:12 -0400 Subject: [PATCH] taco: add parser support for windowing/striding/index sets Fixes #413. This commit adds support for the command line tool to accept index expressions containing windowing, striding and index sets. An example of each of these features is added to the help message of taco: ``` taco "a(i) = b(i(1, 5))" -d=a:4 # Slice b[1:4] taco "a(i) = b(i(1, 5, 2))" -d=a:2 # Slice b[1:4:2] taco "a(i) = b(i({1, 3, 5, 7}))" -d=a:4 # Slice b[[1, 3, 5, 7]] ``` --- include/taco/index_notation/index_notation.h | 10 ++- include/taco/parser/parser.h | 8 +- src/index_notation/index_notation.cpp | 2 +- src/parser/parser.cpp | 79 ++++++++++++++++++-- src/tensor.cpp | 2 +- tools/taco.cpp | 3 + 6 files changed, 90 insertions(+), 14 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 81ed1685a..461abc43b 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -877,6 +877,7 @@ Multi multi(IndexStmt stmt1, IndexStmt stmt2); class IndexVarInterface { public: virtual ~IndexVarInterface() = default; + virtual IndexVar getIndexVar() const = 0; /// match performs a dynamic case analysis of the implementers of IndexVarInterface /// as a utility for handling the different values within. It mimics the dynamic @@ -912,7 +913,7 @@ class WindowedIndexVar : public util::Comparable, public Index ~WindowedIndexVar() = default; /// getIndexVar returns the underlying IndexVar. - IndexVar getIndexVar() const; + IndexVar getIndexVar() const override; /// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of /// this index variable. @@ -940,7 +941,7 @@ class IndexSetVar : public util::Comparable, public IndexVarInterfa ~IndexSetVar() = default; /// getIndexVar returns the underlying IndexVar. - IndexVar getIndexVar() const; + IndexVar getIndexVar() const override; /// getIndexSet returns the index set. const std::vector& getIndexSet() const; @@ -957,6 +958,9 @@ class IndexVar : public util::Comparable, public IndexVarInterface { ~IndexVar() = default; IndexVar(const std::string& name); + // getIndexVar implements the IndexVarInterface. + IndexVar getIndexVar() const override { return *this; } + /// Returns the name of the index variable. std::string getName() const; @@ -967,7 +971,7 @@ class IndexVar : public util::Comparable, public IndexVarInterface { WindowedIndexVar operator()(int lo, int hi, int stride = 1); /// Indexing into an IndexVar with a vector returns an index set into it. - IndexSetVar operator()(std::vector indexSet); + IndexSetVar operator()(std::vector&& indexSet); IndexSetVar operator()(std::vector& indexSet); private: diff --git a/include/taco/parser/parser.h b/include/taco/parser/parser.h index 9a3c4cfff..356d6a67d 100644 --- a/include/taco/parser/parser.h +++ b/include/taco/parser/parser.h @@ -14,6 +14,7 @@ namespace taco { class TensorBase; class Format; class IndexVar; +class IndexVarInterface; class IndexExpr; class Access; @@ -88,10 +89,13 @@ class Parser : public util::Uncopyable { Access parseAccess(); /// varlist ::= var {, var} - std::vector parseVarList(); + std::vector> parseVarList(); /// var ::= identifier - IndexVar parseVar(); + /// | identifier '(' int ',' int ')' -- Windowed access. + /// | identifier '(' int ',' int ',' int ')' -- Windowed access with a stride. + /// | identifier '(' '{' int, ... '}' ')' -- Access with an index set. + std::shared_ptr parseVar(); std::string currentTokenString(); diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 26f69676a..e14f746ed 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1978,7 +1978,7 @@ WindowedIndexVar IndexVar::operator()(int lo, int hi, int stride) { return WindowedIndexVar(*this, lo, hi, stride); } -IndexSetVar IndexVar::operator()(std::vector indexSet) { +IndexSetVar IndexVar::operator()(std::vector&& indexSet) { return IndexSetVar(*this, indexSet); } diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 472914c60..a566989fb 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -282,7 +282,7 @@ Access Parser::parseAccess() { consume(Token::identifier); names.push_back(tensorName); - vector varlist; + vector> varlist; if (content->currentToken == Token::underscore) { consume(Token::underscore); if (content->currentToken == Token::lcurly) { @@ -322,8 +322,8 @@ Access Parser::parseAccess() { if (util::contains(content->tensorDimensions, tensorName)) { tensorDimensions[i] = content->tensorDimensions.at(tensorName)[i]; } - else if (util::contains(content->indexVarDimensions, varlist[i])) { - tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]); + else if (util::contains(content->indexVarDimensions, varlist[i]->getIndexVar())) { + tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]->getIndexVar()); } else { tensorDimensions[i] = content->defaultDimension; @@ -347,8 +347,8 @@ Access Parser::parseAccess() { return tensor(varlist); } -vector Parser::parseVarList() { - vector varlist; +vector> Parser::parseVarList() { + vector> varlist; varlist.push_back(parseVar()); while (content->currentToken == Token::comma) { consume(Token::comma); @@ -357,13 +357,78 @@ vector Parser::parseVarList() { return varlist; } -IndexVar Parser::parseVar() { +std::shared_ptr Parser::parseVar() { if (content->currentToken != Token::identifier) { throw ParseError("Expected index variable"); } IndexVar var = getIndexVar(content->lexer.getIdentifier()); consume(Token::identifier); - return var; + // If there is a paren after this identifier, then we may have a window + // or index set access. + if (this->content->currentToken == Token::lparen) { + this->consume(Token::lparen); + switch (this->content->currentToken) { + case Token::int_scalar: { + // In this case, we have a window or strided window. Start off by + // parsing the lo and hi of the window. + int lo, hi; + // Parse out lo. + std::istringstream value(this->content->lexer.getIdentifier()); + value >> lo; + this->consume(Token::int_scalar); + + // Parse the comma. + this->consume(Token::comma); + + // Parse out hi. + value = std::istringstream(this->content->lexer.getIdentifier()); + value >> hi; + this->consume(Token::int_scalar); + + // Now, there might be the stride. If there is another comma, then there + // is a stride value to parse. Otherwise, it's just the window of (lo, hi). + if (this->content->currentToken == Token::comma) { + this->consume(Token::comma); + int stride; + value = std::istringstream(this->content->lexer.getIdentifier()); + value >> stride; + this->consume(Token::int_scalar); + this->consume(Token::rparen); + return std::make_shared(var(lo, hi, stride)); + } else { + this->consume(Token::rparen); + return std::make_shared(var(lo, hi)); + } + } + case Token::lcurly: { + // If we see a curly brace, then an index set is being applied to the + // IndexVar. So, we'll parse a list of integers. + this->consume(Token::lcurly); + std::vector indexSet; + bool first = true; + do { + // If this isn't the first iteration of the loop, consume a comma. + if (!first) { + this->consume(Token::comma); + } + first = false; + // Parse and consume the next integer. + std::istringstream value(this->content->lexer.getIdentifier()); + int index; + value >> index; + indexSet.push_back(index); + this->consume(Token::int_scalar); + // Break when we hit a '}' to end the list. + } while (this->content->currentToken != Token::rcurly); + this->consume(Token::rcurly); + this->consume(Token::rparen); + return std::make_shared(var(indexSet)); + } + default: + throw ParseError("Expected windowing expression."); + } + } + return std::make_shared(var); } bool Parser::hasIndexVar(std::string name) const { diff --git a/src/tensor.cpp b/src/tensor.cpp index 234085f51..b8706de60 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -486,7 +486,7 @@ struct AccessTensorNode : public AccessNode { // Ensure that it has at most dim(t, i) elements. taco_uassert(indexSet.size() <= size_t(tensor.getDimension(i))); // Pack up the index set into a sparse tensor. - TensorBase indexSetTensor(tensor.getComponentType(), {int(indexSet.size())}, Compressed); + Tensor indexSetTensor({int(indexSet.size())}, Compressed); for (auto& coord : indexSet) { indexSetTensor.insert({coord}, 1); } diff --git a/tools/taco.cpp b/tools/taco.cpp index 5a02c2ae1..1fea302fb 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -90,6 +90,9 @@ static void printUsageInfo() { cout << " taco \"a(i) = b(i) + c(i)\" -f=b:s -f=c:s -f=a:s # Sparse vector add" << endl; cout << " taco \"a(i) = B(i,j) * c(j)\" -f=B:ds # SpMV" << endl; cout << " taco \"A(i,l) = B(i,j,k) * C(j,l) * D(k,l)\" -f=B:sss # MTTKRP" << endl; + cout << " taco \"a(i) = b(i(1, 5))\" -d=a:4 # Slice b[1:4]" << endl; + cout << " taco \"a(i) = b(i(1, 5, 2))\" -d=a:2 # Slice b[1:4:2]" << endl; + cout << " taco \"a(i) = b(i({1, 3, 5, 7}))\" -d=a:4 # Slice b[[1, 3, 5, 7]]" << endl; cout << endl; cout << "Options:" << endl; printFlag("d=:",