furiosa_visa_std/vector_engine/
branch.rs1use std::fmt::{self, Display, Formatter};
4
5use furiosa_mapping::{Atom, Ident, M};
6use furiosa_mapping_macro::primitive;
7use smart_default::SmartDefault;
8
9use crate::scalar::Opt;
10use crate::tensor::Tensor;
11
12use super::scalar::VeScalar;
13
14#[primitive(ve::BranchMode)]
16#[derive(Debug, Clone, SmartDefault)]
17pub enum BranchMode {
18 #[default]
20 Unconditional,
21 AxisToggle {
23 axis: Ident,
26 },
27 ValidCount,
29 Comparison([InputCmp; 4]),
31 Vrf,
34}
35
36impl Display for BranchMode {
37 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38 match self {
39 Self::Unconditional => write!(f, "BranchMode::Unconditional"),
40 Self::AxisToggle { axis } => write!(f, "BranchMode::AxisToggle {{ axis: {axis} }}"),
41 Self::ValidCount => write!(f, "BranchMode::ValidCount"),
42 Self::Comparison(input_cmps) => {
43 write!(f, "BranchMode::Comparison(")?;
44 for (i, cmp) in input_cmps.iter().enumerate() {
45 if i > 0 {
46 write!(f, ", ")?;
47 }
48 write!(f, "{cmp}")?;
49 }
50 write!(f, ")")
51 }
52 Self::Vrf => write!(f, "BranchMode::Vrf"),
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
59pub enum InputCmp {
60 I32(InputCmpI32),
62 F32(InputCmpF32),
64}
65
66impl Display for InputCmp {
67 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68 match self {
69 Self::I32(input_cmp_i32) => write!(f, "{input_cmp_i32}"),
70 Self::F32(input_cmp_f32) => write!(f, "{input_cmp_f32}"),
71 }
72 }
73}
74
/// Comparator over `i32` inputs; each variant with a `boundary` tests the input against it.
#[derive(Debug, Clone)]
pub enum InputCmpI32 {
    /// Matches when the input equals `boundary`.
    Equal {
        /// Value the input is compared with.
        boundary: i32,
    },
    /// Matches when the input is less than `boundary` (signed comparison).
    Less {
        /// Value the input is compared with.
        boundary: i32,
    },
    /// Matches when the input is greater than `boundary` (signed comparison).
    Greater {
        /// Value the input is compared with.
        boundary: i32,
    },
    /// Matches when the input is less than `boundary` after both are cast to `u32`.
    LessUnsigned {
        /// Value the input is compared with; reinterpreted as `u32`.
        boundary: i32,
    },
    /// Matches when the input is greater than `boundary` after both are cast to `u32`.
    GreaterUnsigned {
        /// Value the input is compared with; reinterpreted as `u32`.
        boundary: i32,
    },
    /// Matches every input.
    True,
    /// Matches no input.
    False,
}
108
109impl Display for InputCmpI32 {
110 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
111 match self {
112 Self::Equal { boundary } => write!(f, "={boundary}"),
113 Self::Less { boundary } => write!(f, "<{boundary}"),
114 Self::Greater { boundary } => write!(f, ">{boundary}"),
115 Self::LessUnsigned { boundary } => write!(f, "<u{boundary}"),
116 Self::GreaterUnsigned { boundary } => write!(f, ">u{boundary}"),
117 Self::True => write!(f, "true"),
118 Self::False => write!(f, "false"),
119 }
120 }
121}
122
/// Comparator over `f32` inputs; each variant with a `boundary` tests the input against it.
#[derive(Debug, Clone)]
pub enum InputCmpF32 {
    /// Matches when the input compares equal to `boundary` under `f32` equality.
    Equal {
        /// Value the input is compared with.
        boundary: f32,
    },
    /// Matches when the input is less than `boundary` under `f32` ordering.
    Less {
        /// Value the input is compared with.
        boundary: f32,
    },
    /// Matches when the input is greater than `boundary` under `f32` ordering.
    Greater {
        /// Value the input is compared with.
        boundary: f32,
    },
    /// Matches when the input's raw bits (`to_bits`) are less than the boundary's raw bits.
    LessUnsigned {
        /// Value whose bit pattern the input's bit pattern is compared with.
        boundary: f32,
    },
    /// Matches when the input's raw bits (`to_bits`) are greater than the boundary's raw bits.
    GreaterUnsigned {
        /// Value whose bit pattern the input's bit pattern is compared with.
        boundary: f32,
    },
    /// Matches every input.
    True,
    /// Matches no input.
    False,
}
156
157impl Display for InputCmpF32 {
158 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
159 match self {
160 Self::Equal { boundary } => write!(f, "={boundary}"),
161 Self::Less { boundary } => write!(f, "<{boundary}"),
162 Self::Greater { boundary } => write!(f, ">{boundary}"),
163 Self::LessUnsigned { boundary } => write!(f, "<u{boundary}"),
164 Self::GreaterUnsigned { boundary } => write!(f, ">u{boundary}"),
165 Self::True => write!(f, "true"),
166 Self::False => write!(f, "false"),
167 }
168 }
169}
170
171impl InputCmpI32 {
172 pub fn matches(&self, x: i32) -> bool {
174 match self {
175 InputCmpI32::Equal { boundary } => x == *boundary,
176 InputCmpI32::Less { boundary } => x < *boundary,
177 InputCmpI32::Greater { boundary } => x > *boundary,
178 InputCmpI32::LessUnsigned { boundary } => (x as u32) < (*boundary as u32),
179 InputCmpI32::GreaterUnsigned { boundary } => (x as u32) > (*boundary as u32),
180 InputCmpI32::True => true,
181 InputCmpI32::False => false,
182 }
183 }
184}
185
186impl InputCmpF32 {
187 pub fn matches(&self, x: f32) -> bool {
189 match self {
190 InputCmpF32::Equal { boundary } => x == *boundary,
191 InputCmpF32::Less { boundary } => x < *boundary,
192 InputCmpF32::Greater { boundary } => x > *boundary,
193 InputCmpF32::LessUnsigned { boundary } => {
194 let x_bits = x.to_bits();
195 let boundary_bits = boundary.to_bits();
196 x_bits < boundary_bits
197 }
198 InputCmpF32::GreaterUnsigned { boundary } => {
199 let x_bits = x.to_bits();
200 let boundary_bits = boundary.to_bits();
201 x_bits > boundary_bits
202 }
203 InputCmpF32::True => true,
204 InputCmpF32::False => false,
205 }
206 }
207}
208
209impl InputCmp {
210 pub fn matches<D: VeScalar>(&self, x: D) -> bool {
212 use std::any::TypeId;
213 match self {
214 InputCmp::I32(cmp) => {
215 if TypeId::of::<D>() == TypeId::of::<i32>() {
216 unsafe {
217 let x_i32 = std::mem::transmute_copy::<D, i32>(&x);
218 cmp.matches(x_i32)
219 }
220 } else {
221 panic!("Type mismatch: InputCmp::I32 used with f32 data")
222 }
223 }
224 InputCmp::F32(cmp) => {
225 if TypeId::of::<D>() == TypeId::of::<f32>() {
226 unsafe {
227 let x_f32 = std::mem::transmute_copy::<D, f32>(&x);
228 cmp.matches(x_f32)
229 }
230 } else {
231 panic!("Type mismatch: InputCmp::F32 used with i32 data")
232 }
233 }
234 }
235 }
236}
237
/// Identifier of one of two groups; see [`GroupId::bit_value`] for the numeric encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupId {
    /// Group whose bit value is `0`.
    Zero,
    /// Group whose bit value is `1`.
    One,
}
246
247impl GroupId {
248 pub fn bit_value(&self) -> u8 {
250 match self {
251 GroupId::Zero => 0,
252 GroupId::One => 1,
253 }
254 }
255}
256
257#[primitive(ve::ValidBranchIds)]
267#[derive(Debug, Clone, Default)]
268pub enum ValidBranchIds {
269 ValidGroup {
271 id: GroupId,
273 },
274 #[default]
276 ValidAlways,
277}
278
279impl ValidBranchIds {
280 pub fn matches(&self, exec_id: Opt<u8>) -> bool {
283 match (self, exec_id) {
284 (_, Opt::Uninit) => false,
285 (ValidBranchIds::ValidAlways, Opt::Init(_)) => true,
286 (ValidBranchIds::ValidGroup { id }, Opt::Init(eid_val)) => ((eid_val >> 3) & 1) == id.bit_value(),
287 }
288 }
289}
290
291impl From<GroupId> for ValidBranchIds {
292 fn from(id: GroupId) -> Self {
293 ValidBranchIds::ValidGroup { id }
294 }
295}
296
297pub fn apply_branch_config<D: VeScalar, Mapping: M>(
299 data: &Tensor<D, Mapping>,
300 config: &BranchMode,
301) -> Tensor<u8, Mapping> {
302 match config {
303 BranchMode::Unconditional => data.map(|_| Opt::Init(0u8)),
304 BranchMode::AxisToggle { axis } => Tensor::from_fn(|axes, idx| {
305 let axis_pos = axes.iter().position(|term| {
306 if let Atom::Symbol { symbol, .. } = &term.inner {
307 symbol == axis
308 } else {
309 false
310 }
311 });
312
313 if let Some(pos) = axis_pos {
314 let axis_idx = idx[pos];
315 let group_id = (axis_idx % 2) as u8;
316 let exec_id = group_id << 3;
317 Opt::Init(exec_id)
318 } else {
319 Opt::Init(0u8)
320 }
321 }),
322 BranchMode::ValidCount => todo!(),
323 BranchMode::Vrf => todo!("BranchMode::Vrf: load execution IDs from VRF (GenBranch::WithLog)"),
324 BranchMode::Comparison(cmps) => data.map(|x| match x {
325 Opt::Init(x) => {
326 let mut exec_id: u8 = 0;
327 for (bit_pos, cmp) in cmps.iter().enumerate() {
328 let bit = if cmp.matches(*x) { 0x1 } else { 0x0 };
329 exec_id |= bit << bit_pos;
330 }
331
332 Opt::Init(exec_id)
333 }
334 Opt::Uninit => Opt::Uninit,
335 }),
336 }
337}