Tao,在英文中的意思是「The ultimate principle of the universe」,即「道」,它是宇宙的終極奧義。
「道生一,一生二,二生三,三生無窮。」 ——《道德經》
Tao同時也是我用Go語言開發的一個異步的TCP服務器框架(TCP Asynchronous Go server FramewOrk),秉承Go語言「Less is more」的極簡主義哲學,它能穿透一切表象,帶你一窺網絡編程的世界,讓你今後完全擺脫只會寫「socket-bind-listen-accept」的窘境。本文將簡單討論一下這個框架的設計思路以及我自己的一些思考。
package main import ( "fmt" "net" "github.com/leesper/holmes" "github.com/leesper/tao" "github.com/leesper/tao/examples/chat" ) // ChatServer is the chatting server. type ChatServer struct { *tao.Server } // NewChatServer returns a ChatServer. func NewChatServer() *ChatServer { onConnectOption := tao.OnConnectOption(func(conn tao.WriteCloser) bool { holmes.Infoln("on connect") return true }) onErrorOption := tao.OnErrorOption(func(conn tao.WriteCloser) { holmes.Infoln("on error") }) onCloseOption := tao.OnCloseOption(func(conn tao.WriteCloser) { holmes.Infoln("close chat client") }) return &ChatServer{ tao.NewServer(onConnectOption, onErrorOption, onCloseOption), } } func main() { defer holmes.Start().Stop() tao.Register(chat.ChatMessage, chat.DeserializeMessage, chat.ProcessMessage) l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "", 12345)) if err != nil { holmes.Fatalln("listen error", err) } chatServer := NewChatServer() err = chatServer.Start(l) if err != nil { holmes.Fatalln("start error", err) } }
// ProcessMessage handles the Message logic. func ProcessMessage(ctx context.Context, conn tao.WriteCloser) { holmes.Infof("ProcessMessage") s, ok := tao.ServerFromContext(ctx) if ok { msg := tao.MessageFromContext(ctx) s.Broadcast(msg) } }
// ChatServer is the chatting server. type ChatServer struct { *tao.Server }
func (ls *logSegment)Write(p []byte) (n int, err error) { if ls.timeToCreate != nil && ls.logFile != os.Stdout && ls.logFile != os.Stderr { select { case current := <-ls.timeToCreate: ls.logFile.Close() ls.logFile = nil name := getLogFileName(current) ls.logFile, err = os.Create(path.Join(ls.logPath, name)) if err != nil { fmt.Fprintln(os.Stderr, err) ls.logFile = os.Stderr } else { next := current.Truncate(ls.unit).Add(ls.unit) ls.timeToCreate = time.After(next.Sub(time.Now())) } default: // do nothing } } return ls.logFile.Write(p) }
// Writer is implemented by anything that accepts a byte slice. Write consumes
// p and reports how many bytes were taken and any error; the method set is
// the same as io.Writer's.
type Writer interface {
	Write(p []byte) (n int, err error)
}
Tao框架支持通過tao.TLSCredsOption()函數提供傳輸層安全的TLS Server。服務器的核心職責是「監聽並接受客戶端連接」。每個進程能夠打開的文件描述符是有限的,所以它還需要限制最大並發連接數,關鍵代碼如下:
// Start starts the TCP server, accepting new clients and creating service // go-routine for each. The service go-routines read messages and then call // the registered handlers to handle them. Start returns when failed with fatal // errors, the listener willl be closed when returned. func (s *Server) Start(l net.Listener) error { s.mu.Lock() if s.lis == nil { s.mu.Unlock() l.Close() return ErrServerClosed } s.lis[l] = true s.mu.Unlock() defer func() { s.mu.Lock() if s.lis != nil && s.lis[l] { l.Close() delete(s.lis, l) } s.mu.Unlock() }() holmes.Infof("server start, net %s addr %s\n", l.Addr().Network(), l.Addr().String()) s.wg.Add(1) go s.timeOutLoop() var tempDelay time.Duration for { rawConn, err := l.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay >= max { tempDelay = max } holmes.Errorf("accept error %v, retrying in %d\n", err, tempDelay) select { case <-time.After(tempDelay): case <-s.ctx.Done(): } continue } return err } tempDelay = 0 // how many connections do we have ? sz := s.conns.Size() if sz >= MaxConnections { holmes.Warnf("max connections size %d, refuse\n", sz) rawConn.Close() continue } if s.opts.tlsCfg != nil { rawConn = tls.Server(rawConn, s.opts.tlsCfg) } netid := netIdentifier.GetAndIncrement() sc := NewServerConn(netid, s, rawConn) sc.SetName(sc.rawConn.RemoteAddr().String()) s.mu.Lock() if s.sched != nil { sc.RunEvery(s.interv, s.sched) } s.mu.Unlock() s.conns.Put(netid, sc) addTotalConn(1) s.wg.Add(1) go func() { sc.Start() }() holmes.Infof("accepted client %s, id %d, total %d\n", sc.GetName(), netid, s.conns.Size()) s.conns.RLock() for _, c := range s.conns.m { holmes.Infof("client %s\n", c.GetName()) } s.conns.RUnlock() } // for loop }
Go語言在發佈1.7版時在標準庫中引入了context包。context包提供的Context接口能夠在服務器、網絡連接以及各相關協程(goroutine)之間建立一種相關聯的「上下文」關係。這種上下文關係包含的信息是與某次網絡請求有關的(request scoped),因此與該請求有關的所有Go協程都能安全地訪問這個上下文:讀取其中攜帶的數據,或者基於它派生出攜帶新數據的子上下文(Context本身是不可修改的)。比如handleLoop協程會將某個網絡連接的net ID以及message打包到上下文中,然後連同handler函數一起交給工作者協程去處理:
// handleLoop() - put handler or timeout callback into worker go-routines
//
// Abridged excerpt: only the message-handler dispatch path is shown. The
// variables ctx, netID, askForWorker and handlerCh are set up in the omitted
// part.
func handleLoop(c WriteCloser, wg *sync.WaitGroup) {
	//... omitted ...
	for {
		select {
		//... omitted ...
		case msgHandler := <-handlerCh:
			msg, handler := msgHandler.message, msgHandler.handler
			if handler != nil {
				if askForWorker {
					// Hand the call to the worker pool keyed by netID. The handler
					// receives ctx extended with the message and the net ID.
					WorkerPoolInstance().Put(netID, func() {
						handler(NewContextWithNetID(NewContextWithMessage(ctx, msg), netID), c)
					})
				}
			}
		//... omitted ...
		}
	}
}
隨後,在工作者協程真正執行時,業務邏輯代碼就能在handler函數中獲取到message或者net ID,這些都是與本次請求有關的上下文數據。比如一個典型的echo server就會這樣處理:
// ProcessMessage process the logic of echo message. func ProcessMessage(ctx context.Context, conn tao.WriteCloser) { msg := tao.MessageFromContext(ctx).(Message) holmes.Infof("receving message %s\n", msg.Content) conn.Write(msg) }
// Stop gracefully closes the server, it blocked until all connections // are closed and all go-routines are exited. func (s *Server) Stop() { // immediately stop accepting new clients s.mu.Lock() listeners := s.lis s.lis = nil s.mu.Unlock() for l := range listeners { l.Close() holmes.Infof("stop accepting at address %s\n", l.Addr().String()) } // close all connections conns := map[int64]*ServerConn{} s.conns.RLock() for k, v := range s.conns.m { conns[k] = v } s.conns.Clear() s.conns.RUnlock() for _, c := range conns { c.rawConn.Close() holmes.Infof("close client %s\n", c.GetName()) } s.mu.Lock() s.cancel() s.mu.Unlock() s.wg.Wait() holmes.Infoln("server stopped gracefully, bye.") os.Exit(0) }
// Start starts the server connection, creating go-routines for reading, // writing and handlng. func (sc *ServerConn) Start() { holmes.Infof("conn start, <%v -> %v>\n", sc.rawConn.LocalAddr(), sc.rawConn.RemoteAddr()) onConnect := sc.belong.opts.onConnect if onConnect != nil { onConnect(sc) } loopers := []func(WriteCloser, *sync.WaitGroup){readLoop, writeLoop, handleLoop} for _, l := range loopers { looper := l sc.wg.Add(1) go looper(sc, sc.wg) } }
/* readLoop() blocking read from connection, deserialize bytes into message, then find corresponding handler, put it into channel */ func readLoop(c WriteCloser, wg *sync.WaitGroup) { var ( rawConn net.Conn codec Codec cDone <-chan struct{} sDone <-chan struct{} setHeartBeatFunc func(int64) onMessage onMessageFunc handlerCh chan MessageHandler msg Message err error ) switch c := c.(type) { case *ServerConn: rawConn = c.rawConn codec = c.belong.opts.codec cDone = c.ctx.Done() sDone = c.belong.ctx.Done() setHeartBeatFunc = c.SetHeartBeat onMessage = c.belong.opts.onMessage handlerCh = c.handlerCh case *ClientConn: rawConn = c.rawConn codec = c.opts.codec cDone = c.ctx.Done() sDone = nil setHeartBeatFunc = c.SetHeartBeat onMessage = c.opts.onMessage handlerCh = c.handlerCh } defer func() { if p := recover(); p != nil { holmes.Errorf("panics: %v\n", p) } wg.Done() holmes.Debugln("readLoop go-routine exited") c.Close() }() for { select { case <-cDone: // connection closed holmes.Debugln("receiving cancel signal from conn") return case <-sDone: // server closed holmes.Debugln("receiving cancel signal from server") return default: msg, err = codec.Decode(rawConn) if err != nil { holmes.Errorf("error decoding message %v\n", err) if _, ok := err.(ErrUndefined); ok { // update heart beats setHeartBeatFunc(time.Now().UnixNano()) continue } return } setHeartBeatFunc(time.Now().UnixNano()) handler := GetHandlerFunc(msg.MessageNumber()) if handler == nil { if onMessage != nil { holmes.Infof("message %d call onMessage()\n", msg.MessageNumber()) onMessage(msg, c.(WriteCloser)) } else { holmes.Warnf("no handler or onMessage() found for message %d\n", msg.MessageNumber()) } continue } handlerCh <- MessageHandler{msg, handler} } } }
/* writeLoop() receive message from channel, serialize it into bytes, then blocking write into connection */ func writeLoop(c WriteCloser, wg *sync.WaitGroup) { var ( rawConn net.Conn sendCh chan []byte cDone <-chan struct{} sDone <-chan struct{} pkt []byte err error ) switch c := c.(type) { case *ServerConn: rawConn = c.rawConn sendCh = c.sendCh cDone = c.ctx.Done() sDone = c.belong.ctx.Done() case *ClientConn: rawConn = c.rawConn sendCh = c.sendCh cDone = c.ctx.Done() sDone = nil } defer func() { if p := recover(); p != nil { holmes.Errorf("panics: %v\n", p) } // drain all pending messages before exit OuterFor: for { select { case pkt = <-sendCh: if pkt != nil { if _, err = rawConn.Write(pkt); err != nil { holmes.Errorf("error writing data %v\n", err) } } default: break OuterFor } } wg.Done() holmes.Debugln("writeLoop go-routine exited") c.Close() }() for { select { case <-cDone: // connection closed holmes.Debugln("receiving cancel signal from conn") return case <-sDone: // server closed holmes.Debugln("receiving cancel signal from server") return case pkt = <-sendCh: if pkt != nil { if _, err = rawConn.Write(pkt); err != nil { holmes.Errorf("error writing data %v\n", err) return } } } } }
// handleLoop() - put handler or timeout callback into worker go-routines func handleLoop(c WriteCloser, wg *sync.WaitGroup) { var ( cDone <-chan struct{} sDone <-chan struct{} timerCh chan *OnTimeOut handlerCh chan MessageHandler netID int64 ctx context.Context askForWorker bool ) switch c := c.(type) { case *ServerConn: cDone = c.ctx.Done() sDone = c.belong.ctx.Done() timerCh = c.timerCh handlerCh = c.handlerCh netID = c.netid ctx = c.ctx askForWorker = true case *ClientConn: cDone = c.ctx.Done() sDone = nil timerCh = c.timing.timeOutChan handlerCh = c.handlerCh netID = c.netid ctx = c.ctx } defer func() { if p := recover(); p != nil { holmes.Errorf("panics: %v\n", p) } wg.Done() holmes.Debugln("handleLoop go-routine exited") c.Close() }() for { select { case <-cDone: // connectin closed holmes.Debugln("receiving cancel signal from conn") return case <-sDone: // server closed holmes.Debugln("receiving cancel signal from server") return case msgHandler := <-handlerCh: msg, handler := msgHandler.message, msgHandler.handler if handler != nil { if askForWorker { WorkerPoolInstance().Put(netID, func() { handler(NewContextWithNetID(NewContextWithMessage(ctx, msg), netID), c) }) addTotalHandle() } else { handler(NewContextWithNetID(NewContextWithMessage(ctx, msg), netID), c) } } case timeout := <-timerCh: if timeout != nil { timeoutNetID := NetIDFromContext(timeout.Ctx) if timeoutNetID != netID { holmes.Errorf("timeout net %d, conn net %d, mismatched!\n", timeoutNetID, netID) } if askForWorker { WorkerPoolInstance().Put(netID, func() { timeout.Callback(time.Now(), c.(WriteCloser)) }) } else { timeout.Callback(time.Now(), c.(WriteCloser)) } } } } }
// Handler takes the responsibility to handle incoming messages. type Handler interface { Handle(context.Context, interface{}) } // HandlerFunc serves as an adapter to allow the use of ordinary functions as handlers. type HandlerFunc func(context.Context, WriteCloser) // Handle calls f(ctx, c) func (f HandlerFunc) Handle(ctx context.Context, c WriteCloser) { f(ctx, c) } // UnmarshalFunc unmarshals bytes into Message. type UnmarshalFunc func([]byte) (Message, error) // handlerUnmarshaler is a combination of unmarshal and handle functions for message. type handlerUnmarshaler struct { handler HandlerFunc unmarshaler UnmarshalFunc } func init() { messageRegistry = map[int32]messageFunc{} buf = new(bytes.Buffer) } // Register registers the unmarshal and handle functions for msgType. // If no unmarshal function provided, the message will not be parsed. // If no handler function provided, the message will not be handled unless you // set a default one by calling SetOnMessageCallback. // If Register being called twice on one msgType, it will panics. func Register(msgType int32, unmarshaler func([]byte) (Message, error), handler func(context.Context, WriteCloser)) { if _, ok := messageRegistry[msgType]; ok { panic(fmt.Sprintf("trying to register message %d twice", msgType)) } messageRegistry[msgType] = handlerUnmarshaler{ unmarshaler: unmarshaler, handler: HandlerFunc(handler), } } // GetUnmarshalFunc returns the corresponding unmarshal function for msgType. func GetUnmarshalFunc(msgType int32) UnmarshalFunc { entry, ok := messageRegistry[msgType] if !ok { return nil } return entry.unmarshaler } // GetHandlerFunc returns the corresponding handler function for msgType. func GetHandlerFunc(msgType int32) HandlerFunc { entry, ok := messageRegistry[msgType] if !ok { return nil } return entry.handler } // Message represents the structured data that can be handled. type Message interface { MessageNumber() int32 Serialize() ([]byte, error) }
// Context is the context info for every handler function. // Handler function handles the business logic about message. // We can find the client connection who sent this message by netid and send back responses. type Context struct{ message Message netid int64 } func NewContext(msg Message, id int64) Context { return Context{ message: msg, netid: id, } } func (ctx Context)Message() Message { return ctx.message } func (ctx Context)Id() int64 { return ctx.netid }
// Codec is the interface for message coder and decoder. // Application programmer can define a custom codec themselves. type Codec interface { Decode(Connection) (Message, error) Encode(Message) ([]byte, error) }
// Codec is the interface for message coder and decoder. // Application programmer can define a custom codec themselves. type Codec interface { Decode(net.Conn) (Message, error) Encode(Message) ([]byte, error) } // TypeLengthValueCodec defines a special codec. // Format: type-length-value |4 bytes|4 bytes|n bytes <= 8M| type TypeLengthValueCodec struct{} // Decode decodes the bytes data into Message func (codec TypeLengthValueCodec) Decode(raw net.Conn) (Message, error) { byteChan := make(chan []byte) errorChan := make(chan error) go func(bc chan []byte, ec chan error) { typeData := make([]byte, MessageTypeBytes) _, err := io.ReadFull(raw, typeData) if err != nil { ec <- err close(bc) close(ec) holmes.Debugln("go-routine read message type exited") return } bc <- typeData }(byteChan, errorChan) var typeBytes []byte select { case err := <-errorChan: return nil, err case typeBytes = <-byteChan: if typeBytes == nil { holmes.Warnln("read type bytes nil") return nil, ErrBadData } typeBuf := bytes.NewReader(typeBytes) var msgType int32 if err := binary.Read(typeBuf, binary.LittleEndian, &msgType); err != nil { return nil, err } lengthBytes := make([]byte, MessageLenBytes) _, err := io.ReadFull(raw, lengthBytes) if err != nil { return nil, err } lengthBuf := bytes.NewReader(lengthBytes) var msgLen uint32 if err = binary.Read(lengthBuf, binary.LittleEndian, &msgLen); err != nil { return nil, err } if msgLen > MessageMaxBytes { holmes.Errorf("message(type %d) has bytes(%d) beyond max %d\n", msgType, msgLen, MessageMaxBytes) return nil, ErrBadData } // read application data msgBytes := make([]byte, msgLen) _, err = io.ReadFull(raw, msgBytes) if err != nil { return nil, err } // deserialize message from bytes unmarshaler := GetUnmarshalFunc(msgType) if unmarshaler == nil { return nil, ErrUndefined(msgType) } return unmarshaler(msgBytes) } }
// WorkerPool is a pool of go-routines running functions. type WorkerPool struct { workers []*worker closeChan chan struct{} } var ( globalWorkerPool *WorkerPool ) func init() { globalWorkerPool = newWorkerPool(WorkersNum) } // WorkerPoolInstance returns the global pool. func WorkerPoolInstance() *WorkerPool { return globalWorkerPool } func newWorkerPool(vol int) *WorkerPool { if vol <= 0 { vol = WorkersNum } pool := &WorkerPool{ workers: make([]*worker, vol), closeChan: make(chan struct{}), } for i := range pool.workers { pool.workers[i] = newWorker(i, 1024, pool.closeChan) if pool.workers[i] == nil { panic("worker nil") } } return pool }
// Put appends a function to some worker's channel. func (wp *WorkerPool) Put(k interface{}, cb func()) error { code := hashCode(k) return wp.workers[code&uint32(len(wp.workers)-1)].put(workerFunc(cb)) } func (w *worker) start() { for { select { case <-w.closeChan: return case cb := <-w.callbackChan: before := time.Now() cb() addTotalTime(time.Since(before).Seconds()) } } } func (w *worker) put(cb workerFunc) error { select { case w.callbackChan <- cb: return nil default: return ErrWouldBlock } }
每個定時任務由一個timerType表示,它帶有自己的id和包含定時回調函數的結構OnTimeOut。expiration表示該任務到期要被執行的時間,interval表示時間間隔,interval > 0意味著該任務是會被週期性重複執行的任務。
/* 'expiration' is the time when timer time out, if 'interval' > 0 the timer will time out periodically, 'timeout' contains the callback to be called when times out */ type timerType struct { id int64 expiration time.Time interval time.Duration timeout *OnTimeOut index int // for container/heap } // OnTimeOut represents a timed task. type OnTimeOut struct { Callback func(time.Time, WriteCloser) Ctx context.Context } // NewOnTimeOut returns OnTimeOut. func NewOnTimeOut(ctx context.Context, cb func(time.Time, WriteCloser)) *OnTimeOut { return &OnTimeOut{ Callback: cb, Ctx: ctx, } }
// timerHeap is a heap-based priority queue type timerHeapType []*timerType func (heap timerHeapType) getIndexByID(id int64) int { for _, t := range heap { if t.id == id { return t.index } } return -1 } func (heap timerHeapType) Len() int { return len(heap) } func (heap timerHeapType) Less(i, j int) bool { return heap[i].expiration.UnixNano() < heap[j].expiration.UnixNano() } func (heap timerHeapType) Swap(i, j int) { heap[i], heap[j] = heap[j], heap[i] heap[i].index = i heap[j].index = j } func (heap *timerHeapType) Push(x interface{}) { n := len(*heap) timer := x.(*timerType) timer.index = n *heap = append(*heap, timer) } func (heap *timerHeapType) Pop() interface{} { old := *heap n := len(old) timer := old[n-1] timer.index = -1 *heap = old[0 : n-1] return timer }
func (tw *TimingWheel) update(timers []*timerType) { if timers != nil { for _, t := range timers { if t.isRepeat() { t.expiration = t.expiration.Add(t.interval) heap.Push(&tw.timers, t) } } } } func (tw *TimingWheel) start() { for { select { case timerID := <-tw.cancelChan: index := tw.timers.getIndexByID(timerID) if index >= 0 { heap.Remove(&tw.timers, index) } case tw.sizeChan <- tw.timers.Len(): case <-tw.ctx.Done(): tw.ticker.Stop() return case timer := <-tw.addChan: heap.Push(&tw.timers, timer) case <-tw.ticker.C: timers := tw.getExpired() for _, t := range timers { tw.GetTimeOutChannel() <- t.timeout } tw.update(timers) } } }
// AddTimer adds new timed task. func (tw *TimingWheel) AddTimer(when time.Time, interv time.Duration, to *OnTimeOut) int64 { if to == nil { return int64(-1) } timer := newTimer(when, interv, to) tw.addChan <- timer return timer.id } // Size returns the number of timed tasks. func (tw *TimingWheel) Size() int { return <-tw.sizeChan } // CancelTimer cancels a timed task with specified timer ID. func (tw *TimingWheel) CancelTimer(timerID int64) { tw.cancelChan <- timerID }
要使用同一個連接同時發送心跳和其他業務消息,這樣一旦應用層因為出錯發不出消息,對方就能夠通過心跳的停止馬上感知到。值得注意的是,在Tao框架中,定時器只有一個,而客戶端連接可能會有很多個。在長連接模式下,每個客戶端都需要處理心跳包,或者其他類型的定時任務。將框架設計為「每個客戶端連接自帶一個定時器」是不合適的——有十萬個連接就有十萬個定時器,會有較高的CPU佔用率。定時器應該只有一個,所有客戶端註冊進來的定時任務都由它負責處理。但是如果所有的客戶端連接都等待唯一一個定時器發來的消息,就又會存在並發問題。比如client 1的定時任務到期了,但它現在正忙著處理其他消息,這個定時任務就可能被其他client執行。所以這裡採取了一種「先集中後分散」的處理機制:每一個定時任務都由一個OnTimeOut結構表示,該結構中除了回調函數還包含一個context。客戶端啟動定時任務的時候都會在context中填入net ID。Server統一接收定時任務,然後從定時任務中取出net ID,再將該定時任務交給相應的ServerConn去執行(ClientConn則由自己的定時器直接處理):
// Retrieve the extra data(i.e. net id), and then redispatch timeout callbacks // to corresponding client connection, this prevents one client from running // callbacks of other clients func (s *Server) timeOutLoop() { defer s.wg.Done() for { select { case <-s.ctx.Done(): return case timeout := <-s.timing.GetTimeOutChannel(): netID := timeout.Ctx.Value(netIDCtx).(int64) if sc, ok := s.conns.Get(netID); ok { sc.timerCh <- timeout } else { holmes.Warnf("invalid client %d\n", netID) } } } }
// ConnMap is a safe map for server connection management. type ConnMap struct { sync.RWMutex m map[int64]*ServerConn } // NewConnMap returns a new ConnMap. func NewConnMap() *ConnMap { return &ConnMap{ m: make(map[int64]*ServerConn), } } // Clear clears all elements in map. func (cm *ConnMap) Clear() { cm.Lock() cm.m = make(map[int64]*ServerConn) cm.Unlock() } // Get gets a server connection with specified net ID. func (cm *ConnMap) Get(id int64) (*ServerConn, bool) { cm.RLock() sc, ok := cm.m[id] cm.RUnlock() return sc, ok } // Put puts a server connection with specified net ID in map. func (cm *ConnMap) Put(id int64, sc *ServerConn) { cm.Lock() cm.m[id] = sc cm.Unlock() } // Remove removes a server connection with specified net ID. func (cm *ConnMap) Remove(id int64) { cm.Lock() delete(cm.m, id) cm.Unlock() } // Size returns map size. func (cm *ConnMap) Size() int { cm.RLock() size := len(cm.m) cm.RUnlock() return size } // IsEmpty tells whether ConnMap is empty. func (cm *ConnMap) IsEmpty() bool { return cm.Size() <= 0 }