furiosa_visa_std/vector_engine/
branch.rs

1//! Branch Unit configuration for Vector Engine.
2
3use std::fmt::{self, Display, Formatter};
4
5use furiosa_mapping::{Atom, Ident, M};
6use furiosa_mapping_macro::primitive;
7use smart_default::SmartDefault;
8
9use crate::scalar::Opt;
10use crate::tensor::Tensor;
11
12use super::scalar::VeScalar;
13
14/// Branch mode configuration for Vector Engine.
15#[primitive(ve::BranchMode)]
16#[derive(Debug, Clone, SmartDefault)]
17pub enum BranchMode {
18    /// No branching - all elements processed unconditionally with ExecutionId = 0.
19    #[default]
20    Unconditional,
21    /// Toggle group id (0/1) based on axis index.
22    AxisToggle {
23        /// Axis identifier to toggle on (e.g., Ident::I).
24        /// The group ID will be determined by (axis_index % 2).
25        axis: Ident,
26    },
27    /// Set branch id using valid count generator.
28    ValidCount,
29    /// Set each branch id bit using comparison operations.
30    Comparison([InputCmp; 4]),
31    /// Load execution IDs from VRF (previously stored by a Comparison pass).
32    /// Maps to npu-ir `GenBranch::WithLog`. Enables cross-TuExec branch reuse.
33    Vrf,
34}
35
36impl Display for BranchMode {
37    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38        match self {
39            Self::Unconditional => write!(f, "BranchMode::Unconditional"),
40            Self::AxisToggle { axis } => write!(f, "BranchMode::AxisToggle {{ axis: {axis} }}"),
41            Self::ValidCount => write!(f, "BranchMode::ValidCount"),
42            Self::Comparison(input_cmps) => {
43                write!(f, "BranchMode::Comparison(")?;
44                for (i, cmp) in input_cmps.iter().enumerate() {
45                    if i > 0 {
46                        write!(f, ", ")?;
47                    }
48                    write!(f, "{cmp}")?;
49                }
50                write!(f, ")")
51            }
52            Self::Vrf => write!(f, "BranchMode::Vrf"),
53        }
54    }
55}
56
57/// comparison operations for Vector Engine Branch Unit.
58#[derive(Debug, Clone)]
59pub enum InputCmp {
60    /// i32 comparison
61    I32(InputCmpI32),
62    /// f32 comparison
63    F32(InputCmpF32),
64}
65
66impl Display for InputCmp {
67    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68        match self {
69            Self::I32(input_cmp_i32) => write!(f, "{input_cmp_i32}"),
70            Self::F32(input_cmp_f32) => write!(f, "{input_cmp_f32}"),
71        }
72    }
73}
74
/// Comparison operations on `i32` values.
#[derive(Debug, Clone)]
pub enum InputCmpI32 {
    /// Sets the bit when the value equals `boundary`.
    Equal {
        /// i32 value to compare with.
        boundary: i32,
    },
    /// Sets the bit when the value is less than `boundary`.
    Less {
        /// i32 value to compare with.
        boundary: i32,
    },
    /// Sets the bit when the value is greater than `boundary`.
    Greater {
        /// i32 value to compare with.
        boundary: i32,
    },
    /// Sets the bit when the value is less than `boundary` (unsigned).
    LessUnsigned {
        /// i32 value to compare with.
        boundary: i32,
    },
    /// Sets the bit when the value is greater than `boundary` (unsigned).
    GreaterUnsigned {
        /// i32 value to compare with.
        boundary: i32,
    },
    /// Always true.
    True,
    /// Always false.
    False,
}

impl Display for InputCmpI32 {
    /// Writes the operator (`=`, `<`, `>`, `<u`, `>u`) immediately followed by the boundary,
    /// or the literal `true` / `false` for the constant variants.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let (op, boundary) = match self {
            Self::True => return f.write_str("true"),
            Self::False => return f.write_str("false"),
            Self::Equal { boundary } => ("=", boundary),
            Self::Less { boundary } => ("<", boundary),
            Self::Greater { boundary } => (">", boundary),
            Self::LessUnsigned { boundary } => ("<u", boundary),
            Self::GreaterUnsigned { boundary } => (">u", boundary),
        };
        write!(f, "{op}{boundary}")
    }
}
122
/// Comparison operations on `f32` values.
#[derive(Debug, Clone)]
pub enum InputCmpF32 {
    /// Sets the bit when the value equals `boundary`.
    Equal {
        /// f32 value to compare with.
        boundary: f32,
    },
    /// Sets the bit when the value is less than `boundary`.
    Less {
        /// f32 value to compare with.
        boundary: f32,
    },
    /// Sets the bit when the value is greater than `boundary`.
    Greater {
        /// f32 value to compare with.
        boundary: f32,
    },
    /// Sets the bit when the value is less than `boundary` (unsigned, compares bit representation).
    LessUnsigned {
        /// f32 value to compare with.
        boundary: f32,
    },
    /// Sets the bit when the value is greater than `boundary` (unsigned, compares bit representation).
    GreaterUnsigned {
        /// f32 value to compare with.
        boundary: f32,
    },
    /// Always true.
    True,
    /// Always false.
    False,
}

impl Display for InputCmpF32 {
    /// Writes the operator (`=`, `<`, `>`, `<u`, `>u`) immediately followed by the boundary,
    /// or the literal `true` / `false` for the constant variants.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let (op, boundary) = match self {
            Self::True => return f.write_str("true"),
            Self::False => return f.write_str("false"),
            Self::Equal { boundary } => ("=", boundary),
            Self::Less { boundary } => ("<", boundary),
            Self::Greater { boundary } => (">", boundary),
            Self::LessUnsigned { boundary } => ("<u", boundary),
            Self::GreaterUnsigned { boundary } => (">u", boundary),
        };
        write!(f, "{op}{boundary}")
    }
}
170
171impl InputCmpI32 {
172    /// Check if i32 value matches this comparison
173    pub fn matches(&self, x: i32) -> bool {
174        match self {
175            InputCmpI32::Equal { boundary } => x == *boundary,
176            InputCmpI32::Less { boundary } => x < *boundary,
177            InputCmpI32::Greater { boundary } => x > *boundary,
178            InputCmpI32::LessUnsigned { boundary } => (x as u32) < (*boundary as u32),
179            InputCmpI32::GreaterUnsigned { boundary } => (x as u32) > (*boundary as u32),
180            InputCmpI32::True => true,
181            InputCmpI32::False => false,
182        }
183    }
184}
185
186impl InputCmpF32 {
187    /// Check if f32 value matches this comparison
188    pub fn matches(&self, x: f32) -> bool {
189        match self {
190            InputCmpF32::Equal { boundary } => x == *boundary,
191            InputCmpF32::Less { boundary } => x < *boundary,
192            InputCmpF32::Greater { boundary } => x > *boundary,
193            InputCmpF32::LessUnsigned { boundary } => {
194                let x_bits = x.to_bits();
195                let boundary_bits = boundary.to_bits();
196                x_bits < boundary_bits
197            }
198            InputCmpF32::GreaterUnsigned { boundary } => {
199                let x_bits = x.to_bits();
200                let boundary_bits = boundary.to_bits();
201                x_bits > boundary_bits
202            }
203            InputCmpF32::True => true,
204            InputCmpF32::False => false,
205        }
206    }
207}
208
209impl InputCmp {
210    /// Generic matches method that dispatches to type-specific implementation
211    pub fn matches<D: VeScalar>(&self, x: D) -> bool {
212        use std::any::TypeId;
213        match self {
214            InputCmp::I32(cmp) => {
215                if TypeId::of::<D>() == TypeId::of::<i32>() {
216                    unsafe {
217                        let x_i32 = std::mem::transmute_copy::<D, i32>(&x);
218                        cmp.matches(x_i32)
219                    }
220                } else {
221                    panic!("Type mismatch: InputCmp::I32 used with f32 data")
222                }
223            }
224            InputCmp::F32(cmp) => {
225                if TypeId::of::<D>() == TypeId::of::<f32>() {
226                    unsafe {
227                        let x_f32 = std::mem::transmute_copy::<D, f32>(&x);
228                        cmp.matches(x_f32)
229                    }
230                } else {
231                    panic!("Type mismatch: InputCmp::F32 used with i32 data")
232                }
233            }
234        }
235    }
236}
237
/// Group selector: the most-significant bit (1 bit) of the branch id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupId {
    /// Group 0
    Zero,
    /// Group 1
    One,
}

impl GroupId {
    /// Returns the numeric bit for this group: 0 for `Zero`, 1 for `One`.
    pub fn bit_value(&self) -> u8 {
        u8::from(matches!(self, Self::One))
    }
}
256
257/// Branch ID configuration for Vector Engine operations.
258///
259/// Controls which elements are processed based on their execution ID (set by branch unit).
260/// The execution ID's MSB (bit 3) represents the group ID (0 or 1).
261///
262/// - `ValidGroup { id }`: Only elements whose group ID matches are processed.
263///   Used for conditional execution based on branch conditions.
264/// - `ValidAlways`: All elements are processed regardless of their branch ID.
265///   This is the default for operations that don't need branching.
266#[primitive(ve::ValidBranchIds)]
267#[derive(Debug, Clone, Default)]
268pub enum ValidBranchIds {
269    /// Valid only for a specific group (filtered by MSB of execution_id).
270    ValidGroup {
271        /// The group ID to filter by.
272        id: GroupId,
273    },
274    /// Always valid regardless of branch ID.
275    #[default]
276    ValidAlways,
277}
278
279impl ValidBranchIds {
280    /// Check if this branch config matches the given execution ID.
281    /// Only Init values can match - Uninit never matches any config.
282    pub fn matches(&self, exec_id: Opt<u8>) -> bool {
283        match (self, exec_id) {
284            (_, Opt::Uninit) => false,
285            (ValidBranchIds::ValidAlways, Opt::Init(_)) => true,
286            (ValidBranchIds::ValidGroup { id }, Opt::Init(eid_val)) => ((eid_val >> 3) & 1) == id.bit_value(),
287        }
288    }
289}
290
291impl From<GroupId> for ValidBranchIds {
292    fn from(id: GroupId) -> Self {
293        ValidBranchIds::ValidGroup { id }
294    }
295}
296
297/// Applies branch unit to generate ExecutionId for each element.
298pub fn apply_branch_config<D: VeScalar, Mapping: M>(
299    data: &Tensor<D, Mapping>,
300    config: &BranchMode,
301) -> Tensor<u8, Mapping> {
302    match config {
303        BranchMode::Unconditional => data.map(|_| Opt::Init(0u8)),
304        BranchMode::AxisToggle { axis } => Tensor::from_fn(|axes, idx| {
305            let axis_pos = axes.iter().position(|term| {
306                if let Atom::Symbol { symbol, .. } = &term.inner {
307                    symbol == axis
308                } else {
309                    false
310                }
311            });
312
313            if let Some(pos) = axis_pos {
314                let axis_idx = idx[pos];
315                let group_id = (axis_idx % 2) as u8;
316                let exec_id = group_id << 3;
317                Opt::Init(exec_id)
318            } else {
319                Opt::Init(0u8)
320            }
321        }),
322        BranchMode::ValidCount => todo!(),
323        BranchMode::Vrf => todo!("BranchMode::Vrf: load execution IDs from VRF (GenBranch::WithLog)"),
324        BranchMode::Comparison(cmps) => data.map(|x| match x {
325            Opt::Init(x) => {
326                let mut exec_id: u8 = 0;
327                for (bit_pos, cmp) in cmps.iter().enumerate() {
328                    let bit = if cmp.matches(*x) { 0x1 } else { 0x0 };
329                    exec_id |= bit << bit_pos;
330                }
331
332                Opt::Init(exec_id)
333            }
334            Opt::Uninit => Opt::Uninit,
335        }),
336    }
337}