1use super::*;
7use crate::prelude::VeScalar;
8use crate::scalar::Opt;
9use crate::vector_engine::layer::{FpToFxp, FxpToFp};
10
11impl LogicBinaryOpI32 {
16 pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
18 match self {
19 LogicBinaryOpI32::BitAnd => |a, b| a & b,
20 LogicBinaryOpI32::BitOr => |a, b| a | b,
21 LogicBinaryOpI32::BitXor => |a, b| a ^ b,
22 LogicBinaryOpI32::LeftShift => |a, b| a << (b as u32),
23 LogicBinaryOpI32::LogicRightShift => |a, b| ((a as u32) >> (b as u32)) as i32,
24 LogicBinaryOpI32::ArithRightShift => |a, b| a >> (b as u32),
25 }
26 }
27}
28
29impl LogicBinaryOpF32 {
30 pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
32 match self {
33 LogicBinaryOpF32::BitAnd => |a, b| f32::from_bits(a.to_bits() & b.to_bits()),
34 LogicBinaryOpF32::BitOr => |a, b| f32::from_bits(a.to_bits() | b.to_bits()),
35 LogicBinaryOpF32::BitXor => |a, b| f32::from_bits(a.to_bits() ^ b.to_bits()),
36 }
37 }
38}
39
40impl LogicOpI {
41 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
43 let op = self.op.op_fn();
44 self.arg_mode.apply_opt(op)
45 }
46}
47
48impl LogicOpF {
49 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
51 let op = self.op.op_fn();
52 self.arg_mode.apply_opt(op)
53 }
54}
55
56impl FxpBinaryOp {
61 pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
63 match self {
64 FxpBinaryOp::AddFxp => |a, b| a.wrapping_add(b),
65 FxpBinaryOp::AddFxpSat => |a, b| a.saturating_add(b),
66 FxpBinaryOp::SubFxp => |a, b| a.wrapping_sub(b),
67 FxpBinaryOp::SubFxpSat => |a, b| a.saturating_sub(b),
68 FxpBinaryOp::LeftShift => |a, b| a << (b as u32),
69 FxpBinaryOp::LeftShiftSat => |a, b| a.saturating_mul(1 << (b as u32)),
70 FxpBinaryOp::MulFxp => |a, b| {
71 if a == i32::MIN && b == i32::MIN {
76 i32::MAX
77 } else {
78 let product = i64::from(a) * i64::from(b);
79 (((product >> 30) + 1) >> 1) as i32
80 }
81 },
82 FxpBinaryOp::MulInt => |a, b| a.wrapping_mul(b),
83 FxpBinaryOp::LogicRightShift => |a, b| ((a as u32) >> (b as u32)) as i32,
84 FxpBinaryOp::ArithRightShift => |a, b| a >> (b as u32),
85 FxpBinaryOp::ArithRightShiftRound => todo!("ArithRightShiftRound not implemented"),
86 }
87 }
88}
89
90impl FxpOp {
91 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
93 let op = self.op.op_fn();
94 self.arg_mode.apply_opt(op)
95 }
96}
97
98impl FpUnaryOp {
103 pub fn op_fn(&self) -> fn(f32) -> f32 {
105 match self {
106 FpUnaryOp::Exp => |x| x.exp(),
107 FpUnaryOp::NegExp => |x| (-x).exp(),
108 FpUnaryOp::Sqrt => |x| x.sqrt(),
109 FpUnaryOp::Tanh => |x| x.tanh(),
110 FpUnaryOp::Sigmoid => |x| 1.0 / (1.0 + (-x).exp()),
111 FpUnaryOp::Erf => |_x| todo!("Erf not implemented"),
112 FpUnaryOp::Log => |x| x.ln(),
113 FpUnaryOp::Sin => |x| x.sin(),
114 FpUnaryOp::Cos => |x| x.cos(),
115 }
116 }
117}
118
119impl FpBinaryOp {
120 pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
122 match self {
123 FpBinaryOp::AddF => |a, b| a + b,
124 FpBinaryOp::SubF => |a, b| a - b,
125 FpBinaryOp::MulF(_) => |a, b| a * b,
126 FpBinaryOp::MaskMulF(_) => |_a, _b| todo!("MaskMulF not implemented"),
127 FpBinaryOp::DivF => |a, b| a / b,
128 }
129 }
130}
131
132impl FpTernaryOp {
133 pub fn op_fn(&self) -> fn(f32, f32, f32) -> f32 {
135 match self {
136 FpTernaryOp::FmaF => |a, b, c| a.mul_add(b, c),
137 FpTernaryOp::MaskFmaF => |_a, _b, _c| todo!("MaskFmaF not implemented"),
138 }
139 }
140}
141
142impl FpOp {
143 pub fn unary_op_opt(&self) -> Box<dyn Fn(Opt<f32>) -> Opt<f32>> {
146 match self {
147 FpOp::UnaryOp { op, mode } => mode.apply_opt(op.op_fn()),
148 _ => panic!("unary_op_opt called on non-unary FpOp"),
149 }
150 }
151
152 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
155 match self {
156 FpOp::BinaryOp { op, mode } => mode.apply_opt(op.op_fn()),
157 _ => panic!("binary_op_opt called on non-binary FpOp"),
158 }
159 }
160
161 pub fn ternary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32>> {
164 match self {
165 FpOp::TernaryOp { op, mode } => mode.apply_opt(op.op_fn()),
166 _ => panic!("ternary_op_opt called on non-ternary FpOp"),
167 }
168 }
169}
170
171impl ClipBinaryOpI32 {
176 pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
178 match self {
179 ClipBinaryOpI32::AddFxp => |a, b| a.wrapping_add(b),
180 ClipBinaryOpI32::AddFxpSat => |a, b| a.saturating_add(b),
181 ClipBinaryOpI32::Min => |a, b| a.min(b),
182 ClipBinaryOpI32::Max => |a, b| a.max(b),
183 ClipBinaryOpI32::AbsMin => |a, b| if a.abs() < b.abs() { a } else { b },
184 ClipBinaryOpI32::AbsMax => |a, b| if a.abs() > b.abs() { a } else { b },
185 }
186 }
187}
188
189impl ClipBinaryOpF32 {
190 pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
192 match self {
193 ClipBinaryOpF32::Add => |a, b| a + b,
194 ClipBinaryOpF32::Min => |a, b| a.min(b),
195 ClipBinaryOpF32::Max => |a, b| a.max(b),
196 ClipBinaryOpF32::AbsMin => |a, b| if a.abs() < b.abs() { a } else { b },
197 ClipBinaryOpF32::AbsMax => |a, b| if a.abs() > b.abs() { a } else { b },
198 }
199 }
200}
201
202impl ClipOpI {
203 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
205 let op = self.op.op_fn();
206 self.mode.apply_opt(op)
207 }
208}
209
210impl ClipOpF {
211 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
213 let op = self.op.op_fn();
214 self.mode.apply_opt(op)
215 }
216}
217
218impl FxpToFp {
223 pub fn op_fn(&self) -> impl Fn(i32) -> f32 {
225 let int_width = self.int_width();
226 move |x| crate::float::fixedpoint_to_float(x, int_width)
227 }
228}
229
230impl FpToFxp {
231 pub fn op_fn(&self) -> impl Fn(f32) -> i32 {
233 let int_width = self.int_width();
234 move |x| crate::float::float_to_fixedpoint(x, int_width)
235 }
236}
237
238pub trait HasConversionOp<D: VeScalar, D2: VeScalar>: Clone + Copy {
240 fn conversion_op_fn(&self) -> impl Fn(D) -> D2;
242}
243
244impl HasConversionOp<i32, f32> for FxpToFp {
245 fn conversion_op_fn(&self) -> impl Fn(i32) -> f32 {
246 self.op_fn()
247 }
248}
249
250impl HasConversionOp<f32, i32> for FpToFxp {
251 fn conversion_op_fn(&self) -> impl Fn(f32) -> i32 {
252 self.op_fn()
253 }
254}
255
256fn lift_reduce_fn<D: Copy>(reduce_fn: impl Fn(D, D) -> D + 'static) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D> {
263 move |a: Opt<D>, b: Opt<D>| match (a, b) {
264 (Opt::Uninit, _) => b,
265 (_, Opt::Uninit) => a,
266 (Opt::Init(x), Opt::Init(y)) => Opt::Init(reduce_fn(x, y)),
267 }
268}
269
270impl IntraSliceReduceOpI32 {
271 pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
273 match self {
274 IntraSliceReduceOpI32::AddSat => |a, b| a.saturating_add(b),
275 IntraSliceReduceOpI32::Max => |a, b| a.max(b),
276 IntraSliceReduceOpI32::Min => |a, b| a.min(b),
277 }
278 }
279
280 pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
282 Box::new(lift_reduce_fn(self.reduce_fn()))
283 }
284
285 pub fn identity(&self) -> i32 {
287 match self {
288 IntraSliceReduceOpI32::AddSat => 0,
289 IntraSliceReduceOpI32::Max => i32::MIN,
290 IntraSliceReduceOpI32::Min => i32::MAX,
291 }
292 }
293}
294
295impl IntraSliceReduceOpF32 {
296 pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
298 match self {
299 IntraSliceReduceOpF32::Add => |a, b| a + b,
300 IntraSliceReduceOpF32::Max => |a, b| a.max(b),
301 IntraSliceReduceOpF32::Min => |a, b| a.min(b),
302 }
303 }
304
305 pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
307 Box::new(lift_reduce_fn(self.reduce_fn()))
308 }
309
310 pub fn identity(&self) -> f32 {
312 match self {
313 IntraSliceReduceOpF32::Add => 0.0,
314 IntraSliceReduceOpF32::Max => f32::NEG_INFINITY,
315 IntraSliceReduceOpF32::Min => f32::INFINITY,
316 }
317 }
318}
319
320impl InterSliceReduceOpI32 {
325 pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
327 match self {
328 InterSliceReduceOpI32::Add => |a, b| a.wrapping_add(b),
329 InterSliceReduceOpI32::AddSat => |a, b| a.saturating_add(b),
330 InterSliceReduceOpI32::Max => |a, b| a.max(b),
331 InterSliceReduceOpI32::Min => |a, b| a.min(b),
332 }
333 }
334
335 pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
337 Box::new(lift_reduce_fn(self.reduce_fn()))
338 }
339}
340
341impl InterSliceReduceOpF32 {
342 pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
344 match self {
345 InterSliceReduceOpF32::Add => |a, b| a + b,
346 InterSliceReduceOpF32::Max => |a, b| a.max(b),
347 InterSliceReduceOpF32::Min => |a, b| a.min(b),
348 InterSliceReduceOpF32::Mul => |a, b| a * b,
349 }
350 }
351
352 pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
354 Box::new(lift_reduce_fn(self.reduce_fn()))
355 }
356}
357
358impl FpDivBinaryOp {
363 pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
365 match self {
366 FpDivBinaryOp::DivF => |a, b| a / b,
367 }
368 }
369}
370
371impl FpDivOp {
372 pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
374 let op = self.op.op_fn();
375 self.mode.apply_opt(op)
376 }
377}
378
379pub trait HasUnaryOp<D>: Clone + Copy {
381 fn unary_op_fn(self, mode: Option<UnaryArgMode>) -> impl Fn(Opt<D>) -> Opt<D>;
384}
385
386pub trait HasBinaryOp<D>: Clone + Copy {
388 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D>;
391}
392
393pub trait HasTernaryOp<D>: Clone + Copy {
395 fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>;
397}
398
399impl HasBinaryOp<i32> for LogicBinaryOpI32 {
404 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
405 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
406 }
407}
408
409impl HasBinaryOp<f32> for LogicBinaryOpF32 {
410 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
411 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
412 }
413}
414
415impl HasBinaryOp<i32> for FxpBinaryOp {
416 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
417 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
418 }
419}
420
421impl HasUnaryOp<f32> for FpUnaryOp {
422 fn unary_op_fn(self, mode: Option<UnaryArgMode>) -> impl Fn(Opt<f32>) -> Opt<f32> {
423 mode.unwrap_or(UnaryArgMode::Mode0).apply_opt(self.op_fn())
424 }
425}
426
427impl HasBinaryOp<f32> for FpBinaryOp {
428 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
429 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
430 }
431}
432
433impl HasTernaryOp<f32> for FpTernaryOp {
434 fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32> {
435 mode.unwrap_or(TernaryArgMode::Mode012).apply_opt(self.op_fn())
436 }
437}
438
439impl HasBinaryOp<f32> for FpDivBinaryOp {
440 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
441 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
442 }
443}
444
445impl HasBinaryOp<f32> for FpDivOp {
446 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
447 match mode {
448 Some(mode) => mode.apply_opt(self.op.op_fn()),
449 None => self.binary_op_opt(),
450 }
451 }
452}
453
454impl HasBinaryOp<i32> for ClipBinaryOpI32 {
455 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
456 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
457 }
458}
459
460impl HasBinaryOp<f32> for ClipBinaryOpF32 {
461 fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
462 mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
463 }
464}