furiosa_visa_std/vector_engine/op/arg_mode.rs

1//! Argument mode types for VE operations.
2//!
3//! Defines how operands are mapped to operation arguments.
4
5use std::fmt::{self, Display, Formatter};
6
7use crate::scalar::Opt;
8use crate::vector_engine::scalar::VeScalar;
9use furiosa_mapping_macro::primitive;
10
11// ============================================================================
12// ArgMode types
13// ============================================================================
14
15/// Arg mode: what operand to use as each argument of the operator.
16#[derive(Debug, Clone, Copy)]
17pub enum ArgMode {
18    /// Unary argument mode.
19    Unary(UnaryArgMode),
20    /// Binary argument mode.
21    Binary(BinaryArgMode),
22    /// Ternary argument mode.
23    Ternary(TernaryArgMode),
24}
25
/// Argument mapping for a one-argument operator.
///
/// `Mode0` feeds the mainstream value to the operator; `Mode1` feeds `operand0`.
#[derive(Clone, Copy, Debug)]
pub enum UnaryArgMode {
    /// op(mainstream).
    Mode0,
    /// op(operand0).
    Mode1,
}
35
36impl Display for UnaryArgMode {
37    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38        match self {
39            Self::Mode0 => write!(f, "UnaryArgMode::Mode0"),
40            Self::Mode1 => write!(f, "UnaryArgMode::Mode1"),
41        }
42    }
43}
44
45impl UnaryArgMode {
46    /// Applies arg mode to a unary operation (Opt version).
47    pub fn apply_opt<D: VeScalar>(&self, op: impl Fn(D) -> D + 'static) -> Box<dyn Fn(Opt<D>) -> Opt<D>> {
48        Box::new(move |x| match x {
49            Opt::Init(x) => Opt::Init(op(x)),
50            Opt::Uninit => Opt::Uninit,
51        })
52    }
53}
54
/// Binary arg mode: selects which operand feeds each of a binary operator's two arguments.
///
/// `ModeXY` means `op(argX, argY)`, where `0` = mainstream and `1` = operand0.
// NOTE(review): `primitive` is a proc-macro from `furiosa_mapping_macro`; it presumably
// registers this enum under the path `op::BinaryArgMode` — confirm before changing tokens here.
#[primitive(op::BinaryArgMode)]
#[derive(Debug, Clone, Copy)]
pub enum BinaryArgMode {
    /// op(mainstream, mainstream); `apply_opt` ignores operand0 in this mode.
    Mode00,
    /// op(mainstream, operand0).
    Mode01,
    /// op(operand0, mainstream).
    Mode10,
    /// op(operand0, operand0); `apply_opt` ignores mainstream in this mode.
    Mode11,
}
69
70impl Display for BinaryArgMode {
71    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
72        match self {
73            Self::Mode00 => write!(f, "BinaryArgMode::Mode00"),
74            Self::Mode01 => write!(f, "BinaryArgMode::Mode01"),
75            Self::Mode10 => write!(f, "BinaryArgMode::Mode10"),
76            Self::Mode11 => write!(f, "BinaryArgMode::Mode11"),
77        }
78    }
79}
80
81impl BinaryArgMode {
82    /// Applies arg mode to a binary operation (Opt version).
83    pub fn apply_opt<D: VeScalar>(&self, op: impl Fn(D, D) -> D + 'static) -> Box<dyn Fn(Opt<D>, Opt<D>) -> Opt<D>> {
84        match self {
85            BinaryArgMode::Mode00 => Box::new(move |a, _b| match a {
86                Opt::Init(a) => Opt::Init(op(a, a)),
87                Opt::Uninit => Opt::Uninit,
88            }),
89            BinaryArgMode::Mode01 => Box::new(move |a, b| match (a, b) {
90                (Opt::Init(a), Opt::Init(b)) => Opt::Init(op(a, b)),
91                _ => Opt::Uninit,
92            }),
93            BinaryArgMode::Mode10 => Box::new(move |a, b| match (a, b) {
94                (Opt::Init(a), Opt::Init(b)) => Opt::Init(op(b, a)),
95                _ => Opt::Uninit,
96            }),
97            BinaryArgMode::Mode11 => Box::new(move |_a, b| match b {
98                Opt::Init(b) => Opt::Init(op(b, b)),
99                Opt::Uninit => Opt::Uninit,
100            }),
101        }
102    }
103}
104
/// Argument mapping for a three-argument operator.
///
/// In `ModeXYZ` the digits name the source of the first, second and third
/// argument: `0` is mainstream, `1` is operand0, `2` is operand1.
#[derive(Clone, Copy, Debug)]
pub enum TernaryArgMode {
    /// op(mainstream, operand0, operand1).
    Mode012,
    /// op(mainstream, mainstream, operand1).
    Mode002,
    /// op(operand0, mainstream, operand1).
    Mode102,
    /// op(operand0, operand0, operand1).
    Mode112,
    /// op(mainstream, operand1, mainstream).
    Mode020,
    /// op(mainstream, operand1, operand0).
    Mode021,
    /// op(operand0, operand1, mainstream).
    Mode120,
}
124
125impl Display for TernaryArgMode {
126    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127        match self {
128            Self::Mode012 => write!(f, "TernaryArgMode::Mode012"),
129            Self::Mode002 => write!(f, "TernaryArgMode::Mode002"),
130            Self::Mode102 => write!(f, "TernaryArgMode::Mode102"),
131            Self::Mode112 => write!(f, "TernaryArgMode::Mode112"),
132            Self::Mode020 => write!(f, "TernaryArgMode::Mode020"),
133            Self::Mode021 => write!(f, "TernaryArgMode::Mode021"),
134            Self::Mode120 => write!(f, "TernaryArgMode::Mode120"),
135        }
136    }
137}
138
139impl TernaryArgMode {
140    /// Applies arg mode to a ternary operation (Opt version).
141    pub fn apply_opt<D: VeScalar>(
142        &self,
143        op: impl Fn(D, D, D) -> D + 'static,
144    ) -> Box<dyn Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>> {
145        match self {
146            TernaryArgMode::Mode012 => Box::new(move |m, o0, o1| match (m, o0, o1) {
147                (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o0, o1)),
148                _ => Opt::Uninit,
149            }),
150            TernaryArgMode::Mode002 => Box::new(move |m, _o0, o1| match (m, o1) {
151                (Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, m, o1)),
152                _ => Opt::Uninit,
153            }),
154            TernaryArgMode::Mode102 => Box::new(move |m, o0, o1| match (m, o0, o1) {
155                (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, m, o1)),
156                _ => Opt::Uninit,
157            }),
158            TernaryArgMode::Mode112 => Box::new(move |_m, o0, o1| match (o0, o1) {
159                (Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o0, o1)),
160                _ => Opt::Uninit,
161            }),
162            TernaryArgMode::Mode020 => Box::new(move |m, _o0, o1| match (m, o1) {
163                (Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, o1, m)),
164                _ => Opt::Uninit,
165            }),
166            TernaryArgMode::Mode021 => Box::new(move |m, o0, o1| match (m, o0, o1) {
167                (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o1, o0)),
168                _ => Opt::Uninit,
169            }),
170            TernaryArgMode::Mode120 => Box::new(move |m, o0, o1| match (m, o0, o1) {
171                (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o1, m)),
172                _ => Opt::Uninit,
173            }),
174        }
175    }
176}