Example 28
Difficulty: intermediate
Tags: RNN, LSTM, GRU, Sequences, Neural Networks

Recurrent Neural Network Layers

Recurrent neural networks process sequential data by maintaining a hidden state across time steps. This example demonstrates Deepbox's three recurrent layers: RNN (simple recurrent unit, prone to vanishing gradients on long sequences), LSTM (Long Short-Term Memory, which uses gates to control information flow and handle long-range dependencies), and GRU (Gated Recurrent Unit, a simplified LSTM with fewer parameters and often comparable performance). For each layer, you create an instance with specified input and hidden sizes, pass a sequence tensor of shape [batch, seq_len, features], and inspect both the full output sequence and the final hidden state, comparing the shapes produced by each layer.
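For reference, the gating the paragraph above describes follows the standard LSTM formulation (symbols use the common textbook convention, not necessarily Deepbox's internal naming): sigmoid gates decide what to forget, write, and expose, while the cell state carries information across time steps.

```latex
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) && \text{forget gate} \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) && \text{input gate} \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) && \text{output gate} \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) && \text{candidate cell state} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Because the cell state $c_t$ is updated additively rather than by repeated matrix multiplication, gradients can flow through it over many time steps, which is why the LSTM handles long-range dependencies better than a simple RNN.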

Deepbox Modules Used

  • deepbox/ndarray
  • deepbox/nn

What You Will Learn

  • RNN is simplest but suffers from vanishing gradients for long sequences
  • LSTM adds forget/input/output gates — handles long-range dependencies
  • GRU simplifies LSTM to 2 gates — fewer parameters, often similar performance
  • Output shape: [batch, seq_len, hidden_size] — one hidden state per time step
  • Use final hidden state for classification, full output for sequence-to-sequence
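To make the GRU's "2 gates" concrete, here is a minimal single-step GRU cell in plain TypeScript, independent of Deepbox. It is an illustration only: the scalar weights (Wz, Uz, etc.) are arbitrary fixed values, not trained parameters, and real implementations use matrices over vector-valued states.

```typescript
const sigmoid = (v: number): number => 1 / (1 + Math.exp(-v));

// One GRU step with scalar input and hidden state (inputSize = hiddenSize = 1).
// z: update gate, r: reset gate, hTilde: candidate hidden state.
function gruStep(x: number, hPrev: number): number {
  const Wz = 0.5, Uz = 0.5, bz = 0; // update-gate parameters (illustrative)
  const Wr = 0.5, Ur = 0.5, br = 0; // reset-gate parameters (illustrative)
  const Wh = 1.0, Uh = 1.0, bh = 0; // candidate parameters (illustrative)
  const z = sigmoid(Wz * x + Uz * hPrev + bz);
  const r = sigmoid(Wr * x + Ur * hPrev + br);
  const hTilde = Math.tanh(Wh * x + Uh * (r * hPrev) + bh);
  // Interpolate between the old state and the candidate via the update gate.
  return (1 - z) * hPrev + z * hTilde;
}

// Run a short sequence; the hidden state carries information across steps.
let h = 0;
for (const x of [1, 0, -1]) h = gruStep(x, h);
console.log(h.toFixed(4));
```

The update gate z plays the combined role of the LSTM's forget and input gates, which is where the parameter savings come from.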

Source Code

28-rnn-lstm-gru/index.ts
import { randn } from "deepbox/ndarray";
import { RNN, LSTM, GRU } from "deepbox/nn";

console.log("=== RNN / LSTM / GRU ===\n");

// Input: [batch=2, seq_len=5, features=10]
const x = randn([2, 5, 10]);

// Simple RNN
const rnn = new RNN(10, 20); // 10 input, 20 hidden
const { output: rnnOut, hidden: rnnH } = rnn.forward(x);
console.log("RNN output:", rnnOut.shape); // [2, 5, 20]
console.log("RNN hidden:", rnnH.shape);   // [2, 20]

// LSTM (returns output, hidden, cell)
const lstm = new LSTM(10, 32);
const { output: lstmOut, hidden: lstmH, cell: lstmC } = lstm.forward(x);
console.log("\nLSTM output:", lstmOut.shape); // [2, 5, 32]
console.log("LSTM hidden:", lstmH.shape);     // [2, 32]
console.log("LSTM cell:  ", lstmC.shape);     // [2, 32]

// GRU
const gru = new GRU(10, 24);
const { output: gruOut, hidden: gruH } = gru.forward(x);
console.log("\nGRU output:", gruOut.shape); // [2, 5, 24]
console.log("GRU hidden:", gruH.shape);     // [2, 24]
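The vanishing-gradient problem mentioned above can be demonstrated without Deepbox at all. In a simple RNN, the gradient of the last hidden state with respect to the first is a product of per-step terms w * tanh'(a_t); with a recurrent weight below 1, that product shrinks geometrically with sequence length. A minimal scalar sketch (the weight 0.5 and initial state 0.1 are arbitrary illustrative values):

```typescript
const w = 0.5;   // recurrent weight (illustrative, < 1)
let h = 0.1;     // initial hidden state
let grad = 1;    // accumulated d h_T / d h_0

for (let t = 0; t < 20; t++) {
  const a = w * h;          // pre-activation (inputs omitted for simplicity)
  h = Math.tanh(a);
  grad *= w * (1 - h * h);  // chain rule: dh_t/dh_{t-1} = w * tanh'(a_t)
}

console.log(grad); // tiny after 20 steps: each factor is at most w = 0.5
```

Each factor is bounded by w, so after 20 steps the gradient is below 0.5^20, roughly 1e-6. This is exactly the failure mode the LSTM's additive cell-state update and the GRU's gated interpolation are designed to avoid.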

Console Output

$ npx tsx 28-rnn-lstm-gru/index.ts
=== RNN / LSTM / GRU ===

RNN output: [2, 5, 20]
RNN hidden: [2, 20]

LSTM output: [2, 5, 32]
LSTM hidden: [2, 32]
LSTM cell:   [2, 32]

GRU output: [2, 5, 24]
GRU hidden: [2, 24]