1use shrew_core::backend::Backend;
24use shrew_core::dtype::DType;
25use shrew_core::error::Result;
26use shrew_core::tensor::Tensor;
27
28pub struct RNNCell<B: Backend> {
44 w_ih: Tensor<B>, w_hh: Tensor<B>, b_ih: Option<Tensor<B>>, b_hh: Option<Tensor<B>>, pub input_size: usize,
49 pub hidden_size: usize,
50}
51
52impl<B: Backend> RNNCell<B> {
53 pub fn new(
55 input_size: usize,
56 hidden_size: usize,
57 use_bias: bool,
58 dtype: DType,
59 device: &B::Device,
60 ) -> Result<Self> {
61 let k = (1.0 / hidden_size as f64).sqrt();
62
63 let w_ih = Tensor::<B>::rand((hidden_size, input_size), dtype, device)?
64 .affine(2.0 * k, -k)?
65 .set_variable();
66 let w_hh = Tensor::<B>::rand((hidden_size, hidden_size), dtype, device)?
67 .affine(2.0 * k, -k)?
68 .set_variable();
69
70 let (b_ih, b_hh) = if use_bias {
71 let bi = Tensor::<B>::rand((1, hidden_size), dtype, device)?
72 .affine(2.0 * k, -k)?
73 .set_variable();
74 let bh = Tensor::<B>::rand((1, hidden_size), dtype, device)?
75 .affine(2.0 * k, -k)?
76 .set_variable();
77 (Some(bi), Some(bh))
78 } else {
79 (None, None)
80 };
81
82 Ok(RNNCell {
83 w_ih,
84 w_hh,
85 b_ih,
86 b_hh,
87 input_size,
88 hidden_size,
89 })
90 }
91
92 pub fn forward(&self, x: &Tensor<B>, h: &Tensor<B>) -> Result<Tensor<B>> {
98 let wih_t = self.w_ih.t()?.contiguous()?;
100 let mut gates = x.matmul(&wih_t)?;
101 if let Some(ref b) = self.b_ih {
102 gates = gates.add(b)?;
103 }
104
105 let whh_t = self.w_hh.t()?.contiguous()?;
107 let mut h_part = h.matmul(&whh_t)?;
108 if let Some(ref b) = self.b_hh {
109 h_part = h_part.add(b)?;
110 }
111
112 gates.add(&h_part)?.tanh()
113 }
114
115 pub fn parameters(&self) -> Vec<Tensor<B>> {
117 let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
118 if let Some(ref b) = self.b_ih {
119 params.push(b.clone());
120 }
121 if let Some(ref b) = self.b_hh {
122 params.push(b.clone());
123 }
124 params
125 }
126
127 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
129 let mut named = vec![
130 ("w_ih".to_string(), self.w_ih.clone()),
131 ("w_hh".to_string(), self.w_hh.clone()),
132 ];
133 if let Some(ref b) = self.b_ih {
134 named.push(("b_ih".to_string(), b.clone()));
135 }
136 if let Some(ref b) = self.b_hh {
137 named.push(("b_hh".to_string(), b.clone()));
138 }
139 named
140 }
141}
142
143pub struct RNN<B: Backend> {
152 cell: RNNCell<B>,
153}
154
155impl<B: Backend> RNN<B> {
156 pub fn new(
158 input_size: usize,
159 hidden_size: usize,
160 use_bias: bool,
161 dtype: DType,
162 device: &B::Device,
163 ) -> Result<Self> {
164 let cell = RNNCell::new(input_size, hidden_size, use_bias, dtype, device)?;
165 Ok(RNN { cell })
166 }
167
168 pub fn forward(&self, x: &Tensor<B>, h0: Option<&Tensor<B>>) -> Result<(Tensor<B>, Tensor<B>)> {
178 let dims = x.dims();
179 let batch = dims[0];
180 let seq_len = dims[1];
181
182 let mut h = match h0 {
184 Some(h) => h.clone(),
185 None => Tensor::<B>::zeros((batch, self.cell.hidden_size), x.dtype(), x.device())?,
186 };
187
188 let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
190 for t in 0..seq_len {
191 let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
193 h = self.cell.forward(&x_t, &h)?;
194 outputs.push(h.reshape((batch, 1, self.cell.hidden_size))?);
196 }
197
198 let output = Tensor::cat(&outputs, 1)?;
200 Ok((output, h))
201 }
202
203 pub fn parameters(&self) -> Vec<Tensor<B>> {
205 self.cell.parameters()
206 }
207
208 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
210 self.cell
211 .named_parameters()
212 .into_iter()
213 .map(|(k, v)| (format!("cell.{k}"), v))
214 .collect()
215 }
216
217 pub fn cell(&self) -> &RNNCell<B> {
219 &self.cell
220 }
221}
222
223pub struct LSTMCell<B: Backend> {
245 w_ih: Tensor<B>, w_hh: Tensor<B>, b_ih: Option<Tensor<B>>, b_hh: Option<Tensor<B>>, pub input_size: usize,
250 pub hidden_size: usize,
251}
252
253impl<B: Backend> LSTMCell<B> {
254 pub fn new(
256 input_size: usize,
257 hidden_size: usize,
258 use_bias: bool,
259 dtype: DType,
260 device: &B::Device,
261 ) -> Result<Self> {
262 let gate_size = 4 * hidden_size;
263 let k = (1.0 / hidden_size as f64).sqrt();
264
265 let w_ih = Tensor::<B>::rand((gate_size, input_size), dtype, device)?
266 .affine(2.0 * k, -k)?
267 .set_variable();
268 let w_hh = Tensor::<B>::rand((gate_size, hidden_size), dtype, device)?
269 .affine(2.0 * k, -k)?
270 .set_variable();
271
272 let (b_ih, b_hh) = if use_bias {
273 let bi = Tensor::<B>::rand((1, gate_size), dtype, device)?
274 .affine(2.0 * k, -k)?
275 .set_variable();
276 let bh = Tensor::<B>::rand((1, gate_size), dtype, device)?
277 .affine(2.0 * k, -k)?
278 .set_variable();
279 (Some(bi), Some(bh))
280 } else {
281 (None, None)
282 };
283
284 Ok(LSTMCell {
285 w_ih,
286 w_hh,
287 b_ih,
288 b_hh,
289 input_size,
290 hidden_size,
291 })
292 }
293
294 pub fn forward(
301 &self,
302 x: &Tensor<B>,
303 h: &Tensor<B>,
304 c: &Tensor<B>,
305 ) -> Result<(Tensor<B>, Tensor<B>)> {
306 let wih_t = self.w_ih.t()?.contiguous()?;
308 let mut gates = x.matmul(&wih_t)?;
309 if let Some(ref b) = self.b_ih {
310 gates = gates.add(b)?;
311 }
312
313 let whh_t = self.w_hh.t()?.contiguous()?;
314 let mut h_part = h.matmul(&whh_t)?;
315 if let Some(ref b) = self.b_hh {
316 h_part = h_part.add(b)?;
317 }
318
319 gates = gates.add(&h_part)?;
320
321 let chunks = gates.chunk(4, 1)?;
323 let i_gate = chunks[0].sigmoid()?; let f_gate = chunks[1].sigmoid()?; let g_gate = chunks[2].tanh()?; let o_gate = chunks[3].sigmoid()?; let c_new = f_gate.mul(c)?.add(&i_gate.mul(&g_gate)?)?;
330
331 let h_new = o_gate.mul(&c_new.tanh()?)?;
333
334 Ok((h_new, c_new))
335 }
336
337 pub fn parameters(&self) -> Vec<Tensor<B>> {
339 let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
340 if let Some(ref b) = self.b_ih {
341 params.push(b.clone());
342 }
343 if let Some(ref b) = self.b_hh {
344 params.push(b.clone());
345 }
346 params
347 }
348
349 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
351 let mut named = vec![
352 ("w_ih".to_string(), self.w_ih.clone()),
353 ("w_hh".to_string(), self.w_hh.clone()),
354 ];
355 if let Some(ref b) = self.b_ih {
356 named.push(("b_ih".to_string(), b.clone()));
357 }
358 if let Some(ref b) = self.b_hh {
359 named.push(("b_hh".to_string(), b.clone()));
360 }
361 named
362 }
363}
364
365pub struct LSTM<B: Backend> {
375 cell: LSTMCell<B>,
376}
377
378impl<B: Backend> LSTM<B> {
379 pub fn new(
381 input_size: usize,
382 hidden_size: usize,
383 use_bias: bool,
384 dtype: DType,
385 device: &B::Device,
386 ) -> Result<Self> {
387 let cell = LSTMCell::new(input_size, hidden_size, use_bias, dtype, device)?;
388 Ok(LSTM { cell })
389 }
390
391 #[allow(clippy::type_complexity)]
401 pub fn forward(
402 &self,
403 x: &Tensor<B>,
404 hc0: Option<(&Tensor<B>, &Tensor<B>)>,
405 ) -> Result<(Tensor<B>, (Tensor<B>, Tensor<B>))> {
406 let dims = x.dims();
407 let batch = dims[0];
408 let seq_len = dims[1];
409 let hs = self.cell.hidden_size;
410
411 let (mut h, mut c) = match hc0 {
413 Some((h0, c0)) => (h0.clone(), c0.clone()),
414 None => (
415 Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
416 Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
417 ),
418 };
419
420 let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
422 for t in 0..seq_len {
423 let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
424 let (h_new, c_new) = self.cell.forward(&x_t, &h, &c)?;
425 h = h_new;
426 c = c_new;
427 outputs.push(h.reshape((batch, 1, hs))?);
428 }
429
430 let output = Tensor::cat(&outputs, 1)?;
431 Ok((output, (h, c)))
432 }
433
434 pub fn parameters(&self) -> Vec<Tensor<B>> {
436 self.cell.parameters()
437 }
438
439 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
441 self.cell
442 .named_parameters()
443 .into_iter()
444 .map(|(k, v)| (format!("cell.{k}"), v))
445 .collect()
446 }
447
448 pub fn cell(&self) -> &LSTMCell<B> {
450 &self.cell
451 }
452}
453
454pub struct GRUCell<B: Backend> {
478 w_ih: Tensor<B>, w_hh: Tensor<B>, b_ih: Option<Tensor<B>>, b_hh: Option<Tensor<B>>, pub input_size: usize,
483 pub hidden_size: usize,
484}
485
486impl<B: Backend> GRUCell<B> {
487 pub fn new(
489 input_size: usize,
490 hidden_size: usize,
491 use_bias: bool,
492 dtype: DType,
493 device: &B::Device,
494 ) -> Result<Self> {
495 let gate_size = 3 * hidden_size;
496 let k = (1.0 / hidden_size as f64).sqrt();
497
498 let w_ih = Tensor::<B>::rand((gate_size, input_size), dtype, device)?
499 .affine(2.0 * k, -k)?
500 .set_variable();
501 let w_hh = Tensor::<B>::rand((gate_size, hidden_size), dtype, device)?
502 .affine(2.0 * k, -k)?
503 .set_variable();
504
505 let (b_ih, b_hh) = if use_bias {
506 let bi = Tensor::<B>::rand((1, gate_size), dtype, device)?
507 .affine(2.0 * k, -k)?
508 .set_variable();
509 let bh = Tensor::<B>::rand((1, gate_size), dtype, device)?
510 .affine(2.0 * k, -k)?
511 .set_variable();
512 (Some(bi), Some(bh))
513 } else {
514 (None, None)
515 };
516
517 Ok(GRUCell {
518 w_ih,
519 w_hh,
520 b_ih,
521 b_hh,
522 input_size,
523 hidden_size,
524 })
525 }
526
527 pub fn forward(&self, x: &Tensor<B>, h: &Tensor<B>) -> Result<Tensor<B>> {
533 let wih_t = self.w_ih.t()?.contiguous()?;
535 let mut gates_ih = x.matmul(&wih_t)?;
536 if let Some(ref b) = self.b_ih {
537 gates_ih = gates_ih.add(b)?;
538 }
539
540 let whh_t = self.w_hh.t()?.contiguous()?;
541 let mut gates_hh = h.matmul(&whh_t)?;
542 if let Some(ref b) = self.b_hh {
543 gates_hh = gates_hh.add(b)?;
544 }
545
546 let ih_chunks = gates_ih.chunk(3, 1)?;
548 let hh_chunks = gates_hh.chunk(3, 1)?;
549
550 let r = ih_chunks[0].add(&hh_chunks[0])?.sigmoid()?;
552
553 let z = ih_chunks[1].add(&hh_chunks[1])?.sigmoid()?;
555
556 let n = ih_chunks[2].add(&r.mul(&hh_chunks[2])?)?.tanh()?;
558
559 let one_minus_z = z.affine(-1.0, 1.0)?;
562 one_minus_z.mul(&n)?.add(&z.mul(h)?)
563 }
564
565 pub fn parameters(&self) -> Vec<Tensor<B>> {
567 let mut params = vec![self.w_ih.clone(), self.w_hh.clone()];
568 if let Some(ref b) = self.b_ih {
569 params.push(b.clone());
570 }
571 if let Some(ref b) = self.b_hh {
572 params.push(b.clone());
573 }
574 params
575 }
576
577 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
579 let mut named = vec![
580 ("w_ih".to_string(), self.w_ih.clone()),
581 ("w_hh".to_string(), self.w_hh.clone()),
582 ];
583 if let Some(ref b) = self.b_ih {
584 named.push(("b_ih".to_string(), b.clone()));
585 }
586 if let Some(ref b) = self.b_hh {
587 named.push(("b_hh".to_string(), b.clone()));
588 }
589 named
590 }
591}
592
593pub struct GRU<B: Backend> {
602 cell: GRUCell<B>,
603}
604
605impl<B: Backend> GRU<B> {
606 pub fn new(
608 input_size: usize,
609 hidden_size: usize,
610 use_bias: bool,
611 dtype: DType,
612 device: &B::Device,
613 ) -> Result<Self> {
614 let cell = GRUCell::new(input_size, hidden_size, use_bias, dtype, device)?;
615 Ok(GRU { cell })
616 }
617
618 pub fn forward(&self, x: &Tensor<B>, h0: Option<&Tensor<B>>) -> Result<(Tensor<B>, Tensor<B>)> {
628 let dims = x.dims();
629 let batch = dims[0];
630 let seq_len = dims[1];
631 let hs = self.cell.hidden_size;
632
633 let mut h = match h0 {
634 Some(h) => h.clone(),
635 None => Tensor::<B>::zeros((batch, hs), x.dtype(), x.device())?,
636 };
637
638 let mut outputs: Vec<Tensor<B>> = Vec::with_capacity(seq_len);
639 for t in 0..seq_len {
640 let x_t = x.narrow(1, t, 1)?.reshape((batch, self.cell.input_size))?;
641 h = self.cell.forward(&x_t, &h)?;
642 outputs.push(h.reshape((batch, 1, hs))?);
643 }
644
645 let output = Tensor::cat(&outputs, 1)?;
646 Ok((output, h))
647 }
648
649 pub fn parameters(&self) -> Vec<Tensor<B>> {
651 self.cell.parameters()
652 }
653
654 pub fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
656 self.cell
657 .named_parameters()
658 .into_iter()
659 .map(|(k, v)| (format!("cell.{k}"), v))
660 .collect()
661 }
662
663 pub fn cell(&self) -> &GRUCell<B> {
665 &self.cell
666 }
667}