import numpy as np


def lstm_cell_forward(xt, a_prev, c_prev, parameters):
"""
Implement a single forward step of the LSTM-cell as described in Figure (4)
Arguments:
xt -- your input data at timestep "t", numpy array of shape (n_x, m).
a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
parameters -- python dictionary containing:
Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
Wi -- Weight matrix of the update gate, numpy array of shape (n_a, n_a + n_x)
bi -- Bias of the update gate, numpy array of shape (n_a, 1)
Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
Wo -- Weight matrix of the output gate, numpy array of shape (n_a, n_a + n_x)
bo -- Bias of the output gate, numpy array of shape (n_a, 1)
Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)
Returns:
a_next -- next hidden state, of shape (n_a, m)
c_next -- next memory state, of shape (n_a, m)
yt_pred -- prediction at timestep "t", numpy array of shape (n_y, m)
    cache -- tuple of values needed for the backward pass, contains (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilde),
          c stands for the cell state (memory)
    """
    # Retrieve parameters from "parameters".
    Wf = parameters["Wf"]  # forget gate weights
    bf = parameters["bf"]
    Wi = parameters["Wi"]  # update gate weights (note the subscript is i, not u)
    bi = parameters["bi"]  # (notice the variable name)
    Wc = parameters["Wc"]  # candidate value weights
    bc = parameters["bc"]
    Wo = parameters["Wo"]  # output gate weights
    bo = parameters["bo"]
    Wy = parameters["Wy"]  # prediction weights
    by = parameters["by"]
    # Concatenate a_prev and xt into a single input for the gates
    concat = np.concatenate((a_prev, xt), axis=0)
    # Equivalent to the following code:
    # Retrieve dimensions from the shapes of xt and Wy
    # n_x, m = xt.shape
    # n_y, n_a = Wy.shape
    # concat = np.zeros((n_a + n_x, m))
    # concat[: n_a, :] = a_prev
    # concat[n_a :, :] = xt

    # Compute ft (forget gate), it (update gate),
    # cct (candidate value), c_next (cell state),
    # ot (output gate), a_next (hidden state),
    # following the equations below.
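    # LSTM cell equations implemented below (standard formulation; [a_prev; xt]
    # denotes the vertical concatenation stored in `concat`):
    #   ft     = sigmoid(Wf @ [a_prev; xt] + bf)
    #   it     = sigmoid(Wi @ [a_prev; xt] + bi)
    #   cct    = tanh(Wc @ [a_prev; xt] + bc)
    #   c_next = ft * c_prev + it * cct
    #   ot     = sigmoid(Wo @ [a_prev; xt] + bo)
    #   a_next = ot * tanh(c_next)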
    ft = sigmoid(np.dot(Wf, concat) + bf)   # forget gate
    it = sigmoid(np.dot(Wi, concat) + bi)   # update gate
    cct = np.tanh(np.dot(Wc, concat) + bc)  # candidate value
    c_next = ft * c_prev + it * cct         # cell state
    ot = sigmoid(np.dot(Wo, concat) + bo)   # output gate
    a_next = ot * np.tanh(c_next)           # hidden state
    # Compute the LSTM cell's prediction
    yt_pred = softmax(np.dot(Wy, a_next) + by)
    # Cache values needed for the backward pass
    cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)
    return a_next, c_next, yt_pred, cache
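

# The function above assumes `sigmoid` and `softmax` helpers are already available
# (in the original assignment they come from a utilities module). A minimal sketch of
# those helpers, plus a small smoke test with random data, could look like the code
# below; the dimensions (n_x, n_a, n_y, m) are arbitrary example values, not values
# taken from the original assignment.

def sigmoid(x):
    # Element-wise logistic function
    return 1 / (1 + np.exp(-x))


def softmax(x):
    # Column-wise softmax (each column of x is one example);
    # subtract the column max for numerical stability
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / np.sum(e_x, axis=0, keepdims=True)


if __name__ == "__main__":
    np.random.seed(1)
    n_x, n_a, n_y, m = 3, 5, 2, 10  # example dimensions
    xt = np.random.randn(n_x, m)
    a_prev = np.random.randn(n_a, m)
    c_prev = np.random.randn(n_a, m)
    parameters = {
        "Wf": np.random.randn(n_a, n_a + n_x), "bf": np.random.randn(n_a, 1),
        "Wi": np.random.randn(n_a, n_a + n_x), "bi": np.random.randn(n_a, 1),
        "Wc": np.random.randn(n_a, n_a + n_x), "bc": np.random.randn(n_a, 1),
        "Wo": np.random.randn(n_a, n_a + n_x), "bo": np.random.randn(n_a, 1),
        "Wy": np.random.randn(n_y, n_a), "by": np.random.randn(n_y, 1),
    }
    a_next, c_next, yt_pred, cache = lstm_cell_forward(xt, a_prev, c_prev, parameters)
    print("a_next.shape =", a_next.shape)    # expected (5, 10)
    print("c_next.shape =", c_next.shape)    # expected (5, 10)
    print("yt_pred.shape =", yt_pred.shape)  # expected (2, 10)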