frp源碼剖析-frp中的mux模塊

前言

frp幾乎全部的鏈接處理都是構建在mux模塊之上的,重要性沒必要多說,來看一下這是個啥吧git

ps: 安裝方法github

go get "github.com/fatedier/golib/net/mux"

該模塊很小,不到300行,分爲兩個文件:mux.gorule.go
由於rule.go文件相對簡單一些,咱們先來看這個。golang

role.go文件

首先看其中所命名的函數類型MatchFunc算法

type MatchFunc func(data []byte) (match bool)

該類型的函數用來判斷data屬於什麼協議。編程

那麼具體如何判斷呢,這裏也實現了三個例子:網絡

var (
    HttpsNeedBytesNum uint32 = 1
    HttpNeedBytesNum  uint32 = 3
    YamuxNeedBytesNum uint32 = 2
)

var HttpsMatchFunc MatchFunc = func(data []byte) bool {
    if len(data) < int(HttpsNeedBytesNum) {
        return false
    }

    if data[0] == 0x16 {
        return true
    } else {
        return false
    }
}

// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
var httpHeadBytes = map[string]struct{}{
    "GET": struct{}{},
    "HEA": struct{}{},
    "POS": struct{}{},
    "PUT": struct{}{},
    "DEL": struct{}{},
    "CON": struct{}{},
    "OPT": struct{}{},
    "TRA": struct{}{},
    "PAT": struct{}{},
}

var HttpMatchFunc MatchFunc = func(data []byte) bool {
    if len(data) < int(HttpNeedBytesNum) {
        return false
    }

    _, ok := httpHeadBytes[string(data[:3])]
    return ok
}

// From https://github.com/hashicorp/yamux/blob/master/spec.md
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
    if len(data) < int(YamuxNeedBytesNum) {
        return false
    }

    if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
        return true
    }
    return false
}

這三個函數分別實現了區分HTTPS,HTTP以及go中特有的yamux(實際上這是一個庫,能夠參考Go中的I/O多路複用)。app

mux.go文件

先來看其中的struct,第一個是Mux第二個是listener,這裏先來看一下較爲簡單的listenertcp

listener結構體

type listener struct {
    mux *Mux

    priority     int
    needBytesNum uint32
    matchFn      MatchFunc

    c  chan net.Conn
    mu sync.RWMutex
}

// Accept waits for and returns the next connection to the listener.
func (ln *listener) Accept() (net.Conn, error) {
    ...
}

// Close removes this listener from the parent mux and closes the channel.
func (ln *listener) Close() error {
    ...
}

func (ln *listener) Addr() net.Addr {
    ...
}

剛看到這個結構體咱們可能很迷惑,不知道都是幹啥的,並且網絡編程中通常listener這種東西要綁定在一個套接字上,但很明顯listener沒有,不過其惟一跟套接字相關的多是其c字段,其是一個由net包中的Conn接口組成的chanel;而後mu字段就是讀寫鎖了,這個很簡單;而後mux字段則是上面提到的兩個結構體中的另外一個結構體Mux的指針;接下來到了priority字段上,顧名思義,這個彷佛跟優先級有關係,暫且存疑;needBytesNum則更有些蒙了,不過感受其是跟讀取byte的數量有關係,最後是matchFn函數

好,初步認識了這個結構體的結構後,咱們看看其方法。三個方法的listener實現了net模塊中的Listener接口:ui

// A Listener is a generic network listener for stream-oriented protocols.
//
// Multiple goroutines may invoke methods on a Listener simultaneously.
type Listener interface {
    // Accept waits for and returns the next connection to the listener.
    Accept() (Conn, error)

    // Close closes the listener.
    // Any blocked Accept operations will be unblocked and return errors.
    Close() error

    // Addr returns the listener's network address.
    Addr() Addr
}

而後先來分析其Accept方法:

func (ln *listener) Accept() (net.Conn, error) {
    conn, ok := <-ln.c
    if !ok {
        return nil, fmt.Errorf("network connection closed")
    }
    return conn, nil
}

該方法很簡單,就是從c這個由Conn組成的channel中,獲取Conn對象,好這裏咱們就明白了,這個listener和普通的不同,他很特別,普通的listener監聽的是套接字,而他監聽的是channel,另外,確定有某個地方在不停的往c這個channel中放Conn

接下來是Close方法:

func (ln *listener) Close() error {
    if ok := ln.mux.release(ln); ok {
        // Close done to signal to any RLock holders to release their lock.
        close(ln.c)
    }
    return nil
}

咱們暫且先把這個ln.mux.release(ln)放到一邊,由於還不知道這個東西幹了啥,暫且只需關注close(ln.c),咱們知道這個函數是用來關閉channel的,go推薦由發送端調用,但這裏彷佛listener是一個消費端,能夠看一下如何優雅的關閉Go Channel,看來重點在於ln.mux.release(ln)這裏,咱們暫且存疑[1],留待下面解決。

最後是Addr方法:

func (ln *listener) Addr() net.Addr {
    if ln.mux == nil {
        return nil
    }
    ln.mux.mu.RLock()
    defer ln.mux.mu.RUnlock()
    if ln.mux.ln == nil {
        return nil
    }
    return ln.mux.ln.Addr()
}

在這裏,mu字段就用上了,加讀鎖,而後返回mux字段中的ln字段的Addr方法。也就是這句return ln.mux.ln.Addr()

Mux結構體

字段以及相關函數

Mux結構體則相對來講複雜不少,先來看一下他的字段定義:

type Mux struct {
    ln net.Listener

    defaultLn *listener

    // sorted by priority
    lns             []*listener
    maxNeedBytesNum uint32

    mu sync.RWMutex
}

好,第一個字段ln是一個Listener接口;而後defaultLn是一個listener的指針;lns則是由listener的指針組成的切片,根據註釋// sorted by priority,咱們終於知道listenerpriority字段是幹啥的了;接下來是maxNeedBytesNum字段,好奇怪,比起listenerneedBytesNum多了個「Max」,因此咱們推測這個值取得是lns以及defaultLn字段中全部listenerneedBytesNum值最大的;最後的mu字段咱們就不說了。

須要注意的是:咱們可能會發現Muxlistener存在相互引用,但在Go中咱們倒也不用太擔憂,由於Go採用「標記-回收」或者其變種的垃圾回收算法,感興趣能夠參考Golang 垃圾回收剖析

mux.go文件中定義了Mux的生成函數NewMux:

func NewMux(ln net.Listener) (mux *Mux) {
    mux = &Mux{
        ln:  ln,
        lns: make([]*listener, 0),
    }
    return
}

很簡單,須要注意的是ln字段存儲的通常不是listener這樣的很是規Listener,通常是TCPListener這樣具體的綁定了套接字的監聽器。

Mux方法

接下來看Mux結構體的方法,首先看ListencopyLns

// priority
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
    // 1
    ln := &listener{
        c:            make(chan net.Conn),
        mux:          mux,
        priority:     priority,
        needBytesNum: needBytesNum,
        matchFn:      fn,
    }

    mux.mu.Lock()
    defer mux.mu.Unlock()
    // 2
    if needBytesNum > mux.maxNeedBytesNum {
        mux.maxNeedBytesNum = needBytesNum
    }

    // 3
    newlns := append(mux.copyLns(), ln)
    sort.Slice(newlns, func(i, j int) bool {
        if newlns[i].priority == newlns[j].priority {
            return newlns[i].needBytesNum < newlns[j].needBytesNum
        }
        return newlns[i].priority < newlns[j].priority
    })
    mux.lns = newlns
    return ln
}

func (mux *Mux) copyLns() []*listener {
    lns := make([]*listener, 0, len(mux.lns))
    for _, l := range mux.lns {
        lns = append(lns, l)
    }
    return lns
}

copyLns方法很簡單,就是跟名字的含義同樣,生成一個lns字段的副本並返回。

Listen基本作了三步:

  1. 生成一個listener結構體實例,並獲取互斥鎖
  2. 根據狀況更新needBytesNum字段
  3. 將新生成的listener實例按照優先級放入lns字段對應的slice中

接下來是ListenHttpListenHttps方法:

func (mux *Mux) ListenHttp(priority int) net.Listener {
    return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
}

func (mux *Mux) ListenHttps(priority int) net.Listener {
    return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
}

這兩個差很少,因此放到一塊兒說,基本都是專門寫了一個方法讓咱們能方便的建立處理Http或者Httpslistener

再來看DefaultListener方法:

func (mux *Mux) DefaultListener() net.Listener {
    mux.mu.Lock()
    defer mux.mu.Unlock()
    if mux.defaultLn == nil {
        mux.defaultLn = &listener{
            c:   make(chan net.Conn),
            mux: mux,
        }
    }
    return mux.defaultLn
}

這個方法很簡單,基本就是有則返回沒有則生成而後返回的套路。不過咱們要注意defaultLn字段中的listener是不放入lns字段中的。

接下來是Server方法:

// Serve handles connections from ln and multiplexes then across registered listeners.
func (mux *Mux) Serve() error {
    for {
        // Wait for the next connection.
        // If it returns a temporary error then simply retry.
        // If it returns any other error then exit immediately.
        conn, err := mux.ln.Accept()
        if err, ok := err.(interface {
            Temporary() bool
        }); ok && err.Temporary() {
            continue
        }

        if err != nil {
            return err
        }

        go mux.handleConn(conn)
    }
}

通常來講,當咱們調用NewMux函數之後,接下來就會調用Server方法,該方法基本上就是阻塞監聽某個套接字,當有鏈接創建成功後當即另起一個goroutine調用handleConn方法;當鏈接創建失敗根據err是否含有Temporary方法,若是有則執行並忽略錯誤,沒有則返回錯誤。

如今咱們看看handleConn方法幹了些啥:

func (mux *Mux) handleConn(conn net.Conn) {
    // 1
    mux.mu.RLock()
    maxNeedBytesNum := mux.maxNeedBytesNum
    lns := mux.lns
    defaultLn := mux.defaultLn
    mux.mu.RUnlock()
    
    // 2
    sharedConn, rd := gnet.NewSharedConnSize(conn, int(maxNeedBytesNum))
    data := make([]byte, maxNeedBytesNum)

    conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
    _, err := io.ReadFull(rd, data)
    if err != nil {
        conn.Close()
        return
    }
    conn.SetReadDeadline(time.Time{})
    // 3
    for _, ln := range lns {
        if match := ln.matchFn(data); match {
            err = errors.PanicToError(func() {
                ln.c <- sharedConn
            })
            if err != nil {
                conn.Close()
            }
            return
        }
    }

    // No match listeners
    if defaultLn != nil {
        err = errors.PanicToError(func() {
            defaultLn.c <- sharedConn
        })
        if err != nil {
            conn.Close()
        }
        return
    }

    // No listeners for this connection, close it.
    conn.Close()
    return
}

handleConn方法也不算複雜,大致能夠分爲三步:

  1. 獲取當前狀態
  2. conn中讀取數據,注意:shareConnrd存在單向關係,若是從rd中讀取數據的話,數據也會複製一份放到shareConn中,反過來就不成立了
  3. 讀取到的數據會被遍歷,最終選出與matchFunc匹配的最高優先級的listener,並將shareConn放入該listenerc字段中,若是沒有匹配到則放到defaultLn中的c字段中,若是defaultLnnil的話就不處理,直接關閉conn

最後來到了release方法了:

func (mux *Mux) release(ln *listener) bool {
    result := false
    mux.mu.Lock()
    defer mux.mu.Unlock()
    lns := mux.copyLns()

    for i, l := range lns {
        if l == ln {
            lns = append(lns[:i], lns[i+1:]...)
            result = true
            break
        }
    }
    mux.lns = lns
    return result
}

release方法意思很明確:把對應的listenerlns中移除,並把結果返回,整個過程有互斥鎖,咱們回到存疑1,儘管有互斥鎖,但在這種狀況下:當某個goroutine運行到handleConn已經執行到了第三階段的開始狀態(也就是尚未找到匹配的listener)時,且Go運行在多核狀態下,當另外一個goroutine運行完listenerClose方法時,這時就可能發生往一個已經關閉的channel中send數據,但請注意handleConn的第三步的這段代碼:

err = errors.PanicToError(func() { // 就是這裏了
    ln.c <- sharedConn
})
if err != nil {
    conn.Close()
}

這個PanicToError是這樣的:

func PanicToError(fn func()) (err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("Panic error: %v", r)
        }
    }()

    fn()
    return
}

基本上就是執行了recover而後將錯誤打印出來,結合下面的對err的判斷,就會將send失敗的conn關閉。

總結

  1. Mux中包含了一個初始監聽器,基本上全部的事件(好比說新的鏈接創建,之因此叫事件是由於我實在想不出更精確的詞語了)都起源於此
  2. listener實現了net.Listener接口,能夠做爲二級監聽器使用(好比傳給net/http.Server結構體的Server方法進行處理)。
  3. Mux包含了一個由listener組成的有序slice,當有事件產生時就會遍歷這個slice找出合適的listener並將事件傳給他。

講到這裏基本上是完事了。整個mux模塊仍是比較簡單的,起碼是由一個個簡單的東西組合而成。那麼一塊兒來意淫一下總體流程吧。

假如我要實現這麼一個網絡程序:

  1. 綁定監聽一個基於tcp的套接字
  2. 咱們容許其應用層可支持多個(好比說支持http https這兩個吧,儘管http和https能夠說是一個協議。。),不一樣的應用層協議對應不一樣的處理函數

就這麼兩個很簡單的要求,不難吧。

那麼咱們一塊兒來實現吧:

type HandleFunc func(c net.Conn) (n int, err error) 

type MyServer struct {
    l net.Listener
    hFunc HandleFunc
}

func (h *MyServer) Server() (err error) {
    for {
        conn, err := h.l.Accept()
        if err != nil {
            return
        }
        go h.hFunc(conn)
    }
}

func HandleHttp(c net.Conn)(n int, err error){
    n, err = c.Write([]byte("Get Off! Don't you know that it is not safe?"))
}

func HandleHttps(c net.Conn)(n int, err error){
    n, err = c.Write([]byte("Get Off! Don't you know that this is more complicated than http?"))
}


func main() (err error){
    ln, err := net.Listen("tcp", "0.0.0.0:12345")
    if err != nil {
        err = fmt.Errorf("Create server listener error, %v", err)
        return
    }
    muxer = mux.NewMux(ln)
    
    var lHttp, lHttps net.Listener
    lHttp = muxer.ListenHttp(1)
    httpServer := *MyServer{lHttp, HandleHttp}
    
    lHttps = muxer.ListenHttps(2)
    httpsServer := *MyServer{lHttps, HandleHttps}
    
    go httpServer.Server()
    go httpsServer.Server()

    err = muxer.Serve()
}
相關文章
相關標籤/搜索