Thread-Safe Stack Implementation

C++
Author

Quasar

Published

February 23, 2025

Basic thread-safe stack and queue

We can use C++ synchronization primitives to implement a basic thread-safe stack and queue.

Thread-safe stack

The unit tests below describe the basic functionality expected from a thread-safe stack.

#include <gtest/gtest.h>
#include "threadsafe_stack.h"
#include <thread>
#include <vector>
#include <algorithm>

// A value pushed onto an empty stack must be the one reported by top().
TEST(ThreadSafeStackTest, PushAndTopTest) {
    // Arrange
    dev::threadsafe_stack<int> stack;

    // Act
    stack.push(42);

    // Assert
    EXPECT_EQ(stack.top(), 42);
}

// pop() on a one-element stack must hand back that element and leave the
// stack empty.
TEST(ThreadSafeStackTest, PopTest) {
    dev::threadsafe_stack<int> stack;
    stack.push(42);

    auto popped = stack.pop();

    // Fatal assertion: the optional must hold a value before it is read,
    // otherwise value() would throw std::bad_optional_access and obscure the
    // real failure.
    ASSERT_TRUE(popped.has_value());
    EXPECT_EQ(popped.value(), 42);
    EXPECT_TRUE(stack.empty());
}

// empty() must report true for a fresh stack and false once an element has
// been pushed.
TEST(ThreadSafeStackTest, EmptyTest) {
    dev::threadsafe_stack<int> stack;

    // Freshly constructed: nothing stored yet.
    EXPECT_TRUE(stack.empty());

    // After a single push the stack holds one element.
    stack.push(42);
    EXPECT_FALSE(stack.empty());
}

// Ten threads each push one distinct value; afterwards the stack must contain
// exactly those ten values (in any order).
TEST(ThreadSafeStackTest, ConcurrentPushTest) {
    dev::threadsafe_stack<int> stack;
    std::vector<std::thread> threads;
    threads.reserve(10);

    for (int i = 0; i < 10; ++i) {
        threads.emplace_back([&stack, i]() {
            stack.push(i);
        });
    }

    for (auto& thread : threads) {
        thread.join();
    }

    EXPECT_FALSE(stack.empty());
    // EXPECT_EQ (rather than EXPECT_TRUE on a comparison) prints the actual
    // size when the check fails; the unsigned literal avoids a sign-compare
    // warning.
    EXPECT_EQ(stack.size(), 10u);

    // Drain the stack single-threaded and verify that no pushed value was
    // lost or duplicated. Push order across threads is unspecified, so the
    // drained values are sorted before comparing.
    std::vector<int> drained;
    while (auto popped = stack.pop()) {
        drained.push_back(*popped);
    }
    std::sort(drained.begin(), drained.end());

    std::vector<int> expected;
    for (int i = 0; i < 10; ++i) {
        expected.push_back(i);
    }
    EXPECT_EQ(drained, expected);
}

// Ten threads each pop once from a stack pre-filled with 0..9; every value
// must be handed out exactly once and the stack must end up empty.
TEST(ThreadSafeStackTest, ConcurrentPopTest) {
    dev::threadsafe_stack<int> stack;
    for (int i = 0; i < 10; ++i) {
        stack.push(i);
    }

    std::vector<std::thread> threads;
    threads.reserve(10);
    std::vector<int> results;
    std::mutex mtx;  // guards `results`, which is shared by all workers

    for (int i = 0; i < 10; ++i) {
        threads.emplace_back([&stack, &results, &mtx]() {
            auto popped = stack.pop();
            if (popped.has_value()) {
                std::lock_guard<std::mutex> guard(mtx);
                results.push_back(popped.value());
            }
        });
    }

    for (auto& thread : threads) {
        thread.join();
    }

    // Unsigned literal: results.size() is unsigned, comparing it with a plain
    // int triggers a sign-compare warning inside EXPECT_EQ.
    EXPECT_EQ(results.size(), 10u);
    EXPECT_TRUE(stack.empty());

    // The order in which threads pop is unspecified, so compare sorted values
    // to make sure nothing was lost or returned twice.
    std::sort(results.begin(), results.end());
    std::vector<int> expected;
    for (int i = 0; i < 10; ++i) {
        expected.push_back(i);
    }
    EXPECT_EQ(results, expected);
}

// Two threads each swap the same pair of stacks once. Two swaps cancel out,
// so after both threads finish each stack must equal its initial snapshot.
TEST(ThreadSafeStackTest, ConcurrentSwapTest){
    dev::threadsafe_stack<int> evens, odds;
    for(int i{0};i<5;++i)
        evens.push(2*i);

    for(int i{0};i<5;++i)
        odds.push(2*i + 1);

    // Snapshots taken before any swapping happens.
    dev::threadsafe_stack<int> evensCopy {evens};
    dev::threadsafe_stack<int> oddsCopy {odds};

    std::vector<std::thread> threads;
    threads.reserve(2);

    for(int j{0};j<2;++j){
        threads.emplace_back([&evens, &odds](){
            evens.swap(odds);
        });
    }

    // Join every spawned thread rather than indexing a hard-coded count, so
    // the loop cannot drift out of sync with the number of threads created.
    for(auto& thread : threads)
        thread.join();

    EXPECT_TRUE(evens == evensCopy);
    EXPECT_TRUE(odds == oddsCopy);
}

Implementation Notes

#include <cstddef>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <stack>
#include <utility>

namespace dev{
    /**
     * @brief A std::stack wrapper whose every operation runs under a
     *        reader/writer lock.
     *
     * Mutating operations (push, pop, swap) take the mutex exclusively;
     * observers (top, empty, size, operator==, copy construction) take it in
     * shared mode. Each call is individually synchronised; a sequence such as
     * `if(!s.empty()) s.top()` is NOT atomic, so prefer pop(), which combines
     * the emptiness check and the removal under one lock.
     *
     * @tparam T Element type. It needs to be copy-constructible only for the
     *           operations that copy (push(const T&), top(), copy
     *           construction); move-only types can be pushed and popped.
     */
    template<typename T>
    class threadsafe_stack{
        private:
        std::stack<T> m_stack;
        // mutable: locking does not change the logical state of the stack, so
        // const observers must still be able to acquire the mutex.
        mutable std::shared_mutex m_shared_mutex;

        public:
        using value_type = T;
        using reference = T&;
        using const_reference = const T&;

        // default constructor
        threadsafe_stack() = default;

        /**
         * @brief Copy constructor; copies the elements of @p other while
         *        holding a shared lock on it.
         */
        threadsafe_stack(const threadsafe_stack& other)
        {
            std::shared_lock<std::shared_mutex> shared_lck(other.m_shared_mutex);
            m_stack = other.m_stack;
        }

        // copy assignment
        threadsafe_stack& operator=(const threadsafe_stack& ) = delete;

        /**
         * @brief Inserts a copy of @p element at the top of the stack.
         */
        void push(const_reference element){
            std::unique_lock<std::shared_mutex> unique_lck(m_shared_mutex);
            m_stack.push(element);
        }

        /**
         * @brief Moves @p element onto the top of the stack.
         *
         * Chosen for rvalue arguments; allows move-only element types.
         */
        void push(T&& element){
            std::unique_lock<std::shared_mutex> unique_lck(m_shared_mutex);
            m_stack.push(std::move(element));
        }

        /**
         * @brief Removes the top element and returns it.
         * @return The removed element, or std::nullopt if the stack was empty.
         *
         * The element is moved out when T's move constructor is noexcept (or
         * T is not copyable) and copied otherwise, so a throwing constructor
         * leaves the stack unchanged.
         */
        std::optional<T> pop(){
            std::unique_lock<std::shared_mutex> unique_lck(m_shared_mutex);
            if(m_stack.empty())
                return std::nullopt;

            // Construct the result directly from the top element: no default
            // constructor or assignment operator is required of T.
            std::optional<T> element(std::move_if_noexcept(m_stack.top()));
            m_stack.pop();
            return element;
        }

        /**
         * @brief Returns a copy of the top element.
         * @pre The stack is not empty; std::stack::top() on an empty stack is
         *      undefined behaviour and this function does not check for it.
         */
        value_type top() const{
            std::shared_lock<std::shared_mutex> shared_lck(m_shared_mutex);
            return m_stack.top();
        }

        /**
         * @brief Reports whether the stack holds no elements at the time of
         *        the call.
         */
        bool empty() const{
            std::shared_lock<std::shared_mutex> shared_lck(m_shared_mutex);
            return m_stack.empty();
        }

        /**
         * @brief Returns the number of elements held at the time of the call.
         */
        std::size_t size() const{
            std::shared_lock<std::shared_mutex> shared_lck(m_shared_mutex);
            return m_stack.size();
        }

        /**
         * @brief Exchanges the contents of this stack with @p other.
         *
         * Both mutexes are acquired together through std::scoped_lock, which
         * uses a deadlock-avoidance algorithm, so concurrent a.swap(b) and
         * b.swap(a) calls are safe.
         */
        void swap(threadsafe_stack& other){
            // Self-swap is a no-op. Without this guard the same mutex would be
            // handed to scoped_lock twice, which is undefined behaviour.
            if(this == &other)
                return;

            std::scoped_lock<std::shared_mutex, std::shared_mutex> scoped_lck(m_shared_mutex, other.m_shared_mutex);
            m_stack.swap(other.m_stack);
        }

        /**
         * @brief Non-member swap, found through argument-dependent lookup.
         */
        friend void swap(threadsafe_stack& lhs, threadsafe_stack& rhs){
            lhs.swap(rhs);
        }

        /**
         * @brief Element-wise equality of two stacks.
         *
         * Only reads both operands, so each mutex is taken in shared mode.
         */
        friend bool operator==(const threadsafe_stack& lhs, const threadsafe_stack& rhs){
            if(&lhs == &rhs){
                // Same object on both sides: lock its mutex once only.
                std::shared_lock<std::shared_mutex> shared_lck(lhs.m_shared_mutex);
                return lhs.m_stack == lhs.m_stack;
            }

            // Defer both shared locks and let std::lock acquire them together
            // so that concurrent a == b and b == a cannot deadlock.
            std::shared_lock<std::shared_mutex> lhs_lck(lhs.m_shared_mutex, std::defer_lock);
            std::shared_lock<std::shared_mutex> rhs_lck(rhs.m_shared_mutex, std::defer_lock);
            std::lock(lhs_lck, rhs_lck);
            return lhs.m_stack == rhs.m_stack;
        }
    };
}