小白的Python 學習筆記(九)itertools深度解析,滿滿的乾貨(下)

簡單實戰

你們好,我又來了,在通過以前兩篇文章的介紹後相信你們對itertools的一些常見的好用的方法有了一個大體的瞭解,我本身在學完以後仿照別人的例子進行了真實場景下的模擬練習,今天和你們一塊兒分享,有不少部分還能夠優化,但願有更好主意和建議的朋友們能夠留言哈,讓咱們一塊兒進步python

實戰:分析標準普爾500指數

數據源及目標

在這個例子中,咱們首先嚐試使用itertools來操做大型數據集:標準普爾500指數的歷史每日價格數據。 我會在這個部分的最後附上下載連接和py文件,這裏的數據源來自雅虎財經git

目標: 找到標準普爾500指數的單日最大收益,最大損失(百分比),和最長的增加週期github

首先咱們手上獲得了 SP500.csv ,讓咱們對數據有個大概的印象,前十行的數據以下:bash

Date,Open,High,Low,Close,Adj Close,Volume
1950-01-03,16.660000,16.660000,16.660000,16.660000,16.660000,1260000
1950-01-04,16.850000,16.850000,16.850000,16.850000,16.850000,1890000
1950-01-05,16.930000,16.930000,16.930000,16.930000,16.930000,2550000
1950-01-06,16.980000,16.980000,16.980000,16.980000,16.980000,2010000
1950-01-09,17.080000,17.080000,17.080000,17.080000,17.080000,2520000
1950-01-10,17.030001,17.030001,17.030001,17.030001,17.030001,2160000
1950-01-11,17.090000,17.090000,17.090000,17.090000,17.090000,2630000
1950-01-12,16.760000,16.760000,16.760000,16.760000,16.760000,2970000
1950-01-13,16.670000,16.670000,16.670000,16.670000,16.670000,3330000
複製代碼

爲了實現目標,具體思路以下:less

  • 讀取csv文件,並利用 Adj Close這一列轉換爲每日百分比變化的序列,表明收益,命名爲gain
  • 找到gain這一序列中的最大值和最小值,而且找到對應的日期,固然,有可能會出現對應多個日期的狀況,咱們這裏選取日期最近的就好。
  • 定義一個sequence叫作growth_streaks,其中包含了全部 gain中出現的連續爲正值的元素組成的tuple,咱們要找到這些tuples中長度最長的一個,從而定位其對應的開始時間和結束時間,固然這裏也是同樣,有可能出現最大長度同樣的的狀況,這種狀況下,咱們仍是選擇日期最近的。

這裏有關百分比的計算公式以下:函數

enter image description here

分步實現

首先在這裏,咱們會常常處理日期,爲了方便後續操做,這裏咱們引入collections模塊的namedtuple來實現對日期的相關操做:post

from collections import namedtuple


class DataPoint(namedtuple('DataPoint', ['date', 'value'])):
    __slots__ = ()

    def __le__(self, other):
        return self.value <= other.value

    def __lt__(self, other):
        return self.value < other.value

    def __gt__(self, other):
        return self.value > other.value
複製代碼

這裏有不少小技巧,以後我會再系統的開一個Python OOP筆記,會爲你們都講到,這裏面涉及的小知識點以下:學習

  • slots :這是一個節省變量內存的好東西,__slot__後面通常都是跟class中 init 方法裏面用到的變量,好處在於可以大量節省內存
  • namedtuple:能夠實現相似屬性同樣調用tuple裏面的元素,我在collections裏面詳細說過,你們能夠看看:小白的Python 學習筆記(七)神奇寶藏 Collections
  • le:運算符重載,能夠獲得class中一個變量的長度,必須是整數,也就是說若是傳入的是list,dict,tuple,set這些必定沒有問題,由於這些序列的長度必定是整數,這裏面傳遞的是tuple()
  • lt:運算符重載(less than ):能夠實現利用 < 比較一個class的不一樣對象中的值大小的比較
  • gt:運算符重載(greater than):能夠實現利用 > 比較一個class的不一樣對象中的值大小的比較

下面爲了喚醒你們的記憶,我這裏快速舉一個有關於namedtuple,le,lt,gt的小栗子:優化

from collections import namedtuple
class Person(namedtuple('person', ['name', 'age','city','job'])):

    def __le__(self):
        return len(self)

    def __lt__(self,other):
        return self.age < other.age

    def __gt__(self,other):
        return self.age > other.age


xiaobai = Person('xiaobai', 18, 'paris','student')
laobai = Person('Walter White',52, 'albuquerque','cook')


print('Infomation for first person: ', xiaobai)     # 顯示所有信息
print('Age of second person is: ', laobai.age)    # 根據name獲得tuple的數據
print(len(xiaobai))
print(xiaobai > laobai)
print(xiaobai < laobai)


Out: Infomation for first person:  Person(name='xiaobai', age=18, city='paris',job='student')
     Age of second person is:  52
     4
     False
     True
複製代碼

若是你們對這個例子中的一些地方還有疑問,不用擔憂,我會在下一個專欄Python OOP學習筆記中和你們慢慢說的 。好的,如今回到剛纔的實戰:ui

from collections import namedtuple

class DataPoint(namedtuple('DataPoint', ['date', 'value'])):
    __slots__ = ()

    def __le__(self, other):
        return self.value <= other.value

    def __lt__(self, other):
        return self.value < other.value

    def __gt__(self, other):
        return self.value > other.value
複製代碼

這裏咱們的DataPoint類有兩個主要屬性,一個是datetime類型的日期,一個是當天的標普500值

接下來讓咱們讀取csv文件,並將每行中的Date和Adj Close列中的值存爲DataPoint的對象,最後把全部的對象組合爲一個sequence序列:

import csv
from datetime import datetime


def read_prices(csvfile, _strptime=datetime.strptime):
    with open(csvfile) as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            yield DataPoint(date=_strptime(row['Date'], '%Y-%m-%d').date(),
                            value=float(row['Adj Close']))


prices = tuple(read_prices('SP500.csv'))
複製代碼

read_prices()生成器打開 SP500.csv 並使用 csv.DictReader()讀取數據的每一行。DictReader()將每一行做爲 OrderedDict 返回,其中key是每行中的列名。

對於每一行,read_prices()都會生成一個DataPoint對象,其中包含「Date」和「Adj Close」列中的值。 最後,完整的數據點序列做爲元組提交給內存並存儲在prices變量中

Ps: Ordereddict是我在collections中漏掉的知識點,我立刻會補上,你們能夠隨時收藏小白的Python 學習筆記(七)神奇寶藏 Collections,我會繼續更新

接下來咱們要把prices這個轉變爲表達每日價格變化百分比的序列,利用的公式就是剛纔提到的,若是忘了的朋友能夠往回翻~

gains = tuple(DataPoint(day.date, 100*(day.value/prev_day.value - 1.))
                for day, prev_day in zip(prices[1:], prices))
複製代碼

爲了獲得標普500單日最大漲幅,咱們能夠用一下方法:

max_gain = DataPoint(None, 0)
for data_point in gains:
    max_gain = max(data_point, max_gain)

print(max_gain)   # DataPoint(date='2008-10-28', value=11.58)
複製代碼

咱們能夠把這個方法用以前提到過的reduce簡化一下:

import functools as ft

max_gain = ft.reduce(max, gains)

print(max_gain)  # DataPoint(date='2008-10-28', value=11.58)
複製代碼

這裏有關reduce 和 lambda的用法,咱們能夠經過一個小栗子來回憶一下:

import functools as ft
x = ft.reduce(lambda x,y:x+y,[1, 2, 3, 4, 5])
print(x)

Out: 15
複製代碼

固然,若是求和在實際場景直接用sum就好,這裏只是爲了讓你們有個印象,若是回憶不起來的老鐵們也沒有關係,輕輕點擊如下連接馬上重溫:

好了,書規正傳,咱們發現用reduce改進了for循環後獲得了一樣的結果,單日最大漲幅的日期也同樣,可是這裏須要注意的是reduce和剛纔的for循環徹底不是一回事

咱們能夠想象一下,假如CSV文件中的數據天天都是跌的話。 max_gain最後究竟是多少?

在 for 循環中,首先設置max_gain = DataPoint(None,0),所以若是沒有漲幅,則最終的max_gain值將是此空 DataPoint 對象。可是,reduce()解決方案會返回最小的單日跌幅,這不是咱們想要的,可能會引入一個難以找到的bug

這就是itertools能夠幫助到咱們的地方。 itertools.filterfalse()函數有兩個參數:一個返回True或False的函數,和一個可迭代的輸入。它返回一個迭代器,是迭代結果都爲False的狀況。這裏是個小栗子:

import itertools as it
only_positives = it.filterfalse(lambda x: x <= 0, [0, 1, -1, 2, -2])
print(list(only_positives))


Out:[1, 2]

複製代碼

因此如今咱們能夠用 itertools.filterfalse()去除掉gains中那些小於0或者爲負數的值,這樣reduce會僅僅做用在咱們想要的正收益上:

max_gain = ft.reduce(max, it.filterfalse(lambda p: p <= 0, gains))

複製代碼

這裏咱們默認爲gains中必定存在大於0的值,這也是事實,可是若是假設gains中沒有的話,咱們會報錯,所以在使用itertools.filterfalse()的實際場景中要注意到這一點。

針對這種狀況,可能你想到的應對方案是在合適的狀況下添加TryExpect捕獲錯誤,可是reduce有個更好的解決方案,reuce裏面能夠傳遞第三個參數,用作reduce返回結果不存在時的默認值,這一點和字典的get方法有殊途同歸之妙,若是對get有疑問的朋友能夠回顧我以前的文章:Python 進階之路 (二) Dict 進階寶典,初二快樂!,仍是看一個小栗子:

>>> ft.reduce(max, it.filterfalse(lambda x: x <= 0, [-1, -2, -3]), 0)
0
複製代碼

這回很好理解了,所以咱們應用到咱們標準普爾指數的實戰上:

zdp = DataPoint(None, 0)  # zero DataPoint
max_gain = ft.reduce(max, it.filterfalse(lambda p: p.value <= 0, diffs), zdp)
複製代碼

同理,對於標普500單日最大跌幅咱們也照貓畫虎:

max_loss = ft.reduce(min, it.filterfalse(lambda p: p.value > 0, gains), zdp)

print(max_loss)  # DataPoint(date='2018-02-08', value=-20.47)
複製代碼

根據咱們的數據源是2018年2月8號那一天,我沒有谷歌查詢那一天發生了什麼,你們感興趣能夠看看哈,可是應該是沒有問題的,由於數據源來自雅虎財經

如今咱們已經獲得了標普500歷史上的單日最大漲跌的日期,咱們接下來要找到它的最長時間段,其實這個問題等同於在gains序列中找到最長的連續爲正數的點的集合,itertools.takewhile()和itertools.dropwhile()函數很是適合處理這種狀況。

itertools.takewhile()接受兩個參數,一個爲判斷的條件,一個爲可迭代的序列,會返回第一個判斷結果爲False時以前的迭代過的全部元素,下面的小栗子很好的解釋了這一點

it.takewhile(lambda x: x < 3, [0, 1, 2, 3, 4])  # 0, 1, 2
複製代碼

itertools.dropwhile() 則偏偏相反:

it.dropwhile(lambda x: x < 3, [0, 1, 2, 3, 4])  # 3, 4

複製代碼

所以咱們能夠建立一下方法來實如今gains中找到連續爲正數的序列:

def consecutive_positives(sequence, zero=0):
    def _consecutives():
        for itr in it.repeat(iter(sequence)):
            yield tuple(it.takewhile(lambda p: p > zero,
                                     it.dropwhile(lambda p: p <= zero, itr)))
    return it.takewhile(lambda t: len(t), _consecutives())
    
growth_streaks = consecutive_positives(gains, zero=DataPoint(None, 0))
longest_streak = ft.reduce(lambda x, y: x if len(x) > len(y) else y,
                           growth_streaks)
複製代碼

最後讓咱們看一下完整的代碼:

from collections import namedtuple
import csv
from datetime import datetime
import itertools as it
import functools as ft


class DataPoint(namedtuple('DataPoint', ['date', 'value'])):
    __slots__ = ()

    def __le__(self, other):
        return self.value <= other.value

    def __lt__(self, other):
        return self.value < other.value

    def __gt__(self, other):
        return self.value > other.value


def consecutive_positives(sequence, zero=0):
    def _consecutives():
        for itr in it.repeat(iter(sequence)):
            yield tuple(it.takewhile(lambda p: p > zero,
                                     it.dropwhile(lambda p: p <= zero, itr)))
    return it.takewhile(lambda t: len(t), _consecutives())


def read_prices(csvfile, _strptime=datetime.strptime):
    with open(csvfile) as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            yield DataPoint(date=_strptime(row['Date'], '%Y-%m-%d').date(),
                            value=float(row['Adj Close']))


# Read prices and calculate daily percent change.
prices = tuple(read_prices('SP500.csv'))
gains = tuple(DataPoint(day.date, 100*(day.value/prev_day.value - 1.))
              for day, prev_day in zip(prices[1:], prices))

# Find maximum daily gain/loss.
zdp = DataPoint(None, 0)  # zero DataPoint
max_gain = ft.reduce(max, it.filterfalse(lambda p: p.value <= zdp, gains))
max_loss = ft.reduce(min, it.filterfalse(lambda p: p.value > zdp, gains), zdp)


# Find longest growth streak.
growth_streaks = consecutive_positives(gains, zero=DataPoint(None, 0))
longest_streak = ft.reduce(lambda x, y: x if len(x) > len(y) else y,
                           growth_streaks)

# Display results.
print('Max gain: {1:.2f}% on {0}'.format(*max_gain))
print('Max loss: {1:.2f}% on {0}'.format(*max_loss))

print('Longest growth streak: {num_days} days ({first} to {last})'.format(
    num_days=len(longest_streak),
    first=longest_streak[0].date,
    last=longest_streak[-1].date
))
複製代碼

最終結果以下:

Max gain: 11.58% on 2008-10-13
Max loss: -20.47% on 1987-10-19
Longest growth streak: 14 days (1971-03-26 to 1971-04-15)
複製代碼

數據源能夠點擊這裏下載

總結

此次我爲你們梳理一個利用itertools進行了簡單實戰的小栗子,這裏咱們旨在多深刻了解itertools,可是真實的生活中,遇到這種問題,哪有這麼麻煩,一個pandas包就搞定了,我之後會和你們分享和pandas有關的知識,這一次接連三期的itertools總結但願你們喜歡。

itertools深度解析至此全劇終。

相關文章
相關標籤/搜索