1use std::fmt::{self, Display, Formatter};
6
7use crate::scalar::Opt;
8use crate::vector_engine::scalar::VeScalar;
9use furiosa_mapping_macro::primitive;
10
/// Argument mode of an operator, grouped by the operator's arity.
///
/// Each variant wraps the arity-specific mode type: [`UnaryArgMode`],
/// [`BinaryArgMode`] or [`TernaryArgMode`].
#[derive(Debug, Clone, Copy)]
pub enum ArgMode {
    /// Mode for a one-operand operator (see [`UnaryArgMode`]).
    Unary(UnaryArgMode),
    /// Mode for a two-operand operator (see [`BinaryArgMode`]).
    Binary(BinaryArgMode),
    /// Mode for a three-operand operator (see [`TernaryArgMode`]).
    Ternary(TernaryArgMode),
}
25
/// Argument mode for a one-operand operator.
///
/// NOTE(review): `UnaryArgMode::apply_opt` builds the same closure for both
/// variants, so the variant does not affect the value computed there.
/// Presumably the mode is consumed elsewhere (e.g. when choosing which source
/// feeds the operator) — TODO confirm.
#[derive(Debug, Clone, Copy)]
pub enum UnaryArgMode {
    /// Mode `0` (by analogy with the binary/ternary modes, presumably the
    /// first input — TODO confirm).
    Mode0,
    /// Mode `1` (by analogy with the binary/ternary modes, presumably the
    /// second input — TODO confirm).
    Mode1,
}
35
36impl Display for UnaryArgMode {
37 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38 match self {
39 Self::Mode0 => write!(f, "UnaryArgMode::Mode0"),
40 Self::Mode1 => write!(f, "UnaryArgMode::Mode1"),
41 }
42 }
43}
44
45impl UnaryArgMode {
46 pub fn apply_opt<D: VeScalar>(&self, op: impl Fn(D) -> D + 'static) -> Box<dyn Fn(Opt<D>) -> Opt<D>> {
48 Box::new(move |x| match x {
49 Opt::Init(x) => Opt::Init(op(x)),
50 Opt::Uninit => Opt::Uninit,
51 })
52 }
53}
54
#[primitive(op::BinaryArgMode)]
/// Selects which of two inputs feed each operand of a two-operand operator.
///
/// The two digits of a variant name are input indices (`0` = first input,
/// `1` = second input), listed in operand order, as implemented by
/// `BinaryArgMode::apply_opt`.
#[derive(Debug, Clone, Copy)]
pub enum BinaryArgMode {
    /// `op(in0, in0)`; the second input is not read.
    Mode00,
    /// `op(in0, in1)`.
    Mode01,
    /// `op(in1, in0)` (operands swapped).
    Mode10,
    /// `op(in1, in1)`; the first input is not read.
    Mode11,
}
69
70impl Display for BinaryArgMode {
71 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
72 match self {
73 Self::Mode00 => write!(f, "BinaryArgMode::Mode00"),
74 Self::Mode01 => write!(f, "BinaryArgMode::Mode01"),
75 Self::Mode10 => write!(f, "BinaryArgMode::Mode10"),
76 Self::Mode11 => write!(f, "BinaryArgMode::Mode11"),
77 }
78 }
79}
80
81impl BinaryArgMode {
82 pub fn apply_opt<D: VeScalar>(&self, op: impl Fn(D, D) -> D + 'static) -> Box<dyn Fn(Opt<D>, Opt<D>) -> Opt<D>> {
84 match self {
85 BinaryArgMode::Mode00 => Box::new(move |a, _b| match a {
86 Opt::Init(a) => Opt::Init(op(a, a)),
87 Opt::Uninit => Opt::Uninit,
88 }),
89 BinaryArgMode::Mode01 => Box::new(move |a, b| match (a, b) {
90 (Opt::Init(a), Opt::Init(b)) => Opt::Init(op(a, b)),
91 _ => Opt::Uninit,
92 }),
93 BinaryArgMode::Mode10 => Box::new(move |a, b| match (a, b) {
94 (Opt::Init(a), Opt::Init(b)) => Opt::Init(op(b, a)),
95 _ => Opt::Uninit,
96 }),
97 BinaryArgMode::Mode11 => Box::new(move |_a, b| match b {
98 Opt::Init(b) => Opt::Init(op(b, b)),
99 Opt::Uninit => Opt::Uninit,
100 }),
101 }
102 }
103}
104
/// Selects which of three inputs feed each operand of a three-operand operator.
///
/// The three digits of a variant name are input indices (`0`, `1` or `2`,
/// i.e. the first, second and third argument of the closure returned by
/// `TernaryArgMode::apply_opt`), listed in operand order.
#[derive(Debug, Clone, Copy)]
pub enum TernaryArgMode {
    /// `op(in0, in1, in2)`.
    Mode012,
    /// `op(in0, in0, in2)`; `in1` is not read.
    Mode002,
    /// `op(in1, in0, in2)`.
    Mode102,
    /// `op(in1, in1, in2)`; `in0` is not read.
    Mode112,
    /// `op(in0, in2, in0)`; `in1` is not read.
    Mode020,
    /// `op(in0, in2, in1)`.
    Mode021,
    /// `op(in1, in2, in0)`.
    Mode120,
}
124
125impl Display for TernaryArgMode {
126 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127 match self {
128 Self::Mode012 => write!(f, "TernaryArgMode::Mode012"),
129 Self::Mode002 => write!(f, "TernaryArgMode::Mode002"),
130 Self::Mode102 => write!(f, "TernaryArgMode::Mode102"),
131 Self::Mode112 => write!(f, "TernaryArgMode::Mode112"),
132 Self::Mode020 => write!(f, "TernaryArgMode::Mode020"),
133 Self::Mode021 => write!(f, "TernaryArgMode::Mode021"),
134 Self::Mode120 => write!(f, "TernaryArgMode::Mode120"),
135 }
136 }
137}
138
139impl TernaryArgMode {
140 pub fn apply_opt<D: VeScalar>(
142 &self,
143 op: impl Fn(D, D, D) -> D + 'static,
144 ) -> Box<dyn Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>> {
145 match self {
146 TernaryArgMode::Mode012 => Box::new(move |m, o0, o1| match (m, o0, o1) {
147 (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o0, o1)),
148 _ => Opt::Uninit,
149 }),
150 TernaryArgMode::Mode002 => Box::new(move |m, _o0, o1| match (m, o1) {
151 (Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, m, o1)),
152 _ => Opt::Uninit,
153 }),
154 TernaryArgMode::Mode102 => Box::new(move |m, o0, o1| match (m, o0, o1) {
155 (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, m, o1)),
156 _ => Opt::Uninit,
157 }),
158 TernaryArgMode::Mode112 => Box::new(move |_m, o0, o1| match (o0, o1) {
159 (Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o0, o1)),
160 _ => Opt::Uninit,
161 }),
162 TernaryArgMode::Mode020 => Box::new(move |m, _o0, o1| match (m, o1) {
163 (Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, o1, m)),
164 _ => Opt::Uninit,
165 }),
166 TernaryArgMode::Mode021 => Box::new(move |m, o0, o1| match (m, o0, o1) {
167 (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o1, o0)),
168 _ => Opt::Uninit,
169 }),
170 TernaryArgMode::Mode120 => Box::new(move |m, o0, o1| match (m, o0, o1) {
171 (Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o1, m)),
172 _ => Opt::Uninit,
173 }),
174 }
175 }
176}