furiosa_visa_std/vector_engine/
operand.rs

//! Operand types for Vector Engine operations.
//!
//! This module provides types for specifying operands in VE binary and ternary operations:
//! - [`VeRhs`]: RHS operand (constant, VRF data, or stash read) with type safety
//! - [`StashOperand`]: Stash operand with branch validity (requires matching D type)
//! - [`VeBranchOperand`]: RHS operand paired with the branches on which it is valid
//! - [`TernaryOperand`]: Operand for ternary operations (f32 only)
//! - [`VeOperand`]: Unified operand type with automatic conversion
//! - [`IntoOperands`] / [`IntoTernaryOperands`]: Traits for converting operands to ArrayVec
//! - [`IntoGroupOperand`] / [`IntoGroupTernaryOperand`]: Per-group operand conversion for `VectorTensorPair`
//! - [`Stash`]: Type-inferred stash marker (compile-time type checked)
10
11use std::marker::PhantomData;
12
13use furiosa_mapping::{M, Pair};
14use furiosa_mapping_macro::primitive;
15
16use crate::{
17    array_vec::ArrayVec,
18    prelude::{GroupId, ValidBranchIds, VrfTensor},
19    tensor::Tensor,
20    vector_engine::{MAX_BRANCHES, scalar::VeScalar},
21};
22
23// ============================================================================
// VeRhs - Constant, VRF, or stash operand (type-safe)
25// ============================================================================
26
27/// RHS operand for Vector Engine operations.
28///
29/// Generic over:
30/// - `D`: Data type (i32 or f32) - ensures type safety with tensor operations
31/// - `TargetMapping`: Target tensor shape for VRF transpose
32#[primitive(op::VeRhs)]
33#[derive(Debug, Clone)]
34pub enum VeRhs<D: VeScalar, TargetMapping: M> {
35    /// Constant value.
36    Const {
37        /// The constant value.
38        v: D,
39    },
40    /// VRF data that has been transposed to match the target tensor shape.
41    Vrf {
42        /// The transposed VRF tensor.
43        data: Tensor<D, TargetMapping>,
44    },
45    /// Read from stash (previously written value).
46    Stash,
47}
48
49impl<D: VeScalar, TargetMapping: M> VeRhs<D, TargetMapping> {
50    /// Creates a constant operand.
51    #[primitive(op::VeRhs::constant)]
52    pub fn constant(v: D) -> Self {
53        VeRhs::Const { v }
54    }
55
56    /// Creates a VeRhs from a VrfTensor, automatically transposing to match the target tensor shape.
57    #[primitive(op::VeRhs::vrf)]
58    pub fn vrf<Chip: M, Cluster: M, Slice: M, Element: M>(vrf: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
59        let transposed = vrf.inner.transpose::<TargetMapping>(true);
60        VeRhs::Vrf { data: transposed }
61    }
62}
63
64impl<TargetMapping: M> From<i32> for VeRhs<i32, TargetMapping> {
65    fn from(v: i32) -> Self {
66        VeRhs::Const { v }
67    }
68}
69
70impl<TargetMapping: M> From<f32> for VeRhs<f32, TargetMapping> {
71    fn from(v: f32) -> Self {
72        VeRhs::Const { v }
73    }
74}
75
76impl<D: VeScalar, TargetMapping: M> From<Stash> for VeRhs<D, TargetMapping> {
77    fn from(_: Stash) -> Self {
78        VeRhs::Stash
79    }
80}
81
82impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, TargetMapping: M>
83    From<&VrfTensor<D, Chip, Cluster, Slice, Element>> for VeRhs<D, TargetMapping>
84{
85    fn from(vrf: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
86        VeRhs::vrf(vrf)
87    }
88}
89
90// ============================================================================
91// StashOperand - Stash read with branch validity (type-safe)
92// ============================================================================
93
94/// Stash operand for Vector Engine operations.
95#[derive(Debug, Clone)]
96pub struct StashOperand<D: VeScalar> {
97    pub(crate) valid_branch_ids: ValidBranchIds,
98    _phantom: PhantomData<D>,
99}
100
101impl<D: VeScalar> StashOperand<D> {
102    pub(crate) fn always() -> Self {
103        Self {
104            valid_branch_ids: ValidBranchIds::ValidAlways,
105            _phantom: PhantomData,
106        }
107    }
108
109    #[expect(dead_code)]
110    pub(crate) fn group(id: GroupId) -> Self {
111        Self {
112            valid_branch_ids: ValidBranchIds::ValidGroup { id },
113            _phantom: PhantomData,
114        }
115    }
116}
117
118// ============================================================================
119// VeBranchOperand - Operand with branch validity
120// ============================================================================
121
122/// Operand with branch validity for multi-operand cases.
123///
124/// Combines a VeRhs (constant, VRF, or stash) with branch validity.
125#[primitive(op::VeBranchOperand)]
126#[derive(Debug, Clone)]
127pub struct VeBranchOperand<D: VeScalar, TargetMapping: M> {
128    /// The operand value.
129    pub operand: VeRhs<D, TargetMapping>,
130    /// Valid branch IDs for this operand.
131    pub valid_branch_ids: ValidBranchIds,
132}
133
134impl<D: VeScalar, TargetMapping: M> VeBranchOperand<D, TargetMapping> {
135    /// Creates an always-valid operand.
136    #[primitive(op::VeBranchOperand::always)]
137    pub fn always(operand: VeRhs<D, TargetMapping>) -> Self {
138        Self {
139            operand,
140            valid_branch_ids: ValidBranchIds::ValidAlways,
141        }
142    }
143
144    /// Creates a group-specific operand.
145    pub fn group(operand: VeRhs<D, TargetMapping>, id: GroupId) -> Self {
146        Self {
147            operand,
148            valid_branch_ids: ValidBranchIds::ValidGroup { id },
149        }
150    }
151
152    /// Creates an always-valid stash operand.
153    pub fn stash_always() -> Self {
154        Self {
155            operand: VeRhs::Stash,
156            valid_branch_ids: ValidBranchIds::ValidAlways,
157        }
158    }
159
160    /// Creates a group-specific stash operand.
161    pub fn stash_group(id: GroupId) -> Self {
162        Self {
163            operand: VeRhs::Stash,
164            valid_branch_ids: ValidBranchIds::ValidGroup { id },
165        }
166    }
167
168    /// Returns true if this operand uses stash.
169    pub fn is_stash(&self) -> bool {
170        matches!(self.operand, VeRhs::Stash)
171    }
172
173    /// Returns the valid branch IDs for this operand.
174    pub fn valid_branch_ids(&self) -> &ValidBranchIds {
175        &self.valid_branch_ids
176    }
177}
178
179// ============================================================================
180// TernaryOperand - For ternary operations (f32 only)
181// ============================================================================
182
183/// User-facing operand for ternary operations.
184/// Generic over `Mapping` to match the target tensor's mapping from creation time.
185/// Ternary operations are only supported for f32 tensors.
186#[derive(Debug, Clone)]
187pub struct TernaryOperand<Mapping: M> {
188    /// First operand as VeRhs.
189    pub operand0: VeRhs<f32, Mapping>,
190    /// Second operand as f32.
191    pub operand1: f32,
192    /// Valid branch IDs.
193    pub valid_branch_ids: ValidBranchIds,
194}
195
196impl<Mapping: M> TernaryOperand<Mapping> {
197    /// Creates a TernaryOperand always valid.
198    pub fn always(operand0: VeRhs<f32, Mapping>, operand1: f32) -> Self {
199        Self {
200            operand0,
201            operand1,
202            valid_branch_ids: ValidBranchIds::ValidAlways,
203        }
204    }
205
206    /// Creates a TernaryOperand valid for a specific group.
207    pub fn group(operand0: VeRhs<f32, Mapping>, operand1: f32, id: GroupId) -> Self {
208        Self {
209            operand0,
210            operand1,
211            valid_branch_ids: ValidBranchIds::ValidGroup { id },
212        }
213    }
214}
215
// From implementations for TernaryOperand (they feed the blanket impls of IntoTernaryOperands and IntoGroupTernaryOperand)
217
218/// `(Into<VeRhs<f32, Mapping>>, f32)` - VeRhs and constant become TernaryOperand.
219impl<R, Mapping: M> From<(R, f32)> for TernaryOperand<Mapping>
220where
221    R: Into<VeRhs<f32, Mapping>>,
222{
223    fn from((operand0, operand1): (R, f32)) -> Self {
224        TernaryOperand::always(operand0.into(), operand1)
225    }
226}
227
228impl<R, B, Mapping: M> From<((R, f32), B)> for TernaryOperand<Mapping>
229where
230    R: Into<VeRhs<f32, Mapping>>,
231    B: Into<ValidBranchIds>,
232{
233    fn from(((operand0, operand1), branch): ((R, f32), B)) -> Self {
234        TernaryOperand {
235            operand0: operand0.into(),
236            operand1,
237            valid_branch_ids: branch.into(),
238        }
239    }
240}
241
242// ============================================================================
243// IntoTernaryOperands trait (for ternary operations, f32 only)
244// ============================================================================
245
246/// Trait for converting various operand types into an ArrayVec of TernaryOperand.
247///
248/// # Supported operand types
249///
250/// - `(f32, f32)` - two constant values (operand0, operand1)
251/// - `(VeRhs<f32, Mapping>, f32)` - VeRhs and constant
252/// - `TernaryOperand<Mapping>` - single ternary operand
253/// - `[TernaryOperand<Mapping>; N]` - array of ternary operands for multi-branch operations
254///
255/// # Example
256/// ```ignore
257/// // Simple usage with tuple (operand0, operand1)
258/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (2.0f32, 3.0f32))
259///
260/// // With VRF as operand0
261/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (&vrf, 3.0f32))
262///
263/// // With stash as operand0
264/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (Stash, 3.0f32))
265///
266/// // Explicit TernaryOperand for branch control
267/// tensor.vector_fp_ternary(
268///     FpTernaryOp::FmaF,
269///     TernaryOperand::always(VeRhs::constant(2.0f32), 3.0f32)
270/// )
271/// ```
272pub trait IntoTernaryOperands<TargetMapping: M> {
273    /// Converts into an ArrayVec of TernaryOperand.
274    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES>;
275}
276
277// Blanket impl: Into<TernaryOperand> automatically provides IntoTernaryOperands
278impl<T, TargetMapping: M> IntoTernaryOperands<TargetMapping> for T
279where
280    T: Into<TernaryOperand<TargetMapping>>,
281{
282    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
283        ArrayVec::new([self.into()])
284    }
285}
286
287/// Array of `TernaryOperand` for multi-branch operations.
288impl<TargetMapping: M, const N: usize> IntoTernaryOperands<TargetMapping> for [TernaryOperand<TargetMapping>; N] {
289    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
290        // Validate: at most one ValidAlways operand is allowed
291        let always_count = self
292            .iter()
293            .filter(|op| matches!(op.valid_branch_ids, ValidBranchIds::ValidAlways))
294            .count();
295        assert!(
296            always_count <= 1,
297            "Multiple ValidAlways operands are not allowed (found {always_count})"
298        );
299        ArrayVec::new(self)
300    }
301}
302
303/// `ArrayVec<TernaryOperand, MAX_BRANCHES>` passes through.
304impl<TargetMapping: M> IntoTernaryOperands<TargetMapping> for ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
305    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
306        self
307    }
308}
309
310// ============================================================================
311// From implementations for VeBranchOperand (enables .into() conversion)
312// ============================================================================
313//
314// These implementations allow ergonomic conversion to VeBranchOperand using `.into()`.
315//
316// # Usage for heterogeneous multi-branch operands
317//
318// When you need multiple operands of different types (e.g., constant + stash),
319// use `.into()` to convert each to `VeBranchOperand`, then pass as array:
320//
321// ```ignore
322// // Single operand (homogeneous) - direct usage
323// tensor.vector_fxp(op, 16384i32)
324// tensor.vector_fxp(op, Stash)
325// tensor.vector_fxp(op, &vrf)
326//
327// // Multiple operands of same type
328// tensor.vector_fxp(op, [
329//     VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
330//     VeBranchOperand::group(VeRhs::constant(200), GroupId::Group1),
331// ])
332//
// // Multiple operands of different types (heterogeneous)
// // Use .into() to convert each type; at most one operand may be always-valid,
// // otherwise the array conversion panics
// tensor.vector_fxp(op, [
//     VeBranchOperand::group(VeRhs::constant(16384), GroupId::Group0),
//     Stash.into(),
// ])
//
// // With branch control
// tensor.vector_fxp(op, [
//     VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
//     VeBranchOperand::stash_group(GroupId::Group1),
// ])
345// ```
346
347impl<R, D: VeScalar, Mapping: M> From<R> for VeBranchOperand<D, Mapping>
348where
349    R: Into<VeRhs<D, Mapping>>,
350{
351    fn from(rhs: R) -> Self {
352        VeBranchOperand::always(rhs.into())
353    }
354}
355
356impl<R, B, D: VeScalar, Mapping: M> From<(R, B)> for VeBranchOperand<D, Mapping>
357where
358    R: Into<VeRhs<D, Mapping>>,
359    B: Into<ValidBranchIds>,
360{
361    fn from((rhs, branch): (R, B)) -> Self {
362        VeBranchOperand {
363            operand: rhs.into(),
364            valid_branch_ids: branch.into(),
365        }
366    }
367}
368
369// ============================================================================
370// IntoOperands trait - Multiple operands conversion
371// ============================================================================
372
373/// Trait for converting various operand types into an ArrayVec.
374///
375/// Types implementing `Into<VeBranchOperand>` automatically get this via blanket impl.
376/// Array types `[VeBranchOperand; N]` and `ArrayVec` implement this directly.
377///
378/// # Supported operand types
379///
380/// **Single operand** (via `Into<VeBranchOperand>`, auto-wrapped in ArrayVec):
381/// - `i32`, `f32` - constant value
382/// - `Stash` - stash read marker
383/// - `StashOperand<D>` - stash read with branch validity
384/// - `VeBranchOperand<D, _>` - explicit operand (pass through)
385/// - `&VrfTensor<D, ...>` - VRF tensor reference
386///
387/// **Multiple operands** (direct implementations):
388/// - `[VeBranchOperand<D, _>; N]` - array of operands for multi-branch operations
389/// - `ArrayVec<VeBranchOperand<D, _>, MAX_BRANCHES>` - pass through
390///
391/// # Examples
392///
393/// ```ignore
394/// // Single operand - direct usage
395/// tensor.vector_fxp(op, 16384i32)
396///
397/// // Multiple homogeneous operands
398/// tensor.vector_fxp(op, [
399///     VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
400///     VeBranchOperand::group(VeRhs::constant(200), GroupId::Group1),
401/// ])
402///
403/// // Multiple heterogeneous operands - use .into()
404/// tensor.vector_fxp(op, [
405///     16384i32.into(),
406///     Stash.into(),
407/// ])
408/// ```
409pub trait IntoOperands<D: VeScalar, TargetMapping: M> {
410    /// Converts into an ArrayVec of operands.
411    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES>;
412}
413
414// Blanket impl: Into<VeBranchOperand> automatically provides IntoOperands
415impl<T, D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping> for T
416where
417    T: Into<VeBranchOperand<D, TargetMapping>>,
418{
419    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
420        ArrayVec::new([self.into()])
421    }
422}
423
424impl<D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping>
425    for ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES>
426{
427    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
428        self
429    }
430}
431
432impl<D: VeScalar, TargetMapping: M, const N: usize> IntoOperands<D, TargetMapping>
433    for [VeBranchOperand<D, TargetMapping>; N]
434{
435    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
436        // Validate: at most one ValidAlways operand is allowed
437        let always_count = self
438            .iter()
439            .filter(|op| matches!(op.valid_branch_ids(), ValidBranchIds::ValidAlways))
440            .count();
441        assert!(
442            always_count <= 1,
443            "Multiple ValidAlways operands are not allowed (found {always_count})"
444        );
445        ArrayVec::new(self)
446    }
447}
448
449// ============================================================================
450// Stash - Type-inferred marker for stash operands (compile-time type checked)
451// ============================================================================
452
453/// Type-inferred stash marker for compile-time type checking.
454///
455/// When used as an operand, the stash data type must match the operation's data type.
456///
457/// # Example
458/// ```ignore
459/// // f32 tensor with f32 stash -> OK
460/// tensor
461///     .vector_stash()
462///     .vector_fp_binary(FpBinaryOp::MulF(FpMulAlu::Mul0), 2.0f32)
463///     .vector_clip(ClipBinaryOpF32::Max, Stash)  // Compiles: D == StashD == f32
464///
465/// ```
466#[primitive(op::Stash)]
467#[derive(Debug, Clone, Copy)]
468pub struct Stash;
469
470// ============================================================================
471// VeOperand - Unified operand type with automatic conversion
472// ============================================================================
473
474/// Unified operand type for Vector Engine operations.
475///
476/// Supports automatic conversion from:
477/// - `D` (i32/f32) - constant value
478/// - `&VrfTensor<D, ...>` - VRF tensor reference
479/// - `Stash` - stash read (always valid)
480///
481/// Use with `impl Into<VeOperand<D, ...>>` for ergonomic API:
482/// ```ignore
483/// .vector_fxp(op, 16384i32)   // i32 auto-converted
484/// .vector_fxp(op, &vrf)       // VRF auto-converted
485/// .vector_fxp(op, Stash)      // Stash (always valid)
486/// ```
487#[derive(Debug)]
488pub enum VeOperand<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> {
489    /// Constant value (always valid).
490    Const(D),
491    /// VRF tensor reference.
492    Vrf(&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>),
493    /// Stash operand.
494    Stash(StashOperand<D>),
495}
496
497// From<i32> for VeOperand<i32, ...>
498impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<i32> for VeOperand<'_, i32, Chip, Cluster, Slice, VrfMapping> {
499    fn from(v: i32) -> Self {
500        VeOperand::Const(v)
501    }
502}
503
504// From<f32> for VeOperand<f32, ...>
505impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<f32> for VeOperand<'_, f32, Chip, Cluster, Slice, VrfMapping> {
506    fn from(v: f32) -> Self {
507        VeOperand::Const(v)
508    }
509}
510
511// From<&VrfTensor<D, ...>> for VeOperand<D, ...>
512impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M>
513    From<&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>> for VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping>
514{
515    fn from(vrf: &'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>) -> Self {
516        VeOperand::Vrf(vrf)
517    }
518}
519
520// From<StashOperand<D>> for VeOperand<D, ...>
521impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<StashOperand<D>>
522    for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
523{
524    fn from(stash: StashOperand<D>) -> Self {
525        VeOperand::Stash(stash)
526    }
527}
528
529// From<Stash> for VeOperand<D, ...> - enables using Stash marker directly
530impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<Stash>
531    for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
532{
533    fn from(_: Stash) -> Self {
534        VeOperand::Stash(StashOperand::always())
535    }
536}
537
538impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping> {
539    /// Converts VeOperand to an ArrayVec of VeBranchOperand with the target tensor mapping.
540    pub fn into_branch_operands<Time: M, Packet: M>(
541        self,
542    ) -> ArrayVec<
543        VeBranchOperand<D, furiosa_opt_macro::m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }]>,
544        MAX_BRANCHES,
545    > {
546        type TargetShape<Chip, Cluster, Slice, Time, Packet> =
547            furiosa_opt_macro::m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }];
548
549        match self {
550            VeOperand::Const(v) => ArrayVec::new([VeBranchOperand::always(VeRhs::Const { v })]),
551            VeOperand::Vrf(vrf) => {
552                let vrf_operand = VeRhs::<D, TargetShape<Chip, Cluster, Slice, Time, Packet>>::vrf(vrf);
553                ArrayVec::new([VeBranchOperand::always(vrf_operand)])
554            }
555            VeOperand::Stash(stash) => ArrayVec::new([VeBranchOperand {
556                operand: VeRhs::Stash,
557                valid_branch_ids: stash.valid_branch_ids,
558            }]),
559        }
560    }
561}
562
563// ============================================================================
564// IntoGroupOperand - Ergonomic operand conversion for VectorTensorPair
565// ============================================================================
566
567/// Type alias for group operand in VectorTensorPair operations.
568/// Generic over `Mapping` to match the target tensor's mapping from creation time.
569pub type GroupOperand<D, Mapping> = Option<VeBranchOperand<D, Mapping>>;
570
571/// Trait for converting various types into a group operand.
572///
573/// Uses `Into<VeBranchOperand>` blanket impl for automatic conversion from
574/// types that implement `From` for `VeBranchOperand` (i32, f32, Stash, etc.).
575///
576/// # Supported types
577/// - `()` - skip operation for this group
578/// - `i32`, `f32` - constant value (via `Into<VeBranchOperand>`)
579/// - `Stash` - stash read marker (via `Into<VeBranchOperand>`)
580/// - `StashOperand<D>` - stash with branch validity (via `Into<VeBranchOperand>`)
581/// - `VeBranchOperand<D, Mapping>` - explicit operand (via `Into<VeBranchOperand>`)
582/// - `Option<VeBranchOperand<D, Mapping>>` - pass through
583pub trait IntoGroupOperand<D: VeScalar, Mapping: M> {
584    /// Converts into a GroupOperand with the specified mapping.
585    fn into_group_operand(self) -> GroupOperand<D, Mapping>;
586}
587
588/// `()` represents skipping the operation for this group.
589impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for () {
590    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
591        None
592    }
593}
594
595/// `Option<VeBranchOperand<D, Mapping>>` passes through.
596impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for Option<VeBranchOperand<D, Mapping>> {
597    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
598        self
599    }
600}
601
602/// Blanket impl: any type that implements `Into<VeBranchOperand>` automatically
603/// implements `IntoGroupOperand` by wrapping in `Some`.
604impl<T, D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for T
605where
606    T: Into<VeBranchOperand<D, Mapping>>,
607{
608    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
609        Some(self.into())
610    }
611}
612
613// ============================================================================
614// IntoGroupTernaryOperand - Ergonomic ternary operand conversion for VectorTensorPair
615// ============================================================================
616
617/// Type alias for group ternary operand in VectorTensorPair operations.
618pub type GroupTernaryOperand<Mapping> = Option<TernaryOperand<Mapping>>;
619
620/// Trait for converting various types into a group ternary operand.
621///
622/// Uses `Into<TernaryOperand>` blanket impl for automatic conversion from
623/// types that implement `From` for `TernaryOperand` ((f32, f32), (VeRhs, f32), etc.).
624///
625/// # Supported types
626/// - `()` - skip operation for this group
627/// - `(f32, f32)` - two constant values (via `Into<TernaryOperand>`)
628/// - `(VeRhs<f32, Mapping>, f32)` - VeRhs and constant (via `Into<TernaryOperand>`)
629/// - `TernaryOperand<Mapping>` - explicit ternary operand (via `Into<TernaryOperand>`)
630/// - `Option<TernaryOperand<Mapping>>` - pass through
631///
632/// # Example
633/// ```ignore
634/// // Apply ternary op only to group0
635/// pair.vector_fp_ternary(op, (2.0f32, 3.0f32), ())
636///
637/// // Apply to both groups with different operands
638/// pair.vector_fp_ternary(op, (2.0f32, 3.0f32), (4.0f32, 5.0f32))
639///
640/// // With stash as operand0 for group0
641/// pair.vector_fp_ternary(op, (Stash, 3.0f32), ())
642/// ```
643pub trait IntoGroupTernaryOperand<Mapping: M> {
644    /// Converts into a GroupTernaryOperand with the specified mapping.
645    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping>;
646}
647
648/// `()` represents skipping the operation for this group.
649impl<Mapping: M> IntoGroupTernaryOperand<Mapping> for () {
650    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
651        None
652    }
653}
654
655/// `Option<TernaryOperand<Mapping>>` passes through.
656impl<Mapping: M> IntoGroupTernaryOperand<Mapping> for Option<TernaryOperand<Mapping>> {
657    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
658        self
659    }
660}
661
662/// Blanket impl: any type that implements `Into<TernaryOperand>` automatically
663/// implements `IntoGroupTernaryOperand` by wrapping in `Some`.
664impl<T, Mapping: M> IntoGroupTernaryOperand<Mapping> for T
665where
666    T: Into<TernaryOperand<Mapping>>,
667{
668    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
669        Some(self.into())
670    }
671}