shrew_nn/
embedding.rs

1// Embedding — Lookup table for discrete tokens
2//
3// An embedding layer maps integer indices to dense vectors. It's the standard
4// way to handle categorical data (words, tokens, item IDs) in neural networks.
5//
6// Think of it as a learnable lookup table:
7//   embedding[token_id] → vector of size embedding_dim
8//
9// For NLP: vocabulary of 50000 words → embedding(50000, 768) gives each word
10// a 768-dimensional vector representation that the network learns during training.
11//
12// IMPLEMENTATION:
13//
14// The embedding table is a [num_embeddings, embedding_dim] matrix.
15// Forward pass: given input indices [batch, seq_len], output is
16// [batch, seq_len, embedding_dim] by looking up each index.
17//
// Forward gathers rows with the backend's `index_select` op, so the lookup
// runs directly on-device with no host round-trip.
20
21use shrew_core::backend::Backend;
22use shrew_core::dtype::DType;
23use shrew_core::error::Result;
24use shrew_core::tensor::Tensor;
25
26use crate::module::Module;
27
/// A learnable lookup table mapping integer indices to dense vectors.
///
/// The table is a `[num_embeddings, embedding_dim]` weight matrix; the
/// forward pass gathers one row per input index.
///
/// # Examples
/// ```ignore
/// let emb = Embedding::<CpuBackend>::new(1000, 128, DType::F32, &dev)?;
/// // Input: token indices [batch=2, seq_len=5]
/// let tokens = CpuTensor::from_f64_slice(&indices, (2, 5), DType::I64, &dev)?;
/// let vectors = emb.forward(&tokens)?; // [2, 5, 128]
/// ```
pub struct Embedding<B: Backend> {
    /// The embedding table: [num_embeddings, embedding_dim]. Both
    /// constructors mark it as a variable so gradients flow into it.
    weight: Tensor<B>,
    /// Number of rows in `weight` (e.g. vocabulary size).
    num_embeddings: usize,
    /// Number of columns in `weight` (size of each embedding vector).
    embedding_dim: usize,
}
43
44impl<B: Backend> Embedding<B> {
45    /// Create a new Embedding layer with normally-distributed random weights.
46    pub fn new(
47        num_embeddings: usize,
48        embedding_dim: usize,
49        dtype: DType,
50        device: &B::Device,
51    ) -> Result<Self> {
52        // Initialize from N(0, 1) — standard for embeddings
53        let weight =
54            Tensor::<B>::randn((num_embeddings, embedding_dim), dtype, device)?.set_variable();
55        Ok(Embedding {
56            weight,
57            num_embeddings,
58            embedding_dim,
59        })
60    }
61
62    /// Create from an existing weight matrix.
63    pub fn from_tensor(weight: Tensor<B>) -> Result<Self> {
64        let dims = weight.dims();
65        if dims.len() != 2 {
66            return Err(shrew_core::Error::msg(format!(
67                "Embedding weight must be 2D, got shape {:?}",
68                dims
69            )));
70        }
71        Ok(Embedding {
72            num_embeddings: dims[0],
73            embedding_dim: dims[1],
74            weight: weight.set_variable(),
75        })
76    }
77
78    pub fn num_embeddings(&self) -> usize {
79        self.num_embeddings
80    }
81    pub fn embedding_dim(&self) -> usize {
82        self.embedding_dim
83    }
84    pub fn weight(&self) -> &Tensor<B> {
85        &self.weight
86    }
87}
88
89impl<B: Backend> Module<B> for Embedding<B> {
90    /// Look up embeddings for the given indices.
91    ///
92    /// Input: integer tensor of any shape [...]
93    /// Output: tensor of shape [..., embedding_dim]
94    ///
95    /// Uses `index_select` to gather rows from the weight matrix directly
96    /// on-device (no host round-trip). For autograd, we record the operation
97    /// so gradients can flow back through the embedding table.
98    fn forward(&self, indices: &Tensor<B>) -> Result<Tensor<B>> {
99        let input_dims = indices.dims().to_vec();
100        let num_indices = indices.elem_count();
101
102        // Flatten indices to 1D, ensure U32 for index_select
103        let flat_idx = indices
104            .reshape(shrew_core::Shape::new(vec![num_indices]))?
105            .to_dtype(shrew_core::dtype::DType::U32)?;
106
107        // index_select on dim=0: weight[flat_idx] → [num_indices, embedding_dim]
108        let flat_result = self.weight.index_select(0, &flat_idx)?;
109
110        // Reshape to [..., embedding_dim]
111        let mut out_dims = input_dims;
112        out_dims.push(self.embedding_dim);
113        flat_result.reshape(shrew_core::Shape::new(out_dims))
114    }
115
116    fn parameters(&self) -> Vec<Tensor<B>> {
117        vec![self.weight.clone()]
118    }
119
120    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
121        vec![("weight".to_string(), self.weight.clone())]
122    }
123}