use std::collections::HashMap;
use std::fmt;
use std::time::Instant;

use shrew_core::backend::Backend;
use shrew_core::dtype::DType as CoreDType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_ir::graph::{ConstantValue, Dim, IrGraph, IrNode, IrProgram, IrType, NodeId, OpKind};

use shrew_nn::Module;

use super::engine::{ir_dtype_to_core, RuntimeConfig};

#[derive(Debug, Clone)]
pub struct SlotRef {
    pub slot: usize,
}

#[derive(Debug, Clone)]
pub enum Instruction {
    LoadInput {
        name: String,
        dst: usize,
    },
    LoadParam {
        graph_name: String,
        param_name: String,
        dst: usize,
    },

    Unary {
        op: UnaryInstr,
        src: usize,
        dst: usize,
    },

    Binary {
        op: BinaryInstr,
        lhs: usize,
        rhs: usize,
        dst: usize,
    },

    Reduce {
        op: ReduceInstr,
        src: usize,
        dst: usize,
        dims: Vec<i64>,
        keepdim: bool,
    },

    Reshape {
        src: usize,
        dst: usize,
        shape: Vec<usize>,
    },
    Transpose {
        src: usize,
        dst: usize,
    },
    Permute {
        src: usize,
        dst: usize,
        dims: Vec<i64>,
    },
    Expand {
        src: usize,
        dst: usize,
        shape: Vec<usize>,
    },
    Concat {
        srcs: Vec<usize>,
        dst: usize,
        dim: usize,
    },
    Split {
        src: usize,
        dst: usize,
        dim: usize,
        chunks: usize,
    },

    Softmax {
        src: usize,
        dst: usize,
        dim: usize,
    },
    Embedding {
        indices: usize,
        table: usize,
        dst: usize,
    },
    Linear {
        input: usize,
        weight: usize,
        bias: Option<usize>,
        dst: usize,
    },
    LayerNorm {
        input: usize,
        weight: usize,
        bias: usize,
        dst: usize,
        eps: f64,
    },
    BatchNorm {
        input: usize,
        weight: Option<usize>,
        bias: Option<usize>,
        dst: usize,
        eps: f64,
    },
    MultiHeadAttention {
        input: usize,
        dst: usize,
        n_heads: usize,
    },
    TransformerBlock {
        input: usize,
        dst: usize,
        n_heads: usize,
    },
    Dropout {
        src: usize,
        dst: usize,
        p: f64,
    },

    CrossEntropy {
        predictions: usize,
        targets: usize,
        dst: usize,
    },
    MseLoss {
        predictions: usize,
        targets: usize,
        dst: usize,
    },

    Constant {
        value: ConstantValue,
        output_type: IrType,
        dst: usize,
    },

    Repeat {
        count: i64,
        body_op: Box<OpKind>,
        src: usize,
        dst: usize,
    },
    Call {
        graph_name: String,
        inputs: Vec<usize>,
        dst: usize,
    },

    Compare {
        op: CompareInstr,
        lhs: usize,
        rhs: usize,
        dst: usize,
    },
    LogicalNot {
        src: usize,
        dst: usize,
    },
    LogicalBinOp {
        op: LogicalBinInstr,
        lhs: usize,
        rhs: usize,
        dst: usize,
    },

    FusedMatMulAdd {
        a: usize,
        b: usize,
        c: usize,
        dst: usize,
    },
    FusedAddRelu {
        lhs: usize,
        rhs: usize,
        dst: usize,
    },
    FusedSubRelu {
        lhs: usize,
        rhs: usize,
        dst: usize,
    },
    FusedMatMulRelu {
        lhs: usize,
        rhs: usize,
        dst: usize,
    },

    Copy {
        src: usize,
        dst: usize,
    },

    Range {
        inputs: Vec<usize>,
        output_type: IrType,
        dst: usize,
    },

    Free {
        slot: usize,
    },
}

#[derive(Debug, Clone, Copy)]
pub enum UnaryInstr {
    Neg,
    Relu,
    Gelu,
    Silu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
}

#[derive(Debug, Clone, Copy)]
pub enum BinaryInstr {
    Add,
    Sub,
    Mul,
    Div,
    MatMul,
    Pow,
    Mod,
}

#[derive(Debug, Clone, Copy)]
pub enum ReduceInstr {
    Sum,
    Mean,
    Max,
    Min,
    Variance,
}

#[derive(Debug, Clone, Copy)]
pub enum CompareInstr {
    Equal,
    NotEqual,
    Less,
    Greater,
    LessEqual,
    GreaterEqual,
}

#[derive(Debug, Clone, Copy)]
pub enum LogicalBinInstr {
    And,
    Or,
}

#[derive(Debug, Clone)]
pub struct ValueLifetime {
    pub produced_at: usize,
    pub last_used_at: usize,
    pub node_id: NodeId,
    pub is_output: bool,
    pub is_external: bool,
}

#[derive(Debug, Clone)]
pub struct MemoryPlan {
    pub num_slots: usize,
    pub node_to_slot: HashMap<usize, usize>,
    pub lifetimes: Vec<ValueLifetime>,
    pub free_points: Vec<(usize, usize)>,
    pub reuse_count: usize,
}

#[derive(Debug)]
pub struct CompiledGraph {
    pub graph_name: String,
    pub instructions: Vec<Instruction>,
    pub memory_plan: MemoryPlan,
    pub output_slots: HashMap<String, usize>,
    pub num_slots: usize,
    pub stats: CompileStats,
}

#[derive(Debug, Clone)]
pub struct CompileStats {
    pub num_instructions: usize,
    pub num_source_nodes: usize,
    pub num_slots: usize,
    pub num_reused: usize,
    pub num_frees: usize,
    pub num_fused: usize,
    pub compile_time_us: u64,
}

impl fmt::Display for CompileStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "CompiledGraph: {} instructions ({} source nodes), {} slots ({} reused), {} frees, {} fused, compiled in {}μs",
            self.num_instructions,
            self.num_source_nodes,
            self.num_slots,
            self.num_reused,
            self.num_frees,
            self.num_fused,
            self.compile_time_us,
        )
    }
}

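/// Lowers a single [`IrGraph`] into a flat [`CompiledGraph`]: nodes are
/// visited in topological order, every node gets its own buffer slot, and
/// `Free` instructions are inserted after each non-output value's last use.
///
/// Hypothetical usage sketch (construction of `graph`, `program`, and
/// `config` elided):
///
/// ```ignore
/// let compiled = compile_graph(&graph, &program, &config)?;
/// println!("{}", compiled.stats);
/// ```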
pub fn compile_graph(
    graph: &IrGraph,
    program: &IrProgram,
    config: &RuntimeConfig,
) -> Result<CompiledGraph> {
    let start = Instant::now();

    let order = graph.topo_order();

    let mut node_to_slot: HashMap<usize, usize> = HashMap::new();
    let mut next_slot = 0usize;
    let mut produced_at: HashMap<usize, usize> = HashMap::new();
    let mut last_used_at: HashMap<usize, usize> = HashMap::new();

    let output_node_ids: std::collections::HashSet<usize> =
        graph.outputs.iter().map(|o| o.node_id.0).collect();

    let input_node_ids: std::collections::HashSet<usize> =
        graph.inputs.iter().map(|id| id.0).collect();
    let param_node_ids: std::collections::HashSet<usize> =
        graph.params.iter().map(|p| p.node_id.0).collect();

    for &node_id in &order {
        let slot = next_slot;
        node_to_slot.insert(node_id.0, slot);
        next_slot += 1;
    }

    let mut instructions: Vec<Instruction> = Vec::with_capacity(order.len() + graph.inputs.len());
    let mut num_fused = 0;

    for (instr_idx, &node_id) in order.iter().enumerate() {
        let node = graph.node(node_id);
        let dst = node_to_slot[&node_id.0];

        produced_at.insert(node_id.0, instr_idx);

        for &input_id in &node.inputs {
            last_used_at.insert(input_id.0, instr_idx);
        }

        if input_node_ids.contains(&node_id.0) {
            instructions.push(Instruction::LoadInput {
                name: node.name.clone(),
                dst,
            });
            continue;
        }

        if param_node_ids.contains(&node_id.0) {
            if let Some(param) = graph.params.iter().find(|p| p.node_id == node_id) {
                instructions.push(Instruction::LoadParam {
                    graph_name: graph.name.clone(),
                    param_name: param.name.clone(),
                    dst,
                });
            }
            continue;
        }

        let instr = compile_node(graph, node, &node_to_slot, config, program)?;

        match &instr {
            Instruction::FusedMatMulAdd { .. }
            | Instruction::FusedAddRelu { .. }
            | Instruction::FusedSubRelu { .. }
            | Instruction::FusedMatMulRelu { .. } => {
                num_fused += 1;
            }
            _ => {}
        }

        instructions.push(instr);
    }

    let mut lifetimes = Vec::new();
    let mut free_points = Vec::new();

    for &node_id in &order {
        let slot = node_to_slot[&node_id.0];
        let is_output = output_node_ids.contains(&node_id.0);
        let is_external =
            input_node_ids.contains(&node_id.0) || param_node_ids.contains(&node_id.0);
        let prod = produced_at.get(&node_id.0).copied().unwrap_or(0);
        let last = last_used_at.get(&node_id.0).copied().unwrap_or(prod);

        lifetimes.push(ValueLifetime {
            produced_at: prod,
            last_used_at: last,
            node_id,
            is_output,
            is_external,
        });

        if !is_output && !is_external && last < instructions.len().saturating_sub(1) {
            free_points.push((slot, last));
        }
    }

    // Insert frees from the back so earlier insertion positions stay valid.
    free_points.sort_by(|a, b| b.1.cmp(&a.1));

    let num_frees = free_points.len();
    for (slot, after_idx) in &free_points {
        let insert_pos = (*after_idx + 1).min(instructions.len());
        instructions.insert(insert_pos, Instruction::Free { slot: *slot });
    }

    let mut output_slots = HashMap::new();
    for output in &graph.outputs {
        if let Some(&slot) = node_to_slot.get(&output.node_id.0) {
            output_slots.insert(output.name.clone(), slot);
        }
    }

    let num_slots = next_slot;
    let compile_time = start.elapsed();

    let memory_plan = MemoryPlan {
        num_slots,
        node_to_slot,
        lifetimes,
        // The free points have already been materialized as `Free` instructions
        // above, so the plan itself does not need to carry them.
        free_points: Vec::new(),
        reuse_count: 0,
    };

    let stats = CompileStats {
        num_instructions: instructions.len(),
        num_source_nodes: graph.nodes.len(),
        num_slots,
        num_reused: 0,
        num_frees,
        num_fused,
        compile_time_us: compile_time.as_micros() as u64,
    };

    Ok(CompiledGraph {
        graph_name: graph.name.clone(),
        instructions,
        memory_plan,
        output_slots,
        num_slots,
        stats,
    })
}

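/// Translates one IR node into a single [`Instruction`], mapping the node's
/// input [`NodeId`]s to buffer slots through `node_to_slot`.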
fn compile_node(
    _graph: &IrGraph,
    node: &IrNode,
    node_to_slot: &HashMap<usize, usize>,
    config: &RuntimeConfig,
    program: &IrProgram,
) -> Result<Instruction> {
    let dst = node_to_slot[&node.id.0];

    let slot = |idx: usize| -> Result<usize> {
        let input_id = node.inputs.get(idx).ok_or_else(|| {
            shrew_core::Error::msg(format!(
                "Node '{}' expected input at index {}, but has {} inputs",
                node.name,
                idx,
                node.inputs.len()
            ))
        })?;
        node_to_slot.get(&input_id.0).copied().ok_or_else(|| {
            shrew_core::Error::msg(format!(
                "Node '{}' input {} (NodeId {}) not found in slot map",
                node.name, idx, input_id.0
            ))
        })
    };

    match &node.op {
        OpKind::Identity => Ok(Instruction::Copy { src: slot(0)?, dst }),

        OpKind::Neg => Ok(Instruction::Unary {
            op: UnaryInstr::Neg,
            src: slot(0)?,
            dst,
        }),
        OpKind::Relu => Ok(Instruction::Unary {
            op: UnaryInstr::Relu,
            src: slot(0)?,
            dst,
        }),
        OpKind::Gelu => Ok(Instruction::Unary {
            op: UnaryInstr::Gelu,
            src: slot(0)?,
            dst,
        }),
        OpKind::Silu => Ok(Instruction::Unary {
            op: UnaryInstr::Silu,
            src: slot(0)?,
            dst,
        }),
        OpKind::Sigmoid => Ok(Instruction::Unary {
            op: UnaryInstr::Sigmoid,
            src: slot(0)?,
            dst,
        }),
        OpKind::Tanh => Ok(Instruction::Unary {
            op: UnaryInstr::Tanh,
            src: slot(0)?,
            dst,
        }),
        OpKind::Exp => Ok(Instruction::Unary {
            op: UnaryInstr::Exp,
            src: slot(0)?,
            dst,
        }),
        OpKind::Log => Ok(Instruction::Unary {
            op: UnaryInstr::Log,
            src: slot(0)?,
            dst,
        }),
        OpKind::Sqrt => Ok(Instruction::Unary {
            op: UnaryInstr::Sqrt,
            src: slot(0)?,
            dst,
        }),

        OpKind::Add => Ok(Instruction::Binary {
            op: BinaryInstr::Add,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Sub => Ok(Instruction::Binary {
            op: BinaryInstr::Sub,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Mul => Ok(Instruction::Binary {
            op: BinaryInstr::Mul,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Div => Ok(Instruction::Binary {
            op: BinaryInstr::Div,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::MatMul => Ok(Instruction::Binary {
            op: BinaryInstr::MatMul,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Pow => Ok(Instruction::Binary {
            op: BinaryInstr::Pow,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Mod => Ok(Instruction::Binary {
            op: BinaryInstr::Mod,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),

        OpKind::Transpose => Ok(Instruction::Transpose { src: slot(0)?, dst }),

        OpKind::Sum { dims, keepdim } => Ok(Instruction::Reduce {
            op: ReduceInstr::Sum,
            src: slot(0)?,
            dst,
            dims: dims.clone(),
            keepdim: *keepdim,
        }),
        OpKind::Mean { dims, keepdim } => Ok(Instruction::Reduce {
            op: ReduceInstr::Mean,
            src: slot(0)?,
            dst,
            dims: dims.clone(),
            keepdim: *keepdim,
        }),
        OpKind::Max { dim, keepdim } => Ok(Instruction::Reduce {
            op: ReduceInstr::Max,
            src: slot(0)?,
            dst,
            dims: vec![*dim],
            keepdim: *keepdim,
        }),
        OpKind::Min { dim, keepdim } => Ok(Instruction::Reduce {
            op: ReduceInstr::Min,
            src: slot(0)?,
            dst,
            dims: vec![*dim],
            keepdim: *keepdim,
        }),
        OpKind::Variance { dims, keepdim } => Ok(Instruction::Reduce {
            op: ReduceInstr::Variance,
            src: slot(0)?,
            dst,
            dims: dims.clone(),
            keepdim: *keepdim,
        }),

        OpKind::Softmax { dim } => {
            let d = *dim;
            Ok(Instruction::Softmax {
                src: slot(0)?,
                dst,
                dim: d as usize,
            })
        }

        OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
            let shape = resolve_shape_vec(target_shape, config, program)?;
            Ok(Instruction::Reshape {
                src: slot(0)?,
                dst,
                shape,
            })
        }
        OpKind::Permute { dims } => Ok(Instruction::Permute {
            src: slot(0)?,
            dst,
            dims: dims.clone(),
        }),
        OpKind::Expand { target_shape } => {
            let shape = resolve_shape_vec(target_shape, config, program)?;
            Ok(Instruction::Expand {
                src: slot(0)?,
                dst,
                shape,
            })
        }
        OpKind::Concat { dim } => {
            let srcs: Vec<usize> = (0..node.inputs.len())
                .map(&slot)
                .collect::<Result<Vec<_>>>()?;
            Ok(Instruction::Concat {
                srcs,
                dst,
                dim: *dim as usize,
            })
        }
        OpKind::Split { dim, chunks } => Ok(Instruction::Split {
            src: slot(0)?,
            dst,
            // The input's rank is not known at compile time here, so a rank of
            // 4 is assumed when resolving a negative split dimension.
            dim: resolve_neg_dim(*dim, 4),
            chunks: *chunks as usize,
        }),

        OpKind::Embedding => Ok(Instruction::Embedding {
            indices: slot(0)?,
            table: slot(1)?,
            dst,
        }),
        OpKind::Linear { bias } => {
            let bias_slot = if *bias && node.inputs.len() >= 3 {
                Some(slot(2)?)
            } else {
                None
            };
            Ok(Instruction::Linear {
                input: slot(0)?,
                weight: slot(1)?,
                bias: bias_slot,
                dst,
            })
        }
        OpKind::LayerNorm { eps } => Ok(Instruction::LayerNorm {
            input: slot(0)?,
            weight: slot(1)?,
            bias: slot(2)?,
            dst,
            eps: *eps,
        }),
        OpKind::BatchNorm { eps } => {
            let weight = if node.inputs.len() >= 2 {
                Some(slot(1)?)
            } else {
                None
            };
            let bias = if node.inputs.len() >= 3 {
                Some(slot(2)?)
            } else {
                None
            };
            Ok(Instruction::BatchNorm {
                input: slot(0)?,
                weight,
                bias,
                dst,
                eps: *eps,
            })
        }
        OpKind::MultiHeadAttention { n_heads } => Ok(Instruction::MultiHeadAttention {
            input: slot(0)?,
            dst,
            n_heads: *n_heads as usize,
        }),
        OpKind::TransformerBlock { n_heads } => Ok(Instruction::TransformerBlock {
            input: slot(0)?,
            dst,
            n_heads: *n_heads as usize,
        }),
        OpKind::Dropout { p } => Ok(Instruction::Dropout {
            src: slot(0)?,
            dst,
            p: *p,
        }),

        OpKind::CrossEntropy => Ok(Instruction::CrossEntropy {
            predictions: slot(0)?,
            targets: slot(1)?,
            dst,
        }),
        OpKind::MseLoss => Ok(Instruction::MseLoss {
            predictions: slot(0)?,
            targets: slot(1)?,
            dst,
        }),

        OpKind::Constant(val) => Ok(Instruction::Constant {
            value: val.clone(),
            output_type: node.output_type.clone(),
            dst,
        }),

        OpKind::Repeat { count, body_op } => Ok(Instruction::Repeat {
            count: *count,
            body_op: body_op.clone(),
            src: slot(0)?,
            dst,
        }),

        OpKind::Call { graph_name } => {
            let inputs: Vec<usize> = (0..node.inputs.len())
                .map(&slot)
                .collect::<Result<Vec<_>>>()?;
            Ok(Instruction::Call {
                graph_name: graph_name.clone(),
                inputs,
                dst,
            })
        }

        OpKind::Range => {
            let inputs: Vec<usize> = (0..node.inputs.len())
                .map(&slot)
                .collect::<Result<Vec<_>>>()?;
            Ok(Instruction::Range {
                inputs,
                output_type: node.output_type.clone(),
                dst,
            })
        }

        OpKind::Equal => Ok(Instruction::Compare {
            op: CompareInstr::Equal,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::NotEqual => Ok(Instruction::Compare {
            op: CompareInstr::NotEqual,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Less => Ok(Instruction::Compare {
            op: CompareInstr::Less,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Greater => Ok(Instruction::Compare {
            op: CompareInstr::Greater,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::LessEqual => Ok(Instruction::Compare {
            op: CompareInstr::LessEqual,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::GreaterEqual => Ok(Instruction::Compare {
            op: CompareInstr::GreaterEqual,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),

        OpKind::And => Ok(Instruction::LogicalBinOp {
            op: LogicalBinInstr::And,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Or => Ok(Instruction::LogicalBinOp {
            op: LogicalBinInstr::Or,
            lhs: slot(0)?,
            rhs: slot(1)?,
            dst,
        }),
        OpKind::Not => Ok(Instruction::LogicalNot { src: slot(0)?, dst }),

        OpKind::Custom { name, .. } => match name.as_str() {
            "fused_matmul_add" => Ok(Instruction::FusedMatMulAdd {
                a: slot(0)?,
                b: slot(1)?,
                c: slot(2)?,
                dst,
            }),
            "fused_add_relu" => Ok(Instruction::FusedAddRelu {
                lhs: slot(0)?,
                rhs: slot(1)?,
                dst,
            }),
            "fused_sub_relu" => Ok(Instruction::FusedSubRelu {
                lhs: slot(0)?,
                rhs: slot(1)?,
                dst,
            }),
            "fused_matmul_relu" => Ok(Instruction::FusedMatMulRelu {
                lhs: slot(0)?,
                rhs: slot(1)?,
                dst,
            }),
            other => Err(shrew_core::Error::msg(format!(
                "Unknown custom op '{}' during JIT compilation",
                other
            ))),
        },
    }
}

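/// Executes a compiled [`IrProgram`]: every graph is lowered once into a
/// [`CompiledGraph`], parameters are initialized up front, and each `run`
/// call interprets the flat instruction stream over a vector of buffer slots.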
pub struct JitExecutor<B: Backend> {
    compiled: HashMap<String, CompiledGraph>,
    program: IrProgram,
    config: RuntimeConfig,
    device: B::Device,
    params: HashMap<(String, String), Tensor<B>>,
}

#[derive(Debug)]
pub struct JitResult<B: Backend> {
    pub outputs: HashMap<String, Tensor<B>>,
}

impl<B: Backend> JitResult<B> {
    pub fn output(&self) -> Option<&Tensor<B>> {
        self.outputs.values().next()
    }

    pub fn get(&self, name: &str) -> Option<&Tensor<B>> {
        self.outputs.get(name)
    }
}

impl<B: Backend> JitExecutor<B> {
    pub fn compile(program: IrProgram, device: B::Device, config: RuntimeConfig) -> Result<Self> {
        let mut compiled = HashMap::new();

        for graph in &program.graphs {
            let cg = compile_graph(graph, &program, &config)?;
            compiled.insert(graph.name.clone(), cg);
        }

        let mut params = HashMap::new();
        for graph in &program.graphs {
            for param in &graph.params {
                let tensor = init_param::<B>(
                    &param.ty,
                    &param.init,
                    param.frozen,
                    &config,
                    &program,
                    &device,
                )?;
                params.insert((graph.name.clone(), param.name.clone()), tensor);
            }
        }

        Ok(Self {
            compiled,
            program,
            config,
            device,
            params,
        })
    }

    pub fn stats(&self, graph_name: &str) -> Option<&CompileStats> {
        self.compiled.get(graph_name).map(|cg| &cg.stats)
    }

    pub fn all_stats(&self) -> Vec<(&str, &CompileStats)> {
        self.compiled
            .iter()
            .map(|(name, cg)| (name.as_str(), &cg.stats))
            .collect()
    }

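    /// Runs a compiled graph by interpreting its instruction stream over a
    /// vector of optional tensor slots; `Free` instructions drop tensors that
    /// are no longer needed.
    ///
    /// Hypothetical usage sketch (the graph name `"main"` and the input name
    /// `"x"` are placeholders):
    ///
    /// ```ignore
    /// let mut inputs = HashMap::new();
    /// inputs.insert("x".to_string(), x_tensor);
    /// let result = executor.run("main", &inputs)?;
    /// let y = result.output().expect("graph produced at least one output");
    /// ```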
    pub fn run(
        &self,
        graph_name: &str,
        inputs: &HashMap<String, Tensor<B>>,
    ) -> Result<JitResult<B>> {
        let cg = self.compiled.get(graph_name).ok_or_else(|| {
            shrew_core::Error::msg(format!(
                "Graph '{}' not compiled. Available: {:?}",
                graph_name,
                self.compiled.keys().collect::<Vec<_>>()
            ))
        })?;

        let mut slots: Vec<Option<Tensor<B>>> = vec![None; cg.num_slots];

        for instr in &cg.instructions {
            match instr {
                Instruction::LoadInput { name, dst } => {
                    if let Some(tensor) = inputs.get(name) {
                        slots[*dst] = Some(tensor.clone());
                    }
                }

                Instruction::LoadParam {
                    graph_name,
                    param_name,
                    dst,
                } => {
                    let key = (graph_name.clone(), param_name.clone());
                    if let Some(tensor) = self.params.get(&key) {
                        slots[*dst] = Some(tensor.clone());
                    }
                }

                Instruction::Unary { op, src, dst } => {
                    let t = get_slot(&slots, *src)?;
                    let result = match op {
                        UnaryInstr::Neg => t.neg(),
                        UnaryInstr::Relu => t.relu(),
                        UnaryInstr::Gelu => t.gelu(),
                        UnaryInstr::Silu => t.silu(),
                        UnaryInstr::Sigmoid => t.sigmoid(),
                        UnaryInstr::Tanh => t.tanh(),
                        UnaryInstr::Exp => t.exp(),
                        UnaryInstr::Log => t.log(),
                        UnaryInstr::Sqrt => t.sqrt(),
                    }?;
                    slots[*dst] = Some(result);
                }

                Instruction::Binary { op, lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    let result = match op {
                        BinaryInstr::Add => a.add(b),
                        BinaryInstr::Sub => a.sub(b),
                        BinaryInstr::Mul => a.mul(b),
                        BinaryInstr::Div => a.div(b),
                        BinaryInstr::MatMul => a.matmul(b),
                        // a^b evaluated as exp(b * ln a); this only holds for
                        // strictly positive bases.
                        BinaryInstr::Pow => a.log()?.mul(b)?.exp(),
                        BinaryInstr::Mod => {
                            let quotient = a.div(b)?.floor()?;
                            let product = quotient.mul(b)?;
                            a.sub(&product)
                        }
                    }?;
                    slots[*dst] = Some(result);
                }

                Instruction::Reduce {
                    op,
                    src,
                    dst,
                    dims,
                    keepdim,
                } => {
                    let t = get_slot(&slots, *src)?;
                    let result = match op {
                        ReduceInstr::Sum => {
                            if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                                t.sum_all()
                            } else {
                                let d = resolve_neg_dim(dims[0], t.rank());
                                t.sum(d as usize, *keepdim)
                            }
                        }
                        ReduceInstr::Mean => {
                            if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                                t.mean_all()
                            } else {
                                let d = resolve_neg_dim(dims[0], t.rank());
                                t.mean(d as usize, *keepdim)
                            }
                        }
                        ReduceInstr::Max => {
                            let d = resolve_neg_dim(dims[0], t.rank());
                            t.max(d as usize, *keepdim)
                        }
                        ReduceInstr::Min => {
                            let d = resolve_neg_dim(dims[0], t.rank());
                            t.min(d as usize, *keepdim)
                        }
                        ReduceInstr::Variance => {
                            if dims.is_empty() {
                                t.var(0, *keepdim)
                            } else {
                                let d = resolve_neg_dim(dims[0], t.rank());
                                t.var(d as usize, *keepdim)
                            }
                        }
                    }?;
                    slots[*dst] = Some(result);
                }

                Instruction::Reshape { src, dst, shape } => {
                    let t = get_slot(&slots, *src)?;
                    slots[*dst] = Some(t.reshape(shape.clone())?);
                }

                Instruction::Transpose { src, dst } => {
                    let t = get_slot(&slots, *src)?;
                    let rank = t.rank();
                    slots[*dst] = Some(t.transpose(rank - 2, rank - 1)?);
                }

                Instruction::Permute { src, dst, dims } => {
                    let t = get_slot(&slots, *src)?;
                    let mut result = t.clone();
                    let mut current: Vec<usize> = (0..t.rank()).collect();
                    for i in 0..dims.len() {
                        let target = dims[i] as usize;
                        if current[i] != target {
                            if let Some(j) = current.iter().position(|&x| x == target) {
                                result = result.transpose(i, j)?;
                                current.swap(i, j);
                            }
                        }
                    }
                    slots[*dst] = Some(result);
                }

                Instruction::Expand { src, dst, shape } => {
                    let t = get_slot(&slots, *src)?;
                    slots[*dst] = Some(t.expand(shape.clone())?);
                }

                Instruction::Concat { srcs, dst, dim } => {
                    let tensors: Vec<Tensor<B>> = srcs
                        .iter()
                        .map(|s| get_slot(&slots, *s).cloned())
                        .collect::<Result<Vec<_>>>()?;
                    slots[*dst] = Some(Tensor::<B>::cat(&tensors, *dim)?);
                }

                Instruction::Split {
                    src,
                    dst,
                    dim,
                    chunks,
                } => {
                    let t = get_slot(&slots, *src)?;
                    let result = t.chunk(*chunks, *dim)?;
                    if let Some(first) = result.into_iter().next() {
                        slots[*dst] = Some(first);
                    }
                }

                Instruction::Softmax { src, dst, dim } => {
                    let t = get_slot(&slots, *src)?;
                    slots[*dst] = Some(t.softmax(*dim)?);
                }

                Instruction::Embedding {
                    indices,
                    table,
                    dst,
                } => {
                    let idx = get_slot(&slots, *indices)?;
                    let tbl = get_slot(&slots, *table)?;
                    let emb = shrew_nn::Embedding::<B>::from_tensor(tbl.clone())?;
                    slots[*dst] = Some(emb.forward(idx)?);
                }

                Instruction::Linear {
                    input,
                    weight,
                    bias,
                    dst,
                } => {
                    let inp = get_slot(&slots, *input)?;
                    let w = get_slot(&slots, *weight)?;
                    let b = bias.map(|s| get_slot(&slots, s).cloned()).transpose()?;
                    let lin = shrew_nn::Linear::<B>::from_tensors(w.clone(), b)?;
                    slots[*dst] = Some(lin.forward(inp)?);
                }

                Instruction::LayerNorm {
                    input,
                    weight,
                    bias,
                    dst,
                    eps,
                } => {
                    let inp = get_slot(&slots, *input)?;
                    let w = get_slot(&slots, *weight)?;
                    let b = get_slot(&slots, *bias)?;
                    let ln = shrew_nn::LayerNorm::<B>::from_tensors(w.clone(), b.clone(), *eps)?;
                    slots[*dst] = Some(ln.forward(inp)?);
                }

                Instruction::BatchNorm {
                    input,
                    weight,
                    bias,
                    dst,
                    eps,
                } => {
                    let inp = get_slot(&slots, *input)?;
                    if let (Some(ws), Some(bs)) = (weight, bias) {
                        let w = get_slot(&slots, *ws)?;
                        let b = get_slot(&slots, *bs)?;
                        let bn =
                            shrew_nn::BatchNorm2d::<B>::from_tensors(w.clone(), b.clone(), *eps)?;
                        slots[*dst] = Some(bn.forward(inp)?);
                    } else {
                        let dims = inp.dims();
                        let c = if dims.len() == 4 { dims[1] } else { dims[0] };
                        let bn = shrew_nn::BatchNorm2d::<B>::new(
                            c,
                            *eps,
                            0.1,
                            inp.dtype(),
                            &self.device,
                        )?;
                        slots[*dst] = Some(bn.forward(inp)?);
                    }
                }

                Instruction::MultiHeadAttention {
                    input,
                    dst,
                    n_heads,
                } => {
                    let inp = get_slot(&slots, *input)?;
                    let d_model = *inp
                        .dims()
                        .last()
                        .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
                    let mha = shrew_nn::MultiHeadAttention::<B>::new(
                        d_model,
                        *n_heads,
                        inp.dtype(),
                        inp.device(),
                    )?;
                    slots[*dst] = Some(mha.forward(inp)?);
                }

                Instruction::TransformerBlock {
                    input,
                    dst,
                    n_heads,
                } => {
                    let inp = get_slot(&slots, *input)?;
                    let dims = inp.dims();
                    let d_model = dims[dims.len() - 1];
                    let d_ff = d_model * 4;
                    let block = shrew_nn::TransformerBlock::<B>::new(
                        d_model,
                        *n_heads,
                        d_ff,
                        true,
                        inp.dtype(),
                        inp.device(),
                    )?;
                    slots[*dst] = Some(block.forward(inp)?);
                }

                Instruction::Dropout { src, dst, p } => {
                    let t = get_slot(&slots, *src)?;
                    let dropout = shrew_nn::Dropout::new(*p);
                    if self.config.training {
                        slots[*dst] = Some(dropout.forward_t(t)?);
                    } else {
                        slots[*dst] = Some(t.clone());
                    }
                }

                Instruction::CrossEntropy {
                    predictions,
                    targets,
                    dst,
                } => {
                    let p = get_slot(&slots, *predictions)?;
                    let t = get_slot(&slots, *targets)?;
                    slots[*dst] = Some(shrew_nn::cross_entropy_loss(p, t)?);
                }

                Instruction::MseLoss {
                    predictions,
                    targets,
                    dst,
                } => {
                    let p = get_slot(&slots, *predictions)?;
                    let t = get_slot(&slots, *targets)?;
                    slots[*dst] = Some(shrew_nn::mse_loss(p, t)?);
                }

                Instruction::Constant {
                    value,
                    output_type,
                    dst,
                } => {
                    let tensor = materialize_constant::<B>(
                        value,
                        output_type,
                        self.config.default_dtype,
                        &self.device,
                    )?;
                    slots[*dst] = Some(tensor);
                }

                Instruction::Repeat {
                    count,
                    body_op,
                    src,
                    dst,
                } => {
                    let t = get_slot(&slots, *src)?;
                    let mut current = t.clone();
                    for _ in 0..(*count as u32) {
                        current = execute_body_op::<B>(body_op, &current, &self.device)?;
                    }
                    slots[*dst] = Some(current);
                }

                Instruction::Call {
                    graph_name,
                    inputs: input_slots,
                    dst,
                } => {
                    let _sub_cg = self.compiled.get(graph_name).ok_or_else(|| {
                        shrew_core::Error::msg(format!(
                            "Called graph '{}' not compiled",
                            graph_name
                        ))
                    })?;
                    let sub_graph = self.program.get_graph(graph_name).ok_or_else(|| {
                        shrew_core::Error::msg(format!("Called graph '{}' not found", graph_name))
                    })?;
                    let mut sub_inputs = HashMap::new();
                    for (i, &input_id) in sub_graph.inputs.iter().enumerate() {
                        if let Some(&s) = input_slots.get(i) {
                            if let Some(tensor) = &slots[s] {
                                let input_name = sub_graph.node(input_id).name.clone();
                                sub_inputs.insert(input_name, tensor.clone());
                            }
                        }
                    }
                    let result = self.run(graph_name, &sub_inputs)?;
                    if let Some(out) = result.output() {
                        slots[*dst] = Some(out.clone());
                    }
                }

                Instruction::Compare { op, lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    let result = match op {
                        CompareInstr::Equal => a.eq(b),
                        CompareInstr::NotEqual => a.ne(b),
                        CompareInstr::Less => a.lt(b),
                        CompareInstr::Greater => a.gt(b),
                        CompareInstr::LessEqual => a.le(b),
                        CompareInstr::GreaterEqual => a.ge(b),
                    }?;
                    slots[*dst] = Some(result);
                }

                Instruction::LogicalNot { src, dst } => {
                    let t = get_slot(&slots, *src)?;
                    let data = t.to_f64_vec()?;
                    let result: Vec<f64> = data
                        .iter()
                        .map(|&v| if v == 0.0 { 1.0 } else { 0.0 })
                        .collect();
                    let n = result.len();
                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
                        &result,
                        n,
                        CoreDType::U8,
                        &self.device,
                    )?);
                }

                Instruction::LogicalBinOp { op, lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    let a_data = a.to_f64_vec()?;
                    let b_data = b.to_f64_vec()?;
                    let result: Vec<f64> = a_data
                        .iter()
                        .zip(b_data.iter())
                        .map(|(&x, &y)| match op {
                            LogicalBinInstr::And => {
                                if x != 0.0 && y != 0.0 {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                            LogicalBinInstr::Or => {
                                if x != 0.0 || y != 0.0 {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                        })
                        .collect();
                    let n = result.len();
                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
                        &result,
                        n,
                        CoreDType::U8,
                        &self.device,
                    )?);
                }

                Instruction::FusedMatMulAdd { a, b, c, dst } => {
                    let at = get_slot(&slots, *a)?;
                    let bt = get_slot(&slots, *b)?;
                    let ct = get_slot(&slots, *c)?;
                    slots[*dst] = Some(at.matmul(bt)?.add(ct)?);
                }

                Instruction::FusedAddRelu { lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    slots[*dst] = Some(a.add(b)?.relu()?);
                }

                Instruction::FusedSubRelu { lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    slots[*dst] = Some(a.sub(b)?.relu()?);
                }

                Instruction::FusedMatMulRelu { lhs, rhs, dst } => {
                    let a = get_slot(&slots, *lhs)?;
                    let b = get_slot(&slots, *rhs)?;
                    slots[*dst] = Some(a.matmul(b)?.relu()?);
                }

                Instruction::Copy { src, dst } => {
                    let t = get_slot(&slots, *src)?;
                    slots[*dst] = Some(t.clone());
                }

                Instruction::Range {
                    inputs: input_slots,
                    output_type,
                    dst,
                } => {
                    let (start, end) = if input_slots.len() >= 2 {
                        let s = get_slot(&slots, input_slots[0])?.to_scalar_f64()?;
                        let e = get_slot(&slots, input_slots[1])?.to_scalar_f64()?;
                        (s as i64, e as i64)
                    } else if input_slots.len() == 1 {
                        (
                            0i64,
                            get_slot(&slots, input_slots[0])?.to_scalar_f64()? as i64,
                        )
                    } else {
                        match output_type {
                            IrType::Tensor { shape, .. } => {
                                if let Some(Dim::Fixed(n)) = shape.first() {
                                    (0, *n)
                                } else {
                                    (0, 1)
                                }
                            }
                            _ => (0, 1),
                        }
                    };
                    let data: Vec<f64> = (start..end).map(|i| i as f64).collect();
                    let len = data.len();
                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
                        &data,
                        len,
                        CoreDType::I64,
                        &self.device,
                    )?);
                }

                Instruction::Free { slot } => {
                    slots[*slot] = None;
                }
            }
        }

        let mut outputs = HashMap::new();
        for (name, &slot) in &cg.output_slots {
            if let Some(tensor) = &slots[slot] {
                outputs.insert(name.clone(), tensor.clone());
            }
        }

        Ok(JitResult { outputs })
    }

    pub fn program(&self) -> &IrProgram {
        &self.program
    }

    pub fn config(&self) -> &RuntimeConfig {
        &self.config
    }

    pub fn params(&self) -> &HashMap<(String, String), Tensor<B>> {
        &self.params
    }

    pub fn graph_params(&self, graph_name: &str) -> Vec<Tensor<B>> {
        self.params
            .iter()
            .filter(|((g, _), _)| g == graph_name)
            .map(|(_, t)| t.clone())
            .collect()
    }

    pub fn update_params(&mut self, graph_name: &str, new_params: &[Tensor<B>]) {
        // NOTE: `HashMap` iteration order is unspecified, so this pairing is
        // only meaningful if `new_params` was produced in the same order as
        // `graph_params` for the same executor state.
        let param_names: Vec<String> = self
            .params
            .keys()
            .filter(|(g, _)| g == graph_name)
            .map(|(_, n)| n.clone())
            .collect();

        for (name, tensor) in param_names.into_iter().zip(new_params.iter()) {
            self.params
                .insert((graph_name.to_string(), name), tensor.clone());
        }
    }

    pub fn recompile(&mut self, graph_name: &str) -> Result<()> {
        let graph = self
            .program
            .get_graph(graph_name)
            .ok_or_else(|| shrew_core::Error::msg(format!("Graph '{}' not found", graph_name)))?;
        let cg = compile_graph(graph, &self.program, &self.config)?;
        self.compiled.insert(graph_name.to_string(), cg);
        Ok(())
    }

    pub fn dump(&self, graph_name: &str) -> Option<String> {
        let cg = self.compiled.get(graph_name)?;
        let mut out = format!("=== JIT Compiled: {} ===\n", cg.graph_name);
        out.push_str(&format!("{}\n\n", cg.stats));
        for (i, instr) in cg.instructions.iter().enumerate() {
            out.push_str(&format!(" [{:>3}] {:?}\n", i, instr));
        }
        out.push_str(&format!("\nOutputs: {:?}\n", cg.output_slots));
        Some(out)
    }
}

fn get_slot<B: Backend>(slots: &[Option<Tensor<B>>], idx: usize) -> Result<&Tensor<B>> {
    slots.get(idx).and_then(|s| s.as_ref()).ok_or_else(|| {
        shrew_core::Error::msg(format!(
            "Buffer slot {} is empty (value was freed or never produced)",
            idx
        ))
    })
}

fn resolve_neg_dim(dim: i64, rank: usize) -> usize {
    if dim < 0 {
        (rank as i64 + dim) as usize
    } else {
        dim as usize
    }
}

fn resolve_shape_vec(
    dims: &[Dim],
    config: &RuntimeConfig,
    program: &IrProgram,
) -> Result<Vec<usize>> {
    dims.iter()
        .map(|d| resolve_dim(d, config, program))
        .collect()
}

fn resolve_dim(dim: &Dim, config: &RuntimeConfig, program: &IrProgram) -> Result<usize> {
    match dim {
        Dim::Fixed(n) => Ok(*n as usize),
        Dim::Symbolic(name) => {
            if let Some(&val) = config.dims.get(name) {
                return Ok(val);
            }
            if let Some(shrew_ir::graph::ConfigValue::Int(n)) = program.config.get(name) {
                return Ok(*n as usize);
            }
            Err(shrew_core::Error::msg(format!(
                "Unresolved symbolic dimension: '{}'",
                name
            )))
        }
        Dim::Dynamic => Err(shrew_core::Error::msg(
            "Cannot resolve dynamic dimension at compile time",
        )),
    }
}

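/// Builds and initializes a parameter tensor from its IR type and
/// initialization strategy. As a worked example of the Xavier-uniform branch
/// below: a `[8, 4]` weight has `fan_in = 4` and `fan_out = 8`, so values are
/// drawn uniformly from `[-a, a)` with `a = sqrt(6 / 12) ≈ 0.707`.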
fn init_param<B: Backend>(
    ty: &IrType,
    init: &shrew_ir::graph::InitStrategy,
    frozen: bool,
    config: &RuntimeConfig,
    program: &IrProgram,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let (shape, dtype) = resolve_type(ty, config, program)?;

    let tensor = match init {
        shrew_ir::graph::InitStrategy::Zeros => Tensor::<B>::zeros(shape, dtype, device)?,
        shrew_ir::graph::InitStrategy::Ones => Tensor::<B>::ones(shape, dtype, device)?,
        shrew_ir::graph::InitStrategy::Normal { mean, std } => {
            Tensor::<B>::randn(shape, dtype, device)?.affine(*std, *mean)?
        }
        shrew_ir::graph::InitStrategy::Uniform { low, high } => {
            Tensor::<B>::rand(shape, dtype, device)?.affine(*high - *low, *low)?
        }
        shrew_ir::graph::InitStrategy::XavierUniform => {
            let (fan_in, fan_out) = compute_fans(&shape);
            let a = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
            Tensor::<B>::rand(shape, dtype, device)?.affine(2.0 * a, -a)?
        }
        shrew_ir::graph::InitStrategy::XavierNormal => {
            let (fan_in, fan_out) = compute_fans(&shape);
            let std = (2.0_f64 / (fan_in + fan_out) as f64).sqrt();
            Tensor::<B>::randn(shape, dtype, device)?.affine(std, 0.0)?
        }
        shrew_ir::graph::InitStrategy::KaimingUniform => {
            let (fan_in, _) = compute_fans(&shape);
            let bound = (3.0_f64 / fan_in as f64).sqrt();
            Tensor::<B>::rand(shape, dtype, device)?.affine(2.0 * bound, -bound)?
        }
        shrew_ir::graph::InitStrategy::KaimingNormal => {
            let (fan_in, _) = compute_fans(&shape);
            let std = (2.0_f64 / fan_in as f64).sqrt();
            Tensor::<B>::randn(shape, dtype, device)?.affine(std, 0.0)?
        }
        shrew_ir::graph::InitStrategy::Custom(_) => Tensor::<B>::randn(shape, dtype, device)?,
    };

    if frozen {
        Ok(tensor)
    } else {
        Ok(tensor.set_variable())
    }
}

fn resolve_type(
    ty: &IrType,
    config: &RuntimeConfig,
    program: &IrProgram,
) -> Result<(shrew_core::Shape, CoreDType)> {
    match ty {
        IrType::Tensor { shape, dtype } => {
            let dims: Vec<usize> = shape
                .iter()
                .map(|d| resolve_dim(d, config, program))
                .collect::<Result<Vec<_>>>()?;
            let core_dtype = ir_dtype_to_core(*dtype)?;
            Ok((shrew_core::Shape::new(dims), core_dtype))
        }
        IrType::Scalar(dtype) => {
            let core_dtype = ir_dtype_to_core(*dtype)?;
            Ok((shrew_core::Shape::new(vec![1]), core_dtype))
        }
        IrType::Int => Ok((shrew_core::Shape::new(vec![1]), CoreDType::I64)),
        _ => Ok((shrew_core::Shape::new(vec![1]), config.default_dtype)),
    }
}

fn materialize_constant<B: Backend>(
    val: &ConstantValue,
    ty: &IrType,
    default_dtype: CoreDType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    match val {
        ConstantValue::Int(n) => {
            Tensor::<B>::from_f64_slice(&[*n as f64], 1, CoreDType::I64, device)
        }
        ConstantValue::Float(f) => {
            let dtype = match ty {
                IrType::Tensor { dtype, .. } => ir_dtype_to_core(*dtype)?,
                IrType::Scalar(dtype) => ir_dtype_to_core(*dtype)?,
                _ => default_dtype,
            };
            Tensor::<B>::from_f64_slice(&[*f], 1, dtype, device)
        }
        ConstantValue::Bool(b) => {
            Tensor::<B>::from_f64_slice(&[if *b { 1.0 } else { 0.0 }], 1, CoreDType::U8, device)
        }
        ConstantValue::Str(_) => Tensor::<B>::zeros(1, default_dtype, device),
        ConstantValue::Null => Tensor::<B>::zeros(1, default_dtype, device),
    }
}

fn execute_body_op<B: Backend>(
    op: &OpKind,
    input: &Tensor<B>,
    _device: &B::Device,
) -> Result<Tensor<B>> {
    match op {
        OpKind::TransformerBlock { n_heads } => {
            let dims = input.dims();
            let d_model = dims[dims.len() - 1];
            let d_ff = d_model * 4;
            let block = shrew_nn::TransformerBlock::<B>::new(
                d_model,
                *n_heads as usize,
                d_ff,
                true,
                input.dtype(),
                input.device(),
            )?;
            block.forward(input)
        }
        OpKind::MultiHeadAttention { n_heads } => {
            let d_model = *input
                .dims()
                .last()
                .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
            let mha = shrew_nn::MultiHeadAttention::<B>::new(
                d_model,
                *n_heads as usize,
                input.dtype(),
                input.device(),
            )?;
            mha.forward(input)
        }
        _ => Err(shrew_core::Error::msg(format!(
            "Unsupported op in Repeat body: {:?}",
            op
        ))),
    }
}

fn compute_fans(shape: &shrew_core::Shape) -> (usize, usize) {
    let dims = shape.dims();
    match dims.len() {
        0 => (1, 1),
        1 => (dims[0], dims[0]),
        2 => (dims[1], dims[0]),
        _ => {
            let receptive: usize = dims[2..].iter().product();
            let fan_in = dims[1] * receptive;
            let fan_out = dims[0] * receptive;
            (fan_in, fan_out)
        }
    }
}

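/// Parses and lowers source text with `shrew_ir`, validates it, runs shape
/// inference and optimization, and finally compiles the program into a
/// [`JitExecutor`].
///
/// Hypothetical usage sketch: `MyBackend`, `my_device`, the file path, and the
/// graph name `"main"` are placeholders for whatever your build provides.
///
/// ```ignore
/// let source = std::fs::read_to_string("model.shrew")?;
/// let exec = load_jit::<MyBackend>(&source, my_device, RuntimeConfig::default())?;
/// println!("{}", exec.dump("main").unwrap_or_default());
/// ```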
pub fn load_jit<B: Backend>(
    source: &str,
    device: B::Device,
    config: RuntimeConfig,
) -> Result<JitExecutor<B>> {
    let ast =
        shrew_ir::parse(source).map_err(|e| shrew_core::Error::msg(format!("Parse error: {e}")))?;
    let mut ir = shrew_ir::lower(&ast)
        .map_err(|e| shrew_core::Error::msg(format!("Lowering error: {e}")))?;

    if let Err(errors) = shrew_ir::validate(&ir) {
        let msg = errors
            .iter()
            .map(|e| e.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        return Err(shrew_core::Error::msg(format!("Validation errors:\n{msg}")));
    }

    shrew_ir::infer_shapes(&mut ir);
    shrew_ir::optimize(&mut ir);

    JitExecutor::<B>::compile(ir, device, config)
}
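
// A minimal sanity-check sketch for the pure helpers above; the expected
// values follow directly from the function bodies (`resolve_neg_dim` counts
// from the back, `compute_fans` multiplies the receptive field into both fans).
#[cfg(test)]
mod jit_helper_tests {
    use super::{compute_fans, resolve_neg_dim};

    #[test]
    fn negative_dims_resolve_from_the_back() {
        assert_eq!(resolve_neg_dim(-1, 3), 2);
        assert_eq!(resolve_neg_dim(0, 3), 0);
    }

    #[test]
    fn fans_for_linear_and_conv_shapes() {
        // 2-D weight [out, in]: fan_in = in, fan_out = out.
        assert_eq!(compute_fans(&shrew_core::Shape::new(vec![8usize, 4])), (4, 8));
        // 4-D conv weight [out, in, kh, kw]: the 3x3 receptive field scales both fans.
        assert_eq!(
            compute_fans(&shrew_core::Shape::new(vec![16usize, 3, 3, 3])),
            (27, 144)
        );
    }
}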