謝謝朋友們的陪伴來一塊兒學習,這周的文章咱們瞭解一下 KNNBasic() 算法如何實現,以及搞清楚整個算法的結構梳理。node
首先在 surprise 中,全部的算法都從一個基類 AlgoBase() 中繼承而來,這個類設置了你們公有的功能函數的接口,如 fit() 函數用來讓算法對數據集進行擬合,predict() 函數對給定的 user 和 item 進行評分預測。算法
圖片 algo base 類ruby
以 KNN 系列算法爲例,首先定義一個名爲 knns 的庫,其中包含了多個 KNN 系列的算法,例如 KNNBasic(),KNNWithMeans(),KNNWithZScore() 等。微信
咱們在前面推薦理論系列文章中介紹了基於鄰域的思路,KNN 算法本質上就是基於鄰域的協同過濾方法,而這些變種則是對類似度的不一樣應用,好比對類似度進行歸一化,對用戶的平均打分標準進行歸一化等。機器學習
考慮到 KNN 有兩種對稱的子方法:基於用戶的協同過濾和基於物品的協同過濾。兩種方法在計算上是對稱的,因此 KNN 系列的方法有一個基類是 SymmetricAlgo(),這個類繼承了前面提到的 AlgoBase() 類,是全部 KNN 類的父類。函數
class SymmetricAlgo(AlgoBase): """This is an abstract class aimed to ease the use of symmetric algorithms.
A symmetric algorithm is an algorithm that can can be based on users or on items indifferently, e.g. all the algorithms in this module.
When the algo is user-based x denotes a user and y an item. Else, it's reversed. """
def __init__(self, sim_options={}, verbose=True, **kwargs):
AlgoBase.__init__(self, sim_options=sim_options, **kwargs) self.verbose = verbose
def fit(self, trainset):
AlgoBase.fit(self, trainset)
ub = self.sim_options['user_based'] self.n_x = self.trainset.n_users if ub else self.trainset.n_items self.n_y = self.trainset.n_items if ub else self.trainset.n_users self.xr = self.trainset.ur if ub else self.trainset.ir self.yr = self.trainset.ir if ub else self.trainset.ur
return self
def switch(self, u_stuff, i_stuff): """Return x_stuff and y_stuff depending on the user_based field."""
if self.sim_options['user_based']: return u_stuff, i_stuff else: return i_stuff, u_stuff
咱們能夠看到 SymmetricAlgo() 類中重寫了 fit() 函數,在父函數 fit() 函數的基礎上,還加上了一個功能就是判斷是基於用戶類似仍是基於物品類似。學習
獲得判斷結果 ub 後,fit() 函數利用 ub 的值對數據集從新進行處理,至關於經過這個函數後,咱們不用再區分基於用戶或者基於物品的區別,而只須要直接對獲得的 self.n_x,self.n_y 進行處理。測試
這裏咱們能夠看到還有一個 switch() 函數,內容就是根據 self.sim_options['user_based'] 的值來選擇是否翻轉給定的兩組輸入。其實和前面的 fit() 函數思想一致,只不過一個是爲了訓練集的數據處理,一個是爲了測試集的數據處理。this
接下來咱們以 KNNBasic() 函數爲例,介紹一下算法實現的具體思路,具體的計算方法本週的文章就不深刻,主要是捋清楚算法運行的流程和結構。spa
class KNNBasic(SymmetricalAlgo): """ Args: k(int): The (max) number of neighbors to take into account for aggregation (see :ref:`this note <actual_k_note>`). Default is ``40``. min_k(int): The minimum number of neighbors to take into account for aggregation. If there are not enough neighbors, the prediction is set to the global mean of all ratings. Default is ``1``. sim_options(dict): A dictionary of options for the similarity measure. See :ref:`similarity_measures_configuration` for accepted options. verbose(bool): Whether to print trace messages of bias estimation, similarity, etc. Default is True. """ def __init__(self, k=40, min_k=1, sim_options={}, verbose=True, **kwargs): SymmetricalAlgo.__init__(self, sim_options=sim_options, verbose=verbose, **kwargs) self.k = k self.min_k = min_k
def fit(self, trainset): SymmetricalAlgo.fit(self, trainset) self.sim = self.compute_similarities()
return self
def estimate(self, u, i):
if not (self.trainset.knows_user(u) and self.trainset.knows_item(i)): raise PredictionImpossible('User and/or item is unkown.')
x, y = self.switch(u, i)
neighbors = [(self.sim[x, x2], r) for x2, r in self.yr[y]] k_neighbors = heapq.nlargest(self.k, neighbors, key=lambda t: t[0])
sum_sim = sum_ratings = actual_k = 0 for (sim, r) in k_neighbors: if sim > 0: sum_sim += sim sum_ratings += sim*r actual_k += 1
if actual_k < self.min_k: raise PredictionImpossible('Not enough neighbors')
est = sum_ratings / sum_sim
details = {'actual_k': actual_k} return est, details
首先能夠看到在這個子類中再次重寫了 fit() 函數,從 fit() 函數內部經過 self 去調用父類中的 self.compute_similarities() 函數,這個函數被定義在 AlgoBase() 類中,這樣子對算法自己而言只須要直接調用 fit() 就能夠完成擬合過程,將具體的計算抽象了出去。
而 self.sim 則是返回的計算出來的類似性矩陣,給出了用戶兩兩之間的類似性。
再看 estimate() 函數須要輸入目標用戶 u 和目標評分項 i,最終返回一個預測的評分結果。在進行預測以前先調用父類的 switch() 函數,對基於用戶仍是基於物品進行肯定。
具體的計算方式則是按照給定的 k 值,獲得與輸入的給定用戶 u 最類似的 k 個用戶,具體的形式是一個列表,列表中有 k 個元組,每一個元組有兩項,分別是類似用戶和該用戶對物品 i 的評分。
經過獲得最類似的 k 個用戶及其評分,目標用戶對目標物品評分的預測就很簡單了,最直接的就是讓和每一個用戶的類似性乘以該用戶的評分,最終求和取平均則是預測結果。
經過上面的函數進行替代目標的 KNNBasic(),咱們的模型能夠完整的運行一個 demo。
總結
這篇文章很簡單,可是思路很清晰,能夠有效的看清源碼中的調用結構。下篇文章咱們再完善基類 AlgoBase() 的各個方法,將來再繼續豐富其它各種算法,歡迎朋友們和我一塊兒交流學習,共同進步。
往期回顧
若是喜歡做者,請關注我吧 ❤️
本文分享自微信公衆號 - 機器學習與推薦系統(ml-recsys)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。