股票量化交易回測框架pyalgotrade源碼閱讀(一)

PyAlgoTrade是什麼呢?html

一個股票量化交易的策略回測框架。python

而做者的說明以下。git

To make it easy to backtest stock trading strategies.github


  簡單的來講,是一個用於驗證本身交易策略的框架。設計模式

適用如下場景:api

  我有個前無古人後無來者的想法,我以爲我按照這個想法去買股票穩賺不賠,可是爲了穩妥起見,我須要測試一下這個個人這個想法到底用沒有用,怎麼測試呢?微信


大概下面兩種方法數據結構

一:弄個模擬交易的軟件,天天按照本身的想法買入賣出,而後看看一個月或者一年後的收益如何。app

優勢:更貼近現實,至少當下的現實框架

缺點:測試周期大,數據有限


二:我相信個人這個想法不是針對如今或者將來有用,甚至是在之前應該也是起做用的,那麼我能夠將歷史數據調出來,用於測試,看看在歷史行情中收益如何。

優勢:數據充分,能夠反覆測試。

缺點:可能不能貼近現實


  而pyalgotrade就是爲了提供給使用者基於歷史數據回測的框架,即爲了讓你更好的使用上述的第二種方法。

注:不管怎麼測,確定都有誤差的, 由於都是猜,就像×××,你算好了各類機率,想好了各類策略,可是你能保證的只是你贏錢的機率大一些,而不是必贏,由於在沒有欺詐的狀況下,將來是不可測,也不能肯定的,誰也不能預知將來~吧~


文章目錄

  1. 官方示例

  2. 設計模式之觀察者模式

  3. 源碼解析


官方示例

sma_crossover.py文件

from pyalgotrade import strategy
from pyalgotrade.technical import ma
from pyalgotrade.technical import cross


class SMACrossOver(strategy.BacktestingStrategy):
    def __init__(self, feed, instrument, smaPeriod):
        super(SMACrossOver, self).__init__(feed)
        self.__instrument = instrument
        self.__position = None
        # We'll use adjusted close values instead of regular close values.
        self.setUseAdjustedValues(True)
        self.__prices = feed[instrument].getPriceDataSeries()
        self.__sma = ma.SMA(self.__prices, smaPeriod)

    def getSMA(self):
        return self.__sma

    def onEnterCanceled(self, position):
        self.__position = None

    def onExitOk(self, position):
        self.__position = None

    def onExitCanceled(self, position):
        # If the exit was canceled, re-submit it.
        self.__position.exitMarket()

    def onBars(self, bars):
        # If a position was not opened, check if we should enter a long position.
        if self.__position is None:
            if cross.cross_above(self.__prices, self.__sma) > 0:
                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())
                # Enter a buy market order. The order is good till canceled.
                self.__position = self.enterLong(self.__instrument, shares, True)
        # Check if we have to exit the position.
        elif not self.__position.exitActive() and cross.cross_below(self.__prices, self.__sma) > 0:
            self.__position.exitMarket()


sma_crossover_sample.py

import sma_crossover
from pyalgotrade import plotter
from pyalgotrade.tools import yahoofinance
from pyalgotrade.stratanalyzer import sharpe


def main(plot):
    instrument = "aapl"
    smaPeriod = 163

    # Download the bars.
    feed = yahoofinance.build_feed([instrument], 2011, 2012, ".")

    strat = sma_crossover.SMACrossOver(feed, instrument, smaPeriod)
    sharpeRatioAnalyzer = sharpe.SharpeRatio()
    strat.attachAnalyzer(sharpeRatioAnalyzer)

    if plot:
        plt = plotter.StrategyPlotter(strat, True, False, True)
        plt.getInstrumentSubplot(instrument).addDataSeries("sma", strat.getSMA())

    strat.run()
    print "Sharpe ratio: %.2f" % sharpeRatioAnalyzer.getSharpeRatio(0.05)

    if plot:
        plt.plot()


if __name__ == "__main__":
    main(True)


  上面的代碼主要作一件這樣的事。

  建立了一個策略,這個策略就是你的想法,這個想法是什麼呢?

  想法是,當價格高於近163日內的平均價格就買入,低於近163日內的平均價格就賣出(平倉)。

  其實還作了其餘的事,好比策略分析之類的,可是這篇文章暫時忽略。


設計模式之觀察者模式

#!/usr/bin/python
#coding:utf8
'''
Observer
'''
 
 
class Subject(object):
    def __init__(self):
        self._observers = []
 
    def attach(self, observer):
        if not observer in self._observers:
            self._observers.append(observer)
 
    def detach(self, observer):
        try:
            self._observers.remove(observer)
        except ValueError:
            pass
 
    def notify(self, modifier=None):
        for observer in self._observers:
            if modifier != observer:
                observer.update(self)
 
# Example usage
class Data(Subject):
    def __init__(self, name=''):
        Subject.__init__(self)
        self.name = name
        self._data = 0
 
    @property
    def data(self):
        return self._data
 
    @data.setter
    def data(self, value):
        self._data = value
        self.notify()
 
class HexViewer:
    def update(self, subject):
        print('HexViewer: Subject %s has data 0x%x' %
              (subject.name, subject.data))
 
class DecimalViewer:
    def update(self, subject):
        print('DecimalViewer: Subject %s has data %d' %
              (subject.name, subject.data))
 
# Example usage...
def main():
    data1 = Data('Data 1')
    data2 = Data('Data 2')
    view1 = DecimalViewer()
    view2 = HexViewer()
    data1.attach(view1)
    data1.attach(view2)
    data2.attach(view2)
    data2.attach(view1)
 
    print("Setting Data 1 = 10")
    data1.data = 10
    print("Setting Data 2 = 15")
    data2.data = 15
    print("Setting Data 1 = 3")
    data1.data = 3
    print("Setting Data 2 = 5")
    data2.data = 5
    print("Detach HexViewer from data1 and data2.")
    data1.detach(view2)
    data2.detach(view2)
    print("Setting Data 1 = 10")
    data1.data = 10
    print("Setting Data 2 = 15")
    data2.data = 15
 
if __name__ == '__main__':
    main()

意圖:

  定義對象間的一種一對多的依賴關係,當一個對象的狀態發生改變時, 全部依賴於它的對象都獲得通知並被自動更新。

適用性:

  當一個抽象模型有兩個方面, 其中一個方面依賴於另外一方面。將這兩者封裝在獨立的對象中以使它們能夠各自獨立地改變和複用。

  當對一個對象的改變須要同時改變其它對象, 而不知道具體有多少對象有待改變。

  當一個對象必須通知其它對象,而它又不能假定其它對象是誰。換言之, 你不但願這些對象是緊密耦合的。


摘自:http://www.cnblogs.com/Liqiongyu/p/5916710.html



  若是你看得懂就略過吧。

  上面的代碼想作個上面事情呢?

  想達到事件的目的,即,在更新數據的時候,會觸發相關的事件

  上面定義了主要三個種類型的類,subject,data,viewer。

  其中subject是data的父類。

  經過attach的操做,將不一樣的viewer加入到self.__observers列表裏面,當data對象要更新數據的時候,就回調用notify方法,而notify方法則會遍歷self.__observers列表的每一個observer,而後依次調用其update方法。

  這也是爲毛hexViewer,DecimalViewer都要實現自身的update方法。

  爲毛要這麼寫?

  前人總結的經驗~

  能不能不這麼寫?

  能夠的。

  若是看不懂這個設計模式,那麼pyalgotrade的源碼看起來可能會比較吃力,可是也只是可能而已,由於不少人看不懂,只是由於沒有實際的有用場景而已。



源碼解析

  首先是框架,看一遍,好比那些模塊,不過我的經驗之談就是,看完以後,通常都會有一下迷思。

  爲毛這麼寫?

  這裏到底想幹什麼?

  這麼複雜有毛用~

  恩,我也是這種感受~

  通常是pdb跟一遍流程或者一個一個找繼承關係。

  pdb這裏就不講了,主要就是跟每一個方法調用死磕到底,固然了,你也許有你得方法,我比較較真就是這樣看源代碼的,至少如今是這樣的。

  在看源代碼以前,官方文檔,示例什麼的最好也看一下,這樣才能跟接近做者的意思。

  這裏面有個對象,須要着重聲明,那就是bar。

  什麼是bar呢?

  每一個bar都是一個時刻股票各個價格的集合,即,當前價格,當前時間,最高價,最低價,成交量什麼的。

  而這些屬性都是經過get_xxx的方法獲取的。


獲取數據

很明顯數據是經過下面這行代碼獲取的。

feed = yahoofinance.build_feed([instrument], 2011, 2012, ".")

build_feed方法在tools/yahoofinance.py

def build_feed(instruments, fromYear, toYear, storage, frequency=bar.Frequency.DAY, timezone=None, skipErrors=False):
logger = pyalgotrade.logger.getLogger("yahoofinance")
    logger = pyalgotrade.logger.getLogger("yahoofinance")
    ret = yahoofeed.Feed(frequency, timezone)

    for year in range(fromYear, toYear+1):
        for instrument in instruments:
            fileName = os.path.join(storage, "%s-%d-yahoofinance.csv" % (instrument, year))
            if not os.path.exists(fileName):
                logger.info("Downloading %s %d to %s" % (instrument, year, fileName))
                try:
                    if frequency == bar.Frequency.DAY:
                        download_daily_bars(instrument, year, fileName)
                    elif frequency == bar.Frequency.WEEK:
                        download_weekly_bars(instrument, year, fileName)
                    else:
                        raise Exception("Invalid frequency")
                except Exception, e:
                    if skipErrors:
                        logger.error(str(e))
                        continue
                    else:
                        raise e
            ret.addBarsFromCSV(instrument, fileName)
    return ret


在build_feed函數裏面又根據狀況調用了相應的下載函數

def download_csv(instrument, begin, end, frequency):
    url = "http://ichart.finance.yahoo.com/table.csv?s=%s&a=%d&b=%d&c=%d&d=%d&e=%d&f=%d&g=%s&ignore=.csv" % (instrument, __adjust_month(begin.month), begin.day, begin.year, __adjust_month(end.month), end.day, end.year, frequency)
    return csvutils.download_csv(url)

  而最終執行的下載函數爲download_csv,經過這個函數咱們能夠訪問yahoo的api,最終下載函數,固然了,能夠進一步的查看csvutils.download_csv函數。

  這裏咱們知道數據是經過download_csv這個函數,將相應的股票代碼,開始結束時間及頻率傳入,而後訪問相應的url,獲得相應的數據。


feed對象


  在tools/yahoofinance.py中咱們能夠看到,返回的結果並非一個csv的對象,而是一個ret即,Feed對象,而Feed對象經過addBarsFromCSV將下載的數據加載到內存。

  從這裏你也許會開始抓狂了爲毛一層一層的繼承。


其中yahoofeed.Feed在barfeed/yahoofeed.py

class Feed(csvfeed.BarFeed):
    def addBarsFromCSV(self, instrument, path, timezone=None):
        rowParser = RowParser(
            self.getDailyBarTime(), self.getFrequency(), timezone, self.__sanitizeBars, self.__barClass
        )
        super(Feed, self).addBarsFromCSV(instrument, path, rowParser)

上面調用了父類的addBarsFromCSV方法。


父類的addBarsFromCSV在barfeed/csvfeed.py

class BarFeed(membf.BarFeed):
    def addBarsFromCSV(self, instrument, path, rowParser):
        # Load the csv file
        loadedBars = []
        reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())
        for row in reader:
            bar_ = rowParser.parseBar(row)
            if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):
                loadedBars.append(bar_)

        self.addBarsFromSequence(instrument, loadedBars)

而後csvfeed又調用了父類的方法~

值得注意的是,上面的rowParser.parseBar方法在子類實現的 。。。後面會在說起。


addBarsFromSequence方法在barfeed/membf.py

class BarFeed(barfeed.BaseBarFeed):
    def addBarsFromSequence(self, instrument, bars):
        if self.__started:
            raise Exception("Can't add more bars once you started consuming bars")

        self.__bars.setdefault(instrument, [])
        self.__nextPos.setdefault(instrument, 0)

        # Add and sort the bars
        self.__bars[instrument].extend(bars)
        barCmp = lambda x, y: cmp(x.getDateTime(), y.getDateTime())
        self.__bars[instrument].sort(barCmp)

        self.registerInstrument(instrument)

而後又調用了父類的方法~

值得注意的是這裏將yahoo的數據存在了self.__bars中,至於bars是什麼對象,後面再說。


registerInstrument方法在barfeed/__init__.py

class BaseBarFeed(feed.BaseFeed):
    def registerInstrument(self, instrument):
        self.__defaultInstrument = instrument
        self.registerDataSeries(instrument)


而後又調用了父類的方法~

registerDataSeries方法在feed/__init__.py

class BaseFeed(observer.Subject):
    def __init__(self, maxLen):
        super(BaseFeed, self).__init__()

        maxLen = dataseries.get_checked_max_len(maxLen)

        self.__ds = {}
        self.__event = observer.Event()
        self.__maxLen = maxLen
    def registerDataSeries(self, key):
        if key not in self.__ds:
            self.__ds[key] = self.createDataSeries(key, self.__maxLen)

  恩,這裏就是邏輯的終點了,雖然它仍是繼承,不過pyalgotrade裏面大多數對象都是是繼承observer.Subject,之因此繼承,是爲了完成相似觀察者的設計模式裏面的事件操做。

  簡單總結一下繼承關係。

barfeed/yahoofeed.Feed -> barfeed/csvfeed.BarFeed -> barfeed/membf.BarFeed -> barfeed/__init__.py.BaseFeed -> feed/__init.py.BaseFeed

  而後yahoo的數據結果,最終是由RowParser的parseBar方法依次導入,而RowPaser.parseBar方法是在barfeed/yahoofeed.py中。


  而後咱們再來走一遍加載數據的流程,不過此次不僅是整個邏輯,而此次咱們關注於具體的數據是啥。

其中barfeed/yahoofeed裏面的RowParser的邏輯及parsrBar的具體的具體實現,截取以下。

class RowParser(csvfeed.RowParser):
    def __init__(self, dailyBarTime, frequency, timezone=None, sanitize=False, barClass=bar.BasicBar):
        self.__dailyBarTime = dailyBarTime
        self.__frequency = frequency
        self.__timezone = timezone
        self.__sanitize = sanitize
        self.__barClass = barClass

    def __parseDate(self, dateString):
        ret = parse_date(dateString)
        # Time on Yahoo! Finance CSV files is empty. If told to set one, do it.
        if self.__dailyBarTime is not None:
            ret = datetime.datetime.combine(ret, self.__dailyBarTime)
        # Localize the datetime if a timezone was given.
        if self.__timezone:
            ret = dt.localize(ret, self.__timezone)
        return ret

    def getFieldNames(self):
        # It is expected for the first row to have the field names.
        return None

    def getDelimiter(self):
        return ","

    def parseBar(self, csvRowDict):
        dateTime = self.__parseDate(csvRowDict["Date"])
        close = float(csvRowDict["Close"])
        open_ = float(csvRowDict["Open"])
        high = float(csvRowDict["High"])
        low = float(csvRowDict["Low"])
        volume = float(csvRowDict["Volume"])
        adjClose = float(csvRowDict["Adj Close"])

        if self.__sanitize:
            open_, high, low, close = common.sanitize_ohlc(open_, high, low, close)

        return self.__barClass(dateTime, open_, high, low, close, volume, adjClose, self.__frequency)

  其中解析後返回的結果是一個bar.BasicBar對象。

  而後調用父類barfeed/csvfeed裏面的addBarsFromCSV方法,獲得一個bar.BasicBar對象的列表,即loadBars。傳入繼承於父類的addBarsFromSequence方法,截取以下。

class BarFeed(membf.BarFeed):
    def addBarsFromCSV(self, instrument, path, rowParser):
        # Load the csv file
        loadedBars = []
        reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())
        for row in reader:
            bar_ = rowParser.parseBar(row)
            if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):
                loadedBars.append(bar_)

        self.addBarsFromSequence(instrument, loadedBars)


下面則是處理addBarsFromSequence的操做,主要是建立了一個self.__bars的字典,每一個股票代碼對應相應時間段的bar.BasicBar對象的列表,而後調用父類的registerInstrument方法,傳入相應的股票代碼。

barfeed/membf.py --> BarFeed

class BarFeed(barfeed.BaseBarFeed):
    def addBarsFromSequence(self, instrument, bars):
        if self.__started:
            raise Exception("Can't add more bars once you started consuming bars")

        self.__bars.setdefault(instrument, [])
        self.__nextPos.setdefault(instrument, 0)

        # Add and sort the bars
        self.__bars[instrument].extend(bars)
        barCmp = lambda x, y: cmp(x.getDateTime(), y.getDateTime())
        self.__bars[instrument].sort(barCmp)

        self.registerInstrument(instrument)


下面則是registerInstrument的具體邏輯,即註冊DataSeries對象,而registerDataSeries方法是在父類實現。

barfeed/__init__.py --->BaseBarFeed

BaseBarFeed(feed.BaseFeed):
    def registerInstrument(self, instrument):
        self.__defaultInstrument = instrument
        self.registerDataSeries(instrument)


下面則是最終的registerDataSeries操做,建立了一個dataseries的對象。

feed/__init__.py  --->BaseFeed

class BaseFeed(observer.Subject):
    def registerDataSeries(self, key):
        if key not in self.__ds:
            self.__ds[key] = self.createDataSeries(key, self.__maxLen)


而createDataSeries方法並無在基類中實現。

@abc.abstractmethod
def createDataSeries(self, key, maxLen):
    raise NotImplementedError()


createDataSeries的具體實現則是在barfeed/__init__.py --->BaseBarFeed

    def createDataSeries(self, key, maxLen):
        ret = bards.BarDataSeries(maxLen)
        ret.setUseAdjustedValues(self.__useAdjustedValues)
        return ret


因此最終,feed對象有兩個重要的數據集。

一:

self.__bars

裏面的數據結構大概是{"instrument_xx":[bar1,bar2,bar3]}

self.__ds = {}

裏面的數據結構大概是self.__ds = {"instrument_xx": dataseries_xx}

其中instrument指特定的股票代碼,好比aapl,bar1,bar2則是bar.BasicBar對象,dataseries則是bards.BarDataSeries對象。

至於bar.BasicBar以及dataseries的數據結構究竟是什麼,你們能夠自行瞧瞧。

值得注意的是,父類與基類之間數據獲取不會經過共享變量的方式得到,好比最終經過基類self.__ds的數據是經過基類的getKeys的方法暴露給子類去獲取實際的數據。。


策略


初始化策略

strat = sma_crossover.SMACrossOver(feed, instrument, smaPeriod)

策略最終繼承於strategy.BacktestingStrategy


analyzer

建立一個stratanalyzer的實例並attach

sharpeRatioAnalyzer = sharpe.SharpeRatio()
strat.attachAnalyzer(sharpeRatioAnalyzer)

analyzer這裏暫時不說,由於,這裏主要將具體的策略實現,以及feed對象,analyzer以及broker的內容會放在下一篇文章講。


run

運行策略。

strat.run()


run方法在strategy/__init__.py裏面的BaseStrategy類。

class BaseStrategy(object):
    def run(self):
    """Call once (**and only once**) to run the strategy."""
        self.__dispatcher.run()
    if self.__barFeed.getCurrentBars() is not None:
        self.onFinish(self.__barFeed.getCurrentBars())
    else:
        raise Exception("Feed was empty")


而run方法會調用self.__dispatcher的run方法,即dispatcher.py裏面的Dispatcher類,在說Dispatcher類以前,咱們得先看看BaseStrategy在初始化的時候到底初始化了啥。

class BaseStrategy(object):
    def __init__(self, barFeed, broker):
        self.__barFeed = barFeed
        self.__broker = broker
        self.__activePositions = set()
        self.__orderToPosition = {}
        self.__barsProcessedEvent = observer.Event()
        self.__analyzers = []
        self.__namedAnalyzers = {}
        self.__resampledBarFeeds = []
        self.__dispatcher = dispatcher.Dispatcher()
        self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)
        self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)
        self.__dispatcher.getStartEvent().subscribe(self.onStart)
        self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)
        # It is important to dispatch broker events before feed events, specially if we're backtesting.
        self.__dispatcher.addSubject(self.__broker)
        self.__dispatcher.addSubject(self.__barFeed)

  綁定barFeed,broker到self,初始化__activePositions,OderToPosition,__analyzers,__namedAnlyzers,__resampledBarFeeds的值,並初始化一個observer.Event的實例。

  建立一個dispatcher的實例,並在dispatcher的初始化過程當中建立兩個observer.Event,observer.Event的實例。

  其中broker實例經過getOrderUpdatedEvent方法獲得一個event實例,並訂閱策略的onOrderEvent的事件

  barFeed實例經過getNewValuesEvent方法獲得一個event實例,並訂閱策略的onBars的事件。

  dispatcher的實例分別得到startEvent,IdleEvent並訂閱onStart,__onIdle事件。

  最後dispatcher實例將broker,barFeed兩個subject分別加入到dispatcher的subjects列表中。

  而後咱們在回到Dispatcher類的run方法,這裏主要是首先遍歷本身__subjects列表裏面的subject,而後調用每一個subject的start方法,由BaseStrategy類的初始化方法可知,dispatcher加入了兩個subject,分別是broker,barFeed。


具體實現以下。

class Dispatcher(object):
    def run(self):
    try:
        for subject in self.__subjects:
            subject.start()
        self.__startEvent.emit()
        
        while not self.__stop:
            eof, eventsDispatched = self.__dispatch()
        if eof:
            self.__stop = True
        elif not eventsDispatched:
            self.__idleEvent.emit()
    finally:
        for subject in self.__subjects:
            subject.stop()
        for subject in self.__subjects:
            subject.join()


整個回測策略的邏輯基本就是在dispatcher調度各個subject並觸發事件的過程。

調用完每一個subject的start方法後,執行自身的self.__startEvent.emit方法。

而後經過while循環啓動整個運轉邏輯。

在循環結束後依次啓動每一個subject並等待全部subject關閉。

如今再次回到初始化過程,查看各個event,subject的內容究竟是什麼。

self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)
    def __onOrderEvent(self, broker_, orderEvent):
        order = orderEvent.getOrder()
        self.onOrderUpdated(order)
        self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)
    def __onBars(self, dateTime, bars):
        # THE ORDER HERE IS VERY IMPORTANT
        # 1: Let analyzers process bars.
        self.__notifyAnalyzers(lambda s: s.beforeOnBars(self, bars))
        # 2: Let the strategy process current bars and submit orders.
        self.onBars(bars)
        # 3: Notify that the bars were processed.
        self.__barsProcessedEvent.emit(self, bars)
        self.__dispatcher.getStartEvent().subscribe(self.onStart)
    def onStart(self):
        """Override (optional) to get notified when the strategy starts executing. The default implementation is empty. """
        pass
        self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)
        def __onIdle(self):
        # Force a resample check to avoid depending solely on the underlying
        # barfeed events.
        for resampledBarFeed in self.__resampledBarFeeds:
        resampledBarFeed.checkNow(self.getCurrentDateTime())
        self.onIdle()

上面是各個event訂閱的subject,是相應的handler函數。


而後如今瞧瞧每一個subject的start方法。

其中observer.py裏面定義的Subject相似一個抽象工廠,只是定義了各個方法可是並無實現具體方法的邏輯。

咱們首先來看看broker這個subject的start方法的處理邏輯。

而繼承observer.Subject的Broker也只是一個抽象工廠,定義了一系列的接口。

在此策略中,咱們據代碼得知,咱們初始化的broker是一個backtesting的broker,代碼以下。

class BacktestingStrategy(BaseStrategy):
    def __init__(self, barFeed, cash_or_brk=1000000):
        # The broker should subscribe to barFeed events before the strategy.
        # This is to avoid executing orders submitted in the current tick.
        if isinstance(cash_or_brk, pyalgotrade.broker.Broker):
            broker = cash_or_brk
        else:
          broker = backtesting.Broker(cash_or_brk, barFeed)
        查看backtesting的broker
        broker/backtesting.py
        class Broker(broker.Broker):
        def start(self):
            super(Broker, self).start()

 

查看backtesting的broker -> broker/backtesting.py

        class Broker(broker.Broker):
        def start(self):
            super(Broker, self).start()


其中基類的start以下

observer.py
class Subject(object):
@abc.abstractmethod
def start(self):
pass


而後再來看barFeed的subject的start

其中barFeed也沒有本身定義start方法,即,start方法也是如上。


在每一個subject調用start方法後,dispatcher就會調用自身self.__startEvent.emit。而後到循環eof, eventsDispatched = self.__dispatch()

    def __dispatch(self):
        smallestDateTime = None
        eof = True
        eventsDispatched = False

        # Scan for the lowest datetime.
        for subject in self.__subjects:
            if not subject.eof():
                eof = False
                smallestDateTime = utils.safe_min(smallestDateTime, subject.peekDateTime())


再次實例建立的feed爲yahoofeed

而依次繼承於csvfeed.BarFeed,membf.BarFeed,barfeed.BaseBaseFeed,feed.BaseFeed

其中membf.BarFeed,BaseBarFeed都實現了eof方法。


經過代碼追蹤,咱們發現eof主要爲了判斷是否以及迭代完每個bar

代碼以下

    def eof(self):
        ret = True
        # Check if there is at least one more bar to return.
        for instrument, bars in self.__bars.iteritems():
            nextPos = self.__nextPos[instrument]
            if nextPos < len(bars):
                ret = False
                break
        return ret


其中self.__nextPos在addBarsFromSequence函數裏面已經將其定義爲0,也就是說,這個nextPos是爲了在迭代每一個bar的同時記錄迭代的位置,即索引位置。

當判斷完eof以後,則調用__dispatchSubject方法,迭代每一個subject並調用其dispatch方法。

其中dispatch的實如今基類feed/__init__.py

class BaseFeed(observer.Subject):
    def dispatch(self):
        dateTime, values = self.getNextValuesAndUpdateDS()
        if dateTime is not None:
            self.__event.emit(dateTime, values)
        return dateTime is not None


getNextValuesAndUpdateDS方法實如今feed/__init__.py

   def getNextValuesAndUpdateDS(self):
        dateTime, values = self.getNextValues()
        if dateTime is not None:
            for key, value in values.items():
                # Get or create the datseries for each key.
                try:
                    ds = self.__ds[key]
                except KeyError:
                    ds = self.createDataSeries(key, self.__maxLen)
                    self.__ds[key] = ds
                ds.appendWithDateTime(dateTime, value)
        return (dateTime, values)

    def __iter__(self):
        return feed_iterator(self)


而getNextValues的方法實如今barfeed/__init__.py

class BaseBarFeed(feed.BaseFeed):
    def getNextValues(self):
        dateTime = None
        bars = self.getNextBars()
        if bars is not None:
            dateTime = bars.getDateTime()

            # Check that current bar datetimes are greater than the previous one.
            if self.__currentBars is not None and self.__currentBars.getDateTime() >= dateTime:
                raise Exception(
                    "Bar date times are not in order. Previous datetime was %s and current datetime is %s" % (
                        self.__currentBars.getDateTime(),
                        dateTime
                    )
                )

            # Update self.__currentBars and self.__lastBars
            self.__currentBars = bars
            for instrument in bars.getInstruments():
                self.__lastBars[instrument] = bars[instrument]
        return (dateTime, bars)


其中 getNextBars的方法實如今barfeed/membf.py

class BarFeed(barfeed.BaseBarFeed):
    def getNextBars(self):
        # All bars must have the same datetime. We will return all the ones with the smallest datetime.
        smallestDateTime = self.peekDateTime()

        if smallestDateTime is None:
            return None

        # Make a second pass to get all the bars that had the smallest datetime.
        ret = {}
        for instrument, bars in self.__bars.iteritems():
            nextPos = self.__nextPos[instrument]
            if nextPos < len(bars) and bars[nextPos].getDateTime() == smallestDateTime:
                ret[instrument] = bars[nextPos]
                self.__nextPos[instrument] += 1

        if self.__currDateTime == smallestDateTime:
            raise Exception("Duplicate bars found for %s on %s" % (ret.keys(), smallestDateTime))

        self.__currDateTime = smallestDateTime
        return bar.Bars(ret)


其中Bars對象則是對bar的進一層封裝

提供方法以下。

def __getitem__(self, instrument):
return self.__barDict[instrument]
def __contains__(self, instrument):
return instrument in self.__barDict
def items(self):
def keys(self):
def getInstruments(self):
def getDateTime(self):
def getBar(self, instrument):


至此,咱們瞭解到了feed對象,以及每一個bar是怎麼迭代的,可是尚未看到每一個bar的處理操做。

因此在回到feed的dispatch方法,處理流程以下

    def dispatch(self):
        dateTime, values = self.getNextValuesAndUpdateDS()
        if dateTime is not None:
            self.__event.emit(dateTime, values)
        return dateTime is not None

須要着重說明的就是self.__event.emit(dateTime, values)

其中values是一個bar.Bars實例。


broker的dispatch方法

def dispatch(self):
# All events were already emitted while handling barfeed events.
pass


這裏,咱們能夠看到若是dataTime不是None的話,就會經過emit提交時間

而feed裏面註冊了__onBars的handlers

因此在每次迭代的時候都會觸發event的emit操做,即執行每一個在feed中註冊了的handler,這裏只註冊了一個handler--->__onBars

def __onBars(self, dateTime, bars):
    # THE ORDER HERE IS VERY IMPORTANT
    # 1: Let analyzers process bars.
    self.__notifyAnalyzers(lambda s: s.beforeOnBars(self, bars))
    # 2: Let the strategy process current bars and submit orders.
    self.onBars(bars)
    # 3: Notify that the bars were processed.
    self.__barsProcessedEvent.emit(self, bars)


因此迭代每個bar的時候,都會執行onBar的函數。

而onBar函數是本身定義的,在本示例中,onBar的函數內容以下

def onBars(self, bars):
    def onBars(self, bars):
        # If a position was not opened, check if we should enter a long position.
        if self.__position is None:
            if cross.cross_above(self.__prices, self.__sma) > 0:
                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())
                # Enter a buy market order. The order is good till canceled.
                self.__position = self.enterLong(self.__instrument, shares, True)
        # Check if we have to exit the position.
        elif not self.__position.exitActive() and cross.cross_below(self.__prices, self.__sma) > 0:
            self.__position.exitMarket()


bar是每一個指定頻率的open,close,low,high,adj close,volume數據集合對象。

DataSeries是一個隨着迭代,不斷增長datetime,以及bar的序列。

而technical的觸發是在feed/__init__.py裏面的ds.appendWithDateTime。

    def getNextValuesAndUpdateDS(self):
        dateTime, values = self.getNextValues()
        if dateTime is not None:
            for key, value in values.items():
                # Get or create the datseries for each key.
                try:
                    ds = self.__ds[key]
                except KeyError:
                    ds = self.createDataSeries(key, self.__maxLen)
                    self.__ds[key] = ds
                ds.appendWithDateTime(dateTime, value)
        return (dateTime, values)


而後ma.py

class SMA(technical.EventBasedFilter):
    def __init__(self, dataSeries, period, maxLen=None):
    super(SMA, self).__init__(dataSeries, SMAEventWindow(period), maxLen)


而後technical/__init__.py

class EventBasedFilter(dataseries.SequenceDataSeries):
    def __init__(self, windowSize, dtype=float, skipNone=True):
        assert(windowSize > 0)
        assert(isinstance(windowSize, int))
        self.__values = collections.NumPyDeque(windowSize, dtype)
        self.__windowSize = windowSize
        self.__skipNone = skipNone
    def __onNewValue(self, dataSeries, dateTime, value):
        # Let the event window perform calculations.
        self.__eventWindow.onNewValue(dateTime, value)
        # Get the resulting value
        newValue = self.__eventWindow.getValue()
        # Add the new value.
        self.appendWithDateTime(dateTime, newValue)


而__eventWindow.onNewValue在technical/ma.py

class SMAEventWindow(technical.EventWindow):
    def __init__(self, period):
        assert(period > 0)
        super(SMAEventWindow, self).__init__(period)
        self.__value = None

    def onNewValue(self, dateTime, value):
        firstValue = None
        if len(self.getValues()) > 0:
            firstValue = self.getValues()[0]
            assert(firstValue is not None)

        super(SMAEventWindow, self).onNewValue(dateTime, value)

        if value is not None and self.windowFull():
            if self.__value is None:
                self.__value = self.getValues().mean()
            else:
                self.__value = self.__value + value / float(self.getWindowSize()) - firstValue / float(self.getWindowSize())

    def getValue(self):
        return self.__value


至此基於pyalgotrade的一個簡單示例,按照其執行流程的源碼解讀到此完畢。



後記:後面有點亂了,寫篇文章仍是蠻費時間的,太長了,pyalgotrade的源碼解讀估計還得寫一段時間去了。

這就是從無到用寫個股票分析APP系列的衍生篇了。


參考連接:

Python設計模式: http://www.cnblogs.com/Liqiongyu/p/5916710.html

PyAlgoTrade 文檔: http://gbeced.github.io/pyalgotrade/docs/v0.6/html/index.html


若是以爲不錯,並有所收穫,請我喝杯茶唄


wKioL1lU4MXwELckAADg-gB3Tsc583.jpg-wh_50wKiom1lU4Mqg8rxIAADzypnX0FU518.jpg-wh_50

相關文章
相關標籤/搜索