furiosa_visa_std/vector_engine/op/
semantics.rs

1//! Semantic implementations for VE operations.
2//!
3//! This module provides the actual operation logic (apply functions, operation functions)
4//! separated from type definitions in `op.rs`.
5
6use super::*;
7use crate::prelude::VeScalar;
8use crate::scalar::Opt;
9use crate::vector_engine::layer::{FpToFxp, FxpToFp};
10
11// ============================================================================
12// Operation functions - Logic
13// ============================================================================
14
15impl LogicBinaryOpI32 {
16    /// Returns the raw binary operation function.
17    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
18        match self {
19            LogicBinaryOpI32::BitAnd => |a, b| a & b,
20            LogicBinaryOpI32::BitOr => |a, b| a | b,
21            LogicBinaryOpI32::BitXor => |a, b| a ^ b,
22            LogicBinaryOpI32::LeftShift => |a, b| a << (b as u32),
23            LogicBinaryOpI32::LogicRightShift => |a, b| ((a as u32) >> (b as u32)) as i32,
24            LogicBinaryOpI32::ArithRightShift => |a, b| a >> (b as u32),
25        }
26    }
27}
28
29impl LogicBinaryOpF32 {
30    /// Returns the raw binary operation function.
31    pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
32        match self {
33            LogicBinaryOpF32::BitAnd => |a, b| f32::from_bits(a.to_bits() & b.to_bits()),
34            LogicBinaryOpF32::BitOr => |a, b| f32::from_bits(a.to_bits() | b.to_bits()),
35            LogicBinaryOpF32::BitXor => |a, b| f32::from_bits(a.to_bits() ^ b.to_bits()),
36        }
37    }
38}
39
40impl LogicOpI {
41    /// Returns the binary operation with arg mode applied (Opt version).
42    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
43        let op = self.op.op_fn();
44        self.arg_mode.apply_opt(op)
45    }
46}
47
48impl LogicOpF {
49    /// Returns the binary operation with arg mode applied (Opt version).
50    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
51        let op = self.op.op_fn();
52        self.arg_mode.apply_opt(op)
53    }
54}
55
56// ============================================================================
57// Operation functions - Fxp
58// ============================================================================
59
60impl FxpBinaryOp {
61    /// Returns the raw binary operation function.
62    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
63        match self {
64            FxpBinaryOp::AddFxp => |a, b| a.wrapping_add(b),
65            FxpBinaryOp::AddFxpSat => |a, b| a.saturating_add(b),
66            FxpBinaryOp::SubFxp => |a, b| a.wrapping_sub(b),
67            FxpBinaryOp::SubFxpSat => |a, b| a.saturating_sub(b),
68            FxpBinaryOp::LeftShift => |a, b| a << (b as u32),
69            FxpBinaryOp::LeftShiftSat => |a, b| a.saturating_mul(1 << (b as u32)),
70            FxpBinaryOp::MulFxp => |a, b| {
71                // Q31 fixed-point multiply with rounding, matching npu-ir BinOp::MulFxp.
72                // Operands are interpreted as Q31 (2^31 ≈ 1.0), so the raw product is
73                // shifted right by 31 with a round-to-nearest step. The sole overflow
74                // case is MIN × MIN, which saturates to MAX.
75                if a == i32::MIN && b == i32::MIN {
76                    i32::MAX
77                } else {
78                    let product = i64::from(a) * i64::from(b);
79                    (((product >> 30) + 1) >> 1) as i32
80                }
81            },
82            FxpBinaryOp::MulInt => |a, b| a.wrapping_mul(b),
83            FxpBinaryOp::LogicRightShift => |a, b| ((a as u32) >> (b as u32)) as i32,
84            FxpBinaryOp::ArithRightShift => |a, b| a >> (b as u32),
85            FxpBinaryOp::ArithRightShiftRound => todo!("ArithRightShiftRound not implemented"),
86        }
87    }
88}
89
90impl FxpOp {
91    /// Returns the binary operation with arg mode applied (Opt version).
92    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
93        let op = self.op.op_fn();
94        self.arg_mode.apply_opt(op)
95    }
96}
97
98// ============================================================================
99// Operation functions - Fp
100// ============================================================================
101
102impl FpUnaryOp {
103    /// Returns the raw unary operation function.
104    pub fn op_fn(&self) -> fn(f32) -> f32 {
105        match self {
106            FpUnaryOp::Exp => |x| x.exp(),
107            FpUnaryOp::NegExp => |x| (-x).exp(),
108            FpUnaryOp::Sqrt => |x| x.sqrt(),
109            FpUnaryOp::Tanh => |x| x.tanh(),
110            FpUnaryOp::Sigmoid => |x| 1.0 / (1.0 + (-x).exp()),
111            FpUnaryOp::Erf => |_x| todo!("Erf not implemented"),
112            FpUnaryOp::Log => |x| x.ln(),
113            FpUnaryOp::Sin => |x| x.sin(),
114            FpUnaryOp::Cos => |x| x.cos(),
115        }
116    }
117}
118
119impl FpBinaryOp {
120    /// Returns the raw binary operation function.
121    pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
122        match self {
123            FpBinaryOp::AddF => |a, b| a + b,
124            FpBinaryOp::SubF => |a, b| a - b,
125            FpBinaryOp::MulF(_) => |a, b| a * b,
126            FpBinaryOp::MaskMulF(_) => |_a, _b| todo!("MaskMulF not implemented"),
127            FpBinaryOp::DivF => |a, b| a / b,
128        }
129    }
130}
131
132impl FpTernaryOp {
133    /// Returns the raw ternary operation function.
134    pub fn op_fn(&self) -> fn(f32, f32, f32) -> f32 {
135        match self {
136            FpTernaryOp::FmaF => |a, b, c| a.mul_add(b, c),
137            FpTernaryOp::MaskFmaF => |_a, _b, _c| todo!("MaskFmaF not implemented"),
138        }
139    }
140}
141
142impl FpOp {
143    /// Returns the unary operation with arg mode applied (Opt version).
144    /// Panics if not a unary operation.
145    pub fn unary_op_opt(&self) -> Box<dyn Fn(Opt<f32>) -> Opt<f32>> {
146        match self {
147            FpOp::UnaryOp { op, mode } => mode.apply_opt(op.op_fn()),
148            _ => panic!("unary_op_opt called on non-unary FpOp"),
149        }
150    }
151
152    /// Returns the binary operation with arg mode applied (Opt version).
153    /// Panics if not a binary operation.
154    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
155        match self {
156            FpOp::BinaryOp { op, mode } => mode.apply_opt(op.op_fn()),
157            _ => panic!("binary_op_opt called on non-binary FpOp"),
158        }
159    }
160
161    /// Returns the ternary operation with arg mode applied (Opt version).
162    /// Panics if not a ternary operation.
163    pub fn ternary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32>> {
164        match self {
165            FpOp::TernaryOp { op, mode } => mode.apply_opt(op.op_fn()),
166            _ => panic!("ternary_op_opt called on non-ternary FpOp"),
167        }
168    }
169}
170
171// ============================================================================
172// Operation functions - Clip
173// ============================================================================
174
175impl ClipBinaryOpI32 {
176    /// Returns the raw binary operation function.
177    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
178        match self {
179            ClipBinaryOpI32::AddFxp => |a, b| a.wrapping_add(b),
180            ClipBinaryOpI32::AddFxpSat => |a, b| a.saturating_add(b),
181            ClipBinaryOpI32::Min => |a, b| a.min(b),
182            ClipBinaryOpI32::Max => |a, b| a.max(b),
183            ClipBinaryOpI32::AbsMin => |a, b| if a.abs() < b.abs() { a } else { b },
184            ClipBinaryOpI32::AbsMax => |a, b| if a.abs() > b.abs() { a } else { b },
185        }
186    }
187}
188
189impl ClipBinaryOpF32 {
190    /// Returns the raw binary operation function.
191    pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
192        match self {
193            ClipBinaryOpF32::Add => |a, b| a + b,
194            ClipBinaryOpF32::Min => |a, b| a.min(b),
195            ClipBinaryOpF32::Max => |a, b| a.max(b),
196            ClipBinaryOpF32::AbsMin => |a, b| if a.abs() < b.abs() { a } else { b },
197            ClipBinaryOpF32::AbsMax => |a, b| if a.abs() > b.abs() { a } else { b },
198        }
199    }
200}
201
202impl ClipOpI {
203    /// Returns the binary operation with arg mode applied (Opt version).
204    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
205        let op = self.op.op_fn();
206        self.mode.apply_opt(op)
207    }
208}
209
210impl ClipOpF {
211    /// Returns the binary operation with arg mode applied (Opt version).
212    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
213        let op = self.op.op_fn();
214        self.mode.apply_opt(op)
215    }
216}
217
218// ============================================================================
219// Operation functions - FxpToFp / FpToFxp conversions
220// ============================================================================
221
222impl FxpToFp {
223    /// Returns the conversion function.
224    pub fn op_fn(&self) -> impl Fn(i32) -> f32 {
225        let int_width = self.int_width();
226        move |x| crate::float::fixedpoint_to_float(x, int_width)
227    }
228}
229
230impl FpToFxp {
231    /// Returns the conversion function.
232    pub fn op_fn(&self) -> impl Fn(f32) -> i32 {
233        let int_width = self.int_width();
234        move |x| crate::float::float_to_fixedpoint(x, int_width)
235    }
236}
237
238/// Trait for ops that provide conversion operation.
239pub trait HasConversionOp<D: VeScalar, D2: VeScalar>: Clone + Copy {
240    /// Returns the conversion function.
241    fn conversion_op_fn(&self) -> impl Fn(D) -> D2;
242}
243
244impl HasConversionOp<i32, f32> for FxpToFp {
245    fn conversion_op_fn(&self) -> impl Fn(i32) -> f32 {
246        self.op_fn()
247    }
248}
249
250impl HasConversionOp<f32, i32> for FpToFxp {
251    fn conversion_op_fn(&self) -> impl Fn(f32) -> i32 {
252        self.op_fn()
253    }
254}
255
256// ============================================================================
257// Operation functions - Intra-Slice Reduce
258// ============================================================================
259
260/// Lifts a binary reduction function on D to operate on Opt<D>, treating Uninit as the identity element.
261/// TODO: this should be replaced by valid count generator, and Opt<D> should be removed.
262fn lift_reduce_fn<D: Copy>(reduce_fn: impl Fn(D, D) -> D + 'static) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D> {
263    move |a: Opt<D>, b: Opt<D>| match (a, b) {
264        (Opt::Uninit, _) => b,
265        (_, Opt::Uninit) => a,
266        (Opt::Init(x), Opt::Init(y)) => Opt::Init(reduce_fn(x, y)),
267    }
268}
269
270impl IntraSliceReduceOpI32 {
271    /// Returns the raw binary reduction function.
272    pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
273        match self {
274            IntraSliceReduceOpI32::AddSat => |a, b| a.saturating_add(b),
275            IntraSliceReduceOpI32::Max => |a, b| a.max(b),
276            IntraSliceReduceOpI32::Min => |a, b| a.min(b),
277        }
278    }
279
280    /// Returns a reduction function lifted to [`Opt`], treating `Uninit` as the identity.
281    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
282        Box::new(lift_reduce_fn(self.reduce_fn()))
283    }
284
285    /// Returns the identity element for reduction.
286    pub fn identity(&self) -> i32 {
287        match self {
288            IntraSliceReduceOpI32::AddSat => 0,
289            IntraSliceReduceOpI32::Max => i32::MIN,
290            IntraSliceReduceOpI32::Min => i32::MAX,
291        }
292    }
293}
294
295impl IntraSliceReduceOpF32 {
296    /// Returns the raw binary reduction function.
297    pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
298        match self {
299            IntraSliceReduceOpF32::Add => |a, b| a + b,
300            IntraSliceReduceOpF32::Max => |a, b| a.max(b),
301            IntraSliceReduceOpF32::Min => |a, b| a.min(b),
302        }
303    }
304
305    /// Returns a reduction function lifted to [`Opt`], treating `Uninit` as the identity.
306    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
307        Box::new(lift_reduce_fn(self.reduce_fn()))
308    }
309
310    /// Returns the identity element for reduction.
311    pub fn identity(&self) -> f32 {
312        match self {
313            IntraSliceReduceOpF32::Add => 0.0,
314            IntraSliceReduceOpF32::Max => f32::NEG_INFINITY,
315            IntraSliceReduceOpF32::Min => f32::INFINITY,
316        }
317    }
318}
319
320// ============================================================================
321// Operation functions - Inter-Slice Reduce (VRU)
322// ============================================================================
323
324impl InterSliceReduceOpI32 {
325    /// Returns the raw binary reduction function.
326    pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
327        match self {
328            InterSliceReduceOpI32::Add => |a, b| a.wrapping_add(b),
329            InterSliceReduceOpI32::AddSat => |a, b| a.saturating_add(b),
330            InterSliceReduceOpI32::Max => |a, b| a.max(b),
331            InterSliceReduceOpI32::Min => |a, b| a.min(b),
332        }
333    }
334
335    /// Returns a reduction function lifted to [`Opt`], treating `Uninit` as the identity.
336    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
337        Box::new(lift_reduce_fn(self.reduce_fn()))
338    }
339}
340
341impl InterSliceReduceOpF32 {
342    /// Returns the raw binary reduction function.
343    pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
344        match self {
345            InterSliceReduceOpF32::Add => |a, b| a + b,
346            InterSliceReduceOpF32::Max => |a, b| a.max(b),
347            InterSliceReduceOpF32::Min => |a, b| a.min(b),
348            InterSliceReduceOpF32::Mul => |a, b| a * b,
349        }
350    }
351
352    /// Returns a reduction function lifted to [`Opt`], treating `Uninit` as the identity.
353    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
354        Box::new(lift_reduce_fn(self.reduce_fn()))
355    }
356}
357
358// ============================================================================
359// Operation functions - FpDiv
360// ============================================================================
361
362impl FpDivBinaryOp {
363    /// Returns the raw binary operation function.
364    pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
365        match self {
366            FpDivBinaryOp::DivF => |a, b| a / b,
367        }
368    }
369}
370
371impl FpDivOp {
372    /// Returns the binary operation with arg mode applied (Opt version).
373    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
374        let op = self.op.op_fn();
375        self.mode.apply_opt(op)
376    }
377}
378
/// Trait for ops that provide a unary operation function over `Opt<D>` values.
///
/// Implementors must be `Copy`, because the method consumes `self` by value.
pub trait HasUnaryOp<D>: Clone + Copy {
    /// Returns a function that applies this unary operation with the given mode.
    ///
    /// `mode` selects the argument mode. The implementation for `FpUnaryOp` in this module
    /// uses `UnaryArgMode::Mode0` when `mode` is `None`.
    fn unary_op_fn(self, mode: Option<UnaryArgMode>) -> impl Fn(Opt<D>) -> Opt<D>;
}
385
/// Trait for ops that provide a binary operation function over `Opt<D>` values.
///
/// Implementors must be `Copy`, because the method consumes `self` by value.
pub trait HasBinaryOp<D>: Clone + Copy {
    /// Returns a function that applies this binary operation with the given mode.
    ///
    /// `mode` selects the argument mode. When `mode` is `None`, the implementations in
    /// this module fall back to `BinaryArgMode::Mode01`, except `FpDivOp`, which falls
    /// back to the mode stored in the op itself (via its `binary_op_opt`).
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D>;
}
392
/// Trait for ops that provide a ternary operation function over `Opt<D>` values.
///
/// Implementors must be `Copy`, because the method consumes `self` by value.
pub trait HasTernaryOp<D>: Clone + Copy {
    /// Returns a function that applies this ternary operation with the given mode.
    ///
    /// `mode` selects the argument mode. The implementation for `FpTernaryOp` in this
    /// module uses `TernaryArgMode::Mode012` when `mode` is `None`.
    fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>;
}
398
399// ============================================================================
400// Op implementations
401// ============================================================================
402
403impl HasBinaryOp<i32> for LogicBinaryOpI32 {
404    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
405        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
406    }
407}
408
409impl HasBinaryOp<f32> for LogicBinaryOpF32 {
410    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
411        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
412    }
413}
414
415impl HasBinaryOp<i32> for FxpBinaryOp {
416    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
417        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
418    }
419}
420
421impl HasUnaryOp<f32> for FpUnaryOp {
422    fn unary_op_fn(self, mode: Option<UnaryArgMode>) -> impl Fn(Opt<f32>) -> Opt<f32> {
423        mode.unwrap_or(UnaryArgMode::Mode0).apply_opt(self.op_fn())
424    }
425}
426
427impl HasBinaryOp<f32> for FpBinaryOp {
428    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
429        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
430    }
431}
432
433impl HasTernaryOp<f32> for FpTernaryOp {
434    fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32> {
435        mode.unwrap_or(TernaryArgMode::Mode012).apply_opt(self.op_fn())
436    }
437}
438
439impl HasBinaryOp<f32> for FpDivBinaryOp {
440    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
441        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
442    }
443}
444
445impl HasBinaryOp<f32> for FpDivOp {
446    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
447        match mode {
448            Some(mode) => mode.apply_opt(self.op.op_fn()),
449            None => self.binary_op_opt(),
450        }
451    }
452}
453
454impl HasBinaryOp<i32> for ClipBinaryOpI32 {
455    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
456        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
457    }
458}
459
460impl HasBinaryOp<f32> for ClipBinaryOpF32 {
461    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
462        mode.unwrap_or(BinaryArgMode::Mode01).apply_opt(self.op_fn())
463    }
464}