// shrew/exec/train.rs

// =============================================================================
// Trainer — Training loop runner powered by the graph executor
// =============================================================================
//
// Reads the @training block from an IrProgram and runs the training loop:
//   1. Initialize parameters
//   2. For each epoch:
//      a. Forward pass through the model graph
//      b. Compute loss
//      c. Backward pass (autograd)
//      d. Optimizer step
//      e. Log metrics
//
// Uses the Executor for graph evaluation and the shrew-optim optimizers.
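//
// The @training block is expected to supply: the model graph name, the loss
// function ("cross_entropy" or "mse"), the epoch count, the batch size, and
// an optimizer spec (kind, learning rate, and optional extras such as
// "momentum" for SGD). These are exactly the fields read by `from_program`
// and `train` below.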

use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_ir::graph::IrProgram;

use shrew_nn::{cross_entropy_loss, mse_loss};
use shrew_optim::{Adam, AdamW, Optimizer, SGD};

use super::engine::{Executor, RuntimeConfig};

// ─────────────────────────────────────────────────────────────────────────────
// Training result types
// ─────────────────────────────────────────────────────────────────────────────

/// Summary of a full training run.
#[derive(Debug, Clone)]
pub struct TrainResult {
    /// Per-epoch logs.
    pub epochs: Vec<EpochLog>,
    /// Final loss value.
    pub final_loss: f64,
}

/// Log for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochLog {
    /// Epoch number (0-indexed).
    pub epoch: usize,
    /// Average loss for this epoch.
    pub loss: f64,
}

impl std::fmt::Display for TrainResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Training complete — {} epochs", self.epochs.len())?;
        for log in &self.epochs {
            writeln!(f, "  epoch {}: loss = {:.6}", log.epoch, log.loss)?;
        }
        write!(f, "  final loss: {:.6}", self.final_loss)
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Trainer
// ─────────────────────────────────────────────────────────────────────────────

/// High-level training loop runner.
///
/// Reads the @training configuration from an IrProgram and orchestrates
/// forward, backward, and optimizer steps.
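///
/// A sketch of typical usage (the backend type `CpuBackend` and the
/// pre-built `device`, `config`, and `batches` values are illustrative
/// assumptions, not part of this module):
///
/// ```ignore
/// let mut trainer = Trainer::<CpuBackend>::from_program(program, device, config)?;
/// let result = trainer.train(&batches, "labels")?;
/// println!("{result}");
/// ```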
pub struct Trainer<B: Backend> {
    /// The graph executor.
    pub executor: Executor<B>,
    /// Name of the model graph.
    model_graph: String,
    /// Loss function name.
    loss_fn: String,
    /// Number of epochs.
    epochs: usize,
    /// Batch size (informational only; batching the input data is the
    /// caller's responsibility).
    pub batch_size: usize,
}

impl<B: Backend> Trainer<B> {
    /// Create a Trainer from an IrProgram.
    ///
    /// Reads the @training block and fails if no training config exists.
    pub fn from_program(
        program: IrProgram,
        device: B::Device,
        config: RuntimeConfig,
    ) -> Result<Self> {
        let training = program.training.as_ref().ok_or_else(|| {
            shrew_core::Error::msg("Program has no @training block. Cannot create Trainer.")
        })?;

        let model_graph = training.model_graph.clone();
        let loss_fn = training.loss.clone();
        let epochs = training.epochs as usize;
        let batch_size = training.batch_size as usize;

        let executor = Executor::<B>::new(program, device, config)?;

        Ok(Self {
            executor,
            model_graph,
            loss_fn,
            epochs,
            batch_size,
        })
    }

    /// Run the training loop over a slice of input batches.
    ///
    /// Each batch is a `HashMap<String, Tensor<B>>` mapping input names to
    /// tensors. Returns a `TrainResult` with per-epoch loss logs.
    ///
    /// `targets_key` is the name of the target tensor in each batch map.
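    ///
    /// A sketch of a one-batch dataset (the input name `"x"` and the
    /// pre-built `inputs` and `targets` tensors are illustrative
    /// assumptions):
    ///
    /// ```ignore
    /// let mut batch = HashMap::new();
    /// batch.insert("x".to_string(), inputs);        // model input
    /// batch.insert("labels".to_string(), targets);  // matched by targets_key
    ///
    /// let result = trainer.train(&[batch], "labels")?;
    /// println!("final loss: {:.6}", result.final_loss);
    /// ```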
    pub fn train(
        &mut self,
        data: &[HashMap<String, Tensor<B>>],
        targets_key: &str,
    ) -> Result<TrainResult> {
        let training = self
            .executor
            .program()
            .training
            .as_ref()
            .ok_or_else(|| shrew_core::Error::msg("No @training config"))?;

        // Create optimizer
        let params = self.executor.graph_params(&self.model_graph);
        let lr = training.optimizer.lr;

        let mut optimizer: Box<dyn Optimizer<B>> = match training.optimizer.kind.as_str() {
            "SGD" | "sgd" => {
                let momentum = training
                    .optimizer
                    .extra
                    .get("momentum")
                    .and_then(|v| match v {
                        shrew_ir::graph::ConfigValue::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(0.0);
                Box::new(SGD::new(params.clone(), lr, momentum, 0.0))
            }
            "Adam" | "adam" => Box::new(Adam::new(params.clone(), lr)),
            "AdamW" | "adamw" => Box::new(AdamW::new(params.clone(), lr, 0.01)),
            other => {
                return Err(shrew_core::Error::msg(format!(
                    "Unknown optimizer type: '{}'. Supported: SGD, Adam, AdamW",
                    other
                )));
            }
        };

        // Set training mode
        self.executor.config_mut().training = true;

        let mut epoch_logs = Vec::new();

        for epoch in 0..self.epochs {
            let mut epoch_loss = 0.0;
            let mut n_batches = 0;

            for batch in data {
                // Forward pass
                let result = self.executor.run(&self.model_graph, batch)?;

                // Get model output
                let output = result
                    .output()
                    .ok_or_else(|| shrew_core::Error::msg("Model graph produced no output"))?;

                // Get target from batch
                let target = batch.get(targets_key).ok_or_else(|| {
                    shrew_core::Error::msg(format!(
                        "Target tensor '{}' not found in batch",
                        targets_key
                    ))
                })?;

                // Compute loss
                let loss = match self.loss_fn.as_str() {
                    "cross_entropy" | "CrossEntropy" => cross_entropy_loss(output, target)?,
                    "mse" | "mse_loss" | "MSE" => mse_loss(output, target)?,
                    other => {
                        return Err(shrew_core::Error::msg(format!(
                            "Unknown loss function: '{}'. Supported: cross_entropy, mse",
                            other
                        )));
                    }
                };

                let loss_val = loss.to_scalar_f64()?;
                epoch_loss += loss_val;
                n_batches += 1;

                // Backward pass
                let grads = loss.backward()?;

                // Optimizer step → new parameters
                let new_params = optimizer.step(&grads)?;

                // Update parameters in executor
                self.executor.update_params(&self.model_graph, &new_params);
            }

            let avg_loss = if n_batches > 0 {
                epoch_loss / n_batches as f64
            } else {
                0.0
            };

            epoch_logs.push(EpochLog {
                epoch,
                loss: avg_loss,
            });
        }

        // Restore inference mode now that training is finished, so later
        // `infer` calls run with training disabled.
        self.executor.config_mut().training = false;

        let final_loss = epoch_logs.last().map_or(0.0, |l| l.loss);
        Ok(TrainResult {
            epochs: epoch_logs,
            final_loss,
        })
    }

    /// Run a single forward pass (inference mode).
    pub fn infer(&self, inputs: &HashMap<String, Tensor<B>>) -> Result<HashMap<String, Tensor<B>>> {
        let result = self.executor.run(&self.model_graph, inputs)?;
        Ok(result.outputs)
    }

    /// Get the model graph name.
    pub fn model_graph_name(&self) -> &str {
        &self.model_graph
    }

    /// Get the loss function name.
    pub fn loss_fn_name(&self) -> &str {
        &self.loss_fn
    }

    /// Get the number of epochs.
    pub fn epochs(&self) -> usize {
        self.epochs
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Convenience functions
// ─────────────────────────────────────────────────────────────────────────────

/// Shared front end for `load_program` and `load_trainer`: parse, lower,
/// validate, infer shapes, and optimize .sw source into an IrProgram.
fn prepare_ir(source: &str) -> Result<IrProgram> {
    let ast =
        shrew_ir::parse(source).map_err(|e| shrew_core::Error::msg(format!("Parse error: {e}")))?;
    let mut ir = shrew_ir::lower(&ast)
        .map_err(|e| shrew_core::Error::msg(format!("Lowering error: {e}")))?;

    // Validate
    if let Err(errors) = shrew_ir::validate(&ir) {
        let msg = errors
            .iter()
            .map(|e| e.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        return Err(shrew_core::Error::msg(format!("Validation errors:\n{msg}")));
    }

    // Infer shapes and optimize
    shrew_ir::infer_shapes(&mut ir);
    shrew_ir::optimize(&mut ir);

    Ok(ir)
}

/// Parse, lower, validate, optimize, and prepare an executor from .sw source.
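///
/// A sketch (the backend type, graph name `"model"`, and pre-built `inputs`
/// map are illustrative assumptions):
///
/// ```ignore
/// let executor = load_program::<CpuBackend>(source, device, config)?;
/// let result = executor.run("model", &inputs)?;
/// ```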
pub fn load_program<B: Backend>(
    source: &str,
    device: B::Device,
    config: RuntimeConfig,
) -> Result<Executor<B>> {
    Executor::<B>::new(prepare_ir(source)?, device, config)
}

/// Parse, lower, validate, optimize, and prepare a trainer from .sw source.
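///
/// A sketch of the end-to-end path from .sw source to a finished run (the
/// backend type, `batches` slice, and target key are illustrative
/// assumptions):
///
/// ```ignore
/// let mut trainer = load_trainer::<CpuBackend>(source, device, config)?;
/// let result = trainer.train(&batches, "labels")?;
/// for log in &result.epochs {
///     println!("epoch {}: {:.6}", log.epoch, log.loss);
/// }
/// ```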
pub fn load_trainer<B: Backend>(
    source: &str,
    device: B::Device,
    config: RuntimeConfig,
) -> Result<Trainer<B>> {
    Trainer::<B>::from_program(prepare_ir(source)?, device, config)
}