use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::dtype::DType as CoreDType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_ir::graph::{
    ConfigValue, ConstantValue, DType as IrDType, Dim, InitStrategy, IrGraph, IrNode, IrProgram,
    IrType, OpKind,
};

use shrew_nn::{
    cross_entropy_loss, mse_loss, Dropout, Embedding, LayerNorm, Linear, Module, TransformerBlock,
};

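/// Runtime configuration for the executor: concrete values for symbolic
/// dimensions, the default dtype, and whether graphs run in training mode.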
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Concrete values for symbolic dimensions referenced by the program.
    pub dims: HashMap<String, usize>,
    /// Default dtype for values whose IR type does not pin one down.
    pub default_dtype: CoreDType,
    /// Whether execution runs in training mode (e.g. dropout is applied).
    pub training: bool,
}

impl Default for RuntimeConfig {
    fn default() -> Self {
        Self {
            dims: HashMap::new(),
            default_dtype: CoreDType::F32,
            training: false,
        }
    }
}

impl RuntimeConfig {
    /// Bind a symbolic dimension to a concrete size (builder style).
    pub fn set_dim(mut self, name: impl Into<String>, value: usize) -> Self {
        self.dims.insert(name.into(), value);
        self
    }

    /// Enable or disable training mode (builder style).
    pub fn with_training(mut self, training: bool) -> Self {
        self.training = training;
        self
    }

    /// Override the default dtype (builder style).
    pub fn with_dtype(mut self, dtype: CoreDType) -> Self {
        self.default_dtype = dtype;
        self
    }
}

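/// The result of executing a graph: the named outputs plus every computed
/// intermediate value, keyed by node id.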
#[derive(Debug)]
pub struct ExecResult<B: Backend> {
    /// Graph outputs keyed by output name.
    pub outputs: HashMap<String, Tensor<B>>,
    /// All computed node values keyed by node id.
    pub values: HashMap<usize, Tensor<B>>,
}

impl<B: Backend> ExecResult<B> {
    /// An arbitrary output (useful when the graph has exactly one).
    pub fn output(&self) -> Option<&Tensor<B>> {
        self.outputs.values().next()
    }

    /// Look up an output by name.
    pub fn get(&self, name: &str) -> Option<&Tensor<B>> {
        self.outputs.get(name)
    }
}

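/// Interprets an [`IrProgram`]: each graph is walked in topological order and
/// every [`OpKind`] is dispatched to the corresponding tensor operation.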
pub struct Executor<B: Backend> {
    /// The IR program being executed.
    program: IrProgram,
    /// Runtime configuration (symbolic dims, default dtype, training flag).
    config: RuntimeConfig,
    /// Device on which parameters and constants are materialized.
    device: B::Device,
    /// Initialized parameters keyed by (graph name, parameter name).
    params: HashMap<(String, String), Tensor<B>>,
}

impl<B: Backend> Executor<B> {
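    /// Build an executor for `program` on `device`, initializing every graph
    /// parameter according to its declared init strategy.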
    pub fn new(program: IrProgram, device: B::Device, config: RuntimeConfig) -> Result<Self> {
        let mut exec = Self {
            program,
            config,
            device,
            params: HashMap::new(),
        };
        exec.init_all_params()?;
        Ok(exec)
    }

    /// The underlying IR program.
    pub fn program(&self) -> &IrProgram {
        &self.program
    }

    /// The current runtime configuration.
    pub fn config(&self) -> &RuntimeConfig {
        &self.config
    }

    /// Mutable access to the runtime configuration.
    pub fn config_mut(&mut self) -> &mut RuntimeConfig {
        &mut self.config
    }

    /// All parameters keyed by (graph name, parameter name).
    pub fn params(&self) -> &HashMap<(String, String), Tensor<B>> {
        &self.params
    }

    /// All parameter tensors, in arbitrary order.
    pub fn all_params(&self) -> Vec<Tensor<B>> {
        self.params.values().cloned().collect()
    }

    /// Parameters as (`"graph/param"`, tensor) pairs, sorted by key.
    pub fn named_params(&self) -> Vec<(String, Tensor<B>)> {
        let mut pairs: Vec<(String, Tensor<B>)> = self
            .params
            .iter()
            .map(|((g, p), t)| (format!("{g}/{p}"), t.clone()))
            .collect();
        pairs.sort_by(|a, b| a.0.cmp(&b.0));
        pairs
    }

    /// Replace a parameter addressed by a `graph/param` key. Returns `true`
    /// if the parameter existed and was updated.
    pub fn set_param_by_key(&mut self, key: &str, tensor: Tensor<B>) -> bool {
        if let Some(pos) = key.find('/') {
            let graph = &key[..pos];
            let param = &key[pos + 1..];
            let k = (graph.to_string(), param.to_string());
            if let std::collections::hash_map::Entry::Occupied(mut e) = self.params.entry(k) {
                e.insert(tensor.set_variable());
                return true;
            }
        }
        false
    }

    /// The device this executor materializes tensors on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }

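    /// Execute the named graph with the given inputs, returning its outputs
    /// and all intermediate values.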
    pub fn run(
        &self,
        graph_name: &str,
        inputs: &HashMap<String, Tensor<B>>,
    ) -> Result<ExecResult<B>> {
        let graph = self.program.get_graph(graph_name).ok_or_else(|| {
            shrew_core::Error::msg(format!("Graph '{}' not found in program", graph_name))
        })?;
        self.execute_graph(graph, inputs)
    }

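    /// Execute `graph`: seed the value map with inputs and parameters, then
    /// evaluate the remaining nodes in topological order.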
    fn execute_graph(
        &self,
        graph: &IrGraph,
        inputs: &HashMap<String, Tensor<B>>,
    ) -> Result<ExecResult<B>> {
        let order = graph.topo_order();
        let mut values: HashMap<usize, Tensor<B>> = HashMap::new();

        // Seed graph inputs from the caller-provided tensors.
        for &input_id in &graph.inputs {
            let node = graph.node(input_id);
            if let Some(tensor) = inputs.get(&node.name) {
                values.insert(input_id.0, tensor.clone());
            }
        }

        // Seed parameter nodes from the initialized parameter store.
        for param in &graph.params {
            let key = (graph.name.clone(), param.name.clone());
            if let Some(tensor) = self.params.get(&key) {
                values.insert(param.node_id.0, tensor.clone());
            }
        }

        // Evaluate the remaining nodes in topological order.
        for &node_id in &order {
            if values.contains_key(&node_id.0) {
                continue;
            }
            let node = graph.node(node_id);
            let result = self.execute_node(graph, node, &values)?;
            values.insert(node_id.0, result);
        }

        // Collect the named outputs.
        let mut outputs = HashMap::new();
        for output in &graph.outputs {
            if let Some(tensor) = values.get(&output.node_id.0) {
                outputs.insert(output.name.clone(), tensor.clone());
            }
        }

        Ok(ExecResult { outputs, values })
    }

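    /// Evaluate a single node given the values computed so far, dispatching
    /// on its [`OpKind`].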
    fn execute_node(
        &self,
        _graph: &IrGraph,
        node: &IrNode,
        values: &HashMap<usize, Tensor<B>>,
    ) -> Result<Tensor<B>> {
        // Gather the already-computed input tensors for this node.
        let input_tensors: Vec<&Tensor<B>> = node
            .inputs
            .iter()
            .filter_map(|id| values.get(&id.0))
            .collect();

        match &node.op {
            OpKind::Identity => input_tensors.first().map(|t| (*t).clone()).ok_or_else(|| {
                shrew_core::Error::msg(format!("Identity node '{}' has no input", node.name))
            }),

            OpKind::Neg => unary(&input_tensors, &node.name, |t| t.neg()),
            OpKind::Relu => unary(&input_tensors, &node.name, |t| t.relu()),
            OpKind::Gelu => unary(&input_tensors, &node.name, |t| t.gelu()),
            OpKind::Silu => unary(&input_tensors, &node.name, |t| t.silu()),
            OpKind::Sigmoid => unary(&input_tensors, &node.name, |t| t.sigmoid()),
            OpKind::Tanh => unary(&input_tensors, &node.name, |t| t.tanh()),
            OpKind::Exp => unary(&input_tensors, &node.name, |t| t.exp()),
            OpKind::Log => unary(&input_tensors, &node.name, |t| t.log()),
            OpKind::Sqrt => unary(&input_tensors, &node.name, |t| t.sqrt()),

            OpKind::Transpose => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let rank = t.rank();
                if rank < 2 {
                    return Err(shrew_core::Error::msg(format!(
                        "Transpose requires rank >= 2, got {} for '{}'",
                        rank, node.name
                    )));
                }
                t.transpose(rank - 2, rank - 1)
            }

            OpKind::Add => binary(&input_tensors, &node.name, |a, b| a.add(b)),
            OpKind::Sub => binary(&input_tensors, &node.name, |a, b| a.sub(b)),
            OpKind::Mul => binary(&input_tensors, &node.name, |a, b| a.mul(b)),
            OpKind::Div => binary(&input_tensors, &node.name, |a, b| a.div(b)),
            OpKind::MatMul => binary(&input_tensors, &node.name, |a, b| a.matmul(b)),

            OpKind::Pow => {
                let base = require_input(&input_tensors, 0, &node.name)?;
                let exp_t = require_input(&input_tensors, 1, &node.name)?;
                // a^b computed as exp(b * ln(a)).
                base.log()?.mul(exp_t)?.exp()
            }

            OpKind::Mod => {
                let a = require_input(&input_tensors, 0, &node.name)?;
                let b = require_input(&input_tensors, 1, &node.name)?;
                // a mod b computed as a - floor(a / b) * b.
                let quotient = a.div(b)?.floor()?;
                let product = quotient.mul(b)?;
                a.sub(&product)
            }

            OpKind::Sum { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                // Empty dims or a single -1 means reduce over all elements.
                if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                    t.sum_all()
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.sum(dim, *keepdim)
                }
            }

            OpKind::Mean { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                    t.mean_all()
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.mean(dim, *keepdim)
                }
            }

            OpKind::Max { dim, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.max(d, *keepdim)
            }

            OpKind::Min { dim, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.min(d, *keepdim)
            }

            OpKind::Variance { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                if dims.is_empty() {
                    t.var(0, *keepdim)
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.var(dim, *keepdim)
                }
            }

            OpKind::Softmax { dim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.softmax(d)
            }

            // Reshape and View are handled identically by the executor.
            OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let shape = self.resolve_shape_vec(target_shape)?;
                t.reshape(shape)
            }

            OpKind::Permute { dims: perm_dims } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let mut result = t.clone();
                // Realize the permutation as a sequence of pairwise transposes.
                let mut current: Vec<usize> = (0..t.rank()).collect();
                for i in 0..perm_dims.len() {
                    let target = perm_dims[i] as usize;
                    if current[i] != target {
                        let j = current.iter().position(|&x| x == target).ok_or_else(|| {
                            shrew_core::Error::msg(format!(
                                "permute: dimension {} not found in current layout",
                                target
                            ))
                        })?;
                        result = result.transpose(i, j)?;
                        current.swap(i, j);
                    }
                }
                Ok(result)
            }

            OpKind::Expand { target_shape } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let shape = self.resolve_shape_vec(target_shape)?;
                t.expand(shape)
            }

            OpKind::Concat { dim } => {
                if input_tensors.is_empty() {
                    return Err(shrew_core::Error::msg(format!(
                        "Concat node '{}' has no inputs",
                        node.name
                    )));
                }
                let owned: Vec<Tensor<B>> = input_tensors.iter().map(|t| (*t).clone()).collect();
                Tensor::<B>::cat(&owned, *dim as usize)
            }

            OpKind::Embedding => {
                // Input 0: indices, input 1: the embedding table.
                let indices = require_input(&input_tensors, 0, &node.name)?;
                let table = require_input(&input_tensors, 1, &node.name)?;
                let emb = Embedding::<B>::from_tensor(table.clone())?;
                emb.forward(indices)
            }

            OpKind::Linear { bias } => {
                // Input 0: activations, input 1: weight, optional input 2: bias.
                let input = require_input(&input_tensors, 0, &node.name)?;
                let weight = require_input(&input_tensors, 1, &node.name)?;
                if *bias && input_tensors.len() >= 3 {
                    let bias_t = require_input(&input_tensors, 2, &node.name)?;
                    let lin = Linear::<B>::from_tensors(weight.clone(), Some(bias_t.clone()))?;
                    lin.forward(input)
                } else {
                    let lin = Linear::<B>::from_tensors(weight.clone(), None)?;
                    lin.forward(input)
                }
            }

            OpKind::LayerNorm { eps } => {
                // Inputs: activations, weight (gamma), bias (beta).
                let input = require_input(&input_tensors, 0, &node.name)?;
                let weight = require_input(&input_tensors, 1, &node.name)?;
                let bias_t = require_input(&input_tensors, 2, &node.name)?;
                let ln = LayerNorm::<B>::from_tensors(weight.clone(), bias_t.clone(), *eps)?;
                ln.forward(input)
            }

            OpKind::MultiHeadAttention { n_heads } => {
                // Builds a fresh attention module sized from the input's last dimension.
                let input = require_input(&input_tensors, 0, &node.name)?;
                let d_model = *input
                    .dims()
                    .last()
                    .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
                let mha = shrew_nn::MultiHeadAttention::<B>::new(
                    d_model,
                    *n_heads as usize,
                    input.dtype(),
                    input.device(),
                )?;
                mha.forward(input)
            }

            OpKind::TransformerBlock { n_heads } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let dims = input.dims();
                if dims.len() != 3 {
                    return Err(shrew_core::Error::msg(format!(
                        "TransformerBlock expects [batch, seq, d_model], got {:?}",
                        dims
                    )));
                }
                let d_model = dims[2];
                // d_ff follows the common 4 * d_model convention.
                let d_ff = d_model * 4;
                let block = TransformerBlock::<B>::new(
                    d_model,
                    *n_heads as usize,
                    d_ff,
                    true,
                    input.dtype(),
                    input.device(),
                )?;
                block.forward(input)
            }

            OpKind::Dropout { p } => {
                // Dropout is a no-op outside of training mode.
                let input = require_input(&input_tensors, 0, &node.name)?;
                let dropout = Dropout::new(*p);
                if self.config.training {
                    dropout.forward_t(input)
                } else {
                    Ok(input.clone())
                }
            }

            OpKind::CrossEntropy => {
                let predictions = require_input(&input_tensors, 0, &node.name)?;
                let targets = require_input(&input_tensors, 1, &node.name)?;
                cross_entropy_loss(predictions, targets)
            }

            OpKind::MseLoss => {
                let predictions = require_input(&input_tensors, 0, &node.name)?;
                let targets = require_input(&input_tensors, 1, &node.name)?;
                mse_loss(predictions, targets)
            }

            OpKind::Equal
            | OpKind::NotEqual
            | OpKind::Less
            | OpKind::Greater
            | OpKind::LessEqual
            | OpKind::GreaterEqual => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                match &node.op {
                    OpKind::Equal => lhs.eq(rhs),
                    OpKind::NotEqual => lhs.ne(rhs),
                    OpKind::Less => lhs.lt(rhs),
                    OpKind::Greater => lhs.gt(rhs),
                    OpKind::LessEqual => lhs.le(rhs),
                    OpKind::GreaterEqual => lhs.ge(rhs),
                    _ => unreachable!(),
                }
            }

            OpKind::Constant(val) => self.materialize_constant(val, &node.output_type),

            OpKind::Repeat { count, body_op } => {
                // Apply the body op `count` times, feeding each result back in.
                let input = require_input(&input_tensors, 0, &node.name)?;
                let mut current = input.clone();
                for _ in 0..*count {
                    current = self.execute_body_op(body_op, &current)?;
                }
                Ok(current)
            }

            OpKind::Call { graph_name } => {
                let sub_graph = self.program.get_graph(graph_name).ok_or_else(|| {
                    shrew_core::Error::msg(format!("Called graph '{}' not found", graph_name))
                })?;
                // Bind this node's inputs to the callee's input nodes by position.
                let mut sub_inputs = HashMap::new();
                for (i, &input_id) in sub_graph.inputs.iter().enumerate() {
                    let input_node = sub_graph.node(input_id);
                    if let Some(tensor) = input_tensors.get(i) {
                        sub_inputs.insert(input_node.name.clone(), (*tensor).clone());
                    }
                }
                let result = self.execute_graph(sub_graph, &sub_inputs)?;
                result.output().cloned().ok_or_else(|| {
                    shrew_core::Error::msg(format!(
                        "Called graph '{}' produced no output",
                        graph_name
                    ))
                })
            }

            OpKind::Range => {
                // Bounds come from scalar inputs when present, otherwise from
                // the node's declared output shape.
                let (start, end) = if input_tensors.len() >= 2 {
                    let s = input_tensors[0].to_scalar_f64()?;
                    let e = input_tensors[1].to_scalar_f64()?;
                    (s as i64, e as i64)
                } else if input_tensors.len() == 1 {
                    (0i64, input_tensors[0].to_scalar_f64()? as i64)
                } else {
                    match &node.output_type {
                        IrType::Tensor { shape, .. } => {
                            if let Some(Dim::Fixed(n)) = shape.first() {
                                (0, *n)
                            } else if let Some(Dim::Symbolic(name)) = shape.first() {
                                let n = self.resolve_symbolic(name)? as i64;
                                (0, n)
                            } else {
                                (0, 1)
                            }
                        }
                        _ => (0, 1),
                    }
                };
                let data: Vec<f64> = (start..end).map(|i| i as f64).collect();
                let len = data.len();
                Tensor::<B>::from_f64_slice(&data, len, CoreDType::I64, &self.device)
            }

            OpKind::BatchNorm { eps } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                if input_tensors.len() >= 3 {
                    let weight = require_input(&input_tensors, 1, &node.name)?;
                    let bias_t = require_input(&input_tensors, 2, &node.name)?;
                    let bn = shrew_nn::BatchNorm2d::<B>::from_tensors(
                        weight.clone(),
                        bias_t.clone(),
                        *eps,
                    )?;
                    bn.forward(input)
                } else {
                    // No affine parameters supplied: build a fresh BatchNorm2d
                    // sized from the channel dimension.
                    let dims = input.dims();
                    if dims.len() != 4 {
                        return Err(shrew_core::Error::msg(format!(
                            "BatchNorm expects 4D input [N,C,H,W], got {:?}",
                            dims
                        )));
                    }
                    let c = dims[1];
                    let bn =
                        shrew_nn::BatchNorm2d::<B>::new(c, *eps, 0.1, input.dtype(), &self.device)?;
                    bn.forward(input)
                }
            }

            OpKind::Split { dim, chunks } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, input.rank());
                let result = input.chunk(*chunks as usize, d)?;
                // Only the first chunk is returned.
                result
                    .into_iter()
                    .next()
                    .ok_or_else(|| shrew_core::Error::msg("Split produced no chunks"))
            }

            OpKind::And => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                // Element-wise logical AND, computed on the host.
                let a_data = lhs.to_f64_vec()?;
                let b_data = rhs.to_f64_vec()?;
                let result: Vec<f64> = a_data
                    .iter()
                    .zip(b_data.iter())
                    .map(|(&a, &b)| if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)
            }
            OpKind::Or => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                let a_data = lhs.to_f64_vec()?;
                let b_data = rhs.to_f64_vec()?;
                let result: Vec<f64> = a_data
                    .iter()
                    .zip(b_data.iter())
                    .map(|(&a, &b)| if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)
            }
            OpKind::Not => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let data = input.to_f64_vec()?;
                let result: Vec<f64> = data
                    .iter()
                    .map(|&v| if v == 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)
            }

            OpKind::Custom { name, .. } => {
                // Fused custom ops recognized by the executor.
                match name.as_str() {
                    "fused_matmul_add" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        let c = require_input(&input_tensors, 2, &node.name)?;
                        a.matmul(b)?.add(c)
                    }
                    "fused_add_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.add(b)?.relu()
                    }
                    "fused_sub_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.sub(b)?.relu()
                    }
                    "fused_matmul_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.matmul(b)?.relu()
                    }
                    _ => Err(shrew_core::Error::msg(format!(
                        "Custom op '{}' is not implemented in the executor",
                        name
                    ))),
                }
            }
        }
    }

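    /// Apply one iteration of a `Repeat` body op to `input`. Only
    /// `TransformerBlock` and `MultiHeadAttention` bodies are supported.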
    fn execute_body_op(&self, op: &OpKind, input: &Tensor<B>) -> Result<Tensor<B>> {
        match op {
            OpKind::TransformerBlock { n_heads } => {
                let dims = input.dims();
                if dims.len() != 3 {
                    return Err(shrew_core::Error::msg(format!(
                        "TransformerBlock expects [batch, seq, d_model], got {:?}",
                        dims
                    )));
                }
                let d_model = dims[2];
                let d_ff = d_model * 4;
                let block = TransformerBlock::<B>::new(
                    d_model,
                    *n_heads as usize,
                    d_ff,
                    true,
                    input.dtype(),
                    input.device(),
                )?;
                block.forward(input)
            }
            OpKind::MultiHeadAttention { n_heads } => {
                let d_model = *input
                    .dims()
                    .last()
                    .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
                let mha = shrew_nn::MultiHeadAttention::<B>::new(
                    d_model,
                    *n_heads as usize,
                    input.dtype(),
                    input.device(),
                )?;
                mha.forward(input)
            }
            _ => Err(shrew_core::Error::msg(format!(
                "Unsupported op in Repeat body: {:?}. \
                 Only TransformerBlock and MultiHeadAttention are supported.",
                op
            ))),
        }
    }

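    /// Initialize every parameter declared by every graph in the program and
    /// store it under its (graph name, parameter name) key.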
    fn init_all_params(&mut self) -> Result<()> {
        // Collect parameter metadata first so that `self.program` is no longer
        // borrowed while `init_param` runs and `self.params` is mutated.
        let graphs: Vec<(String, Vec<_>)> = self
            .program
            .graphs
            .iter()
            .map(|g| {
                (
                    g.name.clone(),
                    g.params
                        .iter()
                        .map(|p| (p.name.clone(), p.ty.clone(), p.init.clone(), p.frozen))
                        .collect::<Vec<_>>(),
                )
            })
            .collect();

        for (graph_name, params) in &graphs {
            for (param_name, ty, init, frozen) in params {
                let tensor = self.init_param(ty, init, *frozen)?;
                self.params
                    .insert((graph_name.clone(), param_name.clone()), tensor);
            }
        }
        Ok(())
    }

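    /// Create a single parameter tensor of the resolved shape and dtype using
    /// the requested initialization strategy. Non-frozen parameters are marked
    /// as variables via `set_variable`.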
    fn init_param(&self, ty: &IrType, init: &InitStrategy, frozen: bool) -> Result<Tensor<B>> {
        let (shape, dtype) = self.resolve_type(ty)?;
        let tensor = match init {
            InitStrategy::Zeros => Tensor::<B>::zeros(shape, dtype, &self.device)?,
            InitStrategy::Ones => Tensor::<B>::ones(shape, dtype, &self.device)?,
            InitStrategy::Normal { mean, std } => {
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(*std, *mean)?
            }
            InitStrategy::Uniform { low, high } => {
                let range = high - low;
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(range, *low)?
            }
            InitStrategy::XavierUniform => {
                // U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
                let (fan_in, fan_out) = compute_fans(&shape);
                let a = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(2.0 * a, -a)?
            }
            InitStrategy::XavierNormal => {
                // N(0, std^2) with std = sqrt(2 / (fan_in + fan_out)).
                let (fan_in, fan_out) = compute_fans(&shape);
                let std = (2.0_f64 / (fan_in + fan_out) as f64).sqrt();
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(std, 0.0)?
            }
            InitStrategy::KaimingUniform => {
                // U(-bound, bound) with bound = sqrt(3 / fan_in).
                let (fan_in, _) = compute_fans(&shape);
                let bound = (3.0_f64 / fan_in as f64).sqrt();
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(2.0 * bound, -bound)?
            }
            InitStrategy::KaimingNormal => {
                // N(0, std^2) with std = sqrt(2 / fan_in).
                let (fan_in, _) = compute_fans(&shape);
                let std = (2.0_f64 / fan_in as f64).sqrt();
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(std, 0.0)?
            }
            InitStrategy::Custom(_) => Tensor::<B>::randn(shape, dtype, &self.device)?,
        };

        if frozen {
            Ok(tensor)
        } else {
            Ok(tensor.set_variable())
        }
    }

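    /// Overwrite the parameters of `graph_name` with `new_params`, pairing
    /// them positionally with the stored parameter names for that graph.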
    pub fn update_params(&mut self, graph_name: &str, new_params: &[Tensor<B>]) {
        let param_names: Vec<String> = self
            .params
            .keys()
            .filter(|(g, _)| g == graph_name)
            .map(|(_, n)| n.clone())
            .collect();

        for (name, tensor) in param_names.into_iter().zip(new_params.iter()) {
            self.params
                .insert((graph_name.to_string(), name), tensor.clone());
        }
    }

    /// All parameter tensors belonging to the named graph.
    pub fn graph_params(&self, graph_name: &str) -> Vec<Tensor<B>> {
        self.params
            .iter()
            .filter(|((g, _), _)| g == graph_name)
            .map(|(_, t)| t.clone())
            .collect()
    }

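    /// Resolve an IR dimension to a concrete size: fixed dims are used as-is,
    /// symbolic dims are looked up in the runtime/program config, and dynamic
    /// dims cannot be resolved here.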
    fn resolve_dim(&self, dim: &Dim) -> Result<usize> {
        match dim {
            Dim::Fixed(n) => Ok(*n as usize),
            Dim::Symbolic(name) => self.resolve_symbolic(name),
            Dim::Dynamic => Err(shrew_core::Error::msg(
                "Cannot resolve dynamic dimension at runtime",
            )),
        }
    }

    /// Look up a symbolic dimension, first in the runtime config, then in the
    /// program's own config block.
    fn resolve_symbolic(&self, name: &str) -> Result<usize> {
        if let Some(&val) = self.config.dims.get(name) {
            return Ok(val);
        }
        if let Some(ConfigValue::Int(n)) = self.program.config.get(name) {
            return Ok(*n as usize);
        }
        Err(shrew_core::Error::msg(format!(
            "Unresolved symbolic dimension: '{}'. Set it via RuntimeConfig::set_dim()",
            name
        )))
    }

    /// Resolve an IR type to a concrete (shape, dtype) pair.
    fn resolve_type(&self, ty: &IrType) -> Result<(shrew_core::Shape, CoreDType)> {
        match ty {
            IrType::Tensor { shape, dtype } => {
                let dims: Vec<usize> = shape
                    .iter()
                    .map(|d| self.resolve_dim(d))
                    .collect::<Result<Vec<_>>>()?;
                let core_dtype = ir_dtype_to_core(*dtype)?;
                Ok((shrew_core::Shape::new(dims), core_dtype))
            }
            IrType::Scalar(dtype) => {
                let core_dtype = ir_dtype_to_core(*dtype)?;
                Ok((shrew_core::Shape::new(vec![1]), core_dtype))
            }
            IrType::Int => Ok((shrew_core::Shape::new(vec![1]), CoreDType::I64)),
            _ => Ok((shrew_core::Shape::new(vec![1]), self.config.default_dtype)),
        }
    }

    /// Resolve a slice of IR dimensions to concrete sizes.
    fn resolve_shape_vec(&self, dims: &[Dim]) -> Result<Vec<usize>> {
        dims.iter().map(|d| self.resolve_dim(d)).collect()
    }

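    /// Materialize an IR constant as a one-element tensor on the executor's
    /// device.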
    fn materialize_constant(&self, val: &ConstantValue, ty: &IrType) -> Result<Tensor<B>> {
        match val {
            ConstantValue::Int(n) => {
                Tensor::<B>::from_f64_slice(&[*n as f64], 1, CoreDType::I64, &self.device)
            }
            ConstantValue::Float(f) => Tensor::<B>::from_f64_slice(
                &[*f],
                1,
                ir_type_dtype(ty, self.config.default_dtype)?,
                &self.device,
            ),
            ConstantValue::Bool(b) => Tensor::<B>::from_f64_slice(
                &[if *b { 1.0 } else { 0.0 }],
                1,
                CoreDType::U8,
                &self.device,
            ),
            ConstantValue::Str(_) => {
                // Strings have no tensor representation; use a zero placeholder.
                Tensor::<B>::zeros(1, self.config.default_dtype, &self.device)
            }
            ConstantValue::Null => Tensor::<B>::zeros(1, self.config.default_dtype, &self.device),
        }
    }
}

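/// Map an IR dtype to the closest dtype supported by the core tensor library.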
pub fn ir_dtype_to_core(dt: IrDType) -> Result<CoreDType> {
    match dt {
        IrDType::F32 => Ok(CoreDType::F32),
        IrDType::F64 => Ok(CoreDType::F64),
        IrDType::U8 => Ok(CoreDType::U8),
        IrDType::U32 => Ok(CoreDType::U32),
        IrDType::I64 => Ok(CoreDType::I64),
        // Types without an exact core equivalent map to the nearest one.
        IrDType::F16 | IrDType::Bf16 => Ok(CoreDType::F32),
        IrDType::I8 | IrDType::I16 | IrDType::I32 => Ok(CoreDType::I64),
        IrDType::U16 => Ok(CoreDType::U32),
        IrDType::U64 => Ok(CoreDType::U32),
        IrDType::Bool => Ok(CoreDType::U8),
        _ => Err(shrew_core::Error::msg(format!(
            "Unsupported IR dtype: {dt}"
        ))),
    }
}

/// The dtype carried by an IR type, falling back to `default` when the type
/// does not specify one.
fn ir_type_dtype(ty: &IrType, default: CoreDType) -> Result<CoreDType> {
    match ty {
        IrType::Tensor { dtype, .. } => ir_dtype_to_core(*dtype),
        IrType::Scalar(dtype) => ir_dtype_to_core(*dtype),
        _ => Ok(default),
    }
}

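/// Convert a possibly negative dimension index (counting from the end) into
/// an absolute index for a tensor of the given rank.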
fn resolve_neg_dim(dim: i64, rank: usize) -> usize {
    if dim < 0 {
        (rank as i64 + dim) as usize
    } else {
        dim as usize
    }
}

/// Fetch the input tensor at `idx`, or report a descriptive error naming the
/// node and how many inputs were actually available.
fn require_input<'a, B: Backend>(
    inputs: &[&'a Tensor<B>],
    idx: usize,
    node_name: &str,
) -> Result<&'a Tensor<B>> {
    inputs.get(idx).copied().ok_or_else(|| {
        shrew_core::Error::msg(format!(
            "Node '{}' expected input at index {}, but only {} inputs available",
            node_name,
            idx,
            inputs.len()
        ))
    })
}

/// Apply a unary tensor op to the node's first input.
fn unary<B: Backend>(
    inputs: &[&Tensor<B>],
    node_name: &str,
    f: impl FnOnce(&Tensor<B>) -> Result<Tensor<B>>,
) -> Result<Tensor<B>> {
    let t = require_input(inputs, 0, node_name)?;
    f(t)
}

/// Apply a binary tensor op to the node's first two inputs.
fn binary<B: Backend>(
    inputs: &[&Tensor<B>],
    node_name: &str,
    f: impl FnOnce(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
) -> Result<Tensor<B>> {
    let a = require_input(inputs, 0, node_name)?;
    let b = require_input(inputs, 1, node_name)?;
    f(a, b)
}

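/// Compute (fan_in, fan_out) for an initializer: for 2D weights this is
/// (cols, rows); for higher-rank weights the trailing dims form the
/// receptive field.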
fn compute_fans(shape: &shrew_core::Shape) -> (usize, usize) {
    let dims = shape.dims();
    match dims.len() {
        0 => (1, 1),
        1 => (dims[0], dims[0]),
        2 => (dims[1], dims[0]),
        _ => {
            let receptive: usize = dims[2..].iter().product();
            let fan_in = dims[1] * receptive;
            let fan_out = dims[0] * receptive;
            (fan_in, fan_out)
        }
    }
}