Thursday, July 31, 2008

Expression Templates Demystified

Often, in the course of our programming work, we want to modify or customize part of a function’s logic at the point of invocation. A very simple example is the C++ STL algorithm std::find_if.

template<class InputIterator, class Predicate>
InputIterator find_if(InputIterator first, InputIterator last, Predicate pred);

Here Predicate represents a functor (an object of a type that overloads the function-call operator). In this particular case, the functor takes an element of the range delimited by the iterators first and last, and returns a boolean value based on some criterion, typically a function of the element passed.
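As a small illustration, a predicate functor and a call to find_if might look like the sketch below. The names GreaterThan and first_index_above are made up for this example; they are not part of the STL.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical predicate: true for values above a limit fixed at construction.
struct GreaterThan {
    explicit GreaterThan(int limit) : limit_(limit) {}
    bool operator()(int value) const { return value > limit_; }
    int limit_;
};

// Returns the index of the first element above `limit`, or -1 if there is none.
int first_index_above(const std::vector<int>& v, int limit) {
    std::vector<int>::const_iterator it =
        std::find_if(v.begin(), v.end(), GreaterThan(limit));
    return it == v.end() ? -1 : static_cast<int>(it - v.begin());
}
```

Changing the functor passed to find_if changes what counts as a match, without touching find_if itself.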

Using find_if, one can find the first element in a range that meets a particular criterion. What’s remarkable is that what find_if identifies as a matching element depends entirely on the logic encoded in the Predicate functor, not on find_if’s implementation. This makes the functionality of find_if open-ended. What we achieve in the process is a form of polymorphism with a very high degree of type safety and performance. There is a specific name for such idioms: functional composition. In fact, by combining an arbitrary number of functors with varied functionality, we can construct a fairly complex piece of functionality. This has a very useful application in developing mini-languages (also called DSELs or Domain Specific Embedded Languages): expressions with a syntactic form alien to C++, embedded inside C++ programs as valid code. In this article and the next in this series, we look at the powerful technique called Expression Templates, which helps solve this class of problems.

We all learned, in our high school calculus course, about functions of a single variable such as:

F(x) = x^3 + 3x - 7

or

G(x) = x·sin(x) + 6

Representing such simple algebraic functions using functors is not a big hassle. For example, to represent G(x), one could write a functor like the one below:

#include <cmath>

struct FuncG {
    double operator()(double x) const {
        return x * std::sin(x) + 6;
    }
};

Imagine you have a Calculus library written in C++ and you have a function called integrate, that is defined as follows:

#include <cassert>

template<class Func>
double integrate(Func f, double low, double high, double epsilon = 0.001) {
    assert(low <= high);
    double auc1 = 0, auc2 = 0;
    double running_x = low;
    while ( running_x < high ) {
        auc1 += f(running_x) * epsilon;
        running_x += epsilon;
        auc2 += f(running_x) * epsilon;
    }

    return (auc1 + auc2)/2;
}

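This is essentially an area-under-the-curve calculation that averages two rectangle sums, one using the left endpoint of each step and one using the right endpoint. For a monotonic function these are the lower and upper bound sums, and their average amounts to the trapezoidal rule with step epsilon.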
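To see the function at work, the sketch below integrates f(x) = x*x over [0, 1], whose exact value is 1/3. The Square functor and the square_integral_error helper are hypothetical names introduced only for this example; integrate is repeated so that the block compiles on its own.

```cpp
#include <cassert>
#include <cmath>

template<class Func>
double integrate(Func f, double low, double high, double epsilon = 0.001) {
    assert(low <= high);
    double auc1 = 0, auc2 = 0;
    double running_x = low;
    while (running_x < high) {
        auc1 += f(running_x) * epsilon;   // rectangle using the left endpoint
        running_x += epsilon;
        auc2 += f(running_x) * epsilon;   // rectangle using the right endpoint
    }
    return (auc1 + auc2) / 2;
}

// Hypothetical functor for f(x) = x*x.
struct Square {
    double operator()(double x) const { return x * x; }
};

// Absolute error of the numerical result against the exact value 1/3.
double square_integral_error() {
    return std::fabs(integrate(Square(), 0, 1) - 1.0 / 3.0);
}
```

Because running_x is accumulated in floating point, the loop may take one step more or less than expected near the upper limit, so the result should only be trusted to within a small multiple of epsilon.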

Now here is the deal: in the course of mathematical programming, one may need to integrate dozens of such mathematical expressions. If there are 50 different expressions to be integrated, the program needs to define 50 different functors, and this is where the efficiency and ease of the approach break down.

Imagine being able to define function objects that can evaluate arbitrary mathematical expressions and being able to pass such expressions to functions like integrate. Something like the following:

MathFunction f = x*x + 2*x + 3;
double d = integrate(f, 0, 2);
f = x*x*sin*sin + 2*x*cos*sin + cos*cos;
const double PI = 3.1415926535897;
d = f(PI/2);

This almost looks like magic, doesn’t it? But it isn’t. In each case, the algebraic expression can be engineered to yield a functor that evaluates the expression for different values of the function argument. The standard idiom for enabling such expression building is known as Expression Templates, and it is the focus of this article. Expression Templates, however, have a reputation for being difficult to understand and learn, and have turned off many a learner. Therefore, we will try to build the logic of constructing such expressions intuitively, bit by bit. We will start off with a completely non-template version of the code, an idiom that I discovered for myself and which I call Expression Functor. We will then ‘deduce’ the Expression Template idiom as a special case of this idiom, one that uses templates to achieve large improvements in both performance and code size.

So here we go.

Breaking down the problem



Let us first break the problem into its basic elements. We begin with a simple polynomial expression:

f(x) = x + 3

Such an expression has two kinds of entities: a variable (x) and a constant (the literal 3). Both are valid expressions by themselves. For example, the straight line parallel to the x-axis, 3 units above it, is represented by the function f(x) = 3. Similarly, the straight line through the origin at an angle of 45 degrees to both the x and y axes is represented by the function f(x) = x. Models of these most trivial functions will be our building blocks for non-trivial expressions.

Consider the case of the horizontal line – f(x) = 3. We want to model a function which takes any value of x and always returns 3. The basic function would be:

double f(double x) {
    return 3;
}


We want to turn this into a functor like this:

struct C3 {
    double operator() (double) {
        return 3;
    }
};

C3 above represents the function f(x) = 3. Now we want to generalize this functor to represent any real number, which is fairly straightforward:

struct Constant {
    Constant(const double& d) : d_(d) {
    }

    double operator() (double) {
        return d_;
    }

    const double d_;
};

Such a functor can be used to model any constant expression. C3 would now be equivalent to Constant(3).

Next, let’s try to model the 45 degree straight line through the origin, f(x) = x. It turns out that this is even easier: "given any value x, return that very value" summarizes our function. In short:

struct Variable {
    double operator() (double x) {
        return x;
    }
};
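A quick sketch of how these two functors behave when called: Constant ignores its argument and Variable hands it straight back. The helper x_plus_3 is a hypothetical name used only here; it adds the two results by hand, which is exactly the step the rest of the article sets out to automate.

```cpp
#include <cassert>

struct Constant {
    Constant(const double& d) : d_(d) {}
    double operator()(double) { return d_; }   // argument is ignored
    const double d_;
};

struct Variable {
    double operator()(double x) { return x; }  // argument is returned unchanged
};

// Hand-written evaluation of x + 3: each functor is called separately
// and the results are added by the caller.
double x_plus_3(double at) {
    Variable x;
    Constant three(3);
    return x(at) + three(at);
}
```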

These two classes work well on their own, but what happens when we try to model a function like:

f(x) = x+3

What type will x + 3 be? It is not a Constant, and it is not an independent variable like Variable either. So we need another class to represent more complex expressions. Come to think of it, a Constant and a Variable are each an expression in themselves, and so is a more general expression like x+3. So it seems logical to have an Expression base class from which Constant, Variable and other expression classes derive. We define it like this:

struct Expression {
    virtual double operator() (double) = 0;
    virtual Expression* clone() = 0;
    virtual ~Expression() {
    }
};

There is a reason for the curious looking Expression* clone() virtual function. I have included this with the benefit of contrived foresight - having implemented this whole solution already. Without getting ahead of the story, let me tell you that it plays a small but important role in the life cycle management of Expressions. It is used to copy Expressions where needed.

struct Constant : Expression {
    ...
    Expression* clone() {
        return new Constant(*this);
    }
};

struct Variable : Expression {
    ...
    Expression* clone() {
        return new Variable(*this);
    }
};
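The sketch below shows what clone() buys us: code that only sees an Expression& can still make a copy with the right dynamic type. The classes are spelled out in full so that the block compiles by itself, and eval_copy is a hypothetical helper written for this example.

```cpp
#include <cassert>

struct Expression {
    virtual double operator()(double) = 0;
    virtual Expression* clone() = 0;
    virtual ~Expression() {}
};

struct Constant : Expression {
    Constant(const double& d) : d_(d) {}
    double operator()(double) { return d_; }
    Expression* clone() { return new Constant(*this); }
    const double d_;
};

struct Variable : Expression {
    double operator()(double x) { return x; }
    Expression* clone() { return new Variable(*this); }
};

// Copies an expression known only through its base class, evaluates the
// copy at x, and frees it again.
double eval_copy(Expression& e, double x) {
    Expression* copy = e.clone();   // dynamic type (Constant, Variable) is preserved
    double result = (*copy)(x);
    delete copy;                    // the caller of clone() owns the copy
    return result;
}
```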

We will refer to Expressions like Constant and Variable as simple expressions. We also need to define a sub-type of Expression to represent all non-trivial (i.e. other than simple) expressions. But before that can happen, we need to understand how multiple expressions can be combined using arithmetic operators. For example, we want to be able to write expressions such as:

Variable x;
Expression& e = x*x + 2*x + 1;
double d = e(5); // d == 36
double d2 = integrate(e, 0, 1); // d2 == 2.33

Clearly, we need to overload operators like * and + between Variables (x*x), between Constants and Variables (2*x), and between multiple ComplexExpressions (x*x and 2*x). Moreover, the literals like 1 and 2 are not Constant objects, they are doubles that need to be converted to Constant objects. So, these operators should also be overloaded for double arguments. Clearly, given the fact that all of these (Variable, Constant, ComplexExpression) are different sub-types of Expression, it will be easier if we just overload these operators between Expressions, and between Expressions and doubles.

Now any complex expression can be represented as:

f(x) = u(x) OP v(x)

OP is a binary arithmetic operator. For example:

f(x) = c is a degenerate case where u(x) = 1, v(x) = c and OP = *, or alternatively u(x) = 0, v(x) = c and OP = +. Assuming that we have an operation type available for each arithmetic operator, we write the class for complex expressions as a class template whose template parameter is the binary operation.

template<class Op>
struct ComplexExpression : Expression {
    Expression* l_;
    Expression* r_;

    ComplexExpression(Expression& l, Expression& r) : l_(l.clone()), r_(r.clone()) {
    }

    ~ComplexExpression() {
        delete l_;
        delete r_;
    }

    double operator() (double d) {
        return Op::apply( (*l_)(d), (*r_)(d) );
    }

    Expression* clone() {
        return new ComplexExpression(*l_, *r_);
    }
};

The binary operations can be easily defined using simple structs with a static apply function, like the following:

struct Add {
    static double apply(double l, double r) {
        return l+r;
    }
};

struct Subtract {
    static double apply(double l, double r) {
        return l-r;
    }
};

struct Multiply {
    static double apply(double l, double r) {
        return l*r;
    }
};

struct Divide {
    static double apply(double l, double r) {
        return l/r;
    }
};
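Before any operators are overloaded, the pieces defined so far can already be wired together by hand. The sketch below builds f(x) = x*x + 3 from two ComplexExpression nodes. The helper eval_x_squared_plus_3 is a hypothetical name, and the classes are repeated in compact form so that the block compiles by itself.

```cpp
#include <cassert>

struct Expression {
    virtual double operator()(double) = 0;
    virtual Expression* clone() = 0;
    virtual ~Expression() {}
};

struct Constant : Expression {
    Constant(const double& d) : d_(d) {}
    double operator()(double) { return d_; }
    Expression* clone() { return new Constant(*this); }
    const double d_;
};

struct Variable : Expression {
    double operator()(double x) { return x; }
    Expression* clone() { return new Variable(*this); }
};

struct Add      { static double apply(double l, double r) { return l + r; } };
struct Multiply { static double apply(double l, double r) { return l * r; } };

template<class Op>
struct ComplexExpression : Expression {
    Expression* l_;
    Expression* r_;
    ComplexExpression(Expression& l, Expression& r) : l_(l.clone()), r_(r.clone()) {}
    ~ComplexExpression() { delete l_; delete r_; }
    double operator()(double d) { return Op::apply((*l_)(d), (*r_)(d)); }
    Expression* clone() { return new ComplexExpression(*l_, *r_); }
};

// Evaluates f(x) = x*x + 3 at the given point by composing the nodes manually.
double eval_x_squared_plus_3(double at) {
    Variable x;
    Constant three(3);
    ComplexExpression<Multiply> square(x, x);   // x*x
    ComplexExpression<Add> f(square, three);    // (x*x) + 3, operands are cloned
    return f(at);
}
```

The operator overloads that follow exist to let the compiler do this wiring from the expression x*x + 3 itself.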

Finally, when we combine two simple Expressions, we get a complex expression. For example, x might be a simple expression, but x*x is a complex expression. We want to be able to cascade the operators to any degree, thus:

x*x*x

should be a valid expression. Clearly, x*x must return an object of a type that can in turn be combined with x using the * operator. x*x must return a reference or a pointer to Expression: Expression is an abstract class, so it cannot be returned by value, and besides there is not much point in returning it by value because we want to treat it polymorphically (through virtual calls to operator() and clone() in ComplexExpression). Something like the following:


Expression* operator * (Expression& l, Expression& r) {
    return new ComplexExpression<Multiply>(l, r);
}

The above does not work because operator* returns a pointer to an Expression object, so x*x yields a pointer. Consider an expression like x*x*2, which is equivalent to (x*x)*2. Since x*x returns an Expression*, we would need an operator* that takes a pointer (Expression*) and a double (2). The Standard does not allow that: an overloaded operator must have at least one parameter of class or enumeration type, and neither a pointer nor a double qualifies. Thus, operator * can only return a reference.


Expression& operator * (Expression& l, Expression& r) {
    return *new ComplexExpression<Multiply>(l, r);
}

The above works, but no one takes responsibility for deallocating the object referred to by the returned reference, so we are left with a memory leak.

What about the following:

Expression& operator * (Expression& l, Expression& r) {
    return ComplexExpression<Multiply>(l, r);
}

The above does not work either, because we are returning a reference to a temporary object created inside the function. A conforming compiler rejects binding the temporary to a non-const reference, and even where it is accepted the reference dangles as soon as the function returns, so using it has undefined behaviour. What do we do? Well, a bit of inevitable complexity creeps in here. We need to return a pointer wrapped in a smart wrapper, which takes care of the life cycle of the underlying pointer and acts as a proxy object when participating in mathematical expressions.

Consider the following definition of a reference counted wrapper cum Proxy for heap allocated Expression pointers:

struct ExpressionRef {
    // Takes ownership of ptr. The counter lives on the heap so that
    // all copies of this wrapper share it.
    explicit ExpressionRef(Expression* ptr) : sp_(ptr), ref_cnt_(new int(1)) {
    }

    ~ExpressionRef() {
        release();
    }

    ExpressionRef(const ExpressionRef& source) : sp_(source.sp_), ref_cnt_(source.ref_cnt_) {
        ++*ref_cnt_;
    }

    ExpressionRef& operator = (const ExpressionRef& rhs) {
        if ( this != &rhs ) {
            release();

            sp_ = rhs.sp_;
            ref_cnt_ = rhs.ref_cnt_;
            ++*ref_cnt_;
        }

        return *this;
    }

    double operator() (double d) const {
        return (*sp_)(d);
    }

    ExpressionRef clone() const {
        return ExpressionRef(sp_->clone());
    }

    Expression* get() const {
        return sp_;
    }

    Expression& getref() const {
        return *sp_;
    }

    operator void*() const {
        return sp_;
    }

private:
    // Drops one reference; the last owner frees the expression and the counter.
    void release() {
        if (--*ref_cnt_ == 0) {
            delete sp_;
            delete ref_cnt_;
        }
    }

    Expression *sp_;
    int *ref_cnt_;
};

Using this class, we can rewrite operator* like below:

ExpressionRef operator * (Expression& l, Expression& r) {
    return ExpressionRef(new ComplexExpression<Multiply>(l, r));
}


Thus, the return value of x*x is of type ExpressionRef, which wraps a polymorphic pointer to an Expression and ensures that there are no memory leaks. The only problem now is to make expressions such as x*x*x valid. x*x*x translates to (x*x)*x, where x*x returns an ExpressionRef and x is some sub-type of Expression. This is not such a big deal: we define an operator * that takes an ExpressionRef on the left and an Expression on the right. We also need the reverse permutation, with the Expression on the left and the ExpressionRef on the right. In fact, to allow expressions such as (x*x)*(x*x), we must also define an operator * that takes an ExpressionRef on both sides.
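As an aside, on a C++11 compiler the reference counting can be delegated to std::shared_ptr instead of being written by hand. The sketch below is an alternative to the wrapper used in this article, not a description of it. SharedExpr, Product and variable() are hypothetical names, and the expression interface is simplified (const operator(), no clone()) because shared ownership makes copying the nodes unnecessary.

```cpp
#include <cassert>
#include <memory>

struct Expression {
    virtual double operator()(double) const = 0;
    virtual ~Expression() {}
};

struct Variable : Expression {
    double operator()(double x) const { return x; }
};

// Hypothetical stand-in for ComplexExpression<Multiply>: it shares its
// operands instead of cloning them.
struct Product : Expression {
    Product(std::shared_ptr<const Expression> l, std::shared_ptr<const Expression> r)
        : l_(l), r_(r) {}
    double operator()(double x) const { return (*l_)(x) * (*r_)(x); }
    std::shared_ptr<const Expression> l_, r_;
};

// Plays the role of the reference-counted wrapper; shared_ptr does the counting.
struct SharedExpr {
    explicit SharedExpr(std::shared_ptr<const Expression> p) : p_(p) {}
    double operator()(double x) const { return (*p_)(x); }
    std::shared_ptr<const Expression> p_;
};

SharedExpr operator*(const SharedExpr& l, const SharedExpr& r) {
    return SharedExpr(std::make_shared<Product>(l.p_, r.p_));
}

// Convenience factory for the independent variable.
SharedExpr variable() {
    return SharedExpr(std::make_shared<Variable>());
}
```

Since every operand is already a SharedExpr, a single overload per operator covers x*x, (x*x)*x and (x*x)*(x*x), at the price of creating the variable through a factory function.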

At this point, all the issues with our solution are resolved - and we need to take stock of the final set of operator overloads we need to make the expressions work naturally. Here is a summary:



Expression Op Expression
ExpressionRef Op ExpressionRef
Expression Op double, and the reverse permutation
ExpressionRef Op double, and the reverse permutation
ExpressionRef Op Expression, and the reverse permutation

That makes it 8 types of signatures. For each type, if we plan to implement +, -, * and /, then we will have a total of 32 operators. One set of operators can have real implementations and the rest can be defined in terms of the first set. Here are the definitions for the + operators.

ExpressionRef operator + (Expression& l, Expression& r) {
    return ExpressionRef(new ComplexExpression<Add>(l, r));
}

// ExpressionRef operands are taken by value so that temporaries such as
// the result of x*x can be passed directly.
ExpressionRef operator + (ExpressionRef l, ExpressionRef r) {
    return l.getref() + r.getref();
}

///////////////

ExpressionRef operator + (Expression& l, const double d) {
    Constant c(d);   // named object: a temporary cannot bind to Expression&
    return l + c;    // safe, because ComplexExpression clones its operands
}

ExpressionRef operator + (const double d, Expression& l) {
    Constant c(d);
    return c + l;
}

//////////////

ExpressionRef operator + (ExpressionRef l, Expression& r) {
    return l.getref() + r;
}

ExpressionRef operator + (Expression& l, ExpressionRef r) {
    return l + r.getref();
}

//////////////

ExpressionRef operator + (ExpressionRef l, double d) {
    Constant c(d);
    return l.getref() + c;
}

ExpressionRef operator + (double d, ExpressionRef r) {
    Constant c(d);
    return c + r.getref();
}
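Since the -, * and / overloads differ from the + overloads only in the operator symbol and the operation type, one way to avoid writing all 32 by hand is a macro that stamps out the eight overloads for one operator. This is a sketch under several assumptions: DEFINE_EXPRESSION_OPERATORS is a hypothetical name, the wrapper delegates its counting to std::shared_ptr (C++11) to keep the block short, and ExpressionRef operands are taken by const reference.

```cpp
#include <cassert>
#include <memory>

struct Expression {
    virtual double operator()(double) = 0;
    virtual Expression* clone() = 0;
    virtual ~Expression() {}
};

struct Constant : Expression {
    explicit Constant(double d) : d_(d) {}
    double operator()(double) { return d_; }
    Expression* clone() { return new Constant(*this); }
    const double d_;
};

struct Variable : Expression {
    double operator()(double x) { return x; }
    Expression* clone() { return new Variable(*this); }
};

struct Add      { static double apply(double l, double r) { return l + r; } };
struct Subtract { static double apply(double l, double r) { return l - r; } };
struct Multiply { static double apply(double l, double r) { return l * r; } };
struct Divide   { static double apply(double l, double r) { return l / r; } };

template<class Op>
struct ComplexExpression : Expression {
    ComplexExpression(Expression& l, Expression& r) : l_(l.clone()), r_(r.clone()) {}
    ComplexExpression(const ComplexExpression&) = delete;   // owns l_ and r_
    ~ComplexExpression() { delete l_; delete r_; }
    double operator()(double d) { return Op::apply((*l_)(d), (*r_)(d)); }
    Expression* clone() { return new ComplexExpression(*l_, *r_); }
private:
    Expression* l_;
    Expression* r_;
};

// Simplified wrapper for this sketch: shared_ptr handles the reference count.
struct ExpressionRef {
    explicit ExpressionRef(Expression* p) : sp_(p) {}
    double operator()(double d) const { return (*sp_)(d); }
    Expression& getref() const { return *sp_; }
private:
    std::shared_ptr<Expression> sp_;
};

// Generates the eight overloads for one operator symbol OP and operation TAG.
// The first overload does the real work; the other seven forward to it.
#define DEFINE_EXPRESSION_OPERATORS(OP, TAG)                                      \
    inline ExpressionRef operator OP (Expression& l, Expression& r) {             \
        return ExpressionRef(new ComplexExpression<TAG>(l, r));                   \
    }                                                                             \
    inline ExpressionRef operator OP (const ExpressionRef& l,                     \
                                      const ExpressionRef& r) {                   \
        return l.getref() OP r.getref();                                          \
    }                                                                             \
    inline ExpressionRef operator OP (Expression& l, double d) {                  \
        Constant c(d); return l OP c;                                             \
    }                                                                             \
    inline ExpressionRef operator OP (double d, Expression& r) {                  \
        Constant c(d); return c OP r;                                             \
    }                                                                             \
    inline ExpressionRef operator OP (const ExpressionRef& l, Expression& r) {    \
        return l.getref() OP r;                                                   \
    }                                                                             \
    inline ExpressionRef operator OP (Expression& l, const ExpressionRef& r) {    \
        return l OP r.getref();                                                   \
    }                                                                             \
    inline ExpressionRef operator OP (const ExpressionRef& l, double d) {         \
        Constant c(d); return l.getref() OP c;                                    \
    }                                                                             \
    inline ExpressionRef operator OP (double d, const ExpressionRef& r) {         \
        Constant c(d); return c OP r.getref();                                    \
    }

DEFINE_EXPRESSION_OPERATORS(+, Add)
DEFINE_EXPRESSION_OPERATORS(-, Subtract)
DEFINE_EXPRESSION_OPERATORS(*, Multiply)
DEFINE_EXPRESSION_OPERATORS(/, Divide)
```

Four macro invocations then replace the 32 hand-written functions, although the repetition has only been hidden, not removed.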

To test this solution, use the following functions:

#include <cassert>
#include <iostream>
using namespace std;

double integrate(Expression& e, double low = 0, double high = 1, double epsilon = 0.001) {
    assert(low <= high);
    double auc1 = 0, auc2 = 0;
    double running_x = low;
    while ( running_x < high ) {
        auc1 += e(running_x) * epsilon;
        running_x += epsilon;
        auc2 += e(running_x) * epsilon;
    }

    return (auc1 + auc2)/2;
}

// Taken by value so that a temporary such as x*x can be passed directly.
double integrate(ExpressionRef er, double low = 0, double high = 1, double epsilon = 0.001) {
    return integrate(er.getref(), low, high, epsilon);
}

int main()
{
    Variable x;

    cout << ((2*x*x + 3*x + 3)*(2*x*x + 3*x + 3))(2) << endl;
    cout << integrate((2*x*x + 3*x + 3), 0, 1) << endl;
    cout << integrate((2*x*x + 3*x + 3)*(2*x*x + 3*x + 3), 0, 1) << endl;
    cout << integrate((x/(1+x)), 0, 1) << endl;
    // Needs an Exp expression class, which is not defined in this article:
    // cout << integrate(Exp(), 0, 1) << endl;
    cout << integrate(x*x, 0, 7) << endl;

    return 0;
}

Having come this far, you may have made two observations. First, we have written code that is nifty and works well, but is repetitive in parts and not very extensible. Second, although this article is about Expression Templates, we have hardly used any templates (except for the ComplexExpression class).

What are the key concepts in this code which enabled creating the Expression Functors?
1. Nesting function objects, with the operator() of outer objects calling the operator() of inner objects, for example ComplexExpression::operator(). The formal term for this idiom is Functional Composition.
2. Overloading operators on Expression types.

What we have done so far lays the foundation for our understanding of Expression Templates. The real deal should not take a lot of time after this. We will now investigate how this model of functional composition can be retained and generalized, the performance of the expressions greatly improved, and the lines of code drastically cut, by using templates. This is the topic of the next part of this article.

References


  1. Todd Veldhuizen: Expression Templates

    http://ubiety.uwaterloo.ca/~tveldhui/papers/Expression-Templates/exprtmpl.html
