使用asio搭建商用服務器

1. 背景介紹

1.1 什麼是asio

2012年從5月份開始我主持了webyy服務器項目(http://www.yy.com/webyy.html),項目中沒有按照慣例使用公司既有的基於epoll的網絡框架,而是嘗試了C++ tr2標準中的實驗網絡庫asio,不管從開發效率、程序性能、穩定性上來講,都是一次成功的嘗試。雖然是商業項目,但使用了linux、asio、protobuf等大量開源項目,開發過程共也借鑑了其餘一些開源項目,所以我決定把與公司無關的部分剝離一下,分享出來,盡到使用自由軟件的義務。html

asio由Christopher M. Kohlhoff大牛從2003年着手開發,2006年申請加入C++ tr1,2008年3月份加入boost1.35.0,按照boost與C++標準庫的發展慣例,預測很快會加入C++標準庫中。其中的async調用方式已經做爲很是重要的新特性,加入到C++0x標準庫。linux

1.2 asio的相關資料

asio官方提供了及其詳細的文檔、例子、教程,沒有必要再累贅地將其轉述一遍。若是有朋友對英文有些吃力,網上也早有不少翻譯版。這裏提供一些官方的文檔資料:web

2. 源碼參考

因爲代碼使用了一點其餘的工具,因此並無想讓讀者可以編譯經過。可是對從頭開始搭建服務器的朋友來講,必定是一份很是有價值的參考。async

2.1 做爲Client的模塊

這部分供做爲Client去鏈接其餘服務器時使用。給出的源碼中有三個類:TcpConnection, BizConnection, Client. 其中

  • TcpConnection 提供了與協議無關的tcp鏈接,異步操做的結果以虛函數方式供派生類使用

  • BizConnection 繼承自TcpConnection,使用具體的協議解析報文

  • Client使用BizConnection,並提供了等待具體某條消息的wait_for、心跳、延遲等功能

2.1.1 tcpconnection.h

#ifndef TCPCONNECTION_H
#define TCPCONNECTION_H

/**
 * @author yurunsun@gmail.com
 */

#include <asio.hpp>
#include <asio/deadline_timer.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/timer.hpp>
#include <sstream>
#include "safehandler.h"

class TcpConnection
        : public boost::enable_shared_from_this<TcpConnection>
        , private boost::noncopyable
{
public:
    typedef std::vector<uint8_t> DataBuffer;
    typedef boost::shared_ptr<TcpConnection> TcpPtr;
    static TcpPtr create(asio::io_service& io_service, const string& name){return TcpPtr(new TcpConnection(io_service, name));}
    virtual ~TcpConnection();
    void start(const string& ip, const string& port);
    void start(unsigned ip, uint16_t port);
    void start(const string& ip, uint16_t port);
    void stop();
    bool isConnected() {return m_socket.is_open();}

    /// Getters and Setters
    void setName(const string& name) {m_name = name;}
    const string& getName() {return m_name;}
    void setHeadLength(uint32_t size) {m_headLength = size;}
    uint32_t getHeadLength() {return m_headLength;}
    void setConnectTimeoutSec(uint32_t sec) {m_connectTimeoutSec = sec;}
    uint32_t getConnectTimeoutSec() {return m_connectTimeoutSec;}
    const string& getip() {return m_ip;}
    uint16_t getport() {return m_port;}
    string getFarpointInfo() { stringstream ss; ss << m_name << " " << m_ip << ":" << m_port << " "; return ss.str(); }

protected:
    explicit TcpConnection(asio::io_service& io_service, const string& name);

    /// Provide for derived class
    void connect(asio::ip::tcp::endpoint endpoint);
    void receiveHead();
    void receiveBody(uint32_t bodyLength);
    void send(const void *data, uint32_t length);

    /// Class override callbacks
    virtual void onConnectSuccess() { assert(false); }
    virtual void onConnectFailure(const asio::error_code& e) { (void)e; assert(false); }
    virtual void onReceiveHeadSuccess(DataBuffer& data) { (void)data; assert(false); }
    virtual void onReceiveBodySuccess(DataBuffer& data) { (void)data; assert(false); }
    virtual void onReceiveFailure(const asio::error_code& e) { (void)e; assert(false); }
    virtual void onSendSuccess() { assert(false); }
    virtual void onSendFailure(const asio::error_code& e) { (void)e; assert(false); }
    virtual void onTimeoutFailure(const asio::error_code& e) { (void)e; assert(false); }
    virtual void onCommonError(uint32_t ec, const string& em) { (void)ec; (void)em; assert(false); }

private:
    void checkDeadline(const asio::error_code& e);
    void handleConnect(const asio::error_code &e);
    void handleReceiveHead(const asio::error_code& e);
    void handleReceiveBody(const asio::error_code& e);
    void handleSend(const asio::error_code& e);

    typedef TcpConnection this_type;
    asio::ip::tcp::socket m_socket;
    asio::deadline_timer m_deadline;
    bool m_stopped;
    DataBuffer m_readBuf;
    string m_name;
    boost::shared_ptr<Probe> m_probe;
    uint32_t m_headLength;
    uint32_t m_connectTimeoutSec;
    string m_ip;
    uint16_t m_port;
};

#endif // TCPCONNECTION_H

2.1.2 tcpconnection.cpp

#include "stdafx.h"
#include "tcpconnection.h"

using asio::ip::tcp;

TcpConnection::TcpConnection(asio::io_service &io_service, const string& name)
    : m_socket(io_service)
    , m_deadline(io_service)
    , m_stopped(false)
    , m_name(name)
    , m_probe(new Probe)
    , m_headLength(10)
    , m_connectTimeoutSec(5)
    , m_ip("")
    , m_port(0)
{
}

TcpConnection::~TcpConnection()
{
    stop();
}

void TcpConnection::start(const string &ip, const string &port)
{
    start(ip, atoi(port.c_str()));
}

void TcpConnection::start(unsigned ip, uint16_t port)
{
    connect(tcp::endpoint(asio::ip::address_v4(ip), port));
}

void TcpConnection::start(const string &ip, uint16_t port)
{
    asio::ip::address_v4 addr_v4 = asio::ip::address_v4::from_string(ip);
    connect(tcp::endpoint(addr_v4, port));
}

void TcpConnection::stop()
{
    if (!m_stopped) {
        m_stopped = true;
        try {
            m_readBuf.clear();
            asio::error_code ignored;
            m_socket.shutdown(tcp::socket::shutdown_both, ignored);
            m_socket.close(ignored);
            m_deadline.cancel();
        } catch (const asio::system_error& err) {
            FATAL("asio::system_error em %s", err.what());
        }
    }
}

void TcpConnection::connect(tcp::endpoint endpoint)
{
    m_stopped = false;
    m_ip = endpoint.address().to_string();
    m_port = endpoint.port();
    INFO("Trying connect %s:%u ...%s", STR(m_ip), m_port, STR(m_name));

    m_deadline.expires_from_now(boost::posix_time::seconds(m_connectTimeoutSec));
    m_socket.async_connect(endpoint, boost::bind(&TcpConnection::handleConnect, shared_from_this(), asio::placeholders::error));
    m_deadline.async_wait(SafeHandler1<this_type, const asio::error_code&>(&this_type::checkDeadline, this, m_probe));
}

void TcpConnection::receiveHead()
{
    m_readBuf.resize(m_headLength);
    asio::async_read(m_socket, asio::buffer(&m_readBuf[0], m_headLength),
                     boost::bind(&TcpConnection::handleReceiveHead, shared_from_this(), asio::placeholders::error));
}

void TcpConnection::receiveBody(uint32_t bodyLength)
{
    if (!m_stopped){
        if ((bodyLength <= MAX_BUFFER_SIZE) && (bodyLength > 0)) {
            m_readBuf.resize(bodyLength + m_headLength);
            asio::async_read(m_socket, asio::buffer(&m_readBuf[m_headLength], bodyLength),
                             boost::bind(&TcpConnection::handleReceiveBody, shared_from_this(), asio::placeholders::error));
        } else {
            onCommonError(S_FATAL, "illegal bodyLength to call receiveBody");
        }
    } else {
        onCommonError(S_ERROR, "illegal to call receiveBody while tcp is not connected");
    }
}

void TcpConnection::send(const void* data, uint32_t length)
{
    if (!m_stopped) {
        if (length <= MAX_BUFFER_SIZE) {
            asio::async_write(m_socket, asio::const_buffers_1(data, length),
                              boost::bind(&TcpConnection::handleSend, shared_from_this(), asio::placeholders::error));
        } else {
            onCommonError(S_ERROR, "too big length to call send");
        }
    } else {
        onCommonError(S_ERROR, "illegal to call send while tcp is not connected");
    }
}

/**
 * @brief TcpConnection::checkDeadline
 * @param e
 * case1: m_stopped == true which means user canceled
 * case2: m_deadline.expires_at() <= asio::deadline_timer::traits_type::now()
 *          Check whether the deadline has passed. We compare the deadline against
            the current time since a new asynchronous operation may have moved the
            deadline before this actor had a chance to run.
 */
void TcpConnection::checkDeadline(const asio::error_code& e)
{
    if (!m_stopped) {
        if (m_deadline.expires_at() <= asio::deadline_timer::traits_type::now()) {
            onTimeoutFailure(e);
        }
    }
}

void TcpConnection::handleConnect(const asio::error_code &e)
{
    if (!m_stopped) {
        if (!e) {
            m_deadline.cancel();
            onConnectSuccess();
        } else {
            onConnectFailure(e);
        }
    } else {
        INFO("%s %s %u user's canceled by stop()", STR(m_name), STR(m_ip), m_port);
    }
}

void TcpConnection::handleReceiveHead(const asio::error_code &e)
{
    if (!m_stopped) {
        if (!e) {
            onReceiveHeadSuccess(m_readBuf);
        } else if (isConnected()){
            onReceiveFailure(e);
        }
    } else {
        INFO("%s %s %u user's canceled by stop()", STR(m_name), STR(m_ip), m_port);
    }
}

void TcpConnection::handleReceiveBody(const asio::error_code &e)
{
    if (!m_stopped) {
        if (!e) {
            onReceiveBodySuccess(m_readBuf);
        } else if (isConnected()){
            onReceiveFailure(e);
        }
    } else {
        INFO("%s %s %u user's canceled by stop()", STR(m_name), STR(m_ip), m_port);
    }
}

void TcpConnection::handleSend(const asio::error_code &e)
{
    if (!m_stopped) {
        if (!e) {
            //onSendSuccess();
        } else if (isConnected()){
            onSendFailure(e);
        }
    } else {
        INFO("%s %s %u user's canceled by stop()", STR(m_name), STR(m_ip), m_port);
    }
}

2.1.3 bizconnection.h

#ifndef BIZCONNECTION_H
#define BIZCONNECTION_H

/**
 * @author yurunsun@gmail.com
 */

#include <asio.hpp>
#include <cstdio>
#include <stdexcept>
#include "sigslot/sigslot.h"

#include "tcpconnection.h"

class BizConnection
        : public TcpConnection
{
public:
    typedef boost::shared_ptr<BizConnection> BizPtr;
    static BizPtr create(asio::io_service& io_service, const string& name = string(""))
    {
        return BizPtr(new BizConnection(io_service, name));
    }
    void sendBizMsg(uint32_t uri, const BizPackage& pkg);

    sigslot::signal0<> BizConnected;
    sigslot::signal2<uint32_t, const string&> BizError;
    sigslot::signal0<> BizClosed;
    sigslot::signal1<BizPackage&> BizMsgArrived;

protected:
    explicit BizConnection(asio::io_service& io_service, const string& name);

    /// Implement callbacks in base class
    virtual void onConnectSuccess();
    virtual void onConnectFailure(const asio::error_code& e);
    virtual void onReceiveHeadSuccess(DataBuffer& data);
    virtual void onReceiveBodySuccess(DataBuffer& data);
    virtual void onReceiveFailure(const asio::error_code& e);
    virtual void onSendSuccess();
    virtual void onSendFailure(const asio::error_code& e);
    virtual void onTimeoutFailure(const asio::error_code &e);
    virtual void onCommonError(uint32_t ec, const string &em);

    static void initNeedErrorSet();
    static std::set<uint32_t> m_needError;

private:
    inline bool peekLength(void* data, uint32_t length, uint32_t& outputi32);
    void handleError(const string& from, uint32_t ec, const asio::error_code& e = asio::error_code());
};

inline bool BizConnection::peekLength(void* data, uint32_t length, uint32_t& outputi32)
{
    if (length >= 4) {
        memcpy(&outputi32, data, sizeof(uint32_t));
        return true;
    } else {
        return false;
    }
}

#endif // BIZCONNECTION_H

2.1.4 bizconnection.cpp

#include "stdafx.h"
#include "bizconnection.h"
#include "tcpconnection.h"

std::set<uint32_t> BizConnection::m_needError;

BizConnection::BizConnection(asio::io_service &io_service, const string& name)
    : TcpConnection(io_service, name)
{
}

void BizConnection::sendBizMsg(unsigned uri, const BizPackage &pkg)
{
    /// TODO: usually this BizPackage contains the buffer of stream data to be sent to farpoint
    /// You should implement this by retriving buffer in BizPackage then call TcpConnection::send();
}

void BizConnection::onConnectSuccess()
{
    receiveHead();
    BizConnected.emit();
}

void BizConnection::onConnectFailure(const asio::error_code &e)
{
    if (e == asio::error::operation_aborted) {
        INFO("%s %s %u operation aborted... %s", STR(getName()), STR(getip()), getport(), STR(e.message()));
    }
    else if ((e == asio::error::already_connected) || (e == asio::error::already_open) || (e == asio::error::already_started)) {
        WARN("%s %s %u alread connected... %s", STR(getName()), STR(getip()), getport(), STR(e.message()));
    } else {
        handleError("onConnectFailure", S_FATAL, e);
    }
}

void BizConnection::onReceiveHeadSuccess(TcpConnection::DataBuffer &data)
{
    uint32_t pkglen = 0;
    if (peekLength(data.data(), data.size(), pkglen)) {
        receiveBody(pkglen - getHeadLength());
    } else {
        handleError("peekLength", S_FATAL);
    }
}

void BizConnection::onReceiveBodySuccess(TcpConnection::DataBuffer &data)
{
    /// This is simply an example, actually it's user's duty to unmarshal buffer to package.
    BizPackage msg;
    BizPackage.unserializeFrom(data);
    BizMsgArrived.emit(msg);
    receiveHead();
}

void BizConnection::onReceiveFailure(const asio::error_code &e)
{
    if ((e == asio::error::operation_aborted)) {
        INFO("%s operation_aborted... %s", STR(getFarpointInfo()), STR(e.message()));
    } else {
        handleError("onReceiveFailure", S_FATAL, e);
    }
}

void BizConnection::onSendSuccess()
{
    /// Leave it blank
}

void BizConnection::onSendFailure(const asio::error_code &e)
{
    if ((e == asio::error::operation_aborted)) {
        INFO("%s operation_aborted... %s", STR(getFarpointInfo()), STR(e.message()));
    } else {
        handleError("onSendFailure", S_FATAL, e);
    }
}

void BizConnection::onTimeoutFailure(const asio::error_code &e)
{
    if ((e == asio::error::operation_aborted)) {
        INFO("%s operation_aborted... %s", STR(getFarpointInfo()), STR(e.message()));
    } else {
        handleError("onTimeoutFailure", S_FATAL);
    }
}

void BizConnection::onCommonError(uint32_t ec, const string &em)
{
    handleError(em, ec);
}

void BizConnection::handleError(const string& from, uint32_t ec, const asio::error_code &e/* = asio::error_code()*/)
{
    stringstream ss;
    ss << getFarpointInfo() << " " << from;
    if (e) {
        ss << " asio " << e.value() << " " << e.message();
    }
    BizError.emit(ec, ss.str());
}

2.1.5 client.h

#ifndef CLIENT_H
#define CLIENT_H

/**
 * @author yurunsun@gmail.com
 */

#include "bizconnection.h"
#include "handler.h"
#include "safehandler.h"

#include <asio.hpp>
#include <boost/timer.hpp>

class Client
        : public sigslot::has_slots<>
{
protected:
    BizConnection::BizPtr m_pBizConnection;

public:
    typedef void (Client::*RequestPtr)(BizPackage&);
    typedef std::map<uint32_t, RequestPtr> RequestMap;

    typedef void (Client::*NotifyPtr)(BizPackage&);
    typedef std::map<uint32_t, NotifyPtr> NotifyMap;

    explicit Client(const string& name = string(""));

    /// 繼承類須要實現的提供外部的方法
    virtual void startServer() = 0;
    virtual bool sendToServer(YProto &proto) = 0;
    virtual void stopServer() {clearWaitforTimer(); m_pBizConnection->stop();}

protected:
    /// 繼承類須要實現的初始化函數
    virtual void initRequestMap() {assert(false);}
    virtual void initNotifyMap() {assert(false);}
    virtual void initSignal();

    /// 繼承類須要實現的鉤子函數,用於處理網絡事件
    virtual void onBizMsgArrived(core::Request& msg) = 0;
    virtual void onBizError(uint32_t ec, const string& em);
    virtual void onBizConnected() = 0;

    /// 繼承類可使用的工具方法
    /// 1. 心跳類
    void setKeepAliveSec(uint32_t sec) {m_keepAliveSec = sec;}
    uint32_t getKeepAliveSec() {return m_keepAliveSec;}
    void startKeepAlive();
    void keepAlive(const asio::error_code& e);
    virtual void onKeepAlive() {assert(false);}
    /// 2. 登錄狀態類
    void setHasLogin(bool b) {m_hasLogin = b;}
    bool getHasLogin() {return m_hasLogin;}
    bool isOnline() { return (m_pBizConnection->isConnected() && m_hasLogin);}
    /// 3. 消息保存類
    template <typename Handler>
    bool savePendingCommand(Handler handler)
    {
        if(m_pendingCmd.size() < MaxPendingCommandCount) {
            m_pendingCmd.push_back(Command(handler));
            return true;
        }
        return false;
    }
    void sendPendingCommand()
    {
        if (isOnline()) {
            vector<Command>::iterator it = m_pendingCmd.begin();
            for(; it != m_pendingCmd.end(); ++it) {
                (*it)();
            }
            m_pendingCmd.clear();
        }
    }
    /// 4. 延遲處理類
    typedef void (Client::*HoldonCallback)();
    void holdonSeconds(uint32_t sec, HoldonCallback func);
    void holdonHandler(HoldonCallback func, const asio::error_code &e);

    /// 5. waitfor 工具 處理異步消息超時
    typedef boost::shared_ptr<asio::deadline_timer> SharedTimerPtr;
    typedef boost::scoped_ptr<asio::deadline_timer> ScopedTimerPtr;
    typedef boost::shared_ptr<Probe> SharedProbe;
    typedef map<uint32_t, SharedTimerPtr> Uri2Timer;            /// 等待收到的包uri --> 這個時間timer
    void waitfor(uint32_t uri, uint32_t sec);                   /// 在發送req的時候調用,sec 秒數 uri 等待收到的uri
    void waitforTimeout(uint32_t uri, const asio::error_code& e); /// 全部waitfor超時都會自動回調這個函數
    virtual void onWaitforTimeout(uint32_t uri) {(void)uri; assert(false);}         /// 繼承類覆蓋這個鉤子函數來進行錯誤處理
    void waitforReceived(uint32_t uri);                         /// 當響應函數handler被回調時,記得調用waitforReceived作清理工做
    void eraseWaitforTimer(uint32_t uri);
    void clearWaitforTimer();

    /// 繼承類可使用的工具成員:心跳 探針 請求阻塞
    typedef std::set<uint32_t> BlockReq;
    BlockReq m_block;
    SharedProbe m_probe;
    ScopedTimerPtr m_timer;
    uint32_t m_keepAliveSec;
    bool m_hasLogin;
    vector<Command> m_pendingCmd;
    static const uint32_t MaxPendingCommandCount = 20;
    ScopedTimerPtr m_holdonTimer;
    Uri2Timer m_uri2timer;
};


#define BIND_REQ(m, uri, callback) \
    m[static_cast<uint32_t>(uri)] = static_cast<RequestPtr>(callback);

#define BIND_NOTIFY(m, uri, callback) \
    m[static_cast<uint32_t>(uri)] = static_cast<NotifyPtr>(callback);

#endif // CLIENT_H

2.1.6 client.cpp

#include "stdafx.h"
#include "client.h"

Client::Client(const string& name)
    : m_pBizConnection(BizConnection::create(ioService::instance(), name))
    , m_probe(new Probe)
    , m_keepAliveSec(10)
    , m_hasLogin(false)
{
}

void Client::initSignal()
{
    m_pBizConnection->BizError.connect(this, &Client::onBizError);
    m_pBizConnection->BizMsgArrived.connect(this, &Client::onBizMsgArrived);
    m_pBizConnection->BizConnected.connect(this, &Client::onBizConnected);
}

void Client::onBizError(uint32_t ec, const string &em)
{
    m_facade.serverError.emit(ec, em);
}

void Client::startKeepAlive()
{
    m_timer.reset(new asio::deadline_timer(m_facade.io_service_ref));
    m_timer->expires_from_now(boost::posix_time::seconds(m_keepAliveSec));
    m_timer->async_wait(SafeHandler1<Client, const asio::error_code&>(&Client::keepAlive, this, m_probe));
}

void Client::keepAlive(const asio::error_code &e)
{
    if (e != asio::error::operation_aborted) {
        FINE("%u send ping to %s %s:%u", m_facade.m_pInfo->uid, STR(m_pBizConnection->getName()), STR(m_pBizConnection->getip()), m_pBizConnection->getport());
        onKeepAlive();
        m_timer->expires_from_now(boost::posix_time::seconds(m_keepAliveSec ));
        m_timer->async_wait(SafeHandler1<Client, const asio::error_code&>(&Client::keepAlive, this, m_probe));
    }
}

void Client::holdonSeconds(uint32_t sec, HoldonCallback func)
{
    m_holdonTimer.reset(new asio::deadline_timer(m_facade.io_service_ref));
    m_holdonTimer->expires_from_now(boost::posix_time::seconds(sec));
    SafeHandler1Bind1<Client, HoldonCallback, const asio::error_code&> h(&Client::holdonHandler, this, func, m_probe);
    m_holdonTimer->async_wait(h);
}

void Client::holdonHandler(HoldonCallback func, const asio::error_code &e)
{
    if (!e) {
        if (m_holdonTimer != NULL)
            m_holdonTimer->cancel();
        (this->*func)();
    } else {
         WARN("error: %s", STR(e.message()));
    }
}

void Client::waitfor(uint32_t uri, uint32_t sec)
{
    SharedTimerPtr t(new asio::deadline_timer(m_facade.io_service_ref));
    t->expires_from_now(boost::posix_time::seconds(sec));
    t->async_wait(SafeHandler1Bind1<Client, uint32_t, const asio::error_code&>(
                      &Client::waitforTimeout, this, uri, m_probe));
    m_uri2timer[uri] = t;
}

void Client::waitforTimeout(uint32_t uri, const asio::error_code &e)
{
    if (e != asio::error::operation_aborted) {
        FATAL("%s waitfor uri %u timeout", STR(m_pBizConnection->getName()), uri);
        eraseWaitforTimer(uri);
        onWaitforTimeout(uri);
    }
}

void Client::waitforReceived(uint32_t uri)
{
    eraseWaitforTimer(uri);
}

void Client::eraseWaitforTimer(uint32_t uri)
{
    Uri2Timer::iterator it = m_uri2timer.find(uri);
    if (it != m_uri2timer.end()) {
        SharedTimerPtr& t = it->second;
        if (t) {
            asio::error_code e;
            t->cancel(e);
            t.reset();
        }
        m_uri2timer.erase(it);
    }
}

void Client::clearWaitforTimer()
{
    Uri2Timer::iterator it = m_uri2timer.begin();
    for (; it != m_uri2timer.end(); ++it) {
        SharedTimerPtr& t = it->second;
        if (t) {
            asio::error_code e;
            t->cancel(e);
            t.reset();
        }
    }
    m_uri2timer.clear();
}

2.2 做爲server模塊

做爲server模塊因爲涉及公司的業務比較多,這裏剝離出一個做爲crossdomain服務器的部分,功能很簡單:flash客戶端經過socket請求crossdomain配置文件,server返回給定的字符串。這裏使用了比較著名的pimpl模式,將實現徹底隱藏在cpp文件中。

2.2.1 crossdomain.h

#ifndef CROSSDOMAIN_H
#define CROSSDOMAIN_H

#include <string>
#include <boost/shared_ptr.hpp>
#include <asio.hpp>

/**
 * @author yurunsun@gmail.com
 */

class CrossDomain
{
private:
    struct Server;
    boost::shared_ptr<Server> m_pserver;
    CrossDomain(asio::io_service& io_service, const std::string& local_port);
    static CrossDomain* s_instance;

public:
    static void create(asio::io_service& io_service, const std::string& local_port)
    {
        s_instance = new CrossDomain(io_service, local_port);
    }

    static CrossDomain* instance();
    void start_server();
};

#endif // CROSSDOMAIN_H

2.2.2 crossdomain.cpp

#include "stdafx.h"
#include "crossdomain.h"

using asio::ip::tcp;
using boost::uint8_t;
CrossDomain* CrossDomain::s_instance = NULL;

struct CrossDomainImpl : public boost::enable_shared_from_this<CrossDomainImpl>
{
public:
    static const unsigned MaxReadSize = 22;
    typedef boost::shared_ptr<CrossDomainImpl> CrossDomainImplPtr;
    static CrossDomainImplPtr create(asio::io_service& io_service) {
        return CrossDomainImplPtr(new CrossDomainImpl(io_service));
    }

    tcp::socket& get_socket() {
        return m_socket;
    }

    void start() {
        start_read_some();
    }

    ~CrossDomainImpl() {
        close();
    }

    void close() {
        if (m_socket.is_open()) {
            m_socket.close();
        }
    }

private:
    CrossDomainImpl(asio::io_service& io_service)
        : m_socket(io_service)
    {
    }

    void start_read_some() {
        m_socket.async_read_some(asio::buffer(m_readbuf, MaxReadSize),
            boost::bind(&CrossDomainImpl::handle_read_some, shared_from_this(), asio::placeholders::error()));
    }

    void handle_read_some(const asio::error_code& err) {
        if (!err) {
            string str(m_readbuf);
            string reply("invalid");
            if (str == "<policy-file-request/>") {
                reply = "anything you wanna send back to client...";
            }
            asio::async_write(m_socket, asio::buffer(ref),
                    boost::bind(&CrossDomainImpl::handle_write, shared_from_this(), asio::placeholders::error));
        }
    }

    void handle_write(const asio::error_code& error) {
        FINE("CrossDomain handle_write, gonna close");
        close();
    }

    tcp::socket m_socket;
    char m_readbuf[MaxReadSize];
};

struct CrossDomain::Server
{
private:
    CrossDomain *m_facade;
    tcp::acceptor m_acceptor;
    bool m_listened;
    string m_local_port;

public:
    Server(asio::io_service& io_service, const string &local_port)
        : m_acceptor(io_service)
        , m_listened(false)
        , m_local_port(local_port)
    {
        // intend to leave it blank
    }
    ~Server() {
        if (m_acceptor.is_open()) {
            INFO("close server acceptor");
            m_acceptor.close();
        }
    }

    void start_server() {
        FINE("CrossDomain start_server....");
        if (!m_listened) {
            FINE("Try to listen...");
            try {
                tcp::endpoint ep(tcp::endpoint(tcp::v4(), atoi(m_local_port.c_str())));
                m_acceptor.open(ep.protocol());
                m_acceptor.bind(ep);
                m_acceptor.listen();
            } catch (const asio::system_error& ec) {
                WARN("Port %s already in use! Fail to listen...", STR(m_local_port));
                return;
            } catch (...) {
                WARN("Unknown error while trying to listen...");
                return;
            }
            m_listened = true;
            FINE("Listen port %s succesfully!", STR(m_local_port));
        }

        CrossDomainImpl::CrossDomainImplPtr new_server_impl = CrossDomainImpl::create(m_acceptor.get_io_service());
        m_acceptor.async_accept(new_server_impl->get_socket(),
            boost::bind(&Server::handle_accept, this, new_server_impl, asio::placeholders::error));
    }

private:
    void handle_accept(CrossDomainImpl::CrossDomainImplPtr pserver_impl, const asio::error_code& err) {
        FINE("CrossDomain handle_accpet....");
        if (!err) {
            FINE("CrossDomain everything ok, start...");
            pserver_impl->start();    // start this server
            start_server();           // waiting for another Tuna Connection
        } else {
            pserver_impl->close();
        }
    }
};

CrossDomain::CrossDomain(asio::io_service &io_service, const std::string &local_port)
    : m_pserver(new Server(io_service, local_port))
{
}

CrossDomain *CrossDomain::instance()
{
    if (!s_instance) {
        return NULL;
    }
    return s_instance;
}

void CrossDomain::start_server()
{
    m_pserver->start_server();
}

3. 使用asio的陷阱

上邊代碼其實有幾點漏洞:

3.1 std::vector<uint_8>不適合做爲buffer

vector<uint8_t>不適合作buffer的緣由是,sgi的內存分配器會以2倍的形式增加vector的內存,例如這個buffer要求100K,但當前vector的capability只有90K,那麼sgi默認內存分配器會將vector的capability增加到180K。注意capability與size的區別。這就致使vector的內存佔用依賴最大buffer的size,這是很危險的。

推薦使用boost的circular_buffer做爲buffer,能有效避免內存碎片、隱式內存泄露等問題。

3.2 asio::const_buffer拷貝構造函數沒有深拷貝

const_buffer系列靜態buffer只能從mutable_buffermerge過來,可是從const_buffer的拷貝構造函數源碼能看到,他並不對buffer作深拷貝。因此試圖將其放到隊列或者容器中,期待產生buffer的拷貝,是錯誤的。

3.3. async_write可能會拆包發送

例如先調用async_write發送一個100K的大包,再立刻調用async_write發送一個8字節的ping包,很是可能出現問題。async_write函數的實現是循環調用async_write_some,對於大包會將其拆分紅幾個小報文。若是此時收到用戶一個新的async_write調用,很是可能將小包夾在大包的幾個部分中間發送,致使接收端出現異常。

解決的辦法能夠直接操做async_write_some,代替async_write。但更方便的辦法是建立一個發送隊列。實際上asio會準確地將發送成功的通知發送給用戶,例如剛剛100K的打包,直到全部100K所有發送完成,纔會調用handle回調。所以能夠在發送時將報文入隊列,回調函數裏將報文出隊列,發送下一個小報時判斷隊列是否爲空,若是非空說明100K的包尚未發完。示例代碼以下:

///發送時入隊列
void TcpConnection::send(const void* data, uint32_t length)
{
    if (!m_stopped) {
        if (length <= MAX_BUFFER_SIZE) {
            const char* begin = (const char*)data;
            vector<char> vec(begin, begin + length);

            bool isLastComplete = m_bufQueue.empty();
            m_bufQueue.push_back(vec);
            /// 若是沒有殘餘的包,就直接發送
            if (isLastComplete) {
                vector<char>& b(m_bufQueue.front());
                send(b);
            }
        } else {
            onCommonError(S_ERROR, "too big length to call send");
        }
    } else {
        onCommonError(S_ERROR, "illegal to call send while tcp is not connected");
    }
}

void TcpConnection::send(const std::vector<char>& vec)
{
    if (!m_stopped) {
        asio::async_write(m_socket, asio::buffer(&vec[0], vec.size()), asio::transfer_all(), 
                          boost::bind(&TcpConnection::handleSend, shared_from_this(), asio::placeholders::error));

    } else {
        onCommonError(S_ERROR, "illegal to call send while tcp is not connected");
    }
}

///回調函數將以前的buffer出隊列,同時檢查是否有後來的包
void TcpConnection::handleSend(const asio::error_code &e)
{
    if (!m_stopped) {
        if (!e) {
            m_bufQueue.pop_front();
            if (!m_bufQueue.empty()) {
                std::vector<char>& b(m_bufQueue.front());
                send(b);
            }
            //onSendSuccess();
        } else if (isConnected()){
            onSendFailure(e);
        }
    } else {
        INFO("%s %s %u user's canceled by stop()", STR(m_name), STR(m_ip), m_port);
    }
}
相關文章
相關標籤/搜索