functools下的partial模塊應用

問題

你有一個被其餘python代碼使用的callable對象,多是一個回調函數或者是一個處理器, 可是它的參數太多了,致使調用時出錯。python

解決方案

若是須要減小某個函數的參數個數,你可使用 functools.partial()partial() 函數容許你給一個或多個參數設置固定的值,減小接下來被調用時的參數個數。 爲了演示清楚,假設你有下面這樣的函數:服務器

def spam(a, b, c, d): print(a, b, c, d)

如今咱們使用 partial() 函數來固定某些參數值:網絡

>>> from functools import partial
>>> s1 = partial(spam, 1) # a = 1
>>> s1(2, 3, 4) 1 2 3 4
>>> s1(4, 5, 6) 1 4 5 6
>>> s2 = partial(spam, d=42) # d = 42
>>> s2(1, 2, 3) 1 2 3 42
>>> s2(4, 5, 5) 4 5 5 42
>>> s3 = partial(spam, 1, 2, d=42) # a = 1, b = 2, d = 42
>>> s3(3) 1 2 3 42
>>> s3(4) 1 2 4 42
>>> s3(5) 1 2 5 42
>>>

能夠看出 partial() 固定某些參數並返回一個新的callable對象。這個新的callable接受未賦值的參數, 而後跟以前已經賦值過的參數合併起來,最後將全部參數傳遞給原始函數。app

討論

本節要解決的問題是讓本來不兼容的代碼能夠一塊兒工做。下面我會列舉一系列的例子。異步

第一個例子是,假設你有一個點的列表來表示(x,y)座標元組。 你可使用下面的函數來計算兩點之間的距離:socket

points = [ (1, 2), (3, 4), (5, 6), (7, 8) ] import math def distance(p1, p2): x1, y1 = p1 x2, y2 = p2 return math.hypot(x2 - x1, y2 - y1)

說明一下這裏的math.hypot默認是以座標原點爲幾點計算座標到原點的直線距離async

import math print(math.hypot(6,8)) >>>10

如今假設你想以某個點爲基點,根據點和基點之間的距離來排序全部的這些點。 列表的 sort() 方法接受一個關鍵字參數來自定義排序邏輯, 可是它只能接受一個單個參數的函數(distance()很明顯是不符合條件的)。 如今咱們能夠經過使用 partial() 來解決這個問題:函數

>>> pt = (4, 3) >>> points.sort(key=partial(distance,pt)) >>> points [(3, 4), (1, 2), (5, 6), (7, 8)] >>>

更進一步,partial() 一般被用來微調其餘庫函數所使用的回調函數的參數。 例如,下面是一段代碼,使用 multiprocessing 來異步計算一個結果值, 而後這個值被傳遞給一個接受一個result值和一個可選logging參數的回調函數:ui

def output_result(result, log=None): if log is not None: log.debug('Got: %r', result) # A sample function def add(x, y): return x + y if __name__ == '__main__': import logging from multiprocessing import Pool from functools import partial logging.basicConfig(level=logging.DEBUG) log = logging.getLogger('test') p = Pool() p.apply_async(add, (3, 4), callback=partial(output_result, log=log)) p.close() p.join()

當給 apply_async() 提供回調函數時,經過使用 partial() 傳遞額外的 logging 參數。 而 multiprocessing 對這些一無所知——它僅僅只是使用單個值來調用回調函數。spa

做爲一個相似的例子,考慮下編寫網絡服務器的問題,socketserver 模塊讓它變得很容易。 下面是個簡單的echo服務器:

from socketserver import StreamRequestHandler, TCPServer class EchoHandler(StreamRequestHandler): def handle(self): for line in self.rfile: self.wfile.write(b'GOT:' + line) serv = TCPServer(('', 15000), EchoHandler) serv.serve_forever()

不過,假設你想給EchoHandler增長一個能夠接受其餘配置選項的 __init__ 方法。好比:

class EchoHandler(StreamRequestHandler): # ack is added keyword-only argument. *args, **kwargs are # any normal parameters supplied (which are passed on) def __init__(self, *args, ack, **kwargs): self.ack = ack super().__init__(*args, **kwargs) def handle(self): for line in self.rfile: self.wfile.write(self.ack + line)

這麼修改後,咱們就不須要顯式地在TCPServer類中添加前綴了。 可是你再次運行程序後會報相似下面的錯誤:

Exception happened during processing of request from ('127.0.0.1', 59834) Traceback (most recent call last): ... TypeError: __init__() missing 1 required keyword-only argument: 'ack'

初看起來好像很難修正這個錯誤,除了修改 socketserver 模塊源代碼或者使用某些奇怪的方法以外。 可是,若是使用 partial() 就能很輕鬆的解決——給它傳遞 ack 參數的值來初始化便可,以下:

from functools import partial serv = TCPServer(('', 15000), partial(EchoHandler, ack=b'RECEIVED:')) serv.serve_forever()

在這個例子中,__init__() 方法中的ack參數聲明方式看上去頗有趣,其實就是聲明ack爲一個強制關鍵字參數。 關於強制關鍵字參數問題咱們在7.2小節咱們已經討論過了,讀者能夠再去回顧一下。

不少時候 partial() 能實現的效果,lambda表達式也能實現。好比,以前的幾個例子可使用下面這樣的表達式:

points.sort(key=lambda p: distance(pt, p)) p.apply_async(add, (3, 4), callback=lambda result: output_result(result,log)) serv = TCPServer(('', 15000), lambda *args, **kwargs: EchoHandler(*args, ack=b'RECEIVED:', **kwargs))

這樣寫也能實現一樣的效果,不過相比而已會顯得比較臃腫,對於閱讀代碼的人來說也更加難懂。 這時候使用 partial() 能夠更加直觀的表達你的意圖(給某些參數預先賦值)。

相關文章
相關標籤/搜索