
The Problem of Tensor-Level Optimization

Ashar edited this page Aug 20, 2019 · 2 revisions

RESOLVED: WE HAVE DECIDED TO USE std::variant TO SOLVE THIS PROBLEM.

PRELUDE

The problem of optimizing a tensor expression with a YAP transform is discussed below. I have tried to keep it simple and easy to understand, while making it detailed enough for us to discuss possible solutions.

In this section we discuss optimization using the distributive law. If we can solve this optimization efficiently, we can apply the same approach to other kinds of optimization.
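For reference, these are the four rewrites the transform below attempts. Here x_1 x_2 + x_3 x_4 stands for op1*op2 + op3*op4, and x_i = x_j means "the same object" (the symbols x_i are introduced only for this illustration). Only the first rewrite holds for an arbitrary product. The other three also need the product to commute, which is true of an element-wise product but not, for example, of a matrix product.

```latex
\begin{aligned}
x_1 = x_3:&\quad x_1 x_2 + x_1 x_4 = x_1 (x_2 + x_4)\\
x_1 = x_4:&\quad x_1 x_2 + x_3 x_1 = x_1 (x_2 + x_3) &&\text{(needs } x_3 x_1 = x_1 x_3\text{)}\\
x_2 = x_3:&\quad x_1 x_2 + x_2 x_4 = x_2 (x_1 + x_4) &&\text{(needs } x_1 x_2 = x_2 x_1\text{)}\\
x_2 = x_4:&\quad x_1 x_2 + x_3 x_2 = (x_1 + x_3)\, x_2 = x_2 (x_1 + x_3) &&\text{(last step needs commutativity)}
\end{aligned}
```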

We wish to design a transform as such:

auto optimized_expr = boost::yap::transform(expr, optimizer{});

Since expr is a YAP expression whose type encodes the operations to be carried out, we cannot write a transform that takes expr by reference, modifies it in place and returns void: the optimized expression generally has a different type. The only way to optimize an expression is therefore to create and return a new expression.
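To see why the result has to be a new expression, here is a minimal sketch that uses hypothetical toy node types (tensor, mul, add) and a hypothetical distribute function instead of Boost.YAP's own expression templates. The factored form has a different static type from the input, so it can only be returned, never written back into the original object.

```cpp
#include <type_traits>

// Hypothetical toy node types (not Boost.YAP's): each node's type records
// the shape of the expression it represents.
struct tensor {};                                         // stand-in for a real tensor
template <class L, class R> struct mul { L lhs; R rhs; }; // L * R
template <class L, class R> struct add { L lhs; R rhs; }; // L + R

// a*b + a*c and a*(b + c) have different static types.
using before_t = add<mul<tensor, tensor>, mul<tensor, tensor>>;
using after_t  = mul<tensor, add<tensor, tensor>>;

// Rewrites x*y + x*z into x*(y + z). It must build and return a new object,
// because after_t is not before_t. (For brevity, this sketch does not check
// that the two left factors are actually the same operand.)
inline after_t distribute(const before_t& e) {
    return after_t{e.lhs.lhs, add<tensor, tensor>{e.lhs.rhs, e.rhs.rhs}};
}
```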

EFFICIENTLY OPTIMIZING

Optimizing an expression involves comparing operands to decide which operand to factor out and which ones to add (in the case of distribution over addition). Comparing operands by value would be an overhead, since it has to compare the tensors element by element. After some discussion, we decided to compare operands by address rather than by value: the check takes constant time, at the cost of treating two distinct tensors that hold equal values as different operands.

To check whether two operands are the same object, we compare their std::addressof(...). std::addressof is constexpr since C++17, but a constexpr function is only evaluated at compile time when its arguments are constant expressions. Our operands are ordinary run-time objects, so the address comparison yields a run-time value and cannot be used to select a branch with if constexpr.
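A minimal sketch of the address check, assuming a hypothetical tensor alias for std::vector<double> and a hypothetical helper same_operand. Two tensors holding equal values still count as different operands, so the check is conservative: it can miss an optimization, but it never factors out two operands that are actually different.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor; only object identity matters here.
using tensor = std::vector<double>;

// True only if both references name the same object. This is a run-time
// check: the addresses of ordinary (e.g. local) tensors are not constant
// expressions, so the result cannot drive `if constexpr`.
inline bool same_operand(const tensor& x, const tensor& y) {
    return std::addressof(x) == std::addressof(y);
}
```

For example, with tensor a{1.0, 2.0} and tensor b{1.0, 2.0}, same_operand(a, a) is true while same_operand(a, b) is false, even though a == b.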

BODY OF TRANSFORM: PROBLEM 1

Before we write the actual optimization transform, we need some helper traits. To optimize via distribution we need one trait, called is_tensor_tensor_multiply, defined as follows:

template <class T>
struct is_tensor_tensor_multiply: std::false_type {};

template <class operandA, class operandB>
struct is_tensor_tensor_multiply< /* Some specialization */ >: std::true_type {};

This trait identifies whether a tensor expression represents the multiplication of two tensors.
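The specialization is left open above. As an illustration only, here is how such a trait could look for a hypothetical toy expression type binary_expr<Op, L, R> with hypothetical tags multiplies_tag and plus_tag and a hypothetical is_tensor helper (none of these are the tensor library's real types). The added is_tensor_tensor_multiply_v variable template applies std::decay_t, which matters inside the transform: LExpr is deduced through a forwarding reference and may be a reference type such as const binary_expr<...>&, for which the plain trait reports false.

```cpp
#include <type_traits>

// Hypothetical toy types standing in for the tensor library's expressions.
struct tensor {};
struct multiplies_tag {};
struct plus_tag {};
template <class Op, class L, class R> struct binary_expr { L lhs; R rhs; };

template <class T> struct is_tensor : std::false_type {};
template <> struct is_tensor<tensor> : std::true_type {};

template <class T>
struct is_tensor_tensor_multiply : std::false_type {};

// True only for a multiplication whose operands are both tensors.
template <class operandA, class operandB>
struct is_tensor_tensor_multiply<binary_expr<multiplies_tag, operandA, operandB>>
    : std::bool_constant<is_tensor<operandA>::value && is_tensor<operandB>::value> {};

// Strips references and cv-qualifiers added by forwarding-reference deduction.
template <class T>
inline constexpr bool is_tensor_tensor_multiply_v =
    is_tensor_tensor_multiply<std::decay_t<T>>::value;
```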

Now, we design our YAP transform for the optimization as:

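struct optimize_xform{
    template <class LExpr, class RExpr>
        decltype(auto) 
        operator()(::boost::yap::expr_tag<::boost::yap::expr_kind::plus>, LExpr &&lhs, RExpr &&rhs){
        // A using-directive is not allowed at class scope, so it lives in the body.
        using namespace ::boost::yap;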
        // Match all `+` nodes in AST.
        // lhs is the expression to left of + and rhs is expression to right
        // We can only optimize if both lhs and rhs are multiplication expression
        // here our trait we wrote will come in handy.
        
        if constexpr (is_tensor_tensor_multiply<std::decay_t<LExpr>>::value && 
                     is_tensor_tensor_multiply<std::decay_t<RExpr>>::value){
            // Both sides are multiplications. Take their operands by reference.
            
            auto& op1 = value(left(lhs));
            auto& op2 = value(right(lhs));
            auto& op3 = value(left(rhs));
            auto& op4 = value(right(rhs));
            
            // Check which operands are the same object and return a new expression
            if(std::addressof(op1) == std::addressof(op3)) return op1*(op2 + op4);
            if(std::addressof(op1) == std::addressof(op4)) return op1*(op2 + op3);
            if(std::addressof(op2) == std::addressof(op3)) return op2*(op1 + op4);
            if(std::addressof(op2) == std::addressof(op4)) return op2*(op1 + op3);
            
            // All these returns have the same type as long as all the op's have
            // the same type, because they all describe the same AST shape.
            // If none of the above conditions holds and control reaches here,
            // the expression cannot be optimized (all operands are different).
            //
            // **********************************************************************
            // But here is a problem. op1*op2 + op3*op4, the expression we want to
            // return, does not have the same type as the ones above. Hence the
            // compiler cannot deduce a single auto return type for us.
            // ---------------------------------------------------------------------
            
             return op1*op2 + op3*op4; // (Inconsistent auto type deduction here)
            
        }
        else{
            // The expression is not something this transform can optimize directly.
            // In this case we recurse into lhs and rhs and optimize each of them.
            
            auto lexpr = transform(lhs, *this); // Assume perfect forwarding
            auto rexpr = transform(rhs, *this);
            
            // We have now optimized the left and right sides recursively.
            // It is time to merge them and return the expression with a plus in between.
            
            return make_expression<expr_kind::plus>(lexpr, rexpr);
        
        } // end else
    } // end operator()
}; // end transform
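The root cause can be reproduced without YAP. In this minimal sketch, the hypothetical empty types factored and unfactored stand for the two expression shapes. Branching on a compile-time condition (as is_tensor_tensor_multiply is) works with an auto return type, because the discarded branch of if constexpr does not take part in deduction. Branching on a run-time value (as the address comparison is) instantiates every return statement, and auto then sees two different types.

```cpp
#include <type_traits>

// Hypothetical stand-ins for two differently shaped expressions.
struct factored {};    // e.g. op1*(op2 + op4)
struct unfactored {};  // e.g. op1*op2 + op3*op4

// OK: the branch is chosen at compile time, so only one return statement
// takes part in deduction and `auto` deduces a single type.
template <bool CanFactor>
auto pick_static() {
    if constexpr (CanFactor) return factored{};
    else return unfactored{};
}

// Ill-formed (left commented out): with a run-time condition both returns
// take part in deduction and they deduce different types.
// auto pick_runtime(bool same_address) {
//     if (same_address) return factored{};
//     return unfactored{};
// }
```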

SOLUTION FOR AUTO DEDUCTION INCONSISTENCY

As mentioned in the code above, we had a problem with auto return type deduction. After a lot of thinking, I decided to make optimize_xform a stateful transform. So we did:

auto xform = optimize_xform{}; // Default construct the transform;
bool b = xform.usable; 
// This is the transform's state. It is false now because the transform has not been applied.

// We now apply the transform to some expr
auto new_expr = transform(expr, xform);

bool x = xform.usable;
// x is true if expr was optimized and new_expr can be used, false otherwise.
// Check x before using new_expr.

tensor_type v;            // assumes tensor_type is default-constructible
if(x) v = new_expr;       // use the optimized expression
else  v = expr;           // use the original expr
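The stateful approach only works if the transform object held by the caller is the one that gets modified. This sketch, with the hypothetical names flag_xform, apply and apply_by_value, contrasts passing a stateful function object by forwarding reference and by value. We assume boost::yap::transform takes its transform argument by forwarding reference, which is what makes xform.usable observable after the call; this should be checked against the YAP reference.

```cpp
#include <cassert>
#include <utility>

// Hypothetical toy transform: it "optimizes" an int by halving it and records
// in its usable flag whether that was possible (only for even inputs).
struct flag_xform {
    bool usable = false;
    int operator()(int x) {
        usable = (x % 2 == 0);
        return usable ? x / 2 : x;
    }
};

// Forwarding reference: an lvalue transform binds as Xform&, so the flag is
// written into the caller's object.
template <class Expr, class Xform>
auto apply(Expr&& e, Xform&& xform) {
    return xform(std::forward<Expr>(e));
}

// By value: the flag is written into a copy and lost when the call returns.
template <class Expr, class Xform>
auto apply_by_value(Expr e, Xform xform) {
    return xform(e);
}
```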

The transform now becomes like this:

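struct optimize_xform{
    template <class LExpr, class RExpr>
        decltype(auto) 
        operator()(::boost::yap::expr_tag<::boost::yap::expr_kind::plus>, LExpr &&lhs, RExpr &&rhs){
        // A using-directive is not allowed at class scope, so it lives in the body.
        using namespace ::boost::yap;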
        if constexpr (is_tensor_tensor_multiply<std::decay_t<LExpr>>::value && 
                     is_tensor_tensor_multiply<std::decay_t<RExpr>>::value){
            auto& op1 = value(left(lhs));
            auto& op2 = value(right(lhs));
            auto& op3 = value(left(rhs));
            auto& op4 = value(right(rhs));
            
            // Assume that we can optimize and set the state to true.
            usable = true;
            
            if(std::addressof(op1) == std::addressof(op3)) return op1*(op2 + op4);
            if(std::addressof(op1) == std::addressof(op4)) return op1*(op2 + op3);
            if(std::addressof(op2) == std::addressof(op3)) return op2*(op1 + op4);
            if(std::addressof(op2) == std::addressof(op4)) return op2*(op1 + op3);
            
            // If we reach here it means that we cannot optimize, if so we set
            // state to false and return a dummy invalid expression.
            
            usable = false;
            
            // The expression returned here must not be evaluated or used: it is a
            // placeholder with the wrong value, and the flag set to false says so.
            // It is returned only so that every return statement in this branch
            // has the same type.
            return op1*(op1+op1);
        }
        else{
            // In this case we recurse to lhs and right side and optimize the expression
            std::remove_reference_t<decltype(*this)> a, b;
            // Create two fresh transforms, a and b, of the same type as *this.
            // At this point their usable flags are false.
            
            auto lexpr = transform(lhs, a); // Assume perfect forwarding
            auto rexpr = transform(rhs, b);
            
            usable = a.usable && b.usable; // if both sides were optimized, set flag true
            // if any one side failed, flag will be set to false, letting caller know
            // not to use the expression. While this works nicely, it has a problem.
            return make_expression<expr_kind::plus>(lexpr, rexpr);
        
        } // end else
    } // end operator()
    bool usable = false;
}; // end transform

This approach is able to optimize expressions like:

auto expr = (a*b + a*c) + (a*d + d*c);
auto xform = optimize_xform{};
auto new_expr = transform(expr, xform);

if(xform.usable) print(std::cout, new_expr); 
else std::cout<<"Oppsie.. Optimization was not done.";

This prints the AST of the optimized expression, which corresponds to:

new_expr = (a*(b+c))+(d*(a+c));

I chose this expression because it is optimized recursively. The top-level plus operator does not have left and right operands that satisfy the trait is_tensor_tensor_multiply, so the transform recurses into the left and right sides and optimizes each of them. In this case both the LHS and the RHS can be optimized, because each side contains an operand shared by its two products.

THE PROBLEM

Notice how we recurse and set the usable flag:

auto lexpr = transform(lhs, a); 
auto rexpr = transform(rhs, b);

usable = a.usable && b.usable;

return make_expression<expr_kind::plus>(lexpr, rexpr);

What if only lhs was optimized, i.e. a.usable is true but b.usable is false? We then treat the whole expression as unoptimized even though one of its sides was optimized.

For example, this expression is not optimized at all:

auto expr = (a*b + a*c) + (a*b + 2*d);
auto xform = optimize_xform{};
auto new_expr = transform(expr, xform);

if(xform.usable) print(std::cout, new_expr); 
else std::cout<<"Oppsie.. Optimization was not done.";

"Oppsie.. Optimization was not done."

The expression could, however, have been partially optimized, and new_expr should have been new_expr = (a*(b+c) + (a*b + 2*d)).

A possible fix is to check whether each side was optimized and then build and return the expression accordingly. So we would have:

auto lexpr = transform(lhs, a); 
auto rexpr = transform(rhs, b);

usable = a.usable || b.usable; // Okay, either of them optimized

//***************************************************************************
// The problem persists. Notice that lhs and lexpr have different types
// regardless of whether lhs was optimized. The same holds for rexpr and rhs.
//***************************************************************************

if(a.usable && !b.usable) return make_expression<expr_kind::plus>(lexpr, rhs);
if(!a.usable && b.usable) return make_expression<expr_kind::plus>(lhs, rexpr);
return make_expression<expr_kind::plus>(lexpr, rexpr);

// This also leads to inconsistent return type deduction, since each return
// produces a different type. Nor can the usable flag be made constexpr: the
// state of a stateful transform is only known at run time. That rules out
// choosing between the return types with if constexpr (...).
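The resolution noted at the top of this page is to use std::variant. A minimal sketch of the idea follows, using hypothetical toy alternatives (unchanged, optimized_left, optimized_right, optimized_both) and hypothetical functions combine and describe, not the project's actual code. The function always returns one variant type, and the run-time usable flags only choose which alternative it holds.

```cpp
#include <cassert>
#include <string_view>
#include <variant>

// Hypothetical toy alternatives, one per shape the plus-node rewrite can
// produce. In the real transform each would be a YAP expression type.
struct unchanged       { static constexpr std::string_view name = "unchanged"; };
struct optimized_left  { static constexpr std::string_view name = "optimized_left"; };
struct optimized_right { static constexpr std::string_view name = "optimized_right"; };
struct optimized_both  { static constexpr std::string_view name = "optimized_both"; };

using plus_result =
    std::variant<unchanged, optimized_left, optimized_right, optimized_both>;

// Run-time flags select the alternative, but the declared return type is
// always plus_result, so there is no inconsistent auto deduction.
inline plus_result combine(bool left_usable, bool right_usable) {
    if (left_usable && right_usable) return optimized_both{};
    if (left_usable)                 return optimized_left{};
    if (right_usable)                return optimized_right{};
    return unchanged{};
}

// Callers dispatch on the alternative actually held, using std::visit.
inline std::string_view describe(const plus_result& r) {
    return std::visit([](const auto& alt) { return alt.name; }, r);
}
```

In the real transform, the alternative list would presumably be the expression types produced by the make_expression calls above, which could be spelled out with decltype. The trade-off is that every consumer of the result has to std::visit it.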