diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 81ed1685a..461abc43b 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -877,6 +877,7 @@ Multi multi(IndexStmt stmt1, IndexStmt stmt2); class IndexVarInterface { public: virtual ~IndexVarInterface() = default; + virtual IndexVar getIndexVar() const = 0; /// match performs a dynamic case analysis of the implementers of IndexVarInterface /// as a utility for handling the different values within. It mimics the dynamic @@ -912,7 +913,7 @@ class WindowedIndexVar : public util::Comparable, public Index ~WindowedIndexVar() = default; /// getIndexVar returns the underlying IndexVar. - IndexVar getIndexVar() const; + IndexVar getIndexVar() const override; /// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of /// this index variable. @@ -940,7 +941,7 @@ class IndexSetVar : public util::Comparable, public IndexVarInterfa ~IndexSetVar() = default; /// getIndexVar returns the underlying IndexVar. - IndexVar getIndexVar() const; + IndexVar getIndexVar() const override; /// getIndexSet returns the index set. const std::vector& getIndexSet() const; @@ -957,6 +958,9 @@ class IndexVar : public util::Comparable, public IndexVarInterface { ~IndexVar() = default; IndexVar(const std::string& name); + // getIndexVar implements the IndexVarInterface. + IndexVar getIndexVar() const override { return *this; } + /// Returns the name of the index variable. std::string getName() const; @@ -967,7 +971,7 @@ class IndexVar : public util::Comparable, public IndexVarInterface { WindowedIndexVar operator()(int lo, int hi, int stride = 1); /// Indexing into an IndexVar with a vector returns an index set into it. - IndexSetVar operator()(std::vector indexSet); + IndexSetVar operator()(std::vector&& indexSet); IndexSetVar operator()(std::vector& indexSet); private: diff --git a/include/taco/parser/parser.h b/include/taco/parser/parser.h index 9a3c4cfff..356d6a67d 100644 --- a/include/taco/parser/parser.h +++ b/include/taco/parser/parser.h @@ -14,6 +14,7 @@ namespace taco { class TensorBase; class Format; class IndexVar; +class IndexVarInterface; class IndexExpr; class Access; @@ -88,10 +89,13 @@ class Parser : public util::Uncopyable { Access parseAccess(); /// varlist ::= var {, var} - std::vector parseVarList(); + std::vector> parseVarList(); /// var ::= identifier - IndexVar parseVar(); + /// | identifier '(' int ',' int ')' -- Windowed access. + /// | identifier '(' int ',' int ',' int ')' -- Windowed access with a stride. + /// | identifier '(' '{' int, ... '}' ')' -- Access with an index set. + std::shared_ptr parseVar(); std::string currentTokenString(); diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 26f69676a..e14f746ed 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1978,7 +1978,7 @@ WindowedIndexVar IndexVar::operator()(int lo, int hi, int stride) { return WindowedIndexVar(*this, lo, hi, stride); } -IndexSetVar IndexVar::operator()(std::vector indexSet) { +IndexSetVar IndexVar::operator()(std::vector&& indexSet) { return IndexSetVar(*this, indexSet); } diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 472914c60..a566989fb 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -282,7 +282,7 @@ Access Parser::parseAccess() { consume(Token::identifier); names.push_back(tensorName); - vector varlist; + vector> varlist; if (content->currentToken == Token::underscore) { consume(Token::underscore); if (content->currentToken == Token::lcurly) { @@ -322,8 +322,8 @@ Access Parser::parseAccess() { if (util::contains(content->tensorDimensions, tensorName)) { tensorDimensions[i] = content->tensorDimensions.at(tensorName)[i]; } - else if (util::contains(content->indexVarDimensions, varlist[i])) { - tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]); + else if (util::contains(content->indexVarDimensions, varlist[i]->getIndexVar())) { + tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]->getIndexVar()); } else { tensorDimensions[i] = content->defaultDimension; @@ -347,8 +347,8 @@ Access Parser::parseAccess() { return tensor(varlist); } -vector Parser::parseVarList() { - vector varlist; +vector> Parser::parseVarList() { + vector> varlist; varlist.push_back(parseVar()); while (content->currentToken == Token::comma) { consume(Token::comma); @@ -357,13 +357,78 @@ vector Parser::parseVarList() { return varlist; } -IndexVar Parser::parseVar() { +std::shared_ptr Parser::parseVar() { if (content->currentToken != Token::identifier) { throw ParseError("Expected index variable"); } IndexVar var = getIndexVar(content->lexer.getIdentifier()); consume(Token::identifier); - return var; + // If there is a paren after this identifier, then we may have a window + // or index set access. + if (this->content->currentToken == Token::lparen) { + this->consume(Token::lparen); + switch (this->content->currentToken) { + case Token::int_scalar: { + // In this case, we have a window or strided window. Start off by + // parsing the lo and hi of the window. + int lo, hi; + // Parse out lo. + std::istringstream value(this->content->lexer.getIdentifier()); + value >> lo; + this->consume(Token::int_scalar); + + // Parse the comma. + this->consume(Token::comma); + + // Parse out hi. + value = std::istringstream(this->content->lexer.getIdentifier()); + value >> hi; + this->consume(Token::int_scalar); + + // Now, there might be the stride. If there is another comma, then there + // is a stride value to parse. Otherwise, it's just the window of (lo, hi). + if (this->content->currentToken == Token::comma) { + this->consume(Token::comma); + int stride; + value = std::istringstream(this->content->lexer.getIdentifier()); + value >> stride; + this->consume(Token::int_scalar); + this->consume(Token::rparen); + return std::make_shared(var(lo, hi, stride)); + } else { + this->consume(Token::rparen); + return std::make_shared(var(lo, hi)); + } + } + case Token::lcurly: { + // If we see a curly brace, then an index set is being applied to the + // IndexVar. So, we'll parse a list of integers. + this->consume(Token::lcurly); + std::vector indexSet; + bool first = true; + do { + // If this isn't the first iteration of the loop, consume a comma. + if (!first) { + this->consume(Token::comma); + } + first = false; + // Parse and consume the next integer. + std::istringstream value(this->content->lexer.getIdentifier()); + int index; + value >> index; + indexSet.push_back(index); + this->consume(Token::int_scalar); + // Break when we hit a '}' to end the list. + } while (this->content->currentToken != Token::rcurly); + this->consume(Token::rcurly); + this->consume(Token::rparen); + return std::make_shared(var(indexSet)); + } + default: + throw ParseError("Expected windowing expression."); + } + } + return std::make_shared(var); } bool Parser::hasIndexVar(std::string name) const { diff --git a/src/tensor.cpp b/src/tensor.cpp index 234085f51..b8706de60 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -486,7 +486,7 @@ struct AccessTensorNode : public AccessNode { // Ensure that it has at most dim(t, i) elements. taco_uassert(indexSet.size() <= size_t(tensor.getDimension(i))); // Pack up the index set into a sparse tensor. - TensorBase indexSetTensor(tensor.getComponentType(), {int(indexSet.size())}, Compressed); + Tensor indexSetTensor({int(indexSet.size())}, Compressed); for (auto& coord : indexSet) { indexSetTensor.insert({coord}, 1); } diff --git a/tools/taco.cpp b/tools/taco.cpp index 5a02c2ae1..1fea302fb 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -90,6 +90,9 @@ static void printUsageInfo() { cout << " taco \"a(i) = b(i) + c(i)\" -f=b:s -f=c:s -f=a:s # Sparse vector add" << endl; cout << " taco \"a(i) = B(i,j) * c(j)\" -f=B:ds # SpMV" << endl; cout << " taco \"A(i,l) = B(i,j,k) * C(j,l) * D(k,l)\" -f=B:sss # MTTKRP" << endl; + cout << " taco \"a(i) = b(i(1, 5))\" -d=a:4 # Slice b[1:4]" << endl; + cout << " taco \"a(i) = b(i(1, 5, 2))\" -d=a:2 # Slice b[1:4:2]" << endl; + cout << " taco \"a(i) = b(i({1, 3, 5, 7}))\" -d=a:4 # Slice b[[1, 3, 5, 7]]" << endl; cout << endl; cout << "Options:" << endl; printFlag("d=:",