The Problem of Tensor-Level Optimization
RESOLVED: WE HAVE DECIDED TO USE std::variant TO SOLVE THIS PROBLEM.
The problem of optimizing tensor expressions with a YAP transform is discussed below. I have tried to keep it simple and easy to understand, while making it detailed enough for us to discuss possible solutions.
In this section we discuss optimization using the distributive law. If we can solve this optimization efficiently, we can apply the same approach to other kinds of optimization.
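For reference, here is the identity behind this optimization, written per element under the assumption that * between two tensors is element-wise multiplication (so it is commutative and distributes over +). For tensors with n elements, the left side costs 2n multiplications and n additions, while the factored right side costs n multiplications and n additions:

```latex
(a \odot b + a \odot c)_i = a_i b_i + a_i c_i = a_i (b_i + c_i) = \bigl(a \odot (b + c)\bigr)_i
```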
We wish to design a transform that is used like this:
auto optimized_expr = boost::yap::transform(expr, optimizer{});
Since expr is a YAP expression whose type encodes the operations to be carried out, we cannot write a transform that takes expr by reference, modifies it in place and returns void. The only way to optimize an expression is therefore to create and return a new expression.
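A toy illustration of why an in-place rewrite is impossible. It uses hypothetical expression-template types (Tensor, Plus and Times are not YAP types, although YAP expressions likewise record the tree in their type): rewriting a*b + a*c into a*(b+c) changes the type of the expression, not just its value.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical toy expression templates: every operator returns a new node
// whose *type* records the shape of the expression tree.
struct Tensor {};  // leaf operand (placeholder, holds no data)
template <class L, class R> struct Plus  { L lhs; R rhs; };
template <class L, class R> struct Times { L lhs; R rhs; };

// Enable the operators only for our toy node types.
template <class T> struct is_node : std::false_type {};
template <> struct is_node<Tensor> : std::true_type {};
template <class L, class R> struct is_node<Plus<L, R>> : std::true_type {};
template <class L, class R> struct is_node<Times<L, R>> : std::true_type {};
template <class L, class R>
using both_nodes = std::enable_if_t<is_node<L>::value && is_node<R>::value>;

template <class L, class R, class = both_nodes<L, R>>
Plus<L, R> operator+(L l, R r) { return {l, r}; }
template <class L, class R, class = both_nodes<L, R>>
Times<L, R> operator*(L l, R r) { return {l, r}; }

// The original and the factored tree have different types, so a transform
// has to build and return a new object instead of mutating its input.
bool rewritten_tree_has_same_type() {
    Tensor a, b, c;
    auto before = a * b + a * c;  // Plus<Times<Tensor, Tensor>, Times<Tensor, Tensor>>
    auto after  = a * (b + c);    // Times<Tensor, Plus<Tensor, Tensor>>
    (void)before;
    (void)after;
    return std::is_same_v<decltype(before), decltype(after)>;
}
```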
Optimizing an expression involves comparing operands to decide which operand to factor out as common and which ones to add (in the case of distribution over addition). Checking operand equality by value can be a significant overhead, since every element has to be compared. After some discussion, we decided to check operands for address equality rather than value equality, as it is a single, constant-time comparison.
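A small illustration of the trade-off, using std::vector<double> as a stand-in for a tensor (an assumption made only for this sketch): value equality inspects every element, while address equality is one pointer comparison. The address check can miss two distinct tensors that happen to hold equal values, but that only means a missed optimization, never a wrong result.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Value equality: O(n), compares every element.
template <class T>
bool same_by_value(const std::vector<T>& x, const std::vector<T>& y) {
    return x == y;
}

// Address equality: O(1), true only if both references name the same object.
template <class T>
bool same_by_address(const T& x, const T& y) {
    return std::addressof(x) == std::addressof(y);
}
```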
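To check operands for address equality we use std::addressof(...). std::addressof is constexpr (since C++17), but comparing two addresses only yields a compile-time constant when the addresses themselves are constant expressions, for example for objects with static storage duration. Our operands are ordinary run-time objects reached through references, so the comparison happens at run time and its result cannot be used as an if constexpr condition.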
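A short sketch of the compile-time versus run-time distinction: addresses of objects with static storage duration are constant expressions, so comparing them works inside static_assert, whereas operands bound to function parameters (as in a transform) can only be compared at run time. The names global_a, global_b and same_operand are illustrative only.

```cpp
#include <cassert>
#include <memory>

// Namespace-scope objects have static storage duration: their addresses are
// constant expressions, so this comparison is evaluated at compile time.
static int global_a = 0;
static int global_b = 0;
static_assert(std::addressof(global_a) != std::addressof(global_b));

// Inside a function, x and y are run-time references. The comparison yields
// a run-time bool: usable in a plain `if`, but not as an `if constexpr`
// condition, because it is not a constant expression.
template <class T>
bool same_operand(const T& x, const T& y) {
    return std::addressof(x) == std::addressof(y);
}
```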
Before we write the actual optimization transform, we need some helper traits. To optimize via distribution we need one trait, is_tensor_tensor_multiply, which we define as follows:
template <class T>
struct is_tensor_tensor_multiply: std::false_type {};
template <class operandA, class operandB>
struct is_tensor_tensor_multiply< /* Some specialization */ >: std::true_type {};
This trait identifies whether a tensor expression represents the multiplication of two tensors.
Now we can write the YAP transform for the optimization:
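Because the specialization above is left open, here is a minimal sketch on hypothetical toy types (Tensor, Plus and Times are not YAP types). It also handles a detail of transforms: operands arrive as forwarding references, so LExpr may be deduced as, for example, const Times<...>&. The hypothetical helper is_tensor_tensor_multiply_v therefore strips references and cv-qualifiers before consulting the trait.

```cpp
#include <type_traits>

// Hypothetical toy types standing in for tensors and YAP expression nodes.
template <class T> struct Tensor {};
template <class L, class R> struct Plus  { L lhs; R rhs; };
template <class L, class R> struct Times { L lhs; R rhs; };

// Primary template: an arbitrary expression is not a tensor-tensor product.
template <class T>
struct is_tensor_tensor_multiply : std::false_type {};

// Specialization: a multiplication node whose two operands are tensors.
template <class operandA, class operandB>
struct is_tensor_tensor_multiply<Times<Tensor<operandA>, Tensor<operandB>>>
    : std::true_type {};

// Forwarding references may deduce e.g. `const Times<...>&`, so strip
// references and cv-qualifiers before asking the trait.
template <class Expr>
inline constexpr bool is_tensor_tensor_multiply_v =
    is_tensor_tensor_multiply<std::remove_cv_t<std::remove_reference_t<Expr>>>::value;
```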
using namespace ::boost::yap; // a using-directive is not allowed at class scope
struct optimize_xform{
template <class LExpr, class RExpr>
decltype(auto)
operator()(expr_tag<expr_kind::plus>, LExpr &&lhs, RExpr &&rhs){
// Match all `+` nodes in AST.
// lhs is the expression to left of + and rhs is expression to right
// We can only optimize if both lhs and rhs are multiplication expression
// here our trait we wrote will come in handy.
if constexpr (is_tensor_tensor_multiply<std::decay_t<LExpr>>::value &&
is_tensor_tensor_multiply<std::decay_t<RExpr>>::value){
// We have both operands as multiply. We now take the operands by reference
auto& op1 = value(left(lhs));
auto& op2 = value(right(lhs));
auto& op3 = value(left(rhs));
auto& op4 = value(right(rhs));
// We now have to check which are equal and return a new expression
if(std::addressof(op1) == std::addressof(op3)) return op1*(op2 + op4);
if(std::addressof(op1) == std::addressof(op4)) return op1*(op2 + op3);
if(std::addressof(op2) == std::addressof(op3)) return op2*(op1 + op4);
if(std::addressof(op2) == std::addressof(op4)) return op2*(op1 + op3);
// All these returns have the same type as long as all ops have the same type,
// because they all describe the same AST structure.
// But if none of the above conditions are true and control reaches here
// It means expression was un-optimizable (because all operands are different)
//
// **********************************************************************
// But here is a problem: op1*op2 + op3*op4, the expression we want to
// return, is not of the same type as the ones above. Hence the compiler
// cannot deduce a single return type for us.
// ---------------------------------------------------------------------
return op1*op2 + op3*op4; // (inconsistent return type deduction here)
}
else{
// The expression is not something this transform can optimize directly.
// In this case we recurse into lhs and rhs and optimize each side.
auto lexpr = transform(lhs, *this); // Assume perfect forwarding
auto rexpr = transform(rhs, *this);
// We have now optimized the left and right sides recursively.
// It is time to merge them and return the expression with plus in between.
return make_expression<expr_kind::plus>(lexpr, rexpr);
} // end else
} // end operator()
}; // end transform
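To make the type problem concrete, and to show the std::variant approach named in the RESOLVED note, here is a minimal sketch on hypothetical toy types rather than YAP expressions; it is not the project's actual implementation. Leaves hold pointers to the operands (mimicking terminals captured by reference), all four shared-operand rewrites have the single type Factored, the untouched expression has the type Unfactored, and the function returns std::variant<Factored, Unfactored> so that a run-time address check can select either. Like the transform above, it assumes * is commutative (true for element-wise multiplication).

```cpp
#include <cassert>
#include <memory>
#include <variant>

// Hypothetical toy expression nodes (not YAP types).
struct Tensor { double v; };                 // scalar stand-in for a tensor
template <class L, class R> struct Plus  { L lhs; R rhs; };
template <class L, class R> struct Times { L lhs; R rhs; };
using Leaf = const Tensor*;                  // leaves refer to the operands

// All four shared-operand rewrites have this single type ...
using Factored   = Times<Leaf, Plus<Leaf, Leaf>>;
// ... but the unoptimized op1*op2 + op3*op4 has a different one.
using Unfactored = Plus<Times<Leaf, Leaf>, Times<Leaf, Leaf>>;
using Optimized  = std::variant<Factored, Unfactored>;

double eval(Leaf t) { return t->v; }
template <class L, class R> double eval(const Plus<L, R>& e)  { return eval(e.lhs) + eval(e.rhs); }
template <class L, class R> double eval(const Times<L, R>& e) { return eval(e.lhs) * eval(e.rhs); }

// Factor op1*op2 + op3*op4 when two operands are the same object. The
// address check runs at run time, so one declared return type must hold
// both outcomes; std::variant provides exactly that.
Optimized distribute(const Tensor& op1, const Tensor& op2,
                     const Tensor& op3, const Tensor& op4) {
    auto same = [](const Tensor& x, const Tensor& y) {
        return std::addressof(x) == std::addressof(y);
    };
    if (same(op1, op3)) return Factored{&op1, {&op2, &op4}};
    if (same(op1, op4)) return Factored{&op1, {&op2, &op3}};
    if (same(op2, op3)) return Factored{&op2, {&op1, &op4}};
    if (same(op2, op4)) return Factored{&op2, {&op1, &op3}};
    return Unfactored{{&op1, &op2}, {&op3, &op4}};
}

// Evaluate whichever alternative the variant holds.
double evaluate(const Optimized& e) {
    return std::visit([](const auto& alt) { return eval(alt); }, e);
}
```

For example, distribute(a, b, a, c) holds the Factored alternative a*(b+c), while distribute(a, b, c, d) keeps the original tree.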
As mentioned in the code above, we ran into a return type deduction problem. After a lot of thinking, I decided to make optimize_xform a stateful transform. So we did:
auto xform = optimize_xform{}; // Default construct the transform;
bool b = xform.usable;
// This is a state of transform, set to false now because transform has not been applied.
// We now apply the transform to some expr
auto new_expr = transform(expr, xform);
bool x = xform.usable;
// x will be true if expr was optimized and new_expr can be used, else false.
// Before using new_expr we check whether x is true or false:
tensor_type v = x ? tensor_type(new_expr) // use the optimized expression
                  : tensor_type(expr);    // use the original expression
The transform now looks like this:
using namespace ::boost::yap; // a using-directive is not allowed at class scope
struct optimize_xform{
template <class LExpr, class RExpr>
decltype(auto)
operator()(expr_tag<expr_kind::plus>, LExpr &&lhs, RExpr &&rhs){
if constexpr (is_tensor_tensor_multiply<std::decay_t<LExpr>>::value &&
is_tensor_tensor_multiply<std::decay_t<RExpr>>::value){
auto& op1 = value(left(lhs));
auto& op2 = value(right(lhs));
auto& op3 = value(left(rhs));
auto& op4 = value(right(rhs));
usable = true;
// Assume that we can optimize, and set the state to true.
if(std::addressof(op1) == std::addressof(op3)) return op1*(op2 + op4);
if(std::addressof(op1) == std::addressof(op4)) return op1*(op2 + op3);
if(std::addressof(op2) == std::addressof(op3)) return op2*(op1 + op4);
if(std::addressof(op2) == std::addressof(op4)) return op2*(op1 + op3);
// If we reach here it means that we cannot optimize, if so we set
// state to false and return a dummy invalid expression.
usable = false;
// The expression we return here must not be evaluated or used: it is a
// meaningless placeholder, and the flag set to false signals this state.
// It is returned only so that every return statement inside this branch
// has the same type.
return op1*(op1+op1);
}
else{
// In this case we recurse into lhs and rhs and optimize each side.
std::remove_reference_t<decltype(*this)> a, b;
// We create two new transforms of the same type as *this, named a and b;
// At this moment their usable flag is set to false;
auto lexpr = transform(lhs, a); // Assume perfect forwarding
auto rexpr = transform(rhs, b);
usable = a.usable && b.usable; // if both sides were optimized, set flag true
// if any one side failed, flag will be set to false, letting caller know
// not to use the expression. While this works nicely, it has a problem.
return make_expression<expr_kind::plus>(lexpr, rexpr);
} // end else
} // end operator()
bool usable = false;
}; // end transform
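The flag-based design can be exercised without YAP on hypothetical toy types (not YAP's). In this sketch the functor always returns a Factored tree, and usable tells the caller whether that tree is meaningful; the fallback returns a placeholder of the same type, just as op1*(op1+op1) does above.

```cpp
#include <cassert>
#include <memory>

// Hypothetical toy expression nodes (not YAP types).
struct Tensor { double v; };
template <class L, class R> struct Plus  { L lhs; R rhs; };
template <class L, class R> struct Times { L lhs; R rhs; };
using Leaf     = const Tensor*;
using Factored = Times<Leaf, Plus<Leaf, Leaf>>;

double eval(Leaf t) { return t->v; }
template <class L, class R> double eval(const Plus<L, R>& e)  { return eval(e.lhs) + eval(e.rhs); }
template <class L, class R> double eval(const Times<L, R>& e) { return eval(e.lhs) * eval(e.rhs); }

// Stateful toy optimizer for op1*op2 + op3*op4: the return type is always
// Factored, and `usable` reports whether the returned tree may be used.
struct flag_optimizer {
    bool usable = false;

    Factored operator()(const Tensor& op1, const Tensor& op2,
                        const Tensor& op3, const Tensor& op4) {
        auto same = [](const Tensor& x, const Tensor& y) {
            return std::addressof(x) == std::addressof(y);
        };
        usable = true;  // assume success
        if (same(op1, op3)) return {&op1, {&op2, &op4}};
        if (same(op1, op4)) return {&op1, {&op2, &op3}};
        if (same(op2, op3)) return {&op2, {&op1, &op4}};
        if (same(op2, op4)) return {&op2, {&op1, &op3}};
        usable = false;
        return {&op1, {&op1, &op1}};  // placeholder of the same type; must not be used
    }
};
```

The caller must check usable before touching the result, which is exactly the burden the flag design places on every level of the recursion.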
This approach is able to optimize expressions like:
auto expr = (a*b + a*c) + (a*d + d*c);
auto xform = optimize_xform{};
auto new_expr = transform(expr, xform);
if(xform.usable) print(std::cout, new_expr);
else std::cout<<"Oppsie.. Optimization was not done.";
This prints the optimized expression's AST:
new_expr = (a*(b+c))+(d*(a+c));
I chose this expression because it is optimized recursively. Note that the top-level plus operator does not have left and right operands that satisfy the trait is_tensor_tensor_multiply, so the transform recurses into the left and right sides and optimizes each of them. In this case both LHS and RHS can be optimized, because each side contains an operand shared by both of its products.
Notice how we recurse and set usable:
auto lexpr = transform(lhs, a);
auto rexpr = transform(rhs, b);
usable = a.usable && b.usable;
return make_expression<expr_kind::plus>(lexpr, rexpr);
What if only lhs was optimized, so that a.usable is true but b.usable is false? We then treat the whole expression as unoptimized, even though one of its sides was optimized.
For example, this expression is not optimized as well as it could be:
auto expr = (a*b + a*c) + (a*b + 2*d);
auto xform = optimize_xform{};
auto new_expr = transform(expr, xform);
if(xform.usable) print(std::cout, new_expr);
else std::cout<<"Oppsie.. Optimization was not done.";
"Oppsie.. Optimization was not done."
The expression, however, could have been partially optimized, and new_expr should have been:
new_expr = (a*(b+c)) + (a*b + 2*d);
A possible solution is to check which sides were optimized and only then build and return the expression. So we would have:
auto lexpr = transform(lhs, a);
auto rexpr = transform(rhs, b);
usable = a.usable || b.usable; // Okay, either of them optimized
//***************************************************************************
// The problem persists. Notice that the types of lhs and lexpr differ,
// regardless of whether they were optimized. The same holds for rexpr and rhs.
//***************************************************************************
if(a.usable && !b.usable) return make_expression<expr_kind::plus>(lexpr, rhs);
if(!a.usable && b.usable) return make_expression<expr_kind::plus>(lhs, rexpr);
return make_expression<expr_kind::plus>(lexpr, rexpr);
// This again leads to inconsistent return type deduction, as each return
// statement yields a different type.
// Also, the `usable` flag cannot be made constexpr, because its value depends
// on run-time address comparisons. That rules out selecting between the
// return types with if constexpr (...)
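With std::variant (the resolution noted at the top of this page), the partial-optimization case needs no flag: each side is returned as a variant, and the top-level plus is rebuilt with std::visit over both sides, giving one alternative per (left, right) combination. This is a minimal sketch on hypothetical toy types, not the project's code; the two sides are constructed directly here, standing in for the results of optimizing each side.

```cpp
#include <cassert>
#include <type_traits>
#include <variant>

// Hypothetical toy expression nodes (not YAP types).
struct Tensor { double v; };
template <class L, class R> struct Plus  { L lhs; R rhs; };
template <class L, class R> struct Times { L lhs; R rhs; };
using Leaf       = const Tensor*;                                // refers to an operand
using Factored   = Times<Leaf, Plus<Leaf, Leaf>>;                // e.g. a*(b+c)
using Unfactored = Plus<Times<Leaf, Leaf>, Times<Leaf, Leaf>>;   // e.g. a*b + c*d
using Side       = std::variant<Factored, Unfactored>;           // one optimized side

// One alternative per (left, right) combination, so a partially optimized
// expression is still representable and no `usable` flag is needed.
using Whole = std::variant<Plus<Factored, Factored>,   Plus<Factored, Unfactored>,
                           Plus<Unfactored, Factored>, Plus<Unfactored, Unfactored>>;

double eval(Leaf t) { return t->v; }
template <class L, class R> double eval(const Plus<L, R>& e)  { return eval(e.lhs) + eval(e.rhs); }
template <class L, class R> double eval(const Times<L, R>& e) { return eval(e.lhs) * eval(e.rhs); }

// Rebuild the top-level `+` from whichever alternatives the two sides hold.
Whole combine(const Side& lhs, const Side& rhs) {
    return std::visit(
        [](const auto& l, const auto& r) -> Whole {
            using L = std::decay_t<decltype(l)>;
            using R = std::decay_t<decltype(r)>;
            return Plus<L, R>{l, r};
        },
        lhs, rhs);
}

// Evaluate whichever combination the variant holds.
double evaluate(const Whole& w) {
    return std::visit([](const auto& e) { return eval(e); }, w);
}
```

Combining two sides with k alternatives each needs k × k alternatives in the result, so in this sketch the variant grows quickly with the depth of the tree.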
We both would like to thank our mentor Cem for his constant support and help in achieving our goals. He was always helpful and easy to reach for help or for discussions about the work. We would also like to thank Google for the Google Summer of Code programme, without which none of this would have been possible. Lastly, we express our gratitude to our parents for providing us with everything we needed, directly or indirectly, to carry out our work at home.