使用C++11實現一個半同步半異步線程池

此處輸入圖片的描述

前言

C++11以前咱們使用線程須要系統提供API、posix線程庫或者使用boost提供的線程庫,C++11後就加入了跨平臺的線程類std::thread,線程同步相關類std::mutex、std::lock_guard、std::condition_variable、std::atomic以及異步操做相關類std::async、std::future、std::promise等等,這使得咱們編寫跨平臺的多線程程序變得容易,線程的一個高級應用就是線程池,使用線程池能夠充分利用多核CPU的並行計算能力,以及避免了使用單個線程的建立和銷燬的開銷,因此線程池在實際項目中用的很普遍,不少RPC框架都是用了線程池來處理事務,好比說Thrifteasyrpc等等,接下來咱們將使用C++11來實現一個通用的半同步半異步線程池(我的博客也發表了《使用C++11實現一個半同步半異步線程池》)。ios

實現

一個半同步半異步線程池分爲三層。git

  1. 同步服務層:它處理來自上層的任務請求,上層的請求多是併發的,這些請求不是立刻就會被處理的,而是將這些任務放到一個同步排隊層中,等待處理。
  2. 同步排隊層: 來自上層的任務請求都會加到排隊層中等待處理,排隊層實際就是一個std::queue。
  3. 異步服務層: 這一層中會有多個線程同時處理排隊層中的任務,異步服務層從同步排隊層中取出任務並行的處理。

這三個層次之間須要使用std::mutex、std::condition_variable來進行事件同步,線程池的實現代碼以下。github

#ifndef _THREADPOOL_H
#define _THREADPOOL_H

#include <vector>
#include <queue>
#include <thread>
#include <mutex>
#include <memory>
#include <functional>
#include <condition_variable>
#include <atomic>
#include <type_traits>

static const std::size_t max_task_quque_size = 100000;
static const std::size_t max_thread_size = 30;

class thread_pool
{
public:
    using work_thread_ptr = std::shared_ptr<std::thread>;
    using task_t = std::function<void()>; 

    explicit thread_pool() : _is_stop_threadpool(false) {}

    ~thread_pool()
    {
        stop();
    }

    void init_thread_num(std::size_t num)
    {
        if (num <= 0 || num > max_thread_size)
        {
            std::string str = "Number of threads in the range of 1 to " + std::to_string(max_thread_size);
            throw std::invalid_argument(str);
        }

        for (std::size_t i = 0; i < num; ++i)
        {
            work_thread_ptr t = std::make_shared<std::thread>(std::bind(&thread_pool::run_task, this));
            _thread_vec.emplace_back(t);
        }
    }

    // 支持普通全局函數、靜態函數、以及lambda表達式
    template<typename Function, typename... Args>
    void add_task(const Function& func, Args... args)
    {
        if (!_is_stop_threadpool)
        {
            // 用lambda表達式來保存函數地址和參數
            task_t task = [&func, args...]{ return func(args...); };
            add_task_impl(task);
        }
    }

    // 支持函數對象(仿函數)
    template<typename Function, typename... Args>
    typename std::enable_if<std::is_class<Function>::value>::type add_task(Function& func, Args... args)
    {
        if (!_is_stop_threadpool)
        {
            task_t task = [&func, args...]{ return func(args...); };
            add_task_impl(task);
        }
    }

    // 支持類成員函數
    template<typename Function, typename Self, typename... Args>
    void add_task(const Function& func, Self* self, Args... args)
    {
        if (!_is_stop_threadpool)
        {
            task_t task = [&func, &self, args...]{ return (*self.*func)(args...); };
            add_task_impl(task);
        }
    }

    void stop()
    {
        // 保證terminate_all函數只被調用一次
        std::call_once(_call_flag, [this]{ terminate_all(); });
    }

private:
    void add_task_impl(const task_t& task)
    {
        {
            // 任務隊列滿了將等待線程池消費任務隊列
            std::unique_lock<std::mutex> locker(_task_queue_mutex);
            while (_task_queue.size() == max_task_quque_size && !_is_stop_threadpool)
            {
                _task_put.wait(locker);
            }

            _task_queue.emplace(std::move(task));
        }

       // 向任務隊列插入了一個任務並提示線程池能夠來取任務了
        _task_get.notify_one();
    }

    void terminate_all()
    {
        _is_stop_threadpool = true;
        _task_get.notify_all();

        for (auto& iter : _thread_vec)
        {
            if (iter != nullptr)
            {
                if (iter->joinable())
                {
                    iter->join();
                }
            }
        }
        _thread_vec.clear();

        clean_task_queue();
    }

    void run_task()
    {
        // 線程池循環取任務
        while (true)
        {
            task_t task = nullptr;
            {
                // 任務隊列爲空將等待
                std::unique_lock<std::mutex> locker(_task_queue_mutex);
                while (_task_queue.empty() && !_is_stop_threadpool)
                {
                    _task_get.wait(locker);
                }

                if (_is_stop_threadpool)
                {
                    break;
                }

                if (!_task_queue.empty())
                {
                    task = std::move(_task_queue.front());
                    _task_queue.pop();
                }
            }

            if (task != nullptr)
            {
                // 執行任務,並通知同步服務層能夠向隊列聽任務了
                task();
                _task_put.notify_one();
            }
        }
    }

    void clean_task_queue()
    {
        std::lock_guard<std::mutex> locker(_task_queue_mutex);
        while (!_task_queue.empty())
        {
            _task_queue.pop();
        }
    }

private:
    std::vector<work_thread_ptr> _thread_vec;
    std::condition_variable _task_put;
    std::condition_variable _task_get;
    std::mutex _task_queue_mutex;
    std::queue<task_t> _task_queue;
    std::atomic<bool> _is_stop_threadpool;
    std::once_flag _call_flag;
};

#endif

測試代碼

#include <iostream>
#include <string>
#include <chrono>
#include "thread_pool.hpp"

void test_task(const std::string& str)
{
    std::cout << "Current thread id: " << std::this_thread::get_id() << ", str: " << str << std::endl;
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
}

class Test
{
public:
    void print(const std::string& str, int i)
    {
        std::cout << "Test: " << str << ", i: " << i << std::endl;
    }
};

class Test2
{
public:
    void operator()(const std::string& str, int i)
    {
        std::cout << "Test2: " << str << ", i: " << i << std::endl;
    }
};

int main()
{
    Test t;
    Test2 t2;
    thread_pool pool;
    // 啓動10個線程
    pool.init_thread_num(10);

    std::string str = "Hello world";
    
    for (int i = 0; i < 1000; ++i)
    {
        // 支持lambda表達式
        pool.add_task([]{ std::cout << "Hello ThreadPool" << std::endl; });
        // 支持全局函數
        pool.add_task(test_task, str);
        // 支持函數對象
        pool.add_task(t2, str, i);
        // 支持類成員函數
        pool.add_task(&Test::print, &t, str, i);
    }

    std::cin.get();
    std::cout << "##############END###################" << std::endl;
    return 0;
}

測試程序啓動了十個線程並調用add_task函數加入了4000個任務,add_task支持普通全局函數、靜態函數、類成員函數、函數對象(仿函數)以及lambda表達式,而且支持函數傳入,該線程池的實現以及測試代碼我已經放到了github上。apache

參考資料

《深刻應用C++11--代碼優化與工程級應用》promise

相關文章
相關標籤/搜索