Python - softmax 實現

Softmax

softmax函數將任意n維的實值向量轉換爲取值範圍在(0,1)之間的n維實值向量,而且總和爲1。
例如:向量softmax([1.0, 2.0, 3.0]) ------> [0.09003057, 0.24472847, 0.66524096]app

性質:函數

  1. 由於softmax是單調遞增函數,所以不改變原始數據的大小順序。
  2. 將原始輸入映射到(0,1)區間,而且總和爲1,經常使用於表徵機率。
  3. softmax(x) = softmax(x+c), 這個性質用於保證數值的穩定性。

softmax的實現及數值穩定性

一個最簡單的計算給定向量的softmax的實現以下:測試

import numpy as np
def softmax(x):
"""Compute the softmax of vector x."""
    exp_x = np.exp(x)
    softmax_x = exp_x / np.sum(exp_x)
    return softmax_x

讓咱們來測試一下上面的代碼:code

softmax([1, 2, 3])
array([0.09003057, 0.24472847, 0.66524096])

可是,當咱們嘗試輸入一個比較大的數值向量時,就會出錯:input

softmax([1000, 2000, 3000])
array([nan, nan, nan])

這是由numpy中的浮點型數值範圍限制所致使的。當輸入一個較大的數值時,sofmax函數將會超出限制,致使出錯。
爲了解決這一問題,這時咱們就能用到sofmax的第三個性質,即:softmax(x) = softmax(x+c),
通常在實際運用中,一般設定c = - max(x)。
接下來,咱們從新定義softmax函數:io

import numpy as np
def softmax(x):
"""Compute the softmax in a numerically stable way."""
    x = x - np.max(x)
    exp_x = np.exp(x)
    softmax_x = exp_x / np.sum(exp_x)
    return softmax_x

而後再次測試一下:table

softmax([1000, 2000, 3000])
array([ 0.,  0.,  1.])

Done!function

以上都是基於向量上的softmax實現,下面提供了基於向量以及矩陣的softmax實現,代碼以下:import

import numpy as np
def softmax(x):
    """
    Compute the softmax function for each row of the input x.

    Arguments:
    x -- A N dimensional vector or M x N dimensional numpy matrix.

    Return:
    x -- You are allowed to modify x in-place
    """
    orig_shape = x.shape

    if len(x.shape) > 1:
        # Matrix
        exp_minmax = lambda x: np.exp(x - np.max(x))
        denom = lambda x: 1.0 / np.sum(x)
        x = np.apply_along_axis(exp_minmax,1,x)
        denominator = np.apply_along_axis(denom,1,x) 
        
        if len(denominator.shape) == 1:
            denominator = denominator.reshape((denominator.shape[0],1))
        
        x = x * denominator
    else:
        # Vector
        x_max = np.max(x)
        x = x - x_max
        numerator = np.exp(x)
        denominator =  1.0 / np.sum(numerator)
        x = numerator.dot(denominator)
    
    assert x.shape == orig_shape
    return x
相關文章
相關標籤/搜索