130行C語言實現個用戶態線程庫(1)

準確的說是除掉頭文件,測試代碼和非關鍵的純算法代碼(只有雙向環形鏈表的ADT),核心代碼只有130行左右,已是蠅量級的用戶態線程庫了。把這個庫取名爲ezthread,意思是,這太easy了,人人均可以讀懂而且實現這個用戶態線程庫。我把該項目放在github上,歡迎來拍磚: https://github.com/Yuandong-Chen/coroutine/tree/old-version(注意,最新的版本已經用了共享棧技術,可以支持1000K數量級的協程了,讀完這篇博文後能夠進一步參考後續的博文:http://www.cnblogs.com/github-Yuandong-Chen/p/6973932.html)。那麼下面談談怎麼實現這個ezthread。html

你們都會雙向環形鏈表(就是頭尾相連的雙向鏈表),咱們構造這個ADT結構:node

首先是每一個節點:git

1 typedef struct __pnode pNode;
2 struct __pnode
3 {
4     pNode *next;
5     pNode *prev;
6     Thread_t *data;
7 };

顯然,next指向下一個節點,prev指向上一個節點,data指向該節點數據,那麼這個Thread_t是什麼類型的數據結構呢?github

typedef struct __ez_thread Thread_t;
struct __ez_thread
{
    Regs regs;
    int tid;
    unsigned int stacktop;
    unsigned int stacksize;
    void *stack;
    void *retval;
};

這個結構體包含了線程內部的信息,好比第一項爲Regs,記錄的是各個寄存器的取值(咱們在下面給出具體的結構),tid就是線程的ID了,stacktop記錄的是線程棧的頂部(和頁對齊的最大地址,每一個線程都有本身的運行時的棧,用於構成他們相對獨立的運行時環境),stacksize就是棧的大小了,stack指針指向咱們給該線程棧分配的堆的指針(什麼?怎麼一會棧一會堆的?咱們其實用了malloc函數分配出一些堆空間,把這些空間用於線程棧,當線程退出時候,咱們再free這些堆),retval就是線程運行完了的返回值(pthread_join裏頭拿到的線程返回值就是這個了)。算法

下面是寄存器結構體:數據結構

typedef struct __thread_table_regs Regs;
struct __thread_table_regs
{
    int _ebp;
    int _esp;
    int _eip;
    int _eflags;
};

真是好懂,一看就知道了,這個結構體只能支持X86體系的計算機了。那麼還有個問題,爲什麼只有這些寄存器,沒用其餘的好比:eax,ebx,edi,esi等等呢?由於咱們在轉換狀態函數switch_to裏頭當返回時(準確地說是從上次切換的點切換回來時)用了return來切換回線程運行時環境,return會自動幫助咱們把這些其餘的寄存器的值恢復原狀的(具體咱們放到switch_to的時候再詳細說明)。多線程

而後呢,咱們定義了一個遊標去取這個環形鏈表的值,不然咱們怎麼讀取這個環形鏈表裏頭的數據呢?總得有個東西指向其中某個節點吧。app

typedef struct __loopcursor Cursor;
struct __loopcursor
{
    int total;
    pNode *current;
};

這個遊標結構體記錄瞭如今指向的節點地址和這個環形鏈表裏頭一共有多少節點。ide

咱們得用兩個這樣的環形鏈表結構體來支持咱們的線程庫,爲什麼是倆呢?一個是正在運行的線程,咱們把他們串成一個環形鏈表,取名爲live(活的),而後用另一個鏈表把運行結束的線程串成一串,取名爲dead(死的)。而後最開始咱們就有個線程在運行了,那就是主線程main,咱們用pmain節點來記錄主線程:函數

extern Cursor live;
extern Cursor dead;
extern Thread_t pmain;

好了,剩下的只有在這些結構體上操做的函數了:

void init();
void switch_to(int, ...);
int threadCreat(Thread_t **, void *(*)(void *), void *);
int threadJoin(Thread_t *, void **);

咱們開始時調用init,以初始化咱們的live,dead和pmain。而後當咱們想創造線程時,就threadCreat就能夠了,用法和pthread_create基本如出一轍,熟悉posix多線程的人一看就明白了,threadJoin也是仿照pthread_join接口寫的。這裏的switch_to就是最關鍵的運行時環境轉換函數了,當線程調用這個函數時候,咱們就切換到其餘線程上次暫停的點去執行了(這些狀態都保存在咱們的Thread_t結構體裏,因此咱們可以記錄下切換前的狀態,從而可以從容地去切換到下一個線程中)。咱們沒有用定時器每隔幾微秒去激發switch_to(實現起來也是很是簡單的,可是得添加多個signal_block函數,很是不簡潔),而是讓線程裏頭的函數主動調用switch_to來切換線程,這有點相似協程。

好了,如今講具體的實現了。首先是對雙向鏈表的操做函數,這個東西不是咱們的重點,懂基礎算法數據結構的人都能實現,具體是雙向環形鏈表的增查刪操做:

 1 void initCursor(Cursor *cur)
 2 {
 3     cur->total = 0;
 4     cur->current = NULL;
 5 }
 6 
 7 Thread_t *findThread(Cursor *cur, int tid)
 8 {
 9     int counter = cur->total;
10     if(counter == 0){
11         return NULL;
12     }
13 
14     int i;
15     pNode *tmp = cur->current;
16     for (int i = 0; i < counter; ++i)
17     {
18         if((tmp->data)->tid == tid){
19             return tmp->data;
20         }
21 
22         tmp = tmp->next;
23     }
24     return NULL;
25 }
26 
27 int appendThread(Cursor *cur, Thread_t *pth)
28 {
29     if(cur->total == 0)
30     {
31         cur->current = (pNode *)malloc(sizeof(pNode));
32         assert(cur->current);
33         (cur->current)->data = pth;
34         (cur->current)->prev = cur->current;
35         (cur->current)->next = cur->current;
36         cur->total++;
37         return 0;
38     }
39     else
40     {    
41         if(cur->total > MAXCOROUTINES)
42         {
43             assert((cur->total == MAXCOROUTINES));
44             return -1;
45         }
46         
47         pNode *tmp = malloc(sizeof(pNode));
48         assert(tmp);
49         tmp->data = pth;
50         tmp->prev = cur->current;
51         tmp->next = (cur->current)->next;
52         ((cur->current)->next)->prev = tmp;
53         (cur->current)->next = tmp;
54         cur->total++;
55         return 0;
56     }
57 }
58 
59 pNode *deleteThread(Cursor *cur, int tid)
60 {
61     int counter = cur->total;
62     int i;
63     pNode *tmp = cur->current;
64     for (int i = 0; i < counter; ++i)
65     {
66         if((tmp->data)->tid == tid){
67             (tmp->prev)->next = tmp->next;
68             (tmp->next)->prev = tmp->prev;
69             if(tmp == cur->current)
70             {
71                 cur->current = cur->current->next;
72             }  
73 
74             cur->total--;
75             assert(cur->total >= 0);
76             return tmp;
77         }
78         tmp = tmp->next;
79     }
80     return NULL;
81 }
雙向鏈表操做函數

拋開這部分純算法代碼,咱們只剩下130行代碼了。這還不如某些函數的代碼量大。可是咱們就是在這130行代碼裏頭實現了switch_to,threadCreat以及threadJoin等等關鍵代碼。

先說下init怎麼實現的:

1 void init()
2 {
3     initCursor(&live);
4     initCursor(&dead);
5     appendThread(&live, &pmain);
6 }

其實關鍵點只有一句,那就是第5行的append(&live,&pmain);往live鏈表裏頭添加pmain節點,可是咱們的pmain還沒初始化呢,裏頭stack,regs等等統統都是0,可是沒事呢,由於當咱們第一次進入switch_to的時候,switch_to在跳轉前會幫助咱們保存當前線程,這時也就是pmain的運行時狀態。

而後咱們看看threadCreat怎麼實現:

 1 int threadCreat(Thread_t **pth, void *(*start_rtn)(void *), void *arg)
 2 {
 3 
 4     *pth = malloc(sizeof(Thread_t));
 5     (*pth)->stack = malloc(PTHREAD_STACK_MIN);
 6     assert((*pth)->stack);
 7     (*pth)->stacktop = (((int)(*pth)->stack + PTHREAD_STACK_MIN)&(0xfffff000));
 8     (*pth)->stacksize = PTHREAD_STACK_MIN - (((int)(*pth)->stack + PTHREAD_STACK_MIN) - (*pth)->stacktop);
 9     (*pth)->tid = fetchTID();
10     /* set params */
11     void *dest = (*pth)->stacktop - 12;
12     memcpy(dest, pth, 4);
13     dest += 4;
14     memcpy(dest, &start_rtn, 4);
15     dest += 4;
16     memcpy(dest, &arg, 4);
17     (*pth)->regs._eip = &real_entry;
18     (*pth)->regs._esp = (*pth)->stacktop - 16;
19     (*pth)->regs._ebp = 0;
20     appendThread(&live, (*pth));
21 
22     return 0;
23 }

咱們在第4行分配了堆空間,而後讓線程棧頂變量stacktop對齊頁,設置stacksize大小(這個其實對咱們的線程庫沒有用,由於咱們尚未實現相似stackguard之類的檢查機制),設置tid,這裏fetchTID函數以下:

1 int fetchTID()
2 {
3     static int tid;
4     return ++tid;
5 }

接着,咱們在threadCreat函數的11-16行代碼中,在棧頂壓入變量pth,start_rtn以及arg(咱們用memcpy來操做線程棧空間),這些都是做爲real_entry這個函數的參數壓入線程棧的。咱們不難發現,其實每一個線程的最初入口地址都是real_entry函數(注意到咱們在17行把eip設置爲real_entry的地址)。最後,咱們於17-19行設置寄存器變量,以知足剛進入該real_entry時的棧的狀態,在live鏈表中添加該線程結構體指針,返回。這一系列操做致使的效果就是,好比咱們第一次調用threadCreat函數,當發生switch_to的時候,固然咱們先保存當前線程狀態,而後就從主線程main中切換到了real_entry裏頭去了,並且對應的參數咱們設置好了,就好像咱們在主線程裏頭直接調用了real_entry同樣。下面看下real_entry作了些什麼:

 1 void real_entry(Thread_t *pth, void *(*start_rtn)(void *), void* args)
 2 {
 3     ALIGN();
 4 
 5     pth->retval = (*start_rtn)(args);
 6 
 7     deleteThread(&live, pth->tid);
 8     appendThread(&dead, pth);
 9 
10     switch_to(-1);
11 }

 

 第3行是對齊棧操做,咱們先不作說明。接下來就是調用start_rtn函數,而且把args做爲參數,返回值賦給線程的retval。當返回時,說明線程已經運行結束,在live鏈表裏頭刪除該節點,在dead鏈表裏頭添加該節點。在第10行最後調用switch_to(-1),也就是在switch_to裏頭直接跳到下一個線程去執行,且不保存當前狀態。

咱們再看下threadJoin函數的實現:

 1 int threadJoin(Thread_t *pth, void **rval_ptr)
 2 {
 3 
 4     Thread_t *find1, *find2;
 5     find1 = findThread(&live, pth->tid);
 6     find2 = findThread(&dead, pth->tid);
 7     
 8 
 9     if((find1 == NULL)&&(find2 == NULL)){
10         
11         return -1;
12     }
13 
14     if(find2){
15         if(rval_ptr != NULL)
16             *rval_ptr = find2->retval;
17 
18         pNode *tmp = deleteThread(&dead, pth->tid);
19         free(tmp);
20         free((Stack_t)find2->stack);
21         free(find2);
22         return 0;
23     }
24 
25     while(1)
26     {
27         switch_to(0);
28         if((find2 = findThread(&dead, pth->tid))!= NULL){
29             if(rval_ptr!= NULL)
30                 *rval_ptr = find2->retval;
31 
32             pNode *tmp = deleteThread(&dead, pth->tid);
33             free(tmp);
34             free((Stack_t)find2->stack);
35             free(find2);
36             return 0;
37         }   
38     }
39     return -1;
40 }

threadJoin是用於回收線程資源並獲得返回值的。實現大致的思路就是,咱們先查找live和dead裏頭有沒有這個線程,若是都沒有,說明根本不存在這個線程,若是dead鏈表裏頭有,那麼咱們就獲得返回值(15-16行),而後釋放堆空間(19-22行)。若是在live裏頭,說明該線程還沒執行結束,咱們進入循環,先調用switch_to(0),保存當前線程狀態,而後切換到下一個線程去。當再次回到這個循環時候,咱們繼續看看dead裏頭有沒有這個線程,有就設置返回值(29-30行),而後釋放資源(32-35行),不然繼續切換並循環。

最後,最關鍵的,咱們給出switch_to的實現:

 1 void switch_to(int signo, ...)
 2 {
 3 
 4     va_list ap; 
 5     va_start(ap, signo);
 6 
 7     Regs regs;
 8 
 9     if(signo == -1)
10     {
11         regs = live.current->data->regs;
12         JMP(regs);
13         assert(0);
14     }
15     
16     int _ebp;
17     int _esp;
18     int _eip = &&_REENTERPOINT;
19     int _eflags;
20     /* save current context */
21     SAVE();
22     /* save context in current thread */
23     live.current->data->regs._eip = _eip;
24     live.current->data->regs._esp = _esp;
25     live.current->data->regs._ebp = _ebp;
26     live.current->data->regs._eflags = _eflags;
27 
28     if(va_arg(ap,int) == -1){
29  _REENTERPOINT:
30         assert(va_arg(ap,int) != -1);
31         return;
32     }
33 
34     va_end(ap);
35     regs = live.current->next->data->regs;
36     live.current = live.current->next;
37     JMP(regs);
38     assert(0);
39 }

先看11-13行,咱們把自動變量regs的值賦爲當前線程的寄存器的結構體,而後跳轉到當前線程(第12行JMP是跳轉語句,13行永遠不會執行)。這裏你們有個疑問,從當前線程跳轉到當前線程,那麼還不是當前線程麼?而後執行assert(0)報錯退出?!其實只有當線程返回時,也就是在real_entry裏頭纔可能執行switch_to(-1),注意到real_entry最後的幾行代碼,裏頭已經把當前線程從live裏頭刪除,並添加到dead裏了,因此如今live裏頭的當前線程實際上是下一個線程。而後咱們看21-26行,咱們保存當前寄存器的值到當前線程中,注意第18行,咱們把返回點設置在了_REENTERPOINT這個標籤上,也就是之後若是再次切換到該線程時,咱們會在第30行繼續向下執行,很簡單,第30行的有意義的代碼只有return,也就是恢復其餘寄存器(eax,edi,esi等等),而後返回到線程繼續執行。咱們繼續看34-38行代碼:咱們把自動變量regs的值賦值爲下一個線程的寄存器,而後live的當前線程指針current也指向了下一個線程,經過37行JMP,咱們調到了下一個線程去執行,下個一個線程多是real_entry處開始執行,也多是_REENTERPOINT處開始執行。最後再重新說說31行的return到底return到哪裏去了,咱們看一下測試代碼:

 1 #include "ezthread.h"
 2 #include <stdio.h>
 3 #include <stdlib.h>
 4 
 5 void *sum1tod(void *d)
 6 {
 7     int i, j=0;
 8 
 9     for (i = 0; i <= d; ++i)
10     {
11         j += i;
12         printf("thread %d is grunting... %d\n",live.current->data->tid , i);
13         switch_to(0); // Give up control to next thread
14     }
15     
16     return ((void *)j);
17 }
18 
19 int main(int argc, char const *argv[])
20 {
21     int res = 0;
22     int i;
23     init();
24     Thread_t *tid1, *tid2;
25     int *res1, *res2;
26 
27     threadCreat(&tid1, sum1tod, 10);
28     threadCreat(&tid2, sum1tod, 10);
29 
30     for (i = 0; i <= 5; ++i){
31         res+=i;
32         printf("main is grunting... %d\n", i); 
33         switch_to(0); //Give up control to next thread
34     }
35     threadJoin(tid1, &res1); //Collect and Release the resourse of tid1
36     threadJoin(tid2, &res2); //Collect and Release the resourse of tid2
37     printf("parallel compute: %d = (1+2+3+4+5) + (1+2+...+10)*2\n", (int)res1+(int)res2+(int)res);
38     return 0;
39 }

注意到咱們在測試代碼裏頭sum1tod裏頭調用了switch_to(0),若是這個循環加法(11-13行)還未結束,那麼上述的那個_REENTERPOINT裏頭的return就會return回這個循環繼續執行,就如在sum1tod裏的switch_to(0)函數直接調用return,什麼事情也沒幹同樣,可是其實咱們通過了無數其餘線程的執行,可是在sum1tod裏頭毫無感受,簡直好像其餘線程不存在同樣(除非咱們在這裏頭調用threadJoin等待其餘線程結束)。

如今咱們給出討厭的內嵌彙編:

 1 #define JMP(r)    asm volatile \
 2                 (   \
 3                     "pushl %3\n\t" \
 4                     "popf\n\t" \
 5                     "movl %0, %%esp\n\t" \
 6                     "movl %2, %%ebp\n\t" \
 7                     "jmp *%1\n\t" \
 8                     : \
 9                     : "m"(r._esp),"a"(r._eip),"m"(r._ebp), "m"(r._eflags) \
10                     :  \
11                 )
12 
13 #define SAVE()                  asm volatile \
14                             (  \
15                                    "movl %%esp, %0\n\t" \
16                                 "movl %%ebp, %1\n\t" \
17                                 "pushf\n\t" \
18                                 "movl (%%esp), %%eax\n\t" \
19                                 "movl %%eax, %2\n\t" \
20                                 "popf\n\t" \
21                                 : "=m"(_esp),"=m"(_ebp), "=m"(_eflags) \ 
22                                 : \
23                                 :  \
24                             )
25 
26 #define ALIGN()             asm volatile \
27                             ( \
28                                 "andl $-16, %%esp\n\t" \
29                                 : \
30                                 : \
31                                 :"%esp" \
32                             )
inline asm

第一個就是起到跳轉做用,第二個是保存寄存器到自動變量做用,最後一個是棧對齊做用。爲什麼要棧對齊?由於咱們在堆裏頭設置了這個棧的空間,這個和普通的棧空間並不徹底同樣,咱們須要作對齊處理。

到這裏咱們就幾乎徹底明白了這個線程庫的實現,還有一小點就是switch_to裏頭的可變參數怎麼回事,其實那個是防止編譯器中消除冗餘代碼形成咱們_REENTERPOINT中的代碼被優化而整個刪除用的。若是咱們在_REENTERPOINT前加入goto語句跳到下面執行,而後刪除這個_REENTERPOINT以前的判斷語句,咱們會發現,編譯器會把switch_to裏頭的第28-32行做爲冗餘代碼所有刪除。

謝謝你能看到最後,告訴大家一個消息,其實咱們的實現是介於longjmp和彙編實現版本之間的某種實現:咱們用匯編保存了運行時狀態,可是其中的return又有點相似longjmp中自動恢復寄存器的做用。並且咱們的庫比純彙編實現更具可移植性,但比longjmp實現版本又弱了點。

相關文章
相關標籤/搜索