pandas高效實現條件邏輯

做者|Louis Chan
編譯|VK
來源|Towards Data Sciencehtml

Python能夠說是當今最酷的編程語言(多虧了機器學習和數據科學),但與最好的編程語言之一C相比,它的效率並非很高。python

在開發機器學習模型時,很常見的狀況是,咱們須要根據從統計分析或上一次迭代的結果導出的硬編碼規則,而後以編程方式更新。認可這一點並不羞恥:我一直在用Pandas apply編寫代碼,直到有一天我對嵌套很是厭煩,因而決定研究(又稱Google)其餘更可維護、更高效的方法git

演示數據集

咱們將要使用的數據集是iris數據集,你能夠經過pandas或seaborn免費得到它。github

import pandas as pd
iris = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')

# import seaborn as sns
# iris = sns.load_dataset("iris")

iris數據集的前5行算法

數據統計信息編程

假設在初始分析以後,咱們但願用如下邏輯標記數據集:多線程

  • 若是萼片長度(sepal length)< 5.1,則標籤爲0;app

  • 不然,若是萼片寬度(sepal width)> 3.3和萼片長度< 5.8,則標籤爲1;機器學習

  • 不然,若是萼片寬度> 3.3,花瓣長度(petal length)> 5.1,則標籤爲2;編程語言

  • 不然,若是萼片寬度> 3.3,花瓣長度< 1.6且萼片長度< 6.4或花瓣寬度< 1.3,則標籤3;

  • 不然,若是萼片寬度>3.3且萼片長度< 6.4或花瓣寬度< 1.3,則標籤爲4;

  • 不然,若是萼片寬度> 3.3,則標籤爲5;

  • 不然標籤6

在深刻研究代碼以前,讓咱們快速地將一個新的label列設置爲None:

iris['label'] = None

Pandas.iterrows+嵌套If Else塊

若是你還在用這個,這篇博文絕對是適合你的地方!

%%timeit
for idx, row in iris.iterrows():
  if row['sepal_length'] < 5.1:
    iris.loc[idx, 'label'] = 0
  elif row['sepal_width'] > 3.3:
    if row['sepal_length'] < 5.8:
      iris.loc[idx, 'label'] = 1
    elif row['petal_length'] > 5.1:
      iris.loc[idx, 'label'] = 2
    elif (row['sepal_length'] < 6.4) or (row['petal_width'] < 1.3):
      if row['petal_length'] < 1.6:
        iris.loc[idx, 'label'] = 3
      else:
        iris.loc[idx, 'label'] = 4
    else:
      iris.loc[idx, 'label'] = 5
  else:
    iris.loc[idx, 'label'] = 6
1min 29s ± 8.91 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

時間挺長…好吧,咱們繼續…

Pandas .apply

Pandas.apply直接用於沿數據幀的軸或Series來應用函數。例如,若是咱們有一個函數f,它能夠是一個數列的和(例如,能夠是一個list, np.array, tuple等),並將其傳遞給以下數據幀,咱們將跨行求和:

def f(numbers):
    return sum(numbers)
    
df['Row Subtotal'] = df.apply(f, axis=1)

在axis=1上應用函數。默認狀況下,apply參數axis=0,即逐行應用函數;而axis=1將逐列應用函數。

如今咱們已經對pandas.apply有了基本的瞭解,如今讓咱們編寫分配標籤的邏輯代碼,看看它運行多長時間:

%%timeit
def rules(row):
  if row['sepal_length'] < 5.1:
    return 0
  elif row['sepal_width'] > 3.3:
    if row['sepal_length'] < 5.8:
      return 1
    elif row['petal_length'] > 5.1:
      return 2
    elif (row['sepal_length'] < 6.4) or (row['petal_width'] < 1.3):
      if row['petal_length'] < 1.6:
        return 3
      return 4
    return 5
  return 6

iris['label'] = iris.apply(rules, 1)
1.43 s ± 115 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

15萬行只須要1.43s比以前的水平有了很大的提升,但仍然很是緩慢。

想象一下,若是你須要處理一個由數百萬個交易數據或信貸批准組成的數據集,那麼每次咱們要應用一組規則並將函數應用在一個列時,它將佔用14秒以上。運行足夠多的列,你一個下午可能就沒了。

Pandas.loc[]

若是你熟悉SQL,那麼使用.loc[]爲新列賦值實際上只是一個帶有WHERE條件的UPDATE語句。所以,這應該比將函數應用於每一個行或列要好得多。

%%timeit
iris['label'] = 6
iris.loc[iris['sepal_width'] > 3.3, 'label'] = 5
iris.loc[
  (iris['sepal_width'] > 3.3) & 
  ((iris['sepal_length'] < 6.4) | (iris['petal_width'] < 1.3)), 
  'label'] = 4
iris.loc[
  (iris['sepal_width'] > 3.3) & 
  ((iris['sepal_length'] < 6.4) | (iris['petal_width'] < 1.3)) & 
  (iris['petal_length'] < 1.6), 
  'label'] = 3
iris.loc[
  (iris['sepal_width'] > 3.3) & 
  (iris['petal_length'] > 5.1), 
  'label'] = 2
iris.loc[
  (iris['sepal_width'] > 3.3) & 
  (iris['sepal_length'] < 5.8), 
  'label'] = 1
iris.loc[
  (iris['sepal_length'] < 5.1), 
  'label'] = 0
13.3 ms ± 837 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

如今咱們只花了前一次的十分之一的時間,這意味着當你在家工做的時候,你沒有更多的藉口離開辦公桌。不過,咱們目前只使用pandas內置的函數。儘管pandas爲咱們提供了一個很是方便的高級接口來與數據表交互,可是經過層層抽象,效率可能會下降。

Numpy.where

Numpy有一個較低級別的接口,容許與n維iterables(即向量、矩陣、張量等)進行更有效的交互。它的方法一般是基於C語言的,當涉及到更復雜的計算時,它使用了優化的算法,使得它比咱們從新發明的輪子更快。

根據numpy的官方文件,np.where()接受如下語法:

np.where(condition, return value if True, return value if False)

本質上,這是一種二分,其中條件將被計算爲布爾值並相應地返回值。這裏的技巧是條件實際上能夠是iterable(即布爾ndarray類型)。這意味着咱們能夠將df['feature']==1做爲條件,並將where邏輯編碼爲:

np.where(
    df['feature'] == 1, 
    'It is one', 
    'It is not one'
)

因此你可能會問,咱們如何用一個像np.where()這樣的二分函數來實現上述邏輯呢?答案很簡單,但卻使人不安。嵌套np.where()

%%timeit
iris['label'] = np.where(
  iris['sepal_length'] < 5.1,
  0,
  np.where(
    iris['sepal_width'] > 3.3,
    np.where(
      iris['sepal_length'] < 5.8,
      1,
      np.where(
        iris['petal_length'] > 5.1,
        2,
        np.where(
          (iris['sepal_length'] < 6.4) | (iris['petal_width'] < 1.3),
          np.where(
            iris['petal_length'] < 1.6,
            3,
            4
          ),
          5
        )
      )
    ),
    6
  )
)
3.6 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

恭喜你,你挺過來了。我不能告訴你我花了多少次來計算右括號,可是嘿,這就完成了!咱們又從pandas身上砍下了10毫秒。loc[]。然而,這個代碼片斷是不可維護的,這意味着,它是不可接受的。

Numpy.select

Numpy.select,它與.where不一樣,它是用來實現多線程邏輯的函數。

np.select(condlist, choicelist, default=0)

它的語法近似於np.where,但第一個參數如今是一個條件列表,它的長度應該與選項的長度相同。使用時要記住一件事np.select是在知足第一個條件後當即選擇一個選項。

這意味着,若是超集規則出如今列表中的子集規則以前,那麼子集選擇將永遠不會被選擇。具體說來:

condlist = [
    df['A'] <= 1,
    df['A'] < 1
]

choicelist = ['<=1', '<1']

selection = np.select(condlist, choicelist, default='>1')

由於全部命中df['A']<1的行也將被df['A']<=1捕獲,所以沒有行最終被標記爲'<1'。爲了不這種狀況發生,請務必在更具體的規則以前先制定一個不太具體的規則:

condlist = [
    df['A'] < 1, # < ───┬ 交換
    df['A'] <= 1 # < ───┘
]

choicelist = ['<1', '<=1'] # 記住也要更新這個!

selection = np.select(condlist, choicelist, default='>1')

從上面能夠看到,你須要同時更新condlist和choicelsit,以確保代碼順利運行。但說真的,這一步也耗咱們本身的時間。經過將其更改成字典,咱們將達到大體相同的時間和內存複雜性,但使用更易於維護的代碼片斷:

%%timeit
rules = {
  0: (iris['sepal_length'] < 5.1),
  1: (iris['sepal_width'] > 3.3) & (iris['sepal_length'] < 5.8),
  2: (iris['sepal_width'] > 3.3) & (iris['petal_length'] > 5.1),
  3: (
    (iris['sepal_width'] > 3.3) & \
    ((iris['sepal_length'] < 6.4) | (iris['petal_width'] < 1.3)) & \
    (iris['petal_length'] < 1.6)
  ),
  4: (
    (iris['sepal_width'] > 3.3) & \
    ((iris['sepal_length'] < 6.4) | (iris['petal_width'] < 1.3))
  ),
  5: (iris['sepal_width'] > 3.3),
}

iris['label'] = np.select(rules.values(), rules.keys(), default=6)
6.29 ms ± 475 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

大約是np.where的一半,但這不只使你免於對各類嵌套的調試,並且使choicelist發生了變化。以前我已經忘記更新choicelist太屢次了,以致於我花了四倍多的時間來調試個人機器學習模型。相信我,np.select和dict。這是很是好的選擇

優秀函數

  1. Numpy的向量化操做:若是你的代碼涉及循環和計算一元函數、二進制函數或對數字序列進行操做的函數。你應該經過將數據轉換爲numpy-ndarray來重構代碼,並充分利用numpy的向量化操做來極大地提升腳本的速度。在Numpy的官方文檔中查看一元函數、二元函數或對數字序列進行操做的函數的示例:https://www.pythonlikeyoumeanit.com/Module3_IntroducingNumpy/VectorizedOperations.html#NumPy’s-Mathematical-Functions

  2. np.vectorize:不要被這個函數的名字愚弄。這只是一個方便的函數,並不會使代碼運行得更快。要使用此函數,首先須要將邏輯編碼爲可調用函數,而後運行np.vectorize(你的函數)(你的數據系列)。另外一個大的缺點是須要將數據幀轉換爲一維的iterable,以便傳遞到「矢量化」函數中。結論:若是不方便使用np.vectorize,別使用。

  3. numba.njit:如今這是真正的向量化。它試圖將任何numpy值移動到儘量接近C語言,以提升其效率。雖然它能夠加速數值計算,但它也將本身限制爲數值計算,這意味着沒有pandas系列,沒有字符串索引,只有具備int、float、datetime、bool和category類型的numpy的ndarray。結論:若是你可以輕鬆地使用Numpy的ndarray並將邏輯轉換爲數值計算或僅轉換爲數值計算,那麼它將是一個很是優秀的選擇。從這裏瞭解更多:https://numba.pydata.org/numba-doc/dev/user/5minguide.html

結尾

若是可能的話,去爭取numba.njit;不然,使用np.select和dict就能夠幫助你遠航了。記住,每一點改進都會有幫助!

原文連接:https://towardsdatascience.com/efficient-implementation-of-conditional-logic-on-pandas-dataframes-4afa61eb7fce

歡迎關注磐創AI博客站:
http://panchuang.net/

sklearn機器學習中文官方文檔:
http://sklearn123.com/

歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/

相關文章
相關標籤/搜索