use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_ir::graph::IrProgram;

use shrew_nn::{cross_entropy_loss, mse_loss};
use shrew_optim::{Adam, AdamW, Optimizer, SGD};

use super::engine::{Executor, RuntimeConfig};
/// Summary of a completed training run: one entry per epoch plus the final
/// loss.
#[derive(Debug, Clone)]
pub struct TrainResult {
    /// Per-epoch loss logs, in order.
    pub epochs: Vec<EpochLog>,
    /// Average loss of the last epoch (0.0 if no epochs ran).
    pub final_loss: f64,
}

/// Loss recorded for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochLog {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Average loss across the epoch's batches.
    pub loss: f64,
}

impl std::fmt::Display for TrainResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Training complete — {} epochs", self.epochs.len())?;
        for log in &self.epochs {
            writeln!(f, "  epoch {}: loss = {:.6}", log.epoch, log.loss)?;
        }
        write!(f, "  final loss: {:.6}", self.final_loss)
    }
}
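
// A two-epoch run renders like this via the Display impl above (loss values
// are illustrative):
//
//   Training complete — 2 epochs
//     epoch 0: loss = 0.693147
//     epoch 1: loss = 0.512300
//     final loss: 0.512300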

/// High-level training driver that wraps an [`Executor`] together with the
/// model graph name, loss function, and schedule taken from the program's
/// @training block.
pub struct Trainer<B: Backend> {
    /// Executor that runs the underlying IR program.
    pub executor: Executor<B>,
    /// Name of the graph to train and run.
    model_graph: String,
    /// Loss function identifier (e.g. "cross_entropy" or "mse").
    loss_fn: String,
    /// Number of epochs to run.
    epochs: usize,
    /// Batch size declared in the @training block.
    pub batch_size: usize,
}

impl<B: Backend> Trainer<B> {
    /// Builds a `Trainer` from an already-lowered [`IrProgram`]. Fails if
    /// the program has no @training block.
    pub fn from_program(
        program: IrProgram,
        device: B::Device,
        config: RuntimeConfig,
    ) -> Result<Self> {
        let training = program.training.as_ref().ok_or_else(|| {
            shrew_core::Error::msg("Program has no @training block. Cannot create Trainer.")
        })?;

        // Copy the training configuration out before the program is moved
        // into the executor.
        let model_graph = training.model_graph.clone();
        let loss_fn = training.loss.clone();
        let epochs = training.epochs as usize;
        let batch_size = training.batch_size as usize;

        let executor = Executor::<B>::new(program, device, config)?;

        Ok(Self {
            executor,
            model_graph,
            loss_fn,
            epochs,
            batch_size,
        })
    }

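    /// Runs the training loop: for each epoch, executes the model graph on
    /// every batch, computes the configured loss against the tensor stored
    /// under `targets_key`, backpropagates, and applies one optimizer step
    /// per batch.
    ///
    /// A minimal usage sketch; `CpuBackend`, `Device::Cpu`, and `make_batch`
    /// are illustrative placeholders rather than confirmed shrew API:
    ///
    /// ```ignore
    /// let mut trainer =
    ///     load_trainer::<CpuBackend>(SOURCE, Device::Cpu, RuntimeConfig::default())?;
    /// let batches: Vec<HashMap<String, Tensor<CpuBackend>>> = vec![make_batch()];
    /// let result = trainer.train(&batches, "targets")?;
    /// println!("{result}");
    /// ```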
    pub fn train(
        &mut self,
        data: &[HashMap<String, Tensor<B>>],
        targets_key: &str,
    ) -> Result<TrainResult> {
        let training = self
            .executor
            .program()
            .training
            .as_ref()
            .ok_or_else(|| shrew_core::Error::msg("No @training config"))?;

        // Trainable parameters of the model graph and the configured
        // learning rate.
        let params = self.executor.graph_params(&self.model_graph);
        let lr = training.optimizer.lr;

        // Build the optimizer named in the @training block.
        let mut optimizer: Box<dyn Optimizer<B>> = match training.optimizer.kind.as_str() {
            "SGD" | "sgd" => {
                // SGD honors an optional `momentum` entry; defaults to 0.0.
                let momentum = training
                    .optimizer
                    .extra
                    .get("momentum")
                    .and_then(|v| match v {
                        shrew_ir::graph::ConfigValue::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(0.0);
                Box::new(SGD::new(params.clone(), lr, momentum, 0.0))
            }
            "Adam" | "adam" => Box::new(Adam::new(params.clone(), lr)),
            "AdamW" | "adamw" => Box::new(AdamW::new(params.clone(), lr, 0.01)),
            other => {
                return Err(shrew_core::Error::msg(format!(
                    "Unknown optimizer type: '{}'. Supported: SGD, Adam, AdamW",
                    other
                )));
            }
        };
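        // For reference, the dispatch above consumes an optimizer entry from
        // the @training block along these lines (illustrative surface syntax,
        // not the confirmed shrew grammar):
        //
        //   optimizer: SGD { lr: 0.01, momentum: 0.9 }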

        // Run the executor in training mode for the duration of the loop.
        self.executor.config_mut().training = true;

        let mut epoch_logs = Vec::new();

        for epoch in 0..self.epochs {
            let mut epoch_loss = 0.0;
            let mut n_batches = 0;

            for batch in data {
                // Forward pass through the model graph.
                let result = self.executor.run(&self.model_graph, batch)?;

                let output = result
                    .output()
                    .ok_or_else(|| shrew_core::Error::msg("Model graph produced no output"))?;

                let target = batch.get(targets_key).ok_or_else(|| {
                    shrew_core::Error::msg(format!(
                        "Target tensor '{}' not found in batch",
                        targets_key
                    ))
                })?;

                // Dispatch on the loss function named in the @training block.
                let loss = match self.loss_fn.as_str() {
                    "cross_entropy" | "CrossEntropy" => cross_entropy_loss(output, target)?,
                    "mse" | "mse_loss" | "MSE" => mse_loss(output, target)?,
                    other => {
                        return Err(shrew_core::Error::msg(format!(
                            "Unknown loss function: '{}'. Supported: cross_entropy, mse",
                            other
                        )));
                    }
                };

                let loss_val = loss.to_scalar_f64()?;
                epoch_loss += loss_val;
                n_batches += 1;

                // Backward pass over the recorded computation.
                let grads = loss.backward()?;

                // One optimizer step yields the updated parameter set.
                let new_params = optimizer.step(&grads)?;

                // Write the updated parameters back into the executor.
                self.executor.update_params(&self.model_graph, &new_params);
            }

            // Average the loss over the epoch's batches (0.0 if there were
            // no batches).
            let avg_loss = if n_batches > 0 {
                epoch_loss / n_batches as f64
            } else {
                0.0
            };

            epoch_logs.push(EpochLog {
                epoch,
                loss: avg_loss,
            });
        }

        let final_loss = epoch_logs.last().map_or(0.0, |l| l.loss);
        Ok(TrainResult {
            epochs: epoch_logs,
            final_loss,
        })
    }

    /// Runs the model graph once on `inputs` and returns all named outputs.
    pub fn infer(&self, inputs: &HashMap<String, Tensor<B>>) -> Result<HashMap<String, Tensor<B>>> {
        let result = self.executor.run(&self.model_graph, inputs)?;
        Ok(result.outputs)
    }

    /// Name of the model graph this trainer drives.
    pub fn model_graph_name(&self) -> &str {
        &self.model_graph
    }

    /// Identifier of the configured loss function.
    pub fn loss_fn_name(&self) -> &str {
        &self.loss_fn
    }

    /// Number of epochs configured in the @training block.
    pub fn epochs(&self) -> usize {
        self.epochs
    }
}

/// Parses and lowers `source`, validates the IR, runs shape inference and
/// optimization, and builds an [`Executor`].
pub fn load_program<B: Backend>(
    source: &str,
    device: B::Device,
    config: RuntimeConfig,
) -> Result<Executor<B>> {
    let ast =
        shrew_ir::parse(source).map_err(|e| shrew_core::Error::msg(format!("Parse error: {e}")))?;
    let mut ir = shrew_ir::lower(&ast)
        .map_err(|e| shrew_core::Error::msg(format!("Lowering error: {e}")))?;

    // Surface every validation error in a single message.
    if let Err(errors) = shrew_ir::validate(&ir) {
        let msg = errors
            .iter()
            .map(|e| e.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        return Err(shrew_core::Error::msg(format!("Validation errors:\n{msg}")));
    }

    shrew_ir::infer_shapes(&mut ir);
    shrew_ir::optimize(&mut ir);

    Executor::<B>::new(ir, device, config)
}

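/// Like [`load_program`], but builds a [`Trainer`]; the program must contain
/// an @training block.
///
/// End-to-end sketch; the backend type, device value, file name, and the
/// `batches`/`inputs` bindings are illustrative assumptions:
///
/// ```ignore
/// let source = std::fs::read_to_string("model.sw")?;
/// let mut trainer =
///     load_trainer::<CpuBackend>(&source, Device::Cpu, RuntimeConfig::default())?;
/// let result = trainer.train(&batches, "targets")?;
/// println!("{result}");
/// let outputs = trainer.infer(&inputs)?;
/// ```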
pub fn load_trainer<B: Backend>(
    source: &str,
    device: B::Device,
    config: RuntimeConfig,
) -> Result<Trainer<B>> {
    let ast =
        shrew_ir::parse(source).map_err(|e| shrew_core::Error::msg(format!("Parse error: {e}")))?;
    let mut ir = shrew_ir::lower(&ast)
        .map_err(|e| shrew_core::Error::msg(format!("Lowering error: {e}")))?;

    if let Err(errors) = shrew_ir::validate(&ir) {
        let msg = errors
            .iter()
            .map(|e| e.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        return Err(shrew_core::Error::msg(format!("Validation errors:\n{msg}")));
    }

    shrew_ir::infer_shapes(&mut ir);
    shrew_ir::optimize(&mut ir);

    Trainer::<B>::from_program(ir, device, config)
}