//
// The batch is (nL, nB, ...)
//
nB, nO = Yt1.shape[0]
// Variables:
//
// At: Activations for this timestep, i.e. output of hstack(Xt, Yt1)@W+b
// Yt1: Previous hidden-layer output (i.e. Y_{t-1})
After Change
ho = sigmoid(At3_ho) // 1c
hc = xp.tanh(At3_hc) // 1d
Ct3 = hf * Ct2 // 2a
Ct3 += hi * hc // 2b
tanhCt3 = tanh(Ct3) // 3a
Yt3 = tanhCt3 * ho // 3b
// We don't need the gradient for this, it's just for backprop calculation.
Gt3 = xp.concatenate((hf, hi, ho, hc), axis=-1)
return (Yt3, Ct3), Gt3
@jax_jit()
def backprop_lstm_gates(