// shrew_nn/embedding.rs
// Embedding — Lookup table for discrete tokens
//
// An embedding layer maps integer indices to dense vectors. It's the standard
// way to handle categorical data (words, tokens, item IDs) in neural networks.
//
// Think of it as a learnable lookup table:
//     embedding[token_id] → vector of size embedding_dim
//
// For NLP: a vocabulary of 50000 words → embedding(50000, 768) gives each word
// a 768-dimensional vector representation that the network learns during training.
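//
// Tiny worked example (illustrative numbers): with num_embeddings = 4 and
// embedding_dim = 3, the table might be
//     weight = [[0.00, 0.01, 0.02],
//               [1.00, 1.01, 1.02],
//               [2.00, 2.01, 2.02],
//               [3.00, 3.01, 3.02]]
// and looking up token_id = 2 simply returns row 2: [2.00, 2.01, 2.02].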
//
// IMPLEMENTATION:
//
// The embedding table is a [num_embeddings, embedding_dim] matrix.
// Forward pass: given input indices [batch, seq_len], output is
// [batch, seq_len, embedding_dim] by looking up each index.
//
// The lookup is implemented with the backend's `index_select` op, which
// gathers rows from the weight table directly on-device (no host round-trip).
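//
// Shape walkthrough (illustrative numbers): weight [50000, 768] and input
// indices [2, 5] → flatten to [10] → index_select yields [10, 768] →
// the final reshape restores [2, 5, 768].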

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// A learnable lookup table mapping integer indices to dense vectors.
///
/// # Examples
/// ```ignore
/// let emb = Embedding::<CpuBackend>::new(1000, 128, DType::F32, &dev)?;
/// // Input: token indices [batch=2, seq_len=5]
/// let tokens = CpuTensor::from_f64_slice(&indices, (2, 5), DType::I64, &dev)?;
/// let vectors = emb.forward(&tokens)?; // [2, 5, 128]
/// ```
pub struct Embedding<B: Backend> {
    /// The embedding table: [num_embeddings, embedding_dim]
    weight: Tensor<B>,
    num_embeddings: usize,
    embedding_dim: usize,
}

impl<B: Backend> Embedding<B> {
    /// Create a new Embedding layer with normally-distributed random weights.
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // Initialize from N(0, 1) — standard for embeddings
        let weight =
            Tensor::<B>::randn((num_embeddings, embedding_dim), dtype, device)?.set_variable();
        Ok(Embedding {
            weight,
            num_embeddings,
            embedding_dim,
        })
    }

    /// Create from an existing weight matrix.
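    ///
    /// # Examples
    /// ```ignore
    /// // Illustrative sketch; assumes the same backend/device setup as the
    /// // struct-level example above.
    /// let weight = Tensor::<CpuBackend>::randn((1000, 128), DType::F32, &dev)?;
    /// let emb = Embedding::from_tensor(weight)?; // 1000 embeddings of dim 128
    /// ```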
    pub fn from_tensor(weight: Tensor<B>) -> Result<Self> {
        let dims = weight.dims();
        if dims.len() != 2 {
            return Err(shrew_core::Error::msg(format!(
                "Embedding weight must be 2D, got shape {:?}",
                dims
            )));
        }
        Ok(Embedding {
            num_embeddings: dims[0],
            embedding_dim: dims[1],
            weight: weight.set_variable(),
        })
    }

    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }
}

impl<B: Backend> Module<B> for Embedding<B> {
    /// Look up embeddings for the given indices.
    ///
    /// Input: integer tensor of any shape [...]
    /// Output: tensor of shape [..., embedding_dim]
    ///
    /// Uses `index_select` to gather rows from the weight matrix directly
    /// on-device (no host round-trip). For autograd, we record the operation
    /// so gradients can flow back through the embedding table.
    fn forward(&self, indices: &Tensor<B>) -> Result<Tensor<B>> {
        let input_dims = indices.dims().to_vec();
        let num_indices = indices.elem_count();

        // Flatten indices to 1D, ensure U32 for index_select
        let flat_idx = indices
            .reshape(shrew_core::Shape::new(vec![num_indices]))?
            .to_dtype(DType::U32)?;

        // index_select on dim=0: weight[flat_idx] → [num_indices, embedding_dim]
        let flat_result = self.weight.index_select(0, &flat_idx)?;

        // Reshape to [..., embedding_dim]
        let mut out_dims = input_dims;
        out_dims.push(self.embedding_dim);
        flat_result.reshape(shrew_core::Shape::new(out_dims))
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![("weight".to_string(), self.weight.clone())]
    }
}
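
// Illustrative usage sketch (not compiled; assumes a concrete backend such as
// the `CpuBackend` from the doc examples, a device `dev`, and an integer index
// tensor `token_ids` built as in the struct-level example):
//
//     let emb = Embedding::<CpuBackend>::new(50000, 768, DType::F32, &dev)?;
//     let hidden = emb.forward(&token_ids)?;  // [batch, seq_len, 768]
//     assert_eq!(emb.parameters().len(), 1);  // just the embedding table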