【譯】使用 Python 編寫虛擬機解釋器

【譯】如何使用 Python 建立一個虛擬機解釋器?

原文地址:Making a simple VM interpreter in Pythonhtml

更新:根據你們的評論我對代碼作了輕微的改動。感謝 robin-gvx、 bs4h 和 Dagur,具體代碼見這裏python

Stack Machine 自己並無任何的寄存器,它將所須要處理的值所有放入堆棧中然後進行處理。Stack Machine 雖然簡單可是卻十分強大,這也是爲神馬 Python,Java,PostScript,Forth 和其餘語言都選擇它做爲本身的虛擬機的緣由。git

首先,咱們先來談談堆棧。咱們須要一個指令指針棧用於保存返回地址。這樣當咱們調用了一個子例程(好比調用一個函數)的時候咱們就可以返回到咱們開始調用的地方了。咱們可使用自修改代碼(self-modifying code)來作這件事,恰如 Donald Knuth 發起的 MIX 所作的那樣。可是若是這麼作的話你不得不本身維護堆棧從而保證遞歸能正常工做。在這篇文章中,我並不會真正的實現子例程調用,可是要實現它其實並不難(能夠考慮把實現它當成練習)。github

有了堆棧以後你會省不少事兒。舉個例子來講,考慮這樣一個表達式 (2+3)*4。在 Stack Machine 上與這個表達式等價的代碼爲 2 3 + 4 *。首先,將 23 推入堆棧中,接下來的是操做符 +,此時讓堆棧彈出這兩個數值,再把它兩加合以後的結果從新入棧。而後將 4 入堆,然後讓堆棧彈出兩個數值,再把他們相乘以後的結果從新入棧。多麼簡單啊!express

讓咱們開始寫一個簡單的堆棧類吧。讓這個類繼承 collections.dequeapp

from collections import deque

class Stack(deque):
    push = deque.append

    def top(self):
        return self[-1]

如今咱們有了 pushpoptop 這三個方法。top 方法用於查看棧頂元素。less

接下來,咱們實現虛擬機這個類。在虛擬機中咱們須要兩個堆棧以及一些內存空間來存儲程序自己(譯者注:這裏的程序請結合下文理解)。得益於 Pyhton 的動態類型咱們能夠往 list 中放入任何類型。惟一的問題是咱們沒法區分出哪些是字符串哪些是內置函數。正確的作法是隻將真正的 Python 函數放入 list 中。我可能會在未來實現這一點。函數

咱們同時還須要一個指令指針指向程序中下一個要執行的代碼。oop

class Machine:
    def __init__(self, code):
        self.data_stack = Stack()
        self.return_addr_stack = Stack()
        self.instruction_pointer = 0
        self.code = code

這時候咱們增長一些方便使用的函數免得之後多敲鍵盤。佈局

def pop(self):
    return self.data_stack.pop()

def push(self, value):
    self.data_stack.push(value)

def top(self):
    return self.data_stack.top()

而後咱們增長一個 dispatch 函數來完成每個操做碼作的事兒(咱們並非真正的使用操做碼,只是動態展開它,你懂的)。首先,增長一個解釋器所必須的循環:

def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)

誠如您所見的,這貨只好好的作一件事兒,即獲取下一條指令,讓指令指針執自增,而後根據操做碼分別處理。dispatch 函數的代碼稍微長了一點。

def dispatch(self, op):
    dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":       self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":     self.drop,
        "dup":      self.dup,
        "if":       self.if_stmt,
        "jmp":      self.jmp,
        "over":     self.over,
        "print":    self.print_,
        "println":  self.println,
        "read":     self.read,
        "stack":    self.dump_stack,
        "swap":     self.swap,
    }

    if op in dispatch_map:
        dispatch_map[op]()
    elif isinstance(op, int):
        # push numbers on the data stack
        self.push(op)
    elif isinstance(op, str) and op[0]==op[-1]=='"':
        # push quoted strings on the data stack
        self.push(op[1:-1])
    else:
        raise RuntimeError("Unknown opcode: '%s'" % op)

基本上,這段代碼只是根據操做碼查找是都有對應的處理函數,例如 * 對應 self.muldrop 對應 self.dropdup 對應 self.dup。順便說一句,你在這裏看到的這段代碼其實本質上就是簡單版的 Forth。並且,Forth 語言仍是值得您看看的。

總之捏,它一但發現操做碼是 * 的話就直接調用 self.mul 並執行它。就像這樣:

def mul(self):
    self.push(self.pop() * self.pop())

其餘的函數也是相似這樣的。若是咱們在 dispatch_map 中查找不到相應操做函數,咱們首先檢查他是否是數字類型,若是是的話直接入棧;若是是被引號括起來的字符串的話也是一樣處理--直接入棧。

截止如今,恭喜你,一個虛擬機就完成了。

讓咱們定義更多的操做,而後使用咱們剛完成的虛擬機和p-code 語言 來寫程序。

# Allow to use "print" as a name for our own method:
from __future__ import print_function

# ...

def plus(self):
    self.push(self.pop() + self.pop())

def minus(self):
    last = self.pop()
    self.push(self.pop() - last)

def mul(self):
    self.push(self.pop() * self.pop())

def div(self):
    last = self.pop()
    self.push(self.pop() / last)

def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()

def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()

讓咱們用咱們的虛擬機寫個與 print((2+3)*4) 等同效果的例子。

Machine([2, 3, "+", 4, "*", "println"]).run()

你能夠試着運行它。

如今引入一個新的操做 jump, 即 go-to 操做

def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")

它只改變指令指針的值。咱們再看看分支跳轉是怎麼作的。

def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)

這一樣也是很直白的。若是你想要添加一個條件跳轉,你只要簡單的執行 test-value true-value false-value IF JMP 就能夠了.(分支處理是很常見的操做,許多虛擬機都提供相似 JNE 這樣的操做。JNEjump if not equal 的縮寫)。

下面的程序要求使用者輸入兩個數字,而後打印出他們的和和乘積。

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"Enter another number: "', "print", "read", "cast_int",
    "over", "over",
    '"Their sum is: "', "print", "+", "println",
    '"Their product is: "', "print", "*", "println"
]).run()

overreadcast_int 這三個操做是長這樣滴:

def cast_int(self):
    self.push(int(self.pop()))

def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)

def read(self):
    self.push(raw_input())

如下這一段程序要求使用者輸入一個數字,而後打印出這個數字是奇數仍是偶數。

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"The number "', "print", "dup", "print", '" is "', "print",
    2, "%", 0, "==", '"even."', '"odd."', "if", "println",
    0, "jmp" # loop forever!
]).run()

這裏有個小練習給你去實現:增長 callreturn 這兩個操做碼。call 操做碼將會作以下事情 :將當前地址推入返回堆棧中,而後調用 self.jmp()return 操做碼將會作以下事情:返回堆棧彈棧,將彈棧出來元素的值賦予指令指針(這個值可讓你跳轉回去或者從 call 調用中返回)。當你完成這兩個命令,那麼你的虛擬機就能夠調用子例程了。

一個簡單的解析器

創造一個模仿上述程序的小型語言。咱們將把它編譯成咱們的機器碼。

import tokenize
from StringIO import StringIO

# ...

def parse(text):
    tokens = tokenize.generate_tokens(StringIO(text).readline)
    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))

一個簡單的優化:常量摺疊

常量摺疊(Constant folding)是窺孔優化(peephole optimization)的一個例子,也便是說再在編譯期間能夠針對某些明顯的代碼片斷作些預計算的工做。好比,對於涉及到常量的數學表達式例如 2 3 + 就能夠很輕鬆的實現這種優化。

def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5."""
    while True:
        # Find two consecutive numbers and an arithmetic operator
        for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
            if isinstance(a, int) and isinstance(b, int) \
                    and op in {"+", "-", "*", "/"}:
                m = Machine((a, b, op))
                m.run()
                code[i:i+3] = [m.top()]
                print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
                break
        else:
            break
    return code

採用常量摺疊遇到惟一問題就是咱們不得不更新跳轉地址,但在不少狀況這是很難辦到的(例如:test cast_int jmp)。針對這個問題有不少解決方法,其中一個簡單的方法就是隻容許跳轉到程序中的命名標籤上,而後在優化以後解析出他們真正的地址。

若是你實現了 Forth words,也即函數,你能夠作更多的優化,好比刪除可能永遠不會被用到的程序代碼(dead code elimination

REPL

咱們能夠創造一個簡單的 PERL,就像這樣

def repl():
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = raw_input("> ")
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except (RuntimeError, IndexError) as e:
            print("IndexError: %s" % e)
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

用一些簡單的程序來測試咱們的 REPL

> 2 3 + 4 * println
Constant-folded 2+3 to 5
Constant-folded 5*4 to 20
20
> 12 dup * println
144
> "Hello, world!" dup println println
Hello, world!
Hello, world!

你能夠看到,常量摺疊看起來運轉正常。在第一個例子中,它把整個程序優化成這樣 20 println

下一步

當你添加完 callreturn 以後,你即可以讓使用者定義本身的函數了。在Forth 中函數被稱爲 words,他們以冒號開頭緊接着是名字而後以分號結束。例如,一個整數平方的 word 是長這樣滴

: square dup * ;

實際上,你能夠試試把這一段放在程序中,好比 Gforth

$ gforth
Gforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.
Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license'
Type `bye' to exit
: square dup * ;  ok
12 square . 144  ok

你能夠在解析器中經過發現 : 來支持這一點。一旦你發現一個冒號,你必須記錄下它的名字及其地址(好比:在程序中的位置)而後把他們插入到符號表(symbol table)中。簡單起見,你甚至能夠把整個函數的代碼(包括分號)放在字典中,譬如:

symbol_table = {
  "square": ["dup", "*"]
  # ...
}

當你完成了解析的工做,你能夠鏈接你的程序:遍歷整個主程序而且在符號表中尋找自定義函數的地方。一旦你找到一個而且它沒有在主程序的後面出現,那麼你能夠把它附加到主程序的後面。而後用 <address> call 替換掉 square,這裏的 <address> 是函數插入的地址。

爲了保證程序能正常執行,你應該考慮剔除 jmp 操做。不然的話,你不得不解析它們。它確實能執行,可是你得按照用戶編寫程序的順序保存它們。舉例來講,你想在子例程之間移動,你要格外當心。你可能須要添加 exit 函數用於中止程序(可能須要告訴操做系統返回值),這樣主程序就不會繼續執行以致於跑到子例程中。

實際上,一個好的程序空間佈局頗有可能把主程序當成一個名爲 main 的子例程。或者由你決定搞成什麼樣子。

如您所見,這一切都是頗有趣的,並且經過這一過程你也學會了不少關於代碼生成、連接、程序空間佈局相關的知識。

更多能作的事兒

你可使用 Python 字節碼生成庫來嘗試將虛擬機代碼爲原生的 Python 字節碼。或者用 Java 實現運行在 JVM 上面,這樣你就能夠自由使用 JITing

一樣的,你也能夠嘗試下register machine。你能夠嘗試用棧幀(stack frames)實現調用棧(call stack),並基於此創建調用會話。

最後,若是你不喜歡相似 Forth 這樣的語言,你能夠創造運行於這個虛擬機之上的自定義語言。譬如,你能夠把相似 (2+3)*4 這樣的中綴表達式轉化成 2 3 + 4 * 而後生成代碼。你也能夠容許 C 風格的代碼塊 { ... } 這樣的話,語句 if ( test ) { ... } else { ... } 將會被翻譯成

<true/false test>
<address of true block>
<address of false block>
if
jmp

<true block>
<address of end of entire if-statement> jmp

<false block>
<address of end of entire if-statement> jmp

例子,

Address  Code
-------  ----
 0       2 3 >
 3       7        # Address of true-block
 4       11       # Address of false-block
 5       if
 6       jmp      # Conditional jump based on test

# True-block
 7       "Two is greater than three."
 8       println
 9       15       # Continue main program
10       jmp

# False-block ("else { ... }")
11       "Two is less than three."
12       println
13       15       # Continue main program
14       jmp

# If-statement finished, main program continues here
15       ...

對了,你還須要添加比較操做符 != < <= > >=

我已經在個人 C++ stack machine 實現了這些東東,你能夠參考下。

我已經把這裏呈現出來的代碼搞成了個項目 Crianza,它使用了更多的優化和實驗性質的模型來吧程序編譯成 Python 字節碼。

祝好運!

完整的代碼

下面是所有的代碼,兼容 Python 2 和 Python 3

你能夠經過這裏 獲得它。

#!/usr/bin/env python
# coding: utf-8

"""
A simple VM interpreter.

Code from the post at http://csl.name/post/vm/
This version should work on both Python 2 and 3.
"""

from __future__ import print_function
from collections import deque
from io import StringIO
import sys
import tokenize


def get_input(*args, **kw):
    """Read a string from standard input."""
    if sys.version[0] == "2":
        return raw_input(*args, **kw)
    else:
        return input(*args, **kw)


class Stack(deque):
    push = deque.append

    def top(self):
        return self[-1]


class Machine:
    def __init__(self, code):
        self.data_stack = Stack()
        self.return_stack = Stack()
        self.instruction_pointer = 0
        self.code = code

    def pop(self):
        return self.data_stack.pop()

    def push(self, value):
        self.data_stack.push(value)

    def top(self):
        return self.data_stack.top()

    def run(self):
        while self.instruction_pointer < len(self.code):
            opcode = self.code[self.instruction_pointer]
            self.instruction_pointer += 1
            self.dispatch(opcode)

    def dispatch(self, op):
        dispatch_map = {
            "%":        self.mod,
            "*":        self.mul,
            "+":        self.plus,
            "-":        self.minus,
            "/":        self.div,
            "==":       self.eq,
            "cast_int": self.cast_int,
            "cast_str": self.cast_str,
            "drop":     self.drop,
            "dup":      self.dup,
            "exit":     self.exit,
            "if":       self.if_stmt,
            "jmp":      self.jmp,
            "over":     self.over,
            "print":    self.print,
            "println":  self.println,
            "read":     self.read,
            "stack":    self.dump_stack,
            "swap":     self.swap,
        }

        if op in dispatch_map:
            dispatch_map[op]()
        elif isinstance(op, int):
            self.push(op) # push numbers on stack
        elif isinstance(op, str) and op[0]==op[-1]=='"':
            self.push(op[1:-1]) # push quoted strings on stack
        else:
            raise RuntimeError("Unknown opcode: '%s'" % op)

    # OPERATIONS FOLLOW:

    def plus(self):
        self.push(self.pop() + self.pop())

    def exit(self):
        sys.exit(0)

    def minus(self):
        last = self.pop()
        self.push(self.pop() - last)

    def mul(self):
        self.push(self.pop() * self.pop())

    def div(self):
        last = self.pop()
        self.push(self.pop() / last)

    def mod(self):
        last = self.pop()
        self.push(self.pop() % last)

    def dup(self):
        self.push(self.top())

    def over(self):
        b = self.pop()
        a = self.pop()
        self.push(a)
        self.push(b)
        self.push(a)

    def drop(self):
        self.pop()

    def swap(self):
        b = self.pop()
        a = self.pop()
        self.push(b)
        self.push(a)

    def print(self):
        sys.stdout.write(str(self.pop()))
        sys.stdout.flush()

    def println(self):
        sys.stdout.write("%s\n" % self.pop())
        sys.stdout.flush()

    def read(self):
        self.push(get_input())

    def cast_int(self):
        self.push(int(self.pop()))

    def cast_str(self):
        self.push(str(self.pop()))

    def eq(self):
        self.push(self.pop() == self.pop())

    def if_stmt(self):
        false_clause = self.pop()
        true_clause = self.pop()
        test = self.pop()
        self.push(true_clause if test else false_clause)

    def jmp(self):
        addr = self.pop()
        if isinstance(addr, int) and 0 <= addr < len(self.code):
            self.instruction_pointer = addr
        else:
            raise RuntimeError("JMP address must be a valid integer.")

    def dump_stack(self):
        print("Data stack (top first):")

        for v in reversed(self.data_stack):
            print(" - type %s, value '%s'" % (type(v), v))


def parse(text):
    # Note that the tokenizer module is intended for parsing Python source
    # code, so if you're going to expand on the parser, you may have to use
    # another tokenizer.

    if sys.version[0] == "2":
        stream = StringIO(unicode(text))
    else:
        stream = StringIO(text)

    tokens = tokenize.generate_tokens(stream.readline)

    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))

def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5."""
    while True:
        # Find two consecutive numbers and an arithmetic operator
        for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
            if isinstance(a, int) and isinstance(b, int) \
                    and op in {"+", "-", "*", "/"}:
                m = Machine((a, b, op))
                m.run()
                code[i:i+3] = [m.top()]
                print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
                break
        else:
            break
    return code

def repl():
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = get_input("> ")
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except (RuntimeError, IndexError) as e:
            print("IndexError: %s" % e)
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

def test(code = [2, 3, "+", 5, "*", "println"]):
    print("Code before optimization: %s" % str(code))
    optimized = constant_fold(code)
    print("Code after optimization: %s" % str(optimized))

    print("Stack after running original program:")
    a = Machine(code)
    a.run()
    a.dump_stack()

    print("Stack after running optimized program:")
    b = Machine(optimized)
    b.run()
    b.dump_stack()

    result = a.data_stack == b.data_stack
    print("Result: %s" % ("OK" if result else "FAIL"))
    return result

def examples():
    print("** Program 1: Runs the code for `print((2+3)*4)`")
    Machine([2, 3, "+", 4, "*", "println"]).run()

    print("\n** Program 2: Ask for numbers, computes sum and product.")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"Enter another number: "', "print", "read", "cast_int",
        "over", "over",
        '"Their sum is: "', "print", "+", "println",
        '"Their product is: "', "print", "*", "println"
    ]).run()

    print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"The number "', "print", "dup", "print", '" is "', "print",
        2, "%", 0, "==", '"even."', '"odd."', "if", "println",
        0, "jmp" # loop forever!
    ]).run()


if __name__ == "__main__":
    try:
        if len(sys.argv) > 1:
            cmd = sys.argv[1]
            if cmd == "repl":
                repl()
            elif cmd == "test":
                test()
                examples()
            else:
                print("Commands: repl, test")
        else:
            repl()
    except EOFError:
        print("")

本文系OneAPM工程師編譯整理。OneAPM是中國基礎軟件領域的新興領軍企業,能幫助企業用戶和開發者輕鬆實現:緩慢的程序代碼和SQL語句的實時抓取。想閱讀更多技術文章,請訪問OneAPM官方技術博客

相關文章
相關標籤/搜索