CRTP(Curiously recurring template pattern)

C++
Author

Quasar

Published

December 28, 2024

Introduction

Imagine you have an inheritance hierarchy to implement mathematical interpolation.

#include <concepts>
#include <exception>
#include <vector>
#include <cmath>
#include <string>
#include <algorithm>
#include <stdexcept>
#include <iostream>

// Common machinery for interpolators over points (x_i, y_i) supplied as two
// iterator ranges. I1 and I2 must support random-access arithmetic (+, -) and
// dereference to values convertible to double. Only the iterators are stored,
// so the caller must keep the underlying data alive and unchanged.
template<typename I1, typename I2>
class LinearInterpolatorBase{
    public:
    // Polymorphic base class: destruction through a base pointer must be well-defined.
    virtual ~LinearInterpolatorBase() = default;

    // Returns a copy of the x-coordinates.
    std::vector<double> xValues() const{
        return std::vector<double>(m_xBegin, m_xEnd);
    }
    // Returns a copy of the y-coordinates.
    std::vector<double> yValues() const{
        return std::vector<double>(m_yBegin, m_yEnd);
    }

    // Validates the stored ranges.
    // Throws std::logic_error when there are fewer than m_requiredPoints x values,
    // when the x and y ranges differ in length, or when the x values are not in
    // non-decreasing order.
    void safetyCheck() const{
        const int xCount = static_cast<int>(m_xEnd - m_xBegin);
        if(xCount < m_requiredPoints){
            // Use std::to_string: `int + "literal"` is pointer arithmetic on the
            // literal, not string concatenation.
            std::string errorMessage = "not enough points to interpolate : at least ";
            errorMessage += std::to_string(m_requiredPoints);
            errorMessage += " required, ";
            errorMessage += std::to_string(xCount);
            errorMessage += " provided";
            throw std::logic_error(errorMessage);
        }

        // determineBracket() indexes the y range with offsets computed from the x range.
        if(static_cast<int>(m_yEnd - m_yBegin) != xCount){
            throw std::logic_error("x and y ranges differ in length");
        }

        if(!std::is_sorted(m_xBegin, m_xEnd)){
            throw std::logic_error("unsorted x values");
        }
    }
    
    // Returns the index of the first x value strictly greater than x
    // (0 when x is below every point, the number of points when none is greater).
    // Runs safetyCheck() first, so it may throw std::logic_error.
    int locate(double x) const{
        safetyCheck();
        return static_cast<int>(std::upper_bound(m_xBegin, m_xEnd, x) - m_xBegin);
    }

    // Interpolated y at the given x; implemented by each concrete interpolator.
    virtual double value(double) = 0;
    
    // Writes into (x1, y1) and (x2, y2) the two neighbouring points used to
    // interpolate at x. Below the first point the first two points are used and
    // at or beyond the last point the last two, so callers extrapolate along the
    // end segments. May throw std::logic_error via locate().
    void determineBracket(double& x1, double& x2, double &y1, double& y2, double x) const{
        const int N = static_cast<int>(this->m_xEnd - this->m_xBegin);
        const int j = this->locate(x);

        if(j == 0)
        {
            x1 = *(this->m_xBegin);
            y1 = *(this->m_yBegin);
            x2 = *(this->m_xBegin + 1);
            y2 = *(this->m_yBegin + 1);
        }else if(j == N){
            x1 = *(this->m_xEnd - 2);
            y1 = *(this->m_yEnd - 2);
            x2 = *(this->m_xEnd - 1);
            y2 = *(this->m_yEnd - 1);
        }else{
            x1 = *(this->m_xBegin + j - 1);
            y1 = *(this->m_yBegin + j - 1);
            x2 = *(this->m_xBegin + j);
            y2 = *(this->m_yBegin + j);
        }
    }
    
    // Stores the two ranges; validation is deferred until the first lookup.
    LinearInterpolatorBase(I1 xBegin, I1 xEnd, I2 yBegin, I2 yEnd)
    : m_xBegin(xBegin)
    , m_xEnd(xEnd)
    , m_yBegin(yBegin)
    , m_yEnd(yEnd)
    {}
    
    protected:
    I1 m_xBegin;
    I1 m_xEnd;
    I2 m_yBegin;
    I2 m_yEnd;
    // Minimum number of points needed to form one bracket. A class-level
    // constant, so it takes no per-object storage and does not block assignment.
    static constexpr int m_requiredPoints {2};
};

// Piecewise-linear interpolation: the result lies on the straight line through
// the two bracketing points chosen by the base class's determineBracket().
template<typename I1, typename I2>
class LinearInterpolator : public LinearInterpolatorBase<I1, I2>{
public:
    // Returns the linearly interpolated y at x.
    // Propagates std::logic_error thrown by the base class's input checks.
    double value(double x) override{
        double x1{0.0}, x2{0.0}, y1{0.0}, y2{0.0};
        this->determineBracket(x1, x2, y1, y2, x);
        
        // Relative position of x between x1 and x2 (0 at x1, 1 at x2).
        const double t {(x - x1)/(x2 - x1)};
        
        return (1 - t) * y1 + t * y2;
    }

    LinearInterpolator(I1 xBegin, I1 xEnd, I2 yBegin, I2 yEnd)
    : LinearInterpolatorBase<I1,I2>(xBegin, xEnd, yBegin, yEnd) {}
};

// Log-linear interpolation: interpolates linearly in log(y) between the two
// bracketing points chosen by the base class's determineBracket(), then maps
// the result back with exp.
template<typename I1, typename I2>
class LogLinearInterpolator : public LinearInterpolatorBase<I1, I2>{
public:
    // Returns the log-linearly interpolated y at x.
    // Propagates std::logic_error thrown by the base class's input checks.
    double value(double x) override{
        double x1{0.0}, x2{0.0}, y1{0.0}, y2{0.0};
        this->determineBracket(x1, x2, y1, y2, x);
        
        // Relative position of x between x1 and x2 (0 at x1, 1 at x2).
        const double t {(x - x1)/(x2 - x1)};
        
        // Qualified std:: calls select the <cmath> overloads explicitly.
        return std::exp((1 - t) * std::log(y1) + t * std::log(y2));
    }
    
    LogLinearInterpolator(I1 xBegin, I1 xEnd, I2 yBegin, I2 yEnd)
    : LinearInterpolatorBase<I1,I2>(xBegin, xEnd, yBegin, yEnd) {}
};



int main(int argc, const char * argv[]) {
    std::vector<double> discountFactors{
        1, 0.9523, 0.9070, 0.8683, 0.8227,0.7835,
        0.7462, 0.7106, 0.6768, 0.6446, 0.6139
    };
    
    std::vector<double> times{
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
    };
    
    double result{0.0};

    using I = std::vector<double>::iterator;

    LinearInterpolatorBase<I,I>* p;
    
    LinearInterpolator linearInterp(
        times.begin(),
        times.end(),
        discountFactors.begin(),
        discountFactors.end()
    );

    p = &linearInterp;

    result = p->value(2.5);
    std::cout << "\nResult of linear interpolation y-value = " << result;

    LogLinearInterpolator logLinearInterp(
        times.begin(),
        times.end(),
        discountFactors.begin(),
        discountFactors.end()
    );

    p = &logLinearInterp;
    result = p->value(2.5);
    std::cout << "\nResult of log-linear interpolation y-value = " << result;
    return 0;
}

Compiler Explorer

The base class LinearInterpolatorBase has a pure virtual function value, and the child classes LinearInterpolator and LogLinearInterpolator override this virtual function, each doing something different. This is called dynamic polymorphism.

The implementation of double value(double) chosen at run-time is determined by the object bound to the base class pointer/reference.

You can also overload free-standing functions or class member functions, provided they have different types/numbers of arguments. For example, you can overload the + operator to support component-wise addition of two std::vectors. This is the compile-time version of polymorphism, called static polymorphism.

Dynamic polymorphism incurs a performance cost: the compiler builds a table of pointers to virtual functions (the vtable) for each class, and to know which function to call the program has to look it up in that table at run time. So, there is some level of indirection when calling virtual functions polymorphically.

Can we get the benefits of dynamic polymorphism at compile time? One way to achieve that is the Curiously Recurring Template Pattern (CRTP).

The Curiously Recurring Template Pattern

While diving into the internals of some libraries, James Coplien observed a commonly recurring pattern, and for the lack of a better word, wrote a paper titled the Curiously Recurring Template Pattern(CRTP). The name stuck.

In practice, a minimalistic example of this idiom is as follows:

#include <iostream>
#include <memory>
#include <vector>

// CRTP base: Derived is the class that inherits from B<Derived>, so the base
// can call Derived::work() directly, without a virtual function.
template<typename Derived>
struct B{
    // Forwards to the derived class's work().
    void do_work(){
        Derived& self = static_cast<Derived&>(*this);
        self.work();
    }
};

// First CRTP client: passes itself to B so that B<X>::do_work() reaches X::work().
struct X : B<X>{
    // Implementation invoked by B<X>::do_work().
    void work(){
        std::cout << "\nimpl of X::work";
    }
};

// Second CRTP client: passes itself to B so that B<Y>::do_work() reaches Y::work().
struct Y : B<Y>{
    // Implementation invoked by B<Y>::do_work().
    void work(){
        std::cout << "\nimpl of Y::work";
    }
};

// Calls do_work() on any object that provides it; a separate function is
// instantiated for each pointed-to type.
template<typename Worker>
void perform_work(Worker* obj){
    (*obj).do_work();
}

int main(){
    X x;
    Y y;
    perform_work(&x);
    perform_work(&y);
    return 0;
}

The do_work method in the base class performs a downcast of the this pointer to the derived class pointer Derived*.

Compiler Explorer

We have moved the runtime polymorphism to compile time. Therefore, the perform_work function cannot treat X and Y objects polymorphically at run time. Instead, the compiler generates two different instantiations of perform_work, one that can handle X objects and one that can handle Y objects. This is static polymorphism.

We could do this for our numerical routines as well:

#include <concepts>
#include <exception>
#include <vector>
#include <cmath>
#include <string>
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <memory>

// CRTP variant of the interpolator base (abridged: the elided sections are not
// shown here). Derived is the concrete interpolator type, passed in as the
// third template argument.
// NOTE(review): no access specifier is visible before computeValue(), so as
// written it would be private; presumably the elided section above contains
// `public:` -- confirm against the full listing.
template<typename I1, typename I2, typename Derived>
class LinearInterpolatorBase{
    /* 
        ...
    */

    // Forwards to Derived::value(x) via a static_cast of `this`; no virtual
    // function is involved.
    double computeValue(double x){
        return static_cast<Derived*>(this)->value(x);
    }
    
    /*
        ...
    */
};

// CRTP client (abridged): passes its own type as the third template argument
// of the base, which lets the base's computeValue() call value() on it.
template<typename I1, typename I2>
class LinearInterpolator : public LinearInterpolatorBase<I1, I2, LinearInterpolator<I1,I2>>{
public:
    // Non-virtual; reached through the static_cast in the base class.
    // The body is elided in this listing.
    double value(double x){
        /*
            ...
        */
    }

    // Forwards the two iterator ranges to the base class.
    LinearInterpolator(I1 xBegin, I1 xEnd, I2 yBegin, I2 yEnd)
    : LinearInterpolatorBase<I1,I2,LinearInterpolator>(xBegin, xEnd, yBegin, yEnd) {}
};

// CRTP client (abridged): passes its own type as the third template argument
// of the base, which lets the base's computeValue() call value() on it.
template<typename I1, typename I2>
class LogLinearInterpolator : public LinearInterpolatorBase<I1, I2, LogLinearInterpolator<I1,I2>>{
public:
    // Non-virtual; reached through the static_cast in the base class.
    // The body is elided in this listing.
    double value(double x){
        /*
            ...
        */
    }
    
    // Forwards the two iterator ranges to the base class.
    LogLinearInterpolator(I1 xBegin, I1 xEnd, I2 yBegin, I2 yEnd)
    : LinearInterpolatorBase<I1,I2,LogLinearInterpolator>(xBegin, xEnd, yBegin, yEnd) {}
};

// Evaluates any CRTP interpolator at x. The concrete type T is deduced from
// the pointer, so the call is resolved at compile time.
template<typename I1, typename I2, typename T>
double computeValueHelper(LinearInterpolatorBase<I1,I2,T>* interpolator, double x){
    auto& target = *interpolator;
    return target.computeValue(x);
}

// Demo: evaluate the same curve at x = 2.5 with both CRTP interpolators
// through the computeValueHelper function template.
int main(int argc, const char * argv[]) {
    // x-axis: time points.
    std::vector<double> times{
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
    };

    // y-axis: one discount factor per time point.
    std::vector<double> discountFactors{
        1, 0.9523, 0.9070, 0.8683, 0.8227,0.7835,
        0.7462, 0.7106, 0.6768, 0.6446, 0.6139
    };

    LinearInterpolator linearInterp(
        times.begin(),
        times.end(),
        discountFactors.begin(),
        discountFactors.end()
    );
    const double linearResult = computeValueHelper(&linearInterp, 2.5);
    std::cout << "\nResult of linear interpolation y-value = " << linearResult;

    LogLinearInterpolator logLinearInterp(
        times.begin(),
        times.end(),
        discountFactors.begin(),
        discountFactors.end()
    );
    const double logLinearResult = computeValueHelper(&logLinearInterp, 2.5);
    std::cout << "\nResult of log-linear interpolation y-value = " << logLinearResult;

    return 0;
}

Compiler Explorer

CRTP in QuantLib

It’s always interesting to learn design patterns not in some theoretical project or a textbook or something, but actually find them working in a real-life library and QuantLib is certainly a big enough and an interesting enough library to give us a few examples of design patterns for our pleasure.

namespace QuantLib {
 
    //! Support for the curiously recurring template pattern
    /*! See James O. Coplien. A Curiously Recurring Template Pattern.
        In Stanley B. Lippman, editor, C++ Gems, 135-144.
        Cambridge University Press, New York, New York, 1996.
 
        \ingroup patterns
    */
    // Impl is the class that derives from CuriouslyRecurringTemplate<Impl>.
    template <class Impl>
    class CuriouslyRecurringTemplate {
      protected:
        // not meant to be instantiated as such
        // (constructor and destructor are protected, so only derived classes
        // can create or destroy this sub-object)
        CuriouslyRecurringTemplate() = default;
        ~CuriouslyRecurringTemplate() = default;
        // support for the curiously recurring template pattern
        // Returns *this viewed as the derived type Impl (non-const overload).
        Impl& impl() {
            return static_cast<Impl&>(*this);
        }
        // Const overload of the accessor above.
        const Impl& impl() const {
            return static_cast<const Impl&>(*this);
        }
    };
 
}
#endif

Essentially, all that this pattern does is, it takes the template argument Impl and then it performs a cast of the current object *this to whatever type is implemented.

The impl() methods are actually protected, so they are not meant to be used within client code, but are internal to the framework.

Let’s look at the derived classes Solver1D and Bisection.

Solver1D is a general-purpose one-dimension numerical solver and the class header has the signature:

// Front end shared by all one-dimensional solvers (abridged listing). Impl is
// the concrete solver; solve() prepares the search and then delegates to
// Impl::solveImpl() through the CRTP accessor impl().
//
// The inheritance must be public: impl() in the base class performs
// static_cast<Impl&>(*this), and a cast from a base to a derived class is
// ill-formed when that base is inaccessible (a `class` inherits privately by
// default).
template<typename Impl>
class Solver1D : public CuriouslyRecurringTemplate<Impl>
{
    public:
    // Finds a root of f, starting from `guess`, to within `accuracy`.
    // NOTE(review): xMin/xMax and root_ are handled by the elided part of the
    // class; root_ is presumably a mutable member since solve() is const --
    // confirm against the full QuantLib header.
    template <class F>
        Real solve(const F& f,
                   Real accuracy,
                   Real guess,
                   Real xMin,
                   Real xMax) const {

            /*
                ...
            */

            root_ = guess;

            // Static dispatch to the concrete solver's algorithm.
            return this->impl().solveImpl(f, accuracy);
        }
};

The class header for Bisection is essentially saying: I am a one-dimensional solver (Solver1D), and I am going to feed my own type Bisection as the template argument to Solver1D<Impl>.

namespace QuantLib {
 
    //! %Bisection 1-D solver
    // Passes its own type to Solver1D, which lets Solver1D<Bisection>::solve()
    // call solveImpl() on it without a virtual function.
    class Bisection : public Solver1D<Bisection> {
      public:
        // Concrete algorithm invoked by Solver1D::solve(); the body is elided
        // in this listing.
        template <class F>
        Real solveImpl(const F& f,
                       Real xAccuracy) const {
 
            /* 
                ...
            */
        }
    };
}

This results in the compiler instantiating the class templates and defining the CuriouslyRecurringTemplate<Bisection> and Solver1D<Bisection> classes.

So, if you go back to Solver1D, it doesn’t know the actual implementation. It says, whatever the implementation(Impl template parameter) is, I want to call solveImpl() on it.

Limiting the object count with CRTP

One scenario in which the CRTP idiom could be applied is to limit the object count. We code up a limit_instances class template and have our user-declared classes inherit from it. The limit_instances class has a class (static) member variable called count that keeps track of the number of instances of the user-objects. Each time the user-object constructor (and thus the limit_instances base class constructor) is invoked, count is incremented, and every destructor call decrements it.

#include <atomic>
#include <exception>
#include <iostream>
#include <stdexcept>

// CRTP mixin that caps the number of simultaneously alive objects of a class.
// T is the class being limited (it gives each user class its own counter) and
// N is the maximum number of live instances.
template<typename T, int N>
struct limit_instances{
    // Number of currently alive instances for this (T, N) pair.
    // Plain int: the counter is not synchronized across threads.
    static int count;

    // Throws std::logic_error when N instances are already alive.
    limit_instances(){
        acquire();
    }

    // A copy is a new live object as well. Without this constructor the
    // implicit copy would skip the increment while the destructor still
    // decrements, letting copies bypass the limit and corrupt the counter.
    // It also suppresses the implicit move constructor, so moves are counted
    // through this path too.
    limit_instances(const limit_instances&){
        acquire();
    }

    // Assignment neither creates nor destroys an object; the count is unchanged.
    limit_instances& operator=(const limit_instances&) = default;

    ~limit_instances(){
        --count;
    }

private:
    // Reserves one slot, or throws if the limit has been reached. The check
    // comes first so that a failed construction leaves the count untouched.
    static void acquire(){
        if(count >= N)
            throw std::logic_error("Too many instances.");
        ++count;
    }
};

template<typename T, int N>
int limit_instances<T,N>::count = 0;

// User class limited to three simultaneously alive instances.
struct X : limit_instances<X, 3>{
    /*
        User-defined private member fields
        and public methods.
    */
};

// User class limited to five simultaneously alive instances.
struct Y : limit_instances<Y, 5>{
    /*
        ...
    */
};

int main()
{
    X x1,x2,x3;   // okay
    try{
        X x4;         
    }
    catch(std::exception& e){
        std::cout << e.what() << "\n";
    }
    
    Y y[5];       // okay
     try{
        Y anotherInstance;         
    }
    catch(std::exception& e){
        std::cout << e.what() << "\n";
    } 
    return 0;
}

Compiler Explorer

The class template limit_instances takes the user-class T and the maximum number of allowable instances N as template parameters. So, the user-class when inheriting from limit_instances and while invoking the base-class constructor should supply appropriate template arguments.

Implementing the Composite design pattern

The composite design pattern lets you compose objects into tree structures.

CRTP can be used to implement the composite design pattern. We can quickly code it up:

#include <vector>
#include <iostream>
#include <string>

// CRTP base shared by single nodes and collections of nodes. T is the derived
// class; connect() iterates over it, so T must be usable in a range-based for
// loop that yields Node&.
template<typename T>
struct Component{
    std::string name;

    Component() = default;
    Component(std::string name_) : name(name_) {}

    // Links every node of *this with every node of `other`, in both directions.
    // Declared here and defined after Node, because the body needs the complete Node type.
    template<typename U>
    void connect(U& other);
};

// A single graph node. begin()/end() present the node as a range containing
// exactly itself, so it can be iterated the same way as a collection of nodes.
struct Node : Component<Node>{
    // Non-owning pointers to the nodes this node is linked to.
    std::vector<Node*> connections;

    Node(std::string name_) : Component<Node>(name_) {}

    Node* begin(){ return this; }

    // One past this object, which ends the single-element range.
    Node* end(){ return this + 1; }
};

// A group of nodes: iteration comes from the std::vector<Node> base and
// connect() from the Component base.
struct Composite : std::vector<Node>, Component<Composite>{
    Composite() = default;
    Composite(std::string name_) : Component<Composite>(name_){}
};

// Connects every node reachable from *this with every node in `other`,
// recording the link on both endpoints. Both T and U only need to be iterable
// as ranges of Node.
template<typename T>
template<typename U>
void Component<T>::connect(U& other){
    // CRTP downcast: view this component as the derived range type.
    T& self = *static_cast<T*>(this);

    for(Node& source : self){
        for(Node& target : other){
            source.connections.push_back(&target);
            target.connections.push_back(&source);
        }
    }
}

int main(){
    
    Node a("Gregory Peck");
    
    Node b("Marlon Brando");
    Node c("Audrey Hepburn");
    Node d("Charles Chaplin");

    Node guitarist("Jimmy Hendrix");
    Node artistParExcellence("Elvis Presley");
    
    Composite actorsTroupe;
    actorsTroupe.push_back(a);
    actorsTroupe.push_back(b);
    actorsTroupe.push_back(c);
    actorsTroupe.push_back(d);

    guitarist.connect(actorsTroupe);
    actorsTroupe.connect(artistParExcellence);
    
    return 0;
}

This helps avoid the explosion of state-space and writing methods such as Node::connect(Node&), Node::connect(Composite&), Composite::connect(Node&) and Composite::connect(Composite&).