LSTM Cell (Long Short-Term Memory)
Implement an LSTM (Long Short-Term Memory) cell. LSTM is a specialized RNN architecture that solves the vanishing gradient problem by using a gating mechanism with three gates: **forget gate**, **input gate**, and **output gate**, plus a separate **cell state** that carries long-term information.
Formula:
combined = [x_t ; h_{t-1}] (concatenate input and previous hidden state)
Forget gate: f_t = sigmoid(W_f @ combined)
Input gate: i_t = sigmoid(W_i @ combined)
Cell candidate: g_t = tanh(W_g @ combined)
Output gate: o_t = sigmoid(W_o @ combined)
New cell state: c_t = f_t * c_{t-1} + i_t * g_t
New hidden state: h_t = o_t * tanh(c_t)
Where sigmoid(x) = 1 / (1 + exp(-x))
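The formulas above translate directly into a minimal NumPy sketch. The function name and the bias-free signature (four separate weight matrices, no bias vectors) are illustrative choices matching the formulas, not a prescribed interface:

```python
import numpy as np

def lstm_cell(x, h_prev, c_prev, W_f, W_i, W_g, W_o):
    """One LSTM forward step, bias-free, following the formulas above."""
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    combined = np.concatenate([x, h_prev])  # [x_t ; h_{t-1}]
    f = sigmoid(W_f @ combined)             # forget gate
    i = sigmoid(W_i @ combined)             # input gate
    g = np.tanh(W_g @ combined)             # cell candidate
    o = sigmoid(W_o @ combined)             # output gate
    c = f * c_prev + i * g                  # new cell state
    h = o * np.tanh(c)                      # new hidden state
    return h, c
```

Each weight matrix has shape (hidden_size, input_size + hidden_size), so every gate is computed from the same concatenated vector.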
Example:
input = [1.0, 0.5], prev_hidden = [0.0, 0.0], prev_cell = [0.0, 0.0]
W_f = [[0.5, 0.3, 0.1, 0.2], [0.1, 0.4, 0.2, 0.3]] (forget gate weights)
W_i = [[0.2, 0.1, 0.3, 0.4], [0.3, 0.2, 0.1, 0.5]] (input gate weights)
W_g = [[0.4, 0.2, 0.1, 0.3], [0.2, 0.5, 0.3, 0.1]] (cell candidate weights)
W_o = [[0.1, 0.3, 0.2, 0.4], [0.4, 0.1, 0.3, 0.2]] (output gate weights)
Step 1: combined = [1.0, 0.5, 0.0, 0.0]
Step 2: f_t = sigmoid(W_f @ combined) = sigmoid([0.65, 0.30]) = [0.657, 0.574]
Step 3: i_t = sigmoid(W_i @ combined) = sigmoid([0.25, 0.40]) = [0.562, 0.599]
Step 4: g_t = tanh(W_g @ combined) = tanh([0.50, 0.45]) = [0.462, 0.422]
Step 5: o_t = sigmoid(W_o @ combined) = sigmoid([0.25, 0.45]) = [0.562, 0.611]
Step 6: c_t = f_t * [0,0] + i_t * g_t = [0.260, 0.253]
Step 7: h_t = o_t * tanh(c_t) = [0.562, 0.611] * [0.254, 0.248] = [0.143, 0.151]
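The seven steps can be reproduced directly in NumPy; the variable names are mine, and the commented values are the rounded results from the walkthrough above:

```python
import numpy as np

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

x, h_prev, c_prev = np.array([1.0, 0.5]), np.zeros(2), np.zeros(2)
W_f = np.array([[0.5, 0.3, 0.1, 0.2], [0.1, 0.4, 0.2, 0.3]])
W_i = np.array([[0.2, 0.1, 0.3, 0.4], [0.3, 0.2, 0.1, 0.5]])
W_g = np.array([[0.4, 0.2, 0.1, 0.3], [0.2, 0.5, 0.3, 0.1]])
W_o = np.array([[0.1, 0.3, 0.2, 0.4], [0.4, 0.1, 0.3, 0.2]])

combined = np.concatenate([x, h_prev])  # step 1: [1.0, 0.5, 0.0, 0.0]
f = sigmoid(W_f @ combined)             # step 2: ~[0.657, 0.574]
i = sigmoid(W_i @ combined)             # step 3: ~[0.562, 0.599]
g = np.tanh(W_g @ combined)             # step 4: ~[0.462, 0.422]
o = sigmoid(W_o @ combined)             # step 5: ~[0.562, 0.611]
c = f * c_prev + i * g                  # step 6: ~[0.260, 0.253]
h = o * np.tanh(c)                      # step 7: ~[0.143, 0.151]
```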
**Explanation:** The LSTM cell addresses the vanishing gradient problem that plagues standard RNNs. The **forget gate** decides what information to discard from the previous cell state. The **input gate** and **cell candidate** determine what new information to store. The **output gate** controls what part of the cell state is exposed as the hidden state. This gating mechanism allows LSTMs to learn long-term dependencies effectively, making them ideal for tasks like language modeling, machine translation, and speech recognition.
Test Cases
1. x=[1.0, 0.5], h=[0.0, 0.0], c=[0.0, 0.0], W of shape (8, 4) -> returns (h_new, c_new), both of shape (2,)
2. x=[0.0, 0.0], h=[0.0, 0.0], c=[0.0, 0.0], W = any -> h_new=[0.0, 0.0], c_new=[0.0, 0.0] (since sigmoid(0) = 0.5 and tanh(0) = 0, the candidate g_t is zero, so c_t and h_t are both zero)
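The second test case follows from the formulas alone: with all-zero inputs every pre-activation is 0, so every gate is sigmoid(0) = 0.5 and the candidate is tanh(0) = 0, giving zero outputs for any weights. A quick NumPy check (the stacked layout W = [W_f; W_i; W_g; W_o] of shape (8, 4) is an assumed convention):

```python
import numpy as np

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

W = np.random.rand(8, 4)                # any weights, stacked [W_f; W_i; W_g; W_o]
combined = np.zeros(4)                  # x, h_prev, c_prev all zero
f = sigmoid(W[0:2] @ combined)          # = [0.5, 0.5]
i = sigmoid(W[2:4] @ combined)          # = [0.5, 0.5]
g = np.tanh(W[4:6] @ combined)          # = [0.0, 0.0]
o = sigmoid(W[6:8] @ combined)          # = [0.5, 0.5]
c_new = f * np.zeros(2) + i * g         # = [0.0, 0.0]
h_new = o * np.tanh(c_new)              # = [0.0, 0.0]
```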