1use shrew_core::backend::Backend;
14use shrew_core::error::Result;
15use shrew_core::tensor::Tensor;
16
17use crate::module::Module;
18
19pub struct ReLU;
21
22impl<B: Backend> Module<B> for ReLU {
23 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
24 x.relu()
25 }
26 fn parameters(&self) -> Vec<Tensor<B>> {
27 vec![]
28 }
29}
30
31pub struct GeLU;
34
35impl<B: Backend> Module<B> for GeLU {
36 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
37 x.gelu()
38 }
39 fn parameters(&self) -> Vec<Tensor<B>> {
40 vec![]
41 }
42}
43
44pub struct SiLU;
47
48impl<B: Backend> Module<B> for SiLU {
49 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
50 x.silu()
51 }
52 fn parameters(&self) -> Vec<Tensor<B>> {
53 vec![]
54 }
55}
56
57pub struct Sigmoid;
59
60impl<B: Backend> Module<B> for Sigmoid {
61 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
62 x.sigmoid()
63 }
64 fn parameters(&self) -> Vec<Tensor<B>> {
65 vec![]
66 }
67}
68
69pub struct Tanh;
71
72impl<B: Backend> Module<B> for Tanh {
73 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
74 x.tanh()
75 }
76 fn parameters(&self) -> Vec<Tensor<B>> {
77 vec![]
78 }
79}
80
/// Leaky rectified linear unit activation.
///
/// Holds the single hyper-parameter of the activation: the factor by which
/// inputs below zero are scaled in the forward pass.
pub struct LeakyReLU {
    /// Multiplier applied to inputs that are below zero.
    negative_slope: f64,
}

impl LeakyReLU {
    /// Builds a `LeakyReLU` with a negative slope of `0.01`.
    pub fn new() -> Self {
        Self::with_slope(0.01)
    }

    /// Builds a `LeakyReLU` that uses the given `negative_slope`.
    ///
    /// The value is stored as-is; no validation is performed.
    pub fn with_slope(negative_slope: f64) -> Self {
        Self { negative_slope }
    }
}

impl Default for LeakyReLU {
    /// Equivalent to [`LeakyReLU::new`].
    fn default() -> Self {
        LeakyReLU::new()
    }
}
108
109impl<B: Backend> Module<B> for LeakyReLU {
110 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
111 let zeros = Tensor::<B>::zeros_like(x)?;
113 let mask = x.ge(&zeros)?; let scaled = x.affine(self.negative_slope, 0.0)?; Tensor::<B>::where_cond(&mask, x, &scaled)
116 }
117 fn parameters(&self) -> Vec<Tensor<B>> {
118 vec![]
119 }
120}
121
/// Exponential linear unit activation.
///
/// Holds the single hyper-parameter of the activation: the factor `alpha`
/// that scales the `exp(x) - 1` branch in the forward pass.
pub struct ELU {
    /// Scale applied to `exp(x) - 1` for inputs that are not above zero.
    alpha: f64,
}

impl ELU {
    /// Builds an `ELU` with `alpha = 1.0`.
    pub fn new() -> Self {
        Self::with_alpha(1.0)
    }

    /// Builds an `ELU` that uses the given `alpha`.
    ///
    /// The value is stored as-is; no validation is performed.
    pub fn with_alpha(alpha: f64) -> Self {
        Self { alpha }
    }
}

impl Default for ELU {
    /// Equivalent to [`ELU::new`].
    fn default() -> Self {
        ELU::new()
    }
}
146
147impl<B: Backend> Module<B> for ELU {
148 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
149 let zeros = Tensor::<B>::zeros_like(x)?;
151 let mask = x.gt(&zeros)?;
152 let exp_x = x.exp()?;
153 let ones = Tensor::<B>::ones_like(x)?;
154 let exp_minus_1 = exp_x.sub(&ones)?;
155 let neg_part = exp_minus_1.affine(self.alpha, 0.0)?;
156 Tensor::<B>::where_cond(&mask, x, &neg_part)
157 }
158 fn parameters(&self) -> Vec<Tensor<B>> {
159 vec![]
160 }
161}
162
163pub struct Mish;
168
169impl<B: Backend> Module<B> for Mish {
170 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
171 let exp_x = x.exp()?;
173 let ones = Tensor::<B>::ones_like(x)?;
174 let one_plus_exp = ones.add(&exp_x)?;
175 let softplus = one_plus_exp.log()?;
176 let tanh_sp = softplus.tanh()?;
178 x.mul(&tanh_sp)
179 }
180 fn parameters(&self) -> Vec<Tensor<B>> {
181 vec![]
182 }
183}