shrew_nn/rnn.rs

// Recurrent Neural Network layers — RNN, LSTM, GRU
//
// This module implements the three fundamental recurrent architectures:
//
//   1. RNNCell / RNN   — Vanilla Elman RNN
//   2. LSTMCell / LSTM — Long Short-Term Memory
//   3. GRUCell / GRU   — Gated Recurrent Unit
//
// Each *Cell operates on a single timestep. The full RNN/LSTM/GRU wraps the
// cell and unrolls it over the sequence dimension, collecting all hidden
// states into a single output tensor via differentiable `cat`.
//
// SHAPES (batch_first convention):
//   input:  [batch, seq_len, input_size]
//   output: [batch, seq_len, hidden_size]
//   h_n:    [batch, hidden_size]       (RNN, GRU)
//   (h_n, c_n): ([batch, hidden_size], [batch, hidden_size])  (LSTM)
//
// WEIGHT INITIALIZATION:
//   All weights use Kaiming uniform U(-k, k) where k = sqrt(1/hidden_size),
//   following PyTorch's default initialization for recurrent layers.
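//
//   Concretely (see the constructors below): `rand` is assumed to sample
//   U(0, 1), and `.affine(2.0 * k, -k)` rescales that to U(-k, k). For
//   example, hidden_size = 16 gives k = sqrt(1/16) = 0.25, so every weight
//   and bias starts in (-0.25, 0.25).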

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

// RNNCell — Single-step vanilla RNN
//
// h_t = tanh(x_t @ W_ih^T + b_ih + h_{t-1} @ W_hh^T + b_hh)
//
// This is the simplest recurrent unit. It suffers from vanishing gradients
// for long sequences, which LSTM and GRU were designed to address.

/// A single-step vanilla RNN cell.
///
/// Computes: `h' = tanh(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh)`
///
/// # Shapes
/// - input x: `[batch, input_size]`
/// - hidden h: `[batch, hidden_size]`
/// - output h': `[batch, hidden_size]`
pub struct RNNCell<B: Backend> {
    w_ih: Tensor<B>,         // [hidden_size, input_size]
    w_hh: Tensor<B>,         // [hidden_size, hidden_size]
    b_ih: Option<Tensor<B>>, // [1, hidden_size]
    b_hh: Option<Tensor<B>>, // [1, hidden_size]
    pub input_size: usize,
    pub hidden_size: usize,
}

impl<B: Backend> RNNCell<B> {
    /// Create a new RNNCell with Kaiming uniform initialization.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let k = (1.0 / hidden_size as f64).sqrt();

        let w_ih = Tensor::<B>::rand((hidden_size, input_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();
        let w_hh = Tensor::<B>::rand((hidden_size, hidden_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let (b_ih, b_hh) = if use_bias {
            let bi = Tensor::<B>::rand((1, hidden_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            let bh = Tensor::<B>::rand((1, hidden_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            (Some(bi), Some(bh))
        } else {
            (None, None)
        };

        Ok(RNNCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            input_size,
            hidden_size,
        })
    }

    /// Forward: h' = tanh(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh)
    ///
    /// - `x`: `[batch, input_size]`
    /// - `h`: `[batch, hidden_size]`
    /// - returns h': `[batch, hidden_size]`
    pub fn forward(&self, x: &Tensor<B>, h: &Tensor<B>) -> Result<Tensor<B>> {
        // x @ W_ih^T → [batch, hidden_size]
        let wih_t = self.w_ih.t()?.contiguous()?;
        let mut gates = x.matmul(&wih_t)?;
        if let Some(ref b) = self.b_ih {
            gates = gates.add(b)?;
        }

        // h @ W_hh^T → [batch, hidden_size]
        let whh_t = self.w_hh.t()?.contiguous()?;
        let mut h_part = h.matmul(&whh_t)?;
        if let Some(ref b) = self.b_hh {
            h_part = h_part.add(b)?;
        }

        gates.add(&h_part)?.tanh()
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
        if let Some(ref b) = self.b_ih {
            params.push(b.clone());
        }
        if let Some(ref b) = self.b_hh {
            params.push(b.clone());
        }
        params
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![
            ("w_ih".to_string(), self.w_ih.clone()),
            ("w_hh".to_string(), self.w_hh.clone()),
        ];
        if let Some(ref b) = self.b_ih {
            named.push(("b_ih".to_string(), b.clone()));
        }
        if let Some(ref b) = self.b_hh {
            named.push(("b_hh".to_string(), b.clone()));
        }
        named
    }
}
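
// A minimal usage sketch (illustrative only, not part of the module's API):
// manually threading the hidden state through `RNNCell::forward` for a few
// timesteps, which is exactly what the `RNN` wrapper below automates. `B`,
// `dtype`, and `device` are left generic so no concrete backend is assumed;
// the sizes are hypothetical.
#[allow(dead_code)]
fn rnn_cell_unroll_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<Tensor<B>> {
    let (batch, input_size, hidden_size): (usize, usize, usize) = (2, 4, 8);
    let cell = RNNCell::<B>::new(input_size, hidden_size, true, dtype, device)?;

    // Start from a zero hidden state, as the RNN wrapper does when h0 is None.
    let mut h = Tensor::<B>::zeros((batch, hidden_size), dtype, device)?;
    for _t in 0..3 {
        // In a real model, x_t would be the slice of the input at timestep t.
        let x_t = Tensor::<B>::rand((batch, input_size), dtype, device)?;
        h = cell.forward(&x_t, &h)?; // [batch, hidden_size]
    }
    Ok(h)
}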

// RNN — Unrolled vanilla RNN over a sequence

/// A multi-step vanilla RNN that unrolls an RNNCell over the sequence dimension.
///
/// # Shapes
/// - input:  `[batch, seq_len, input_size]`
/// - output: `[batch, seq_len, hidden_size]` — all hidden states
/// - h_n:    `[batch, hidden_size]` — final hidden state
pub struct RNN<B: Backend> {
    cell: RNNCell<B>,
}

impl<B: Backend> RNN<B> {
    /// Create a new RNN layer.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let cell = RNNCell::new(input_size, hidden_size, use_bias, dtype, device)?;
        Ok(RNN { cell })
    }

    /// Forward pass over the full sequence.
    ///
    /// - `x`: `[batch, seq_len, input_size]`
    /// - `h0`: optional initial hidden state `[batch, hidden_size]`.
    ///   If None, zeros are used.
    ///
    /// Returns `(output, h_n)` where:
    /// - `output`: `[batch, seq_len, hidden_size]`
    /// - `h_n`: `[batch, hidden_size]`
    pub fn forward(&self, x: &Tensor<B>, h0: Option<&Tensor<B>>) -> Result<(Tensor<B>, Tensor<B>)> {
        let dims = x.dims();
        let batch = dims[0];
        let seq_len = dims[1];

        // Initialize hidden state
        let mut h = match h0 {
            Some(h) => h.clone(),
            None => Tensor::<B>::zeros((batch, self.cell.hidden_size), x.dtype(), x.device())?,
        };

        // Unroll over timesteps
        let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            // x_t: [batch, 1, input_size] → [batch, input_size]
            let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
            h = self.cell.forward(&x_t, &h)?;
            // h: [batch, hidden_size] → [batch, 1, hidden_size] for stacking
            outputs.push(h.reshape((batch, 1, self.cell.hidden_size))?);
        }

        // Stack: [batch, seq_len, hidden_size]
        let output = Tensor::cat(&outputs, 1)?;
        Ok((output, h))
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.cell.parameters()
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.cell
            .named_parameters()
            .into_iter()
            .map(|(k, v)| (format!("cell.{k}"), v))
            .collect()
    }

    /// Access the underlying cell.
    pub fn cell(&self) -> &RNNCell<B> {
        &self.cell
    }
}
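
// A minimal forward-pass sketch for the sequence-level wrapper (illustrative
// only; sizes are hypothetical and the backend stays generic). The 3-D input
// is built with `rand` + `reshape`, the same shape constructors used elsewhere
// in this module.
#[allow(dead_code)]
fn rnn_forward_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<()> {
    let (batch, seq_len, input_size, hidden_size): (usize, usize, usize, usize) = (4, 10, 8, 16);
    let rnn = RNN::<B>::new(input_size, hidden_size, true, dtype, device)?;

    // x: [batch, seq_len, input_size]
    let x = Tensor::<B>::rand((batch * seq_len, input_size), dtype, device)?
        .reshape((batch, seq_len, input_size))?;

    // Passing None uses a zero initial hidden state.
    let (output, h_n) = rnn.forward(&x, None)?;
    assert_eq!(output.dims()[1], seq_len); // output: [batch, seq_len, hidden_size]
    assert_eq!(h_n.dims()[1], hidden_size); // h_n: [batch, hidden_size]
    Ok(())
}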

// LSTMCell — Single-step LSTM
//
// The LSTM uses four gates (input, forget, cell, output) to control
// information flow, mitigating the vanishing gradient problem of vanilla RNNs.
//
// gates = x @ W_ih^T + b_ih + h @ W_hh^T + b_hh    # [batch, 4*hidden]
// i, f, g, o = chunk(gates, 4)
// i = sigmoid(i)   — input gate:  how much new info to let in
// f = sigmoid(f)   — forget gate: how much old info to keep
// g = tanh(g)      — cell gate:   candidate values to add
// o = sigmoid(o)   — output gate: how much state to expose
// c' = f * c + i * g
// h' = o * tanh(c')

/// A single-step LSTM cell.
///
/// # Shapes
/// - input x: `[batch, input_size]`
/// - hidden h: `[batch, hidden_size]`
/// - cell c: `[batch, hidden_size]`
/// - output (h', c'): `([batch, hidden_size], [batch, hidden_size])`
pub struct LSTMCell<B: Backend> {
    w_ih: Tensor<B>,         // [4*hidden_size, input_size]
    w_hh: Tensor<B>,         // [4*hidden_size, hidden_size]
    b_ih: Option<Tensor<B>>, // [1, 4*hidden_size]
    b_hh: Option<Tensor<B>>, // [1, 4*hidden_size]
    pub input_size: usize,
    pub hidden_size: usize,
}

impl<B: Backend> LSTMCell<B> {
    /// Create a new LSTMCell with Kaiming uniform initialization.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let gate_size = 4 * hidden_size;
        let k = (1.0 / hidden_size as f64).sqrt();

        let w_ih = Tensor::<B>::rand((gate_size, input_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();
        let w_hh = Tensor::<B>::rand((gate_size, hidden_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let (b_ih, b_hh) = if use_bias {
            let bi = Tensor::<B>::rand((1, gate_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            let bh = Tensor::<B>::rand((1, gate_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            (Some(bi), Some(bh))
        } else {
            (None, None)
        };

        Ok(LSTMCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            input_size,
            hidden_size,
        })
    }

    /// Forward: compute (h', c') from (x, h, c)
    ///
    /// - `x`: `[batch, input_size]`
    /// - `h`: `[batch, hidden_size]`
    /// - `c`: `[batch, hidden_size]`
    /// - returns `(h', c')`: `([batch, hidden_size], [batch, hidden_size])`
    pub fn forward(
        &self,
        x: &Tensor<B>,
        h: &Tensor<B>,
        c: &Tensor<B>,
    ) -> Result<(Tensor<B>, Tensor<B>)> {
        // Compute all 4 gates at once: [batch, 4*hidden_size]
        let wih_t = self.w_ih.t()?.contiguous()?;
        let mut gates = x.matmul(&wih_t)?;
        if let Some(ref b) = self.b_ih {
            gates = gates.add(b)?;
        }

        let whh_t = self.w_hh.t()?.contiguous()?;
        let mut h_part = h.matmul(&whh_t)?;
        if let Some(ref b) = self.b_hh {
            h_part = h_part.add(b)?;
        }

        gates = gates.add(&h_part)?;

        // Split into 4 gates: each [batch, hidden_size]
        let chunks = gates.chunk(4, 1)?;
        let i_gate = chunks[0].sigmoid()?; // input gate
        let f_gate = chunks[1].sigmoid()?; // forget gate
        let g_gate = chunks[2].tanh()?; // cell gate (candidate)
        let o_gate = chunks[3].sigmoid()?; // output gate

        // c' = f * c + i * g
        let c_new = f_gate.mul(c)?.add(&i_gate.mul(&g_gate)?)?;

        // h' = o * tanh(c')
        let h_new = o_gate.mul(&c_new.tanh()?)?;

        Ok((h_new, c_new))
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
        if let Some(ref b) = self.b_ih {
            params.push(b.clone());
        }
        if let Some(ref b) = self.b_hh {
            params.push(b.clone());
        }
        params
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![
            ("w_ih".to_string(), self.w_ih.clone()),
            ("w_hh".to_string(), self.w_hh.clone()),
        ];
        if let Some(ref b) = self.b_ih {
            named.push(("b_ih".to_string(), b.clone()));
        }
        if let Some(ref b) = self.b_hh {
            named.push(("b_hh".to_string(), b.clone()));
        }
        named
    }
}
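
// A minimal single-step sketch (illustrative only; hypothetical sizes, generic
// backend): unlike RNNCell and GRUCell, the LSTM cell threads *two* states, so
// both `h` and `c` must be carried between calls.
#[allow(dead_code)]
fn lstm_cell_step_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<()> {
    let (batch, input_size, hidden_size): (usize, usize, usize) = (2, 4, 8);
    let cell = LSTMCell::<B>::new(input_size, hidden_size, true, dtype, device)?;

    let x_t = Tensor::<B>::rand((batch, input_size), dtype, device)?;
    let h = Tensor::<B>::zeros((batch, hidden_size), dtype, device)?;
    let c = Tensor::<B>::zeros((batch, hidden_size), dtype, device)?;

    // One timestep: (h, c) -> (h', c'), both [batch, hidden_size].
    let (h_new, c_new) = cell.forward(&x_t, &h, &c)?;
    assert_eq!(h_new.dims()[1], hidden_size);
    assert_eq!(c_new.dims()[1], hidden_size);
    Ok(())
}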

// LSTM — Unrolled LSTM over a sequence

/// A multi-step LSTM that unrolls an LSTMCell over the sequence dimension.
///
/// # Shapes
/// - input:  `[batch, seq_len, input_size]`
/// - output: `[batch, seq_len, hidden_size]`
/// - h_n:    `[batch, hidden_size]`
/// - c_n:    `[batch, hidden_size]`
pub struct LSTM<B: Backend> {
    cell: LSTMCell<B>,
}

impl<B: Backend> LSTM<B> {
    /// Create a new LSTM layer.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let cell = LSTMCell::new(input_size, hidden_size, use_bias, dtype, device)?;
        Ok(LSTM { cell })
    }

    /// Forward pass over the full sequence.
    ///
    /// - `x`: `[batch, seq_len, input_size]`
    /// - `hc0`: optional initial `(h0, c0)`, each `[batch, hidden_size]`.
    ///   If None, zeros are used.
    ///
    /// Returns `(output, (h_n, c_n))` where:
    /// - `output`: `[batch, seq_len, hidden_size]`
    /// - `h_n`, `c_n`: `[batch, hidden_size]`
    #[allow(clippy::type_complexity)]
    pub fn forward(
        &self,
        x: &Tensor<B>,
        hc0: Option<(&Tensor<B>, &Tensor<B>)>,
    ) -> Result<(Tensor<B>, (Tensor<B>, Tensor<B>))> {
        let dims = x.dims();
        let batch = dims[0];
        let seq_len = dims[1];
        let hs = self.cell.hidden_size;

        // Initialize hidden and cell states
        let (mut h, mut c) = match hc0 {
            Some((h0, c0)) => (h0.clone(), c0.clone()),
            None => (
                Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
                Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
            ),
        };

        // Unroll over timesteps
        let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
            let (h_new, c_new) = self.cell.forward(&x_t, &h, &c)?;
            h = h_new;
            c = c_new;
            outputs.push(h.reshape((batch, 1, hs))?);
        }

        let output = Tensor::cat(&outputs, 1)?;
        Ok((output, (h, c)))
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.cell.parameters()
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.cell
            .named_parameters()
            .into_iter()
            .map(|(k, v)| (format!("cell.{k}"), v))
            .collect()
    }

    /// Access the underlying cell.
    pub fn cell(&self) -> &LSTMCell<B> {
        &self.cell
    }
}
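
// A minimal stateful-chunking sketch (illustrative only; hypothetical sizes,
// generic backend): the `(h_n, c_n)` returned for one chunk is fed back in as
// `hc0` for the next chunk, the usual wiring for truncated-BPTT-style training.
#[allow(dead_code)]
fn lstm_stateful_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<()> {
    let (batch, seq_len, input_size, hidden_size): (usize, usize, usize, usize) = (4, 5, 8, 16);
    let lstm = LSTM::<B>::new(input_size, hidden_size, true, dtype, device)?;

    // Two consecutive chunks of the same stream: [batch, seq_len, input_size].
    let chunk1 = Tensor::<B>::rand((batch * seq_len, input_size), dtype, device)?
        .reshape((batch, seq_len, input_size))?;
    let chunk2 = Tensor::<B>::rand((batch * seq_len, input_size), dtype, device)?
        .reshape((batch, seq_len, input_size))?;

    // The first chunk starts from zeros; the second resumes from (h_n, c_n).
    let (_out1, (h_n, c_n)) = lstm.forward(&chunk1, None)?;
    let (out2, _state) = lstm.forward(&chunk2, Some((&h_n, &c_n)))?;
    assert_eq!(out2.dims()[2], hidden_size); // out2: [batch, seq_len, hidden_size]
    Ok(())
}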

// GRUCell — Single-step GRU
//
// The GRU simplifies the LSTM by merging the forget and input gates into a
// single "update" gate, and using a "reset" gate to control how much of the
// previous hidden state to expose.
//
// gates_ih = x @ W_ih^T + b_ih          [batch, 3*hidden]
// gates_hh = h @ W_hh^T + b_hh          [batch, 3*hidden]
// r_ih, z_ih, n_ih = chunk(gates_ih, 3)
// r_hh, z_hh, n_hh = chunk(gates_hh, 3)
//
// r = sigmoid(r_ih + r_hh)    — reset gate
// z = sigmoid(z_ih + z_hh)    — update gate
// n = tanh(n_ih + r * n_hh)   — new gate (candidate)
//
// h' = (1 - z) * n + z * h

/// A single-step GRU cell.
///
/// # Shapes
/// - input x: `[batch, input_size]`
/// - hidden h: `[batch, hidden_size]`
/// - output h': `[batch, hidden_size]`
pub struct GRUCell<B: Backend> {
    w_ih: Tensor<B>,         // [3*hidden_size, input_size]
    w_hh: Tensor<B>,         // [3*hidden_size, hidden_size]
    b_ih: Option<Tensor<B>>, // [1, 3*hidden_size]
    b_hh: Option<Tensor<B>>, // [1, 3*hidden_size]
    pub input_size: usize,
    pub hidden_size: usize,
}

impl<B: Backend> GRUCell<B> {
    /// Create a new GRUCell with Kaiming uniform initialization.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let gate_size = 3 * hidden_size;
        let k = (1.0 / hidden_size as f64).sqrt();

        let w_ih = Tensor::<B>::rand((gate_size, input_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();
        let w_hh = Tensor::<B>::rand((gate_size, hidden_size), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let (b_ih, b_hh) = if use_bias {
            let bi = Tensor::<B>::rand((1, gate_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            let bh = Tensor::<B>::rand((1, gate_size), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            (Some(bi), Some(bh))
        } else {
            (None, None)
        };

        Ok(GRUCell {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            input_size,
            hidden_size,
        })
    }

    /// Forward: compute h' from (x, h)
    ///
    /// - `x`: `[batch, input_size]`
    /// - `h`: `[batch, hidden_size]`
    /// - returns h': `[batch, hidden_size]`
    pub fn forward(&self, x: &Tensor<B>, h: &Tensor<B>) -> Result<Tensor<B>> {
        // Compute input-side and hidden-side gates
        let wih_t = self.w_ih.t()?.contiguous()?;
        let mut gates_ih = x.matmul(&wih_t)?;
        if let Some(ref b) = self.b_ih {
            gates_ih = gates_ih.add(b)?;
        }

        let whh_t = self.w_hh.t()?.contiguous()?;
        let mut gates_hh = h.matmul(&whh_t)?;
        if let Some(ref b) = self.b_hh {
            gates_hh = gates_hh.add(b)?;
        }

        // Split each into 3 parts: reset, update, new
        let ih_chunks = gates_ih.chunk(3, 1)?;
        let hh_chunks = gates_hh.chunk(3, 1)?;

        // r = sigmoid(r_ih + r_hh)  — reset gate
        let r = ih_chunks[0].add(&hh_chunks[0])?.sigmoid()?;

        // z = sigmoid(z_ih + z_hh)  — update gate
        let z = ih_chunks[1].add(&hh_chunks[1])?.sigmoid()?;

        // n = tanh(n_ih + r * n_hh)  — new gate (candidate hidden state)
        let n = ih_chunks[2].add(&r.mul(&hh_chunks[2])?)?.tanh()?;

        // h' = (1 - z) * n + z * h
        // (1 - z) = z.affine(-1.0, 1.0)
        let one_minus_z = z.affine(-1.0, 1.0)?;
        one_minus_z.mul(&n)?.add(&z.mul(h)?)
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
        if let Some(ref b) = self.b_ih {
            params.push(b.clone());
        }
        if let Some(ref b) = self.b_hh {
            params.push(b.clone());
        }
        params
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![
            ("w_ih".to_string(), self.w_ih.clone()),
            ("w_hh".to_string(), self.w_hh.clone()),
        ];
        if let Some(ref b) = self.b_ih {
            named.push(("b_ih".to_string(), b.clone()));
        }
        if let Some(ref b) = self.b_hh {
            named.push(("b_hh".to_string(), b.clone()));
        }
        named
    }
}
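
// A minimal single-step sketch (illustrative only; hypothetical sizes, generic
// backend): GRUCell exposes the same `forward(&x, &h) -> h'` signature as
// RNNCell, so it can be swapped into the same manual unrolling loop unchanged.
#[allow(dead_code)]
fn gru_cell_step_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<Tensor<B>> {
    let (batch, input_size, hidden_size): (usize, usize, usize) = (2, 4, 8);
    let cell = GRUCell::<B>::new(input_size, hidden_size, true, dtype, device)?;

    let x_t = Tensor::<B>::rand((batch, input_size), dtype, device)?;
    let h = Tensor::<B>::zeros((batch, hidden_size), dtype, device)?;
    // One timestep: h' = (1 - z) * n + z * h, shape [batch, hidden_size].
    cell.forward(&x_t, &h)
}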

// GRU — Unrolled GRU over a sequence

/// A multi-step GRU that unrolls a GRUCell over the sequence dimension.
///
/// # Shapes
/// - input:  `[batch, seq_len, input_size]`
/// - output: `[batch, seq_len, hidden_size]`
/// - h_n:    `[batch, hidden_size]`
pub struct GRU<B: Backend> {
    cell: GRUCell<B>,
}

impl<B: Backend> GRU<B> {
    /// Create a new GRU layer.
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let cell = GRUCell::new(input_size, hidden_size, use_bias, dtype, device)?;
        Ok(GRU { cell })
    }

    /// Forward pass over the full sequence.
    ///
    /// - `x`: `[batch, seq_len, input_size]`
    /// - `h0`: optional initial hidden state `[batch, hidden_size]`.
    ///   If None, zeros are used.
    ///
    /// Returns `(output, h_n)` where:
    /// - `output`: `[batch, seq_len, hidden_size]`
    /// - `h_n`: `[batch, hidden_size]`
    pub fn forward(&self, x: &Tensor<B>, h0: Option<&Tensor<B>>) -> Result<(Tensor<B>, Tensor<B>)> {
        let dims = x.dims();
        let batch = dims[0];
        let seq_len = dims[1];
        let hs = self.cell.hidden_size;

        let mut h = match h0 {
            Some(h) => h.clone(),
            None => Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
        };

        let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
            h = self.cell.forward(&x_t, &h)?;
            outputs.push(h.reshape((batch, 1, hs))?);
        }

        let output = Tensor::cat(&outputs, 1)?;
        Ok((output, h))
    }

    /// Return all trainable parameters.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.cell.parameters()
    }

    /// Return all trainable parameters with names.
    pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.cell
            .named_parameters()
            .into_iter()
            .map(|(k, v)| (format!("cell.{k}"), v))
            .collect()
    }

    /// Access the underlying cell.
    pub fn cell(&self) -> &GRUCell<B> {
        &self.cell
    }
}
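
// A minimal end-to-end sketch (illustrative only; hypothetical sizes, generic
// backend): run the GRU over a sequence and collect its trainable tensors,
// e.g. to hand them to an optimizer. Only APIs defined in this module are used.
#[allow(dead_code)]
fn gru_forward_sketch<B: Backend>(dtype: DType, device: &B::Device) -> Result<()> {
    let (batch, seq_len, input_size, hidden_size): (usize, usize, usize, usize) = (4, 10, 8, 16);
    let gru = GRU::<B>::new(input_size, hidden_size, true, dtype, device)?;

    let x = Tensor::<B>::rand((batch * seq_len, input_size), dtype, device)?
        .reshape((batch, seq_len, input_size))?;
    let (output, h_n) = gru.forward(&x, None)?;
    assert_eq!(output.dims()[1], seq_len); // output: [batch, seq_len, hidden_size]
    assert_eq!(h_n.dims()[1], hidden_size); // h_n: [batch, hidden_size]

    // With use_bias = true the cell holds [w_ih, w_hh, b_ih, b_hh]; the wrapper
    // re-exports them, prefixed with "cell." by `named_parameters`.
    let params = gru.parameters();
    assert_eq!(params.len(), 4);
    Ok(())
}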