Implementing std::visit

C++
Author

dev::author

Published

April 25, 2026

Introduction

A variant models a choice between types. It is a smart union. Variants are sum types.

struct on { int temperature_; };
struct off { };

using oven_state = std::variant<on, off>;
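A variant holds exactly one of its alternatives at any time. As a minimal sketch of what that means for oven_state, the block below repeats the two structs and adds two helpers, describe and temperature_of. Both helper names are hypothetical and exist only for illustration.

```cpp
#include <string_view>
#include <variant>

struct on { int temperature_; };
struct off { };

using oven_state = std::variant<on, off>;

// Hypothetical helper: names the alternative that is currently active.
constexpr std::string_view describe(const oven_state& state) {
    return std::holds_alternative<on>(state) ? "on" : "off";
}

// Hypothetical helper: reads the temperature, or 0 when the oven is off.
constexpr int temperature_of(const oven_state& state) {
    const on* active = std::get_if<on>(&state);
    return active != nullptr ? active->temperature_ : 0;
}

// Exactly one alternative is active at a time, and index() reports which one.
static_assert(oven_state{on{180}}.index() == 0);
static_assert(oven_state{off{}}.index() == 1);
```

Querying the active alternative by hand like this works, but it scales poorly as the number of alternatives grows.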

Variants elegantly model recursive node-based structures, error handling, and state machines. Typical examples include:

  • JSON
  • ASTs (Abstract Syntax Trees)
  • State Machines
  • Error Handling

Variant visitation

What is visitation?

Visitation can be defined as an abstraction that allows you to access the currently active variant alternative in an exhaustive and expressive manner. It is exhaustive because you are forced to provide some behavior for all the types that the variant supports. It's expressive because it doesn't work through a chain of if-else; it's a very elegant way of saying: for these types, perform these actions.
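To make the contrast concrete, here is a small sketch that handles the same variant twice: once with a hand-written chain of checks, and once with std::visit. The shape types and the names name_by_chain, shape_namer and name_by_visit are assumptions made up for this illustration.

```cpp
#include <string_view>
#include <variant>

struct circle { double radius_; };
struct box { double width_; double height_; double depth_; };

using shape = std::variant<circle, box>;

// Without visitation: a chain of checks. Nothing forces the chain to cover
// every alternative, so a forgotten type silently falls through to the end.
constexpr std::string_view name_by_chain(const shape& s) {
    if (std::holds_alternative<circle>(s)) { return "circle"; }
    return "unhandled"; // box was forgotten, and the compiler does not object
}

// With visitation: one overload per alternative. circle and box do not convert
// to each other, so deleting either overload makes the std::visit call below
// fail to compile.
struct shape_namer {
    constexpr std::string_view operator()(const circle&) const { return "circle"; }
    constexpr std::string_view operator()(const box&) const { return "box"; }
};

constexpr std::string_view name_by_visit(const shape& s) {
    return std::visit(shape_namer{}, s);
}
```

One caveat worth keeping in mind: when the alternatives are implicitly convertible to each other (int, float and double, say), a missing overload can be masked by a conversion instead of being reported as an error.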

Traditional visitation

Traditional visitation requires a Callable object which can be invoked with every possible variant alternative. The traditional way of creating such an object is defining a struct.

// Traditional visitation example - one variant
struct printer{
    void operator()(int){ std::cout << "int" << "\n"; }
    void operator()(float){ std::cout << "float" << "\n"; }
    void operator()(double){ std::cout << "double" << "\n"; }
};

using my_variant = std::variant<int, float, double>;
my_variant v0{ 20.0f };

std::visit(printer{}, v0);

If you call std::visit with an instance of your printer struct and the variant, it figures out which type the variant currently holds and calls the corresponding overload.

You can extend this logic to multiple variants.

struct circle{ double radius_; };
struct box{
    double width_;
    double height_;
    double depth_;
};

struct collision_resolver{
    void operator()(circle, circle){ /*...*/ }
    void operator()(circle, box){ /*...*/ }
    void operator()(box, circle){ /*...*/ }
    void operator()(box, box){ /*...*/ }
};

using my_variant = std::variant<circle, box>;
my_variant v0{circle{}};
my_variant v1{box{}};

std::visit(collision_resolver{}, v0, v1);

What you get is a form of multiple dispatch that has to handle all possible combinations of alternatives. This makes std::visit a really powerful way of implementing multiple dispatch.

Building helper functions needed for std::visit

Under the hood, std::visit must pick the candidate from the visitor's overload set that matches the currently active alternatives, and invoke it. We start with the simplest case of \(2\) variants. The central idea is to maintain a single array of function pointers. Each entry in the function pointer table has the shape [](Visitor visitor, Variant0 v0, Variant1 v1){ /*...*/ }. For concreteness, assume that both variants v0 and v1 offer a choice between \(2\) types:

using Variant0 = std::variant<int, float>;
using Variant1 = std::variant<int, float>;

It is possible to have \(4\) distinguishable behaviors (overloads) when a visitor visits these variants. So, we build an array of size \(4\).

The std::visit implementation should automatically invoke the correct function overload based on the currently active types in v0 and v1. Before we proceed, let us fix some terminology. Suppose a variant v allows a choice amongst \(n\) types, indexed from \(0\) to \(n-1\):

// Pseudo-code: n alternatives, T0 through T(n-1)
std::variant<T0, T1, ..., T(n-1)> v;

In the case of 2 variants v0 and v1 offering a choice amongst \(n_1\) and \(n_2\) alternative types respectively, we have a total of \(n_1 \times n_2\) distinguishable choices. For example, consider the case of v0 offering a choice of \(3\) types and v1 offering a choice of \(4\) types. Each choice is a distinct entry in the 2d-table, and we have a total of \(3 \times 4 = 12\) distinct choices. We will have \(12\) candidates in our function overload set. We can assign a multi-index <i,j> to the function overload [](Ti, Tj){ /*...*/ }.

      j=0     j=1     j=2     j=3
i=0   <0, 0>  <0, 1>  <0, 2>  <0, 3>
i=1   <1, 0>  <1, 1>  <1, 2>  <1, 3>
i=2   <2, 0>  <2, 1>  <2, 2>  <2, 3>

We can convert this 2d-table to a 1d-table.

Multi-index 1D-index
<0, 0> 0
<0, 1> 1
<0, 2> 2
<0, 3> 3
<1, 0> 4
<1, 1> 5
<1, 2> 6
<1, 3> 7
<2, 0> 8
<2, 1> 9
<2, 2> 10
<2, 3> 11

Thus, the pointer to the function [](Ti, Tj){ /*...*/ } with multi-index <i,j> is stored at the 1D-index \(4i + j\) of the 1d-table.
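The mapping for the \(3 \times 4\) example, together with its inverse, can be sketched in a few lines. The names rows, cols, flatten_2d and unflatten_2d are hypothetical and only serve this illustration.

```cpp
#include <cstddef>
#include <utility>

// Dimensions from the example above: 3 alternatives in v0, 4 in v1.
inline constexpr std::size_t rows = 3;
inline constexpr std::size_t cols = 4;

// <i, j> -> 1D-index, walking the 2d-table row by row.
constexpr std::size_t flatten_2d(std::size_t i, std::size_t j) {
    return cols * i + j;
}

// 1D-index -> <i, j>, the inverse of flatten_2d.
constexpr std::pair<std::size_t, std::size_t> unflatten_2d(std::size_t n) {
    return { n / cols, n % cols };
}

// Spot checks against the table above.
static_assert(flatten_2d(0, 0) == 0);
static_assert(flatten_2d(1, 2) == 6);
static_assert(flatten_2d(2, 3) == rows * cols - 1);
static_assert(unflatten_2d(6).first == 1 && unflatten_2d(6).second == 2);
```

Note that only the number of columns enters the formula; the number of rows just bounds the valid range of i.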

In the case of 3 variants v0, v1 and v2 offering a choice amongst \(n_1 = 3\), \(n_2 = 4\) and \(n_3 = 5\) alternative types respectively, we have a total of \(3 \times 4 \times 5 = 60\) distinguishable choices. We will have \(60\) candidates in our function overload set. We can assign a multi-index <i,j,k> to the function overload [](Ti, Tj, Tk){ /*...*/ }.

      j=0      j=1      j=2      j=3
i=0   <0,0,k>  <0,1,k>  <0,2,k>  <0,3,k>
i=1   <1,0,k>  <1,1,k>  <1,2,k>  <1,3,k>
i=2   <2,0,k>  <2,1,k>  <2,2,k>  <2,3,k>

where each entry <i,j,k> expands over \(k = 0, 1, 2, 3, 4\). The full set of multi-indices and their 1D mappings is:

Multi-index 1D-index Multi-index 1D-index Multi-index 1D-index
<0, 0, 0> 0 <1, 0, 0> 20 <2, 0, 0> 40
<0, 0, 1> 1 <1, 0, 1> 21 <2, 0, 1> 41
<0, 0, 2> 2 <1, 0, 2> 22 <2, 0, 2> 42
<0, 0, 3> 3 <1, 0, 3> 23 <2, 0, 3> 43
<0, 0, 4> 4 <1, 0, 4> 24 <2, 0, 4> 44
<0, 1, 0> 5 <1, 1, 0> 25 <2, 1, 0> 45
<0, 1, 1> 6 <1, 1, 1> 26 <2, 1, 1> 46
<0, 1, 2> 7 <1, 1, 2> 27 <2, 1, 2> 47
<0, 1, 3> 8 <1, 1, 3> 28 <2, 1, 3> 48
<0, 1, 4> 9 <1, 1, 4> 29 <2, 1, 4> 49
<0, 2, 0> 10 <1, 2, 0> 30 <2, 2, 0> 50
<0, 2, 1> 11 <1, 2, 1> 31 <2, 2, 1> 51
<0, 2, 2> 12 <1, 2, 2> 32 <2, 2, 2> 52
<0, 2, 3> 13 <1, 2, 3> 33 <2, 2, 3> 53
<0, 2, 4> 14 <1, 2, 4> 34 <2, 2, 4> 54
<0, 3, 0> 15 <1, 3, 0> 35 <2, 3, 0> 55
<0, 3, 1> 16 <1, 3, 1> 36 <2, 3, 1> 56
<0, 3, 2> 17 <1, 3, 2> 37 <2, 3, 2> 57
<0, 3, 3> 18 <1, 3, 3> 38 <2, 3, 3> 58
<0, 3, 4> 19 <1, 3, 4> 39 <2, 3, 4> 59

Thus, the pointer to the function [](Ti, Tj, Tk){ /*...*/} with multi-index <i,j,k> is mapped to the 1D-index \(20i + 5j + k\).
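Both formulas are instances of the same row-major rule. The general form can be sketched as follows, in notation that is introduced here and is not part of the original text: \(m\) variants with sizes \(n_1, \dots, n_m\), a multi-index \(\langle i_1, \dots, i_m \rangle\), and coefficients \(c_p\).

```latex
\[
c_p = \prod_{q=p+1}^{m} n_q \quad (\text{so } c_m = 1),
\qquad
\operatorname{index}(i_1, \dots, i_m) = \sum_{p=1}^{m} c_p \, i_p
\]
\[
(n_1, n_2) = (3, 4): \quad (c_1, c_2) = (4, 1), \quad \operatorname{index} = 4 i_1 + i_2
\]
\[
(n_1, n_2, n_3) = (3, 4, 5): \quad (c_1, c_2, c_3) = (20, 5, 1), \quad \operatorname{index} = 20 i_1 + 5 i_2 + i_3
\]
```

The inverse direction peels off one coordinate at a time: starting from the 1D-index, divide by \(c_1\) to get \(i_1\), carry the remainder forward, divide by \(c_2\) to get \(i_2\), and so on.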

We shall code up two helper functions to_1d_index<3,4,5>(size_t, size_t, size_t) and from_1d_index<3,4,5>(size_t) to convert a multi-index to a 1d-index and back.

#include <array>
#include <cstddef>
#include <numeric>

namespace dev::tools{
    template<std::size_t... Dimensions>
    constexpr auto build_coeffs_array(){
        constexpr std::array<std::size_t, sizeof...(Dimensions)> dimensions { Dimensions... };
        constexpr std::size_t coeffs_size = sizeof...(Dimensions);
        std::array<std::size_t, coeffs_size> coeffs = {};
        
        for(std::size_t i{0}; i < coeffs_size; ++i)
            coeffs[i] = 1;

        for(std::size_t i{0}; i < coeffs_size - 1; ++i){
            // In step i, we need to populate all coeffs[j], j <= i
            for(std::size_t j{0}; j <= i; ++j){
                coeffs[j] *= dimensions[i+1];
            }
        }
        return coeffs;
    }
    /*
    coeff[0] = v1size * ... * v[n-1] size
    coeff[1] = v2size * .... *v[n-1] size
    ...
    coeff[n-2] = v[n-1] size
    coeff[n-1] = 1
    */
    template<std::size_t... Dimensions>
    std::size_t constexpr to_1d_index(auto... indices){
        static_assert(sizeof...(indices) == sizeof...(Dimensions), "one index per dimension");
        constexpr auto coeffs = build_coeffs_array<Dimensions...>();

        // Cast explicitly so that int arguments do not trigger a narrowing error.
        const std::array<std::size_t, sizeof...(indices)> indices_arr { static_cast<std::size_t>(indices)... };
        // The initial value fixes the accumulator type, so it must be std::size_t rather than int.
        return std::inner_product(coeffs.begin(), coeffs.end(), indices_arr.begin(), std::size_t{0});
    }

    /*
    Suppose we have the dimensions <3, 5, 2> and the coordinates (1, 3, 1).
    (1, 3, 1) maps to the linear index 17. 
    */

    template<std::size_t N>
    constexpr auto build_coords_array(const std::array<std::size_t,N>& coeffs, std::size_t state){
        std::array<std::size_t, N> coords{};
        for(std::size_t i{0}; i < N; ++i){
            // Integer division gives the coordinate; the remainder carries over to the next one.
            coords[i] = state / coeffs[i];
            state %= coeffs[i];
        }
        return coords;
    }

    template<size_t... Dimensions>
    decltype(auto) constexpr from_1d_index(std::size_t n){
        constexpr std::size_t coords_arr_size = sizeof...(Dimensions);
        constexpr auto coeffs = build_coeffs_array<Dimensions...>();
        std::array<size_t, coords_arr_size> coords = build_coords_array(coeffs, n);
        return coords;
    }
}
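As a cross-check of the arithmetic, here is a deliberately different, simplified sketch of the same conversion. It is not the implementation above: the hypothetical functions flatten_nd and unflatten_nd take the dimensions as a std::array argument instead of template parameters, and they use Horner's scheme and repeated division instead of a precomputed coefficient array.

```cpp
#include <array>
#include <cstddef>

// Multi-index -> 1D-index. Horner form of the weighted sum: for dimensions
// <3, 5, 2> this computes ((i0 * 5) + i1) * 2 + i2 = 10*i0 + 2*i1 + i2.
template <std::size_t N>
constexpr std::size_t flatten_nd(const std::array<std::size_t, N>& dims,
                                 const std::array<std::size_t, N>& idx) {
    std::size_t flat = 0;
    for (std::size_t k = 0; k < N; ++k) {
        flat = flat * dims[k] + idx[k];
    }
    return flat;
}

// 1D-index -> multi-index. Peels off the fastest-varying coordinate first.
template <std::size_t N>
constexpr std::array<std::size_t, N> unflatten_nd(const std::array<std::size_t, N>& dims,
                                                  std::size_t flat) {
    std::array<std::size_t, N> idx{};
    for (std::size_t k = N; k-- > 0;) {
        idx[k] = flat % dims[k];
        flat /= dims[k];
    }
    return idx;
}

// The example from the comment above: dimensions <3, 5, 2>, coordinates (1, 3, 1).
inline constexpr std::array<std::size_t, 3> example_dims{3, 5, 2};
static_assert(flatten_nd(example_dims, {1, 3, 1}) == 17);
// Converting back and forth must return the index we started with.
static_assert(flatten_nd(example_dims, unflatten_nd(example_dims, 17)) == 17);
```

Passing the dimensions as template parameters, as the helpers above do, has the advantage that the coefficients are always computed at compile time.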

Simpler case of \(1\) and \(2\) variants

I strongly recommend creating a simpler version first, for example one hardcoded for two or three variants. That way you get an idea of the actual shape before making it variadic. Otherwise you would be facing both problems at once.
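The smallest possible version handles a single variant, where the table index is simply v.index(). The following is a sketch under a few assumptions: the names sketch::make_table, sketch::visit_one and type_namer are hypothetical, every overload of the visitor is assumed to return the same type, and the variant is assumed not to be valueless_by_exception.

```cpp
#include <array>
#include <cstddef>
#include <string_view>
#include <utility>
#include <variant>

namespace sketch {
    // Builds one table entry per alternative index in Is...; entry I assumes
    // that alternative I is active, unwraps it and passes it to the visitor.
    template <typename Result, typename Visitor, typename Variant, std::size_t... Is>
    constexpr auto make_table(std::index_sequence<Is...>) {
        using entry_t = Result (*)(const Visitor&, const Variant&);
        return std::array<entry_t, sizeof...(Is)>{
            [](const Visitor& vis, const Variant& var) -> Result {
                return vis(std::get<Is>(var));
            }...
        };
    }

    // Hypothetical single-variant visit: selects the entry by v.index().
    template <typename Visitor, typename... Ts>
    constexpr decltype(auto) visit_one(const Visitor& visitor, const std::variant<Ts...>& v) {
        using variant_t = std::variant<Ts...>;
        // Assumption: all overloads return what the overload for T0 returns.
        using result_t = decltype(visitor(std::get<0>(v)));
        constexpr auto table =
            make_table<result_t, Visitor, variant_t>(std::index_sequence_for<Ts...>{});
        return table[v.index()](visitor, v);
    }
}

// Demo visitor in the style of printer above; it returns a name instead of printing.
struct type_namer {
    constexpr std::string_view operator()(int) const { return "int"; }
    constexpr std::string_view operator()(float) const { return "float"; }
    constexpr std::string_view operator()(double) const { return "double"; }
};

using number = std::variant<int, float, double>;
static_assert(sketch::visit_one(type_namer{}, number{20.0f}) == "float");
```

With two variants the only new ingredient is the index arithmetic: the table position has to encode two alternative indices instead of one.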

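// Simple case of 2 variants
#include <array>
#include <cstddef>
#include <iostream>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>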
namespace dev{

    namespace example{

        struct Dummy{};
        template<typename T>
        using Wrapper = std::conditional_t<std::is_void_v<T>,Dummy,T>;


        template<typename Visitor, typename Variant0, typename Variant1>
        decltype(auto) visit(Visitor&& visitor, Variant0 v0, Variant1 v1){
            // This version is hardcoded for two variants with two alternatives each.
            static_assert(std::variant_size_v<Variant0> == 2 && std::variant_size_v<Variant1> == 2);
            constexpr std::size_t vtable_size = (std::variant_size_v<Variant0> * std::variant_size_v<Variant1>);
            // The return type is fixed to std::string to keep the example short.
            using cases_t = std::string(*)(Visitor, Variant0, Variant1);
            static constexpr auto vtable {
                []<size_t... Indices>(std::index_sequence<Indices...>){
                    return std::array<cases_t, sizeof...(Indices)>{
                        [](Visitor vis, Variant0 v0, Variant1 v1){ 
                            constexpr auto multi_idx = dev::tools::from_1d_index<2,2>(Indices);
                            return vis(std::get<multi_idx[0]>(v0), std::get<multi_idx[1]>(v1));
                    }... };
                    
                }(std::make_index_sequence<vtable_size>())
            };

            return vtable[dev::tools::to_1d_index<2,2>(v0.index(), v1.index())](visitor, v0, v1);
        }
    }
}

template<typename... Callables>
struct Visitor : Callables...{
    using Callables::operator()...;
};

void test_single_dispatch(){
    std::variant<int, float, double> v{0};
}

void test_double_dispatch(){
    std::variant<int, float> v1{0};
    std::variant<int, float> v2{3.14f};

    auto result = dev::example::visit(
        Visitor{
            [](int, int) -> std::string { return "(int, int)"; },
            [](int, float) -> std::string { return "(int, float)"; },
            [](float, int) -> std::string { return "(float, int)"; },
            [](float, float) -> std::string { return "(float, float)"; },
        },
        v1, v2
    );

    std::cout << "result = " << result << "\n";
}

Generalizing to a pack of variants

The shape of a function inside the vtable array should be something like this:

[](Visitor&& vis, Variants&&... var) { /* ... */ }

Internally you would invoke vis with std::get<Is>(var)..., where Is... would be an index sequence or constexpr array of the desired indices.
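The unpacking step can be studied in isolation before it is embedded in the table. In the sketch below, sketch::invoke_at, sketch::invoke_at_impl, pair_namer and float_then_int are all hypothetical names. Idx refers to a constexpr array that holds one alternative index per variant, and Is... enumerates the positions of the variants.

```cpp
#include <array>
#include <cstddef>
#include <string_view>
#include <utility>
#include <variant>

namespace sketch {
    // Is... enumerates the positions 0, 1, ... of the variants, and Idx[Is] is
    // the alternative to unwrap at that position. Both packs expand in lockstep.
    template <const auto& Idx, typename Visitor, std::size_t... Is, typename... Variants>
    constexpr decltype(auto) invoke_at_impl(Visitor&& vis, std::index_sequence<Is...>,
                                            Variants&&... vars) {
        return vis(std::get<Idx[Is]>(vars)...);
    }

    // Hypothetical entry point. The caller must make sure that Idx names the
    // active alternatives, otherwise std::get throws std::bad_variant_access.
    template <const auto& Idx, typename Visitor, typename... Variants>
    constexpr decltype(auto) invoke_at(Visitor&& vis, Variants&&... vars) {
        static_assert(Idx.size() == sizeof...(Variants), "one index per variant");
        return invoke_at_impl<Idx>(vis, std::index_sequence_for<Variants...>{}, vars...);
    }
}

using number = std::variant<int, float>;

struct pair_namer {
    constexpr std::string_view operator()(int, int) const { return "(int, int)"; }
    constexpr std::string_view operator()(int, float) const { return "(int, float)"; }
    constexpr std::string_view operator()(float, int) const { return "(float, int)"; }
    constexpr std::string_view operator()(float, float) const { return "(float, float)"; }
};

// Multi-index <1, 0>: float in the first variant, int in the second.
inline constexpr std::array<std::size_t, 2> float_then_int{1, 0};

static_assert(sketch::invoke_at<float_then_int>(pair_namer{}, number{1.5f}, number{2})
              == "(float, int)");
```

In the full implementation the index array is not chosen by the caller. Each table entry derives its own array from its position in the table, so entry number n always unwraps the combination that n encodes.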

namespace dev{
    template <typename Visitor, typename... Variants>
    decltype(auto) visit(Visitor &&visitor, Variants &&...vs) {
        constexpr std::size_t vtable_size = (std::variant_size_v<std::remove_cvref_t<Variants>> * ...);

        // Each entry in the vtable has the shape [](Visitor visitor, Variants... vs){}.
        // The result type is taken from the call with the first alternative of every
        // variant; every other overload must return something convertible to it.
        using result_t = decltype(visitor(std::get<0>(vs)...));
        using cases_t = result_t(*)(Visitor, Variants...);

        static constexpr auto vtable{
            []<size_t... Indices>(std::index_sequence<Indices...>){
                return std::array<cases_t, sizeof...(Indices)>{
                    [](Visitor vis, Variants... vs) -> result_t{
                        // Decode the table position into one alternative index per variant.
                        constexpr auto multi_idx = dev::tools::from_1d_index<std::variant_size_v<std::remove_cvref_t<Variants>>...>(Indices);
                        return [&]<size_t... Is>(std::index_sequence<Is...>){
                            // multi_idx[Is] is the alternative to unwrap from the Is-th variant.
                            // No cast is applied: the visitor sees each alternative as stored.
                            return vis(std::get<multi_idx[Is]>(vs)...);
                        }(std::make_index_sequence<sizeof...(Variants)>());
                    }...
                };
            }(std::make_index_sequence<vtable_size>())
        };

        auto i = dev::tools::to_1d_index<std::variant_size_v<std::remove_cvref_t<Variants>>...>(vs.index()...);
        return vtable[i](visitor, vs... );
    }
}

void test_multiple_dispatch(){
    std::variant<int, float, double> v1{0};
    std::variant<int, float, double> v2{3.14f};
    //std::variant<char, int, long, float, double> v3{'H'};

    // The visitor has to cover all 3 x 3 combinations. With a combination missing,
    // the call is either ambiguous or silently goes through an implicit conversion.
    auto result = dev::visit(
        Visitor{
            [](int, int) -> std::string { return "(int, int)"; },
            [](int, float) -> std::string { return "(int, float)"; },
            [](int, double) -> std::string { return "(int, double)"; },
            [](float, int) -> std::string { return "(float, int)"; },
            [](float, float) -> std::string { return "(float, float)"; },
            [](float, double) -> std::string { return "(float, double)"; },
            [](double, int) -> std::string { return "(double, int)"; },
            [](double, float) -> std::string { return "(double, float)"; },
            [](double, double) -> std::string { return "(double, double)"; }
        },
        v1, v2
    );

    std::cout << "result = " << result << "\n"; // prints "result = (int, float)"
}

Compiler Explorer

References