Skip to content

Commit

Permalink
taco: add parser support for windowing/striding/index sets
Browse files Browse the repository at this point in the history
Fixes tensor-compiler#413.

This commit adds support to the command line tool for index expressions
that contain windowing, striding, and index sets. An example of each of
these features is added to the help message of taco:
```
taco "a(i) = b(i(1, 5))" -d=a:4          # Slice b[1:4]
taco "a(i) = b(i(1, 5, 2))" -d=a:2       # Slice b[1:4:2]
taco "a(i) = b(i({1, 3, 5, 7}))" -d=a:4  # Slice b[[1, 3, 5, 7]]
```
  • Loading branch information
rohany authored and Infinoid committed Mar 23, 2021
1 parent ceeabe4 commit 06ea8e3
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 14 deletions.
10 changes: 7 additions & 3 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ Multi multi(IndexStmt stmt1, IndexStmt stmt2);
class IndexVarInterface {
public:
virtual ~IndexVarInterface() = default;
virtual IndexVar getIndexVar() const = 0;

/// match performs a dynamic case analysis of the implementers of IndexVarInterface
/// as a utility for handling the different values within. It mimics the dynamic
Expand Down Expand Up @@ -912,7 +913,7 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index
~WindowedIndexVar() = default;

/// getIndexVar returns the underlying IndexVar.
IndexVar getIndexVar() const;
IndexVar getIndexVar() const override;

/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
/// this index variable.
Expand Down Expand Up @@ -940,7 +941,7 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
~IndexSetVar() = default;

/// getIndexVar returns the underlying IndexVar.
IndexVar getIndexVar() const;
IndexVar getIndexVar() const override;
/// getIndexSet returns the index set.
const std::vector<int>& getIndexSet() const;

Expand All @@ -957,6 +958,9 @@ class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
~IndexVar() = default;
IndexVar(const std::string& name);

/// getIndexVar implements IndexVarInterface by returning this IndexVar itself.
IndexVar getIndexVar() const override { return *this; }

/// Returns the name of the index variable.
std::string getName() const;

Expand All @@ -967,7 +971,7 @@ class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
WindowedIndexVar operator()(int lo, int hi, int stride = 1);

/// Indexing into an IndexVar with a vector returns an index set into it.
IndexSetVar operator()(std::vector<int> indexSet);
IndexSetVar operator()(std::vector<int>&& indexSet);
IndexSetVar operator()(std::vector<int>& indexSet);

private:
Expand Down
8 changes: 6 additions & 2 deletions include/taco/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace taco {
class TensorBase;
class Format;
class IndexVar;
class IndexVarInterface;
class IndexExpr;
class Access;

Expand Down Expand Up @@ -88,10 +89,13 @@ class Parser : public util::Uncopyable {
Access parseAccess();

/// varlist ::= var {, var}
std::vector<IndexVar> parseVarList();
std::vector<std::shared_ptr<IndexVarInterface>> parseVarList();

/// var ::= identifier
IndexVar parseVar();
/// | identifier '(' int ',' int ')' -- Windowed access.
/// | identifier '(' int ',' int ',' int ')' -- Windowed access with a stride.
/// | identifier '(' '{' int, ... '}' ')' -- Access with an index set.
std::shared_ptr<IndexVarInterface> parseVar();

std::string currentTokenString();

Expand Down
2 changes: 1 addition & 1 deletion src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1978,7 +1978,7 @@ WindowedIndexVar IndexVar::operator()(int lo, int hi, int stride) {
return WindowedIndexVar(*this, lo, hi, stride);
}

// Diff hunk (1 addition, 1 deletion). NOTE(review): the by-value signature
// directly below appears to be the removed line and the `&&` signature after
// it the replacement — confirm against the repository.
//
// Builds an IndexSetVar from this IndexVar and the supplied index set.
IndexSetVar IndexVar::operator()(std::vector<int> indexSet) {
IndexSetVar IndexVar::operator()(std::vector<int>&& indexSet) {
// NOTE(review): inside the body `indexSet` is a named rvalue reference, i.e.
// an lvalue, and it is passed on without std::move; presumably the
// IndexSetVar constructor therefore copies the vector — confirm against its
// declaration.
return IndexSetVar(*this, indexSet);
}

Expand Down
79 changes: 72 additions & 7 deletions src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ Access Parser::parseAccess() {
consume(Token::identifier);
names.push_back(tensorName);

vector<IndexVar> varlist;
vector<std::shared_ptr<IndexVarInterface>> varlist;
if (content->currentToken == Token::underscore) {
consume(Token::underscore);
if (content->currentToken == Token::lcurly) {
Expand Down Expand Up @@ -322,8 +322,8 @@ Access Parser::parseAccess() {
if (util::contains(content->tensorDimensions, tensorName)) {
tensorDimensions[i] = content->tensorDimensions.at(tensorName)[i];
}
else if (util::contains(content->indexVarDimensions, varlist[i])) {
tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]);
else if (util::contains(content->indexVarDimensions, varlist[i]->getIndexVar())) {
tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]->getIndexVar());
}
else {
tensorDimensions[i] = content->defaultDimension;
Expand All @@ -347,8 +347,8 @@ Access Parser::parseAccess() {
return tensor(varlist);
}

vector<IndexVar> Parser::parseVarList() {
vector<IndexVar> varlist;
vector<std::shared_ptr<IndexVarInterface>> Parser::parseVarList() {
vector<std::shared_ptr<IndexVarInterface>> varlist;
varlist.push_back(parseVar());
while (content->currentToken == Token::comma) {
consume(Token::comma);
Expand All @@ -357,13 +357,78 @@ vector<IndexVar> Parser::parseVarList() {
return varlist;
}

// Diff hunk for Parser::parseVar. NOTE(review): the `IndexVar` signature line
// and the bare `return var;` further down appear to be the pre-change lines
// removed by this commit; the `std::shared_ptr<IndexVarInterface>` signature
// and everything after `return var;` appear to be the additions — confirm
// against the repository.
//
// Parses a single index variable from the token stream.
//   identifier                      -> IndexVar
//   identifier '(' int ',' int ')'  -> WindowedIndexVar (lo, hi)
//   identifier '(' int ',' int ',' int ')'
//                                   -> WindowedIndexVar (lo, hi, stride)
//   identifier '(' '{' int, ... '}' ')'
//                                   -> IndexSetVar
// Returns the parsed variable wrapped in a std::shared_ptr.
// Throws ParseError when the current token is not an identifier, or when the
// token following '(' is neither an integer nor '{'.
IndexVar Parser::parseVar() {
std::shared_ptr<IndexVarInterface> Parser::parseVar() {
if (content->currentToken != Token::identifier) {
throw ParseError("Expected index variable");
}
// Look the variable up by name before the identifier token is consumed.
IndexVar var = getIndexVar(content->lexer.getIdentifier());
consume(Token::identifier);
return var;
// If there is a paren after this identifier, then we may have a window
// or index set access.
if (this->content->currentToken == Token::lparen) {
this->consume(Token::lparen);
switch (this->content->currentToken) {
case Token::int_scalar: {
// In this case, we have a window or strided window. Start off by
// parsing the lo and hi of the window.
int lo, hi;
// Parse out lo.
// The integer text is taken from lexer.getIdentifier(); the result of the
// stream extraction is not inspected.
std::istringstream value(this->content->lexer.getIdentifier());
value >> lo;
this->consume(Token::int_scalar);

// Parse the comma.
this->consume(Token::comma);

// Parse out hi.
// Here the text is extracted before consume(Token::int_scalar) runs, so the
// token kind is only checked afterwards (presumably consume rejects a
// mismatching token — confirm in Parser::consume).
value = std::istringstream(this->content->lexer.getIdentifier());
value >> hi;
this->consume(Token::int_scalar);

// Now, there might be the stride. If there is another comma, then there
// is a stride value to parse. Otherwise, it's just the window of (lo, hi).
if (this->content->currentToken == Token::comma) {
this->consume(Token::comma);
int stride;
value = std::istringstream(this->content->lexer.getIdentifier());
value >> stride;
this->consume(Token::int_scalar);
this->consume(Token::rparen);
return std::make_shared<WindowedIndexVar>(var(lo, hi, stride));
} else {
this->consume(Token::rparen);
return std::make_shared<WindowedIndexVar>(var(lo, hi));
}
}
// Both branches of the int_scalar case return, so control never falls
// through into the lcurly case.
case Token::lcurly: {
// If we see a curly brace, then an index set is being applied to the
// IndexVar. So, we'll parse a list of integers.
this->consume(Token::lcurly);
std::vector<int> indexSet;
bool first = true;
// do/while: the body runs at least once, so one integer is read right after
// the '{' before the closing '}' is tested for.
do {
// If this isn't the first iteration of the loop, consume a comma.
if (!first) {
this->consume(Token::comma);
}
first = false;
// Parse and consume the next integer.
std::istringstream value(this->content->lexer.getIdentifier());
int index;
value >> index;
indexSet.push_back(index);
this->consume(Token::int_scalar);
// Break when we hit a '}' to end the list.
} while (this->content->currentToken != Token::rcurly);
this->consume(Token::rcurly);
this->consume(Token::rparen);
return std::make_shared<IndexSetVar>(var(indexSet));
}
default:
throw ParseError("Expected windowing expression.");
}
}
// No '(' followed the identifier: return the plain index variable.
return std::make_shared<IndexVar>(var);
}

bool Parser::hasIndexVar(std::string name) const {
Expand Down
2 changes: 1 addition & 1 deletion src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ struct AccessTensorNode : public AccessNode {
// Ensure that it has at most dim(t, i) elements.
taco_uassert(indexSet.size() <= size_t(tensor.getDimension(i)));
// Pack up the index set into a sparse tensor.
TensorBase indexSetTensor(tensor.getComponentType(), {int(indexSet.size())}, Compressed);
Tensor<int> indexSetTensor({int(indexSet.size())}, Compressed);
for (auto& coord : indexSet) {
indexSetTensor.insert({coord}, 1);
}
Expand Down
3 changes: 3 additions & 0 deletions tools/taco.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ static void printUsageInfo() {
cout << " taco \"a(i) = b(i) + c(i)\" -f=b:s -f=c:s -f=a:s # Sparse vector add" << endl;
cout << " taco \"a(i) = B(i,j) * c(j)\" -f=B:ds # SpMV" << endl;
cout << " taco \"A(i,l) = B(i,j,k) * C(j,l) * D(k,l)\" -f=B:sss # MTTKRP" << endl;
cout << " taco \"a(i) = b(i(1, 5))\" -d=a:4 # Slice b[1:4]" << endl;
cout << " taco \"a(i) = b(i(1, 5, 2))\" -d=a:2 # Slice b[1:4:2]" << endl;
cout << " taco \"a(i) = b(i({1, 3, 5, 7}))\" -d=a:4 # Slice b[[1, 3, 5, 7]]" << endl;
cout << endl;
cout << "Options:" << endl;
printFlag("d=<var/tensor>:<size>",
Expand Down

0 comments on commit 06ea8e3

Please sign in to comment.