shrew_nn/
activation.rs

// Activation modules: wrappers around tensor activation functions
//
// These are thin wrappers that turn tensor-level activations (like tensor.relu())
// into Module implementations. This lets you compose activations in Sequential.
//
// Example:
//   let model = Sequential::new(vec![
//       Box::new(linear1),
//       Box::new(ReLU),
//       Box::new(linear2),
//   ]);

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;
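
// Note: every activation below is parameter-free (most are unit structs;
// LeakyReLU and ELU just carry a scalar hyperparameter), so it can be boxed
// into a Sequential as in the example above or called directly through the
// Module trait. A minimal sketch, assuming `x` is an already-constructed
// Tensor<B> (tensor construction is a shrew_core concern, not something this
// module provides):
//
//     let y = ReLU.forward(&x)?;                        // elementwise max(0, x)
//     let z = LeakyReLU::with_slope(0.2).forward(&x)?;  // custom negative slope
//
// `parameters()` returns an empty Vec for all of them, so there is nothing
// for an optimizer to update.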

/// ReLU activation: max(0, x)
pub struct ReLU;

impl<B: Backend> Module<B> for ReLU {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.relu()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
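
// Worked example: applied elementwise, ReLU maps [-1.5, 0.0, 2.0] to
// [0.0, 0.0, 2.0]; negative entries are clamped to zero and the rest pass
// through unchanged.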

/// GELU activation (Gaussian Error Linear Unit)
/// Used in Transformers (BERT, GPT, etc.)
pub struct GeLU;

impl<B: Backend> Module<B> for GeLU {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.gelu()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
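
// For reference, GELU(x) = x * Φ(x), where Φ is the CDF of the standard
// normal, equivalently 0.5 * x * (1 + erf(x / sqrt(2))). Whether
// `tensor.gelu()` evaluates the exact erf form or the common tanh
// approximation 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
// is a backend decision; this wrapper only forwards the call.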

/// SiLU / Swish activation: x * σ(x)
/// Used in modern architectures (EfficientNet, LLaMA, etc.)
pub struct SiLU;

impl<B: Backend> Module<B> for SiLU {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.silu()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
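
// SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x)). A few reference values:
// SiLU(0) = 0, SiLU(1) ≈ 0.731; for large positive x it approaches x, and
// for large negative x it approaches 0.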

/// Sigmoid activation: 1 / (1 + e^(-x))
pub struct Sigmoid;

impl<B: Backend> Module<B> for Sigmoid {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.sigmoid()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}

/// Tanh activation: (e^x - e^(-x)) / (e^x + e^(-x))
pub struct Tanh;

impl<B: Backend> Module<B> for Tanh {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.tanh()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}

/// LeakyReLU activation: max(negative_slope * x, x)
///
/// Allows a small gradient when the unit is not active (x < 0).
/// Default negative_slope = 0.01.
pub struct LeakyReLU {
    negative_slope: f64,
}

impl LeakyReLU {
    /// Create with default negative_slope = 0.01.
    pub fn new() -> Self {
        LeakyReLU {
            negative_slope: 0.01,
        }
    }

    /// Create with custom negative_slope.
    pub fn with_slope(negative_slope: f64) -> Self {
        LeakyReLU { negative_slope }
    }
}

impl Default for LeakyReLU {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> Module<B> for LeakyReLU {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // LeakyReLU(x) = x if x >= 0, negative_slope * x otherwise
        let zeros = Tensor::<B>::zeros_like(x)?;
        let mask = x.ge(&zeros)?; // mask = (x >= 0)
        let scaled = x.affine(self.negative_slope, 0.0)?; // negative_slope * x
        Tensor::<B>::where_cond(&mask, x, &scaled)
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
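
// The forward pass is built from three tensor ops: a boolean mask (x >= 0)
// via `ge`, a scaled copy negative_slope * x via `affine`, and a `where_cond`
// that selects x where the mask holds and the scaled copy elsewhere. Worked
// example with the default slope of 0.01: [-2.0, 3.0] maps to [-0.02, 3.0].
// A custom slope is set with `LeakyReLU::with_slope(0.2)`.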

/// ELU activation: x if x > 0, alpha * (exp(x) - 1) otherwise
///
/// Smoother than ReLU for negative values. Default alpha = 1.0.
pub struct ELU {
    alpha: f64,
}

impl ELU {
    /// Create with default alpha = 1.0.
    pub fn new() -> Self {
        ELU { alpha: 1.0 }
    }

    /// Create with custom alpha.
    pub fn with_alpha(alpha: f64) -> Self {
        ELU { alpha }
    }
}

impl Default for ELU {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> Module<B> for ELU {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // ELU(x) = x if x > 0, alpha * (exp(x) - 1) otherwise
        let zeros = Tensor::<B>::zeros_like(x)?;
        let mask = x.gt(&zeros)?;
        let exp_x = x.exp()?;
        let ones = Tensor::<B>::ones_like(x)?;
        let exp_minus_1 = exp_x.sub(&ones)?;
        let neg_part = exp_minus_1.affine(self.alpha, 0.0)?;
        Tensor::<B>::where_cond(&mask, x, &neg_part)
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
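
// Worked example with the default alpha = 1.0: ELU(2.0) = 2.0 (positive
// branch), ELU(-1.0) = e^(-1) - 1 ≈ -0.632 (negative branch), and the two
// branches meet continuously at 0. A custom alpha is set with
// `ELU::with_alpha(0.5)`.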

/// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
///
/// A self-regularizing non-monotonic activation function.
/// Used in YOLOv4 and other modern architectures.
pub struct Mish;

impl<B: Backend> Module<B> for Mish {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // softplus(x) = ln(1 + exp(x))
        let exp_x = x.exp()?;
        let ones = Tensor::<B>::ones_like(x)?;
        let one_plus_exp = ones.add(&exp_x)?;
        let softplus = one_plus_exp.log()?;
        // mish(x) = x * tanh(softplus(x))
        let tanh_sp = softplus.tanh()?;
        x.mul(&tanh_sp)
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}
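
// Reference behaviour: Mish(0) = 0 * tanh(ln 2) = 0; for large positive x,
// softplus(x) ≈ x and the tanh saturates toward 1, so Mish(x) ≈ x; for large
// negative x the output tends toward 0. Note that softplus is computed here
// literally as ln(1 + exp(x)); exp(x) overflows for large inputs, and the
// standard numerically stable rewrite is max(x, 0) + ln(1 + exp(-|x|)).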