furiosa_visa_std/vector_engine/operand.rs
//! Operand types for Vector Engine operations.
//!
//! This module provides types for specifying operands in VE binary and ternary operations:
//! - [`VeRhs`]: RHS operand (constant, VRF data, or stash read) with type safety
//! - [`StashOperand`]: Stash operand with branch validity (requires matching D type)
//! - [`VeBranchOperand`]: A [`VeRhs`] paired with branch validity
//! - [`TernaryOperand`]: Operand for ternary operations
//! - [`VeOperand`]: Unified operand type with automatic conversion
//! - [`IntoOperands`]: Trait for converting operands to ArrayVec
//! - [`Stash`]: Type-inferred stash marker (compile-time type checked)
10
11use std::marker::PhantomData;
12
13use furiosa_mapping::{M, Pair};
14use furiosa_mapping_macro::primitive;
15
16use crate::{
17 array_vec::ArrayVec,
18 prelude::{GroupId, ValidBranchIds, VrfTensor},
19 tensor::Tensor,
20 vector_engine::{MAX_BRANCHES, scalar::VeScalar},
21};
22
23// ============================================================================
24// VeRhs - Constant or VRF operand (type-safe)
25// ============================================================================
26
/// RHS operand for Vector Engine operations.
///
/// Generic over:
/// - `D`: Data type (i32 or f32) - ensures type safety with tensor operations
/// - `TargetMapping`: Target tensor shape for VRF transpose
///
/// Values are usually built with [`VeRhs::constant`] or [`VeRhs::vrf`], or through the `From`
/// conversions in this module (`i32`, `f32`, [`Stash`], `&VrfTensor`).
#[primitive(op::VeRhs)]
#[derive(Debug, Clone)]
pub enum VeRhs<D: VeScalar, TargetMapping: M> {
    /// Constant value.
    Const {
        /// The constant value.
        v: D,
    },
    /// VRF data that has been transposed to match the target tensor shape.
    Vrf {
        /// The transposed VRF tensor.
        data: Tensor<D, TargetMapping>,
    },
    /// Read from stash (previously written value).
    Stash,
}
48
49impl<D: VeScalar, TargetMapping: M> VeRhs<D, TargetMapping> {
50 /// Creates a constant operand.
51 #[primitive(op::VeRhs::constant)]
52 pub fn constant(v: D) -> Self {
53 VeRhs::Const { v }
54 }
55
56 /// Creates a VeRhs from a VrfTensor, automatically transposing to match the target tensor shape.
57 #[primitive(op::VeRhs::vrf)]
58 pub fn vrf<Chip: M, Cluster: M, Slice: M, Element: M>(vrf: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
59 let transposed = vrf.inner.transpose::<TargetMapping>(true);
60 VeRhs::Vrf { data: transposed }
61 }
62}
63
64impl<TargetMapping: M> From<i32> for VeRhs<i32, TargetMapping> {
65 fn from(v: i32) -> Self {
66 VeRhs::Const { v }
67 }
68}
69
70impl<TargetMapping: M> From<f32> for VeRhs<f32, TargetMapping> {
71 fn from(v: f32) -> Self {
72 VeRhs::Const { v }
73 }
74}
75
76impl<D: VeScalar, TargetMapping: M> From<Stash> for VeRhs<D, TargetMapping> {
77 fn from(_: Stash) -> Self {
78 VeRhs::Stash
79 }
80}
81
82impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, TargetMapping: M>
83 From<&VrfTensor<D, Chip, Cluster, Slice, Element>> for VeRhs<D, TargetMapping>
84{
85 fn from(vrf: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
86 VeRhs::vrf(vrf)
87 }
88}
89
90// ============================================================================
91// StashOperand - Stash read with branch validity (type-safe)
92// ============================================================================
93
/// Stash operand for Vector Engine operations.
///
/// Records which branches the stash read is valid for. The data type `D` is carried only at
/// the type level (through `PhantomData`); no value of type `D` is stored.
#[derive(Debug, Clone)]
pub struct StashOperand<D: VeScalar> {
    /// Branches for which this stash read is valid.
    pub(crate) valid_branch_ids: ValidBranchIds,
    /// Zero-sized marker tying the operand to data type `D`.
    _phantom: PhantomData<D>,
}
100
101impl<D: VeScalar> StashOperand<D> {
102 pub(crate) fn always() -> Self {
103 Self {
104 valid_branch_ids: ValidBranchIds::ValidAlways,
105 _phantom: PhantomData,
106 }
107 }
108
109 #[expect(dead_code)]
110 pub(crate) fn group(id: GroupId) -> Self {
111 Self {
112 valid_branch_ids: ValidBranchIds::ValidGroup { id },
113 _phantom: PhantomData,
114 }
115 }
116}
117
118// ============================================================================
119// VeBranchOperand - Operand with branch validity
120// ============================================================================
121
/// Operand with branch validity for multi-operand cases.
///
/// Combines a VeRhs (constant, VRF, or stash) with branch validity.
/// Both fields are public; the `always`/`group`/`stash_*` constructors and the `From`
/// conversions in this module are shorthand for filling them in.
#[primitive(op::VeBranchOperand)]
#[derive(Debug, Clone)]
pub struct VeBranchOperand<D: VeScalar, TargetMapping: M> {
    /// The operand value.
    pub operand: VeRhs<D, TargetMapping>,
    /// Valid branch IDs for this operand.
    pub valid_branch_ids: ValidBranchIds,
}
133
134impl<D: VeScalar, TargetMapping: M> VeBranchOperand<D, TargetMapping> {
135 /// Creates an always-valid operand.
136 #[primitive(op::VeBranchOperand::always)]
137 pub fn always(operand: VeRhs<D, TargetMapping>) -> Self {
138 Self {
139 operand,
140 valid_branch_ids: ValidBranchIds::ValidAlways,
141 }
142 }
143
144 /// Creates a group-specific operand.
145 pub fn group(operand: VeRhs<D, TargetMapping>, id: GroupId) -> Self {
146 Self {
147 operand,
148 valid_branch_ids: ValidBranchIds::ValidGroup { id },
149 }
150 }
151
152 /// Creates an always-valid stash operand.
153 pub fn stash_always() -> Self {
154 Self {
155 operand: VeRhs::Stash,
156 valid_branch_ids: ValidBranchIds::ValidAlways,
157 }
158 }
159
160 /// Creates a group-specific stash operand.
161 pub fn stash_group(id: GroupId) -> Self {
162 Self {
163 operand: VeRhs::Stash,
164 valid_branch_ids: ValidBranchIds::ValidGroup { id },
165 }
166 }
167
168 /// Returns true if this operand uses stash.
169 pub fn is_stash(&self) -> bool {
170 matches!(self.operand, VeRhs::Stash)
171 }
172
173 /// Returns the valid branch IDs for this operand.
174 pub fn valid_branch_ids(&self) -> &ValidBranchIds {
175 &self.valid_branch_ids
176 }
177}
178
179// ============================================================================
180// TernaryOperand - For ternary operations (f32 only)
181// ============================================================================
182
/// User-facing operand for ternary operations.
/// Generic over `Mapping` to match the target tensor's mapping from creation time.
/// Ternary operations are only supported for f32 tensors.
///
/// Besides the explicit constructors, a `TernaryOperand` can be produced from the tuples
/// `(R, f32)` and `((R, f32), B)` where `R: Into<VeRhs<f32, Mapping>>` and
/// `B: Into<ValidBranchIds>`.
#[derive(Debug, Clone)]
pub struct TernaryOperand<Mapping: M> {
    /// First operand as VeRhs.
    pub operand0: VeRhs<f32, Mapping>,
    /// Second operand as f32.
    pub operand1: f32,
    /// Valid branch IDs.
    pub valid_branch_ids: ValidBranchIds,
}
195
196impl<Mapping: M> TernaryOperand<Mapping> {
197 /// Creates a TernaryOperand always valid.
198 pub fn always(operand0: VeRhs<f32, Mapping>, operand1: f32) -> Self {
199 Self {
200 operand0,
201 operand1,
202 valid_branch_ids: ValidBranchIds::ValidAlways,
203 }
204 }
205
206 /// Creates a TernaryOperand valid for a specific group.
207 pub fn group(operand0: VeRhs<f32, Mapping>, operand1: f32, id: GroupId) -> Self {
208 Self {
209 operand0,
210 operand1,
211 valid_branch_ids: ValidBranchIds::ValidGroup { id },
212 }
213 }
214}
215
216// From implementations for TernaryOperand (enables blanket impl for IntoGroupTernaryOperand)
217
218/// `(Into<VeRhs<f32, Mapping>>, f32)` - VeRhs and constant become TernaryOperand.
219impl<R, Mapping: M> From<(R, f32)> for TernaryOperand<Mapping>
220where
221 R: Into<VeRhs<f32, Mapping>>,
222{
223 fn from((operand0, operand1): (R, f32)) -> Self {
224 TernaryOperand::always(operand0.into(), operand1)
225 }
226}
227
228impl<R, B, Mapping: M> From<((R, f32), B)> for TernaryOperand<Mapping>
229where
230 R: Into<VeRhs<f32, Mapping>>,
231 B: Into<ValidBranchIds>,
232{
233 fn from(((operand0, operand1), branch): ((R, f32), B)) -> Self {
234 TernaryOperand {
235 operand0: operand0.into(),
236 operand1,
237 valid_branch_ids: branch.into(),
238 }
239 }
240}
241
242// ============================================================================
243// IntoTernaryOperands trait (for ternary operations, f32 only)
244// ============================================================================
245
/// Trait for converting various operand types into an ArrayVec of TernaryOperand.
///
/// # Supported operand types
///
/// - `(f32, f32)` - two constant values (operand0, operand1)
/// - `(VeRhs<f32, Mapping>, f32)` - VeRhs and constant
/// - `((R, f32), B)` - as above, with explicit branch validity `B: Into<ValidBranchIds>`
/// - `TernaryOperand<Mapping>` - single ternary operand
/// - `[TernaryOperand<Mapping>; N]` - array of ternary operands for multi-branch operations
///   (panics if more than one element is `ValidBranchIds::ValidAlways`)
/// - `ArrayVec<TernaryOperand<Mapping>, MAX_BRANCHES>` - passed through unchanged
///
/// # Example
/// ```ignore
/// // Simple usage with tuple (operand0, operand1)
/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (2.0f32, 3.0f32))
///
/// // With VRF as operand0
/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (&vrf, 3.0f32))
///
/// // With stash as operand0
/// tensor.vector_fp_ternary(FpTernaryOp::FmaF, (Stash, 3.0f32))
///
/// // Explicit TernaryOperand for branch control
/// tensor.vector_fp_ternary(
///     FpTernaryOp::FmaF,
///     TernaryOperand::always(VeRhs::constant(2.0f32), 3.0f32)
/// )
/// ```
pub trait IntoTernaryOperands<TargetMapping: M> {
    /// Converts into an ArrayVec of TernaryOperand.
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES>;
}
276
277// Blanket impl: Into<TernaryOperand> automatically provides IntoTernaryOperands
278impl<T, TargetMapping: M> IntoTernaryOperands<TargetMapping> for T
279where
280 T: Into<TernaryOperand<TargetMapping>>,
281{
282 fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
283 ArrayVec::new([self.into()])
284 }
285}
286
287/// Array of `TernaryOperand` for multi-branch operations.
288impl<TargetMapping: M, const N: usize> IntoTernaryOperands<TargetMapping> for [TernaryOperand<TargetMapping>; N] {
289 fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
290 // Validate: at most one ValidAlways operand is allowed
291 let always_count = self
292 .iter()
293 .filter(|op| matches!(op.valid_branch_ids, ValidBranchIds::ValidAlways))
294 .count();
295 assert!(
296 always_count <= 1,
297 "Multiple ValidAlways operands are not allowed (found {always_count})"
298 );
299 ArrayVec::new(self)
300 }
301}
302
/// `ArrayVec<TernaryOperand, MAX_BRANCHES>` passes through.
///
/// The value is returned as-is; unlike the array impl, no check on the number of
/// `ValidAlways` operands is performed here.
impl<TargetMapping: M> IntoTernaryOperands<TargetMapping> for ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperand<TargetMapping>, MAX_BRANCHES> {
        self
    }
}
309
310// ============================================================================
311// From implementations for VeBranchOperand (enables .into() conversion)
312// ============================================================================
313//
314// These implementations allow ergonomic conversion to VeBranchOperand using `.into()`.
315//
316// # Usage for heterogeneous multi-branch operands
317//
318// When you need multiple operands of different types (e.g., constant + stash),
319// use `.into()` to convert each to `VeBranchOperand`, then pass as array:
320//
321// ```ignore
322// // Single operand (homogeneous) - direct usage
323// tensor.vector_fxp(op, 16384i32)
324// tensor.vector_fxp(op, Stash)
325// tensor.vector_fxp(op, &vrf)
326//
327// // Multiple operands of same type
328// tensor.vector_fxp(op, [
329// VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
330// VeBranchOperand::group(VeRhs::constant(200), GroupId::Group1),
331// ])
332//
// // Multiple operands of different types (heterogeneous).
// // An array may hold at most one always-valid operand (`into_operands` asserts this),
// // and a bare `.into()` always produces an always-valid operand, so give the
// // remaining operands an explicit group.
// tensor.vector_fxp(op, [
//     16384i32.into(),
//     VeBranchOperand::stash_group(GroupId::Group1),
// ])
//
// // With branch control
// tensor.vector_fxp(op, [
//     VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
//     VeBranchOperand::stash_group(GroupId::Group1),
// ])
345// ```
346
347impl<R, D: VeScalar, Mapping: M> From<R> for VeBranchOperand<D, Mapping>
348where
349 R: Into<VeRhs<D, Mapping>>,
350{
351 fn from(rhs: R) -> Self {
352 VeBranchOperand::always(rhs.into())
353 }
354}
355
356impl<R, B, D: VeScalar, Mapping: M> From<(R, B)> for VeBranchOperand<D, Mapping>
357where
358 R: Into<VeRhs<D, Mapping>>,
359 B: Into<ValidBranchIds>,
360{
361 fn from((rhs, branch): (R, B)) -> Self {
362 VeBranchOperand {
363 operand: rhs.into(),
364 valid_branch_ids: branch.into(),
365 }
366 }
367}
368
369// ============================================================================
370// IntoOperands trait - Multiple operands conversion
371// ============================================================================
372
/// Trait for converting various operand types into an ArrayVec.
///
/// Types implementing `Into<VeBranchOperand>` automatically get this via blanket impl.
/// Array types `[VeBranchOperand; N]` and `ArrayVec` implement this directly.
///
/// # Supported operand types
///
/// **Single operand** (via `Into<VeBranchOperand>`, auto-wrapped in ArrayVec):
/// - `i32`, `f32` - constant value
/// - `Stash` - stash read marker
/// - `VeRhs<D, _>` - explicit RHS operand (always valid)
/// - `(R, B)` - any of the above with explicit branch validity `B: Into<ValidBranchIds>`
/// - `VeBranchOperand<D, _>` - explicit operand (pass through)
/// - `&VrfTensor<D, ...>` - VRF tensor reference
///
/// **Multiple operands** (direct implementations):
/// - `[VeBranchOperand<D, _>; N]` - array of operands for multi-branch operations
///   (panics if more than one element is `ValidBranchIds::ValidAlways`)
/// - `ArrayVec<VeBranchOperand<D, _>, MAX_BRANCHES>` - pass through
///
/// # Examples
///
/// ```ignore
/// // Single operand - direct usage
/// tensor.vector_fxp(op, 16384i32)
///
/// // Multiple homogeneous operands
/// tensor.vector_fxp(op, [
///     VeBranchOperand::group(VeRhs::constant(100), GroupId::Group0),
///     VeBranchOperand::group(VeRhs::constant(200), GroupId::Group1),
/// ])
///
/// // Multiple heterogeneous operands - a bare `.into()` yields an always-valid
/// // operand, and at most one always-valid operand is allowed per array
/// tensor.vector_fxp(op, [
///     16384i32.into(),
///     VeBranchOperand::stash_group(GroupId::Group1),
/// ])
/// ```
// NOTE(review): earlier docs also listed `StashOperand<D>` as a supported single operand, but no
// `From<StashOperand<D>> for VeBranchOperand` is visible in this file — confirm before re-adding.
pub trait IntoOperands<D: VeScalar, TargetMapping: M> {
    /// Converts into an ArrayVec of operands.
    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES>;
}
413
414// Blanket impl: Into<VeBranchOperand> automatically provides IntoOperands
415impl<T, D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping> for T
416where
417 T: Into<VeBranchOperand<D, TargetMapping>>,
418{
419 fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
420 ArrayVec::new([self.into()])
421 }
422}
423
/// `ArrayVec<VeBranchOperand, MAX_BRANCHES>` passes through.
///
/// The value is returned as-is; unlike the array impl, no check on the number of
/// `ValidAlways` operands is performed here.
impl<D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping>
    for ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES>
{
    fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
        self
    }
}
431
432impl<D: VeScalar, TargetMapping: M, const N: usize> IntoOperands<D, TargetMapping>
433 for [VeBranchOperand<D, TargetMapping>; N]
434{
435 fn into_operands(self) -> ArrayVec<VeBranchOperand<D, TargetMapping>, MAX_BRANCHES> {
436 // Validate: at most one ValidAlways operand is allowed
437 let always_count = self
438 .iter()
439 .filter(|op| matches!(op.valid_branch_ids(), ValidBranchIds::ValidAlways))
440 .count();
441 assert!(
442 always_count <= 1,
443 "Multiple ValidAlways operands are not allowed (found {always_count})"
444 );
445 ArrayVec::new(self)
446 }
447}
448
449// ============================================================================
450// Stash - Type-inferred marker for stash operands (compile-time type checked)
451// ============================================================================
452
/// Type-inferred stash marker for compile-time type checking.
///
/// When used as an operand, the stash data type must match the operation's data type.
/// The marker itself carries no data; it converts into `VeRhs::Stash` or into an
/// always-valid `VeOperand::Stash` through the `From` impls in this module.
///
/// # Example
/// ```ignore
/// // f32 tensor with f32 stash -> OK
/// tensor
///     .vector_stash()
///     .vector_fp_binary(FpBinaryOp::MulF(FpMulAlu::Mul0), 2.0f32)
///     .vector_clip(ClipBinaryOpF32::Max, Stash) // Compiles: D == StashD == f32
/// ```
#[primitive(op::Stash)]
#[derive(Debug, Clone, Copy)]
pub struct Stash;
469
470// ============================================================================
471// VeOperand - Unified operand type with automatic conversion
472// ============================================================================
473
/// Unified operand type for Vector Engine operations.
///
/// Supports automatic conversion from:
/// - `D` (i32/f32) - constant value
/// - `&VrfTensor<D, ...>` - VRF tensor reference
/// - `Stash` - stash read (always valid)
/// - `StashOperand<D>` - stash read with its own branch validity
///
/// Use with `impl Into<VeOperand<D, ...>>` for ergonomic API:
/// ```ignore
/// .vector_fxp(op, 16384i32)  // i32 auto-converted
/// .vector_fxp(op, &vrf)      // VRF auto-converted
/// .vector_fxp(op, Stash)     // Stash (always valid)
/// ```
///
/// The lifetime `'a` is that of the borrowed `VrfTensor` in the `Vrf` variant.
#[derive(Debug)]
pub enum VeOperand<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> {
    /// Constant value (always valid).
    Const(D),
    /// VRF tensor reference.
    Vrf(&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>),
    /// Stash operand.
    Stash(StashOperand<D>),
}
496
497// From<i32> for VeOperand<i32, ...>
498impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<i32> for VeOperand<'_, i32, Chip, Cluster, Slice, VrfMapping> {
499 fn from(v: i32) -> Self {
500 VeOperand::Const(v)
501 }
502}
503
504// From<f32> for VeOperand<f32, ...>
505impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<f32> for VeOperand<'_, f32, Chip, Cluster, Slice, VrfMapping> {
506 fn from(v: f32) -> Self {
507 VeOperand::Const(v)
508 }
509}
510
511// From<&VrfTensor<D, ...>> for VeOperand<D, ...>
512impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M>
513 From<&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>> for VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping>
514{
515 fn from(vrf: &'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>) -> Self {
516 VeOperand::Vrf(vrf)
517 }
518}
519
520// From<StashOperand<D>> for VeOperand<D, ...>
521impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<StashOperand<D>>
522 for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
523{
524 fn from(stash: StashOperand<D>) -> Self {
525 VeOperand::Stash(stash)
526 }
527}
528
529// From<Stash> for VeOperand<D, ...> - enables using Stash marker directly
530impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<Stash>
531 for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
532{
533 fn from(_: Stash) -> Self {
534 VeOperand::Stash(StashOperand::always())
535 }
536}
537
538impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping> {
539 /// Converts VeOperand to an ArrayVec of VeBranchOperand with the target tensor mapping.
540 pub fn into_branch_operands<Time: M, Packet: M>(
541 self,
542 ) -> ArrayVec<
543 VeBranchOperand<D, furiosa_opt_macro::m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }]>,
544 MAX_BRANCHES,
545 > {
546 type TargetShape<Chip, Cluster, Slice, Time, Packet> =
547 furiosa_opt_macro::m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }];
548
549 match self {
550 VeOperand::Const(v) => ArrayVec::new([VeBranchOperand::always(VeRhs::Const { v })]),
551 VeOperand::Vrf(vrf) => {
552 let vrf_operand = VeRhs::<D, TargetShape<Chip, Cluster, Slice, Time, Packet>>::vrf(vrf);
553 ArrayVec::new([VeBranchOperand::always(vrf_operand)])
554 }
555 VeOperand::Stash(stash) => ArrayVec::new([VeBranchOperand {
556 operand: VeRhs::Stash,
557 valid_branch_ids: stash.valid_branch_ids,
558 }]),
559 }
560 }
561}
562
563// ============================================================================
564// IntoGroupOperand - Ergonomic operand conversion for VectorTensorPair
565// ============================================================================
566
/// Type alias for group operand in VectorTensorPair operations.
/// Generic over `Mapping` to match the target tensor's mapping from creation time.
///
/// `None` means the operation is skipped for the group (this is what `()` converts to).
pub type GroupOperand<D, Mapping> = Option<VeBranchOperand<D, Mapping>>;
570
/// Trait for converting various types into a group operand.
///
/// Uses `Into<VeBranchOperand>` blanket impl for automatic conversion from
/// types that implement `From` for `VeBranchOperand` (i32, f32, Stash, etc.).
///
/// # Supported types
/// - `()` - skip operation for this group
/// - `i32`, `f32` - constant value (via `Into<VeBranchOperand>`)
/// - `Stash` - stash read marker (via `Into<VeBranchOperand>`)
/// - `&VrfTensor<D, ...>` - VRF tensor reference (via `Into<VeBranchOperand>`)
/// - `(R, B)` - operand with explicit branch validity (via `Into<VeBranchOperand>`)
/// - `VeBranchOperand<D, Mapping>` - explicit operand (via `Into<VeBranchOperand>`)
/// - `Option<VeBranchOperand<D, Mapping>>` - pass through
// NOTE(review): earlier docs also listed `StashOperand<D>` here, but no
// `From<StashOperand<D>> for VeBranchOperand` is visible in this file — confirm before re-adding.
pub trait IntoGroupOperand<D: VeScalar, Mapping: M> {
    /// Converts into a GroupOperand with the specified mapping.
    fn into_group_operand(self) -> GroupOperand<D, Mapping>;
}
587
/// `()` represents skipping the operation for this group.
///
/// The conversion always yields `None`.
impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for () {
    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
        None
    }
}
594
/// `Option<VeBranchOperand<D, Mapping>>` passes through.
///
/// `GroupOperand` is an alias for this exact type, so the value is returned unchanged.
impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for Option<VeBranchOperand<D, Mapping>> {
    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
        self
    }
}
601
602/// Blanket impl: any type that implements `Into<VeBranchOperand>` automatically
603/// implements `IntoGroupOperand` by wrapping in `Some`.
604impl<T, D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for T
605where
606 T: Into<VeBranchOperand<D, Mapping>>,
607{
608 fn into_group_operand(self) -> GroupOperand<D, Mapping> {
609 Some(self.into())
610 }
611}
612
613// ============================================================================
614// IntoGroupTernaryOperand - Ergonomic ternary operand conversion for VectorTensorPair
615// ============================================================================
616
/// Type alias for group ternary operand in VectorTensorPair operations.
///
/// `None` means the operation is skipped for the group (this is what `()` converts to).
pub type GroupTernaryOperand<Mapping> = Option<TernaryOperand<Mapping>>;
619
/// Trait for converting various types into a group ternary operand.
///
/// Uses `Into<TernaryOperand>` blanket impl for automatic conversion from
/// types that implement `From` for `TernaryOperand` ((f32, f32), (VeRhs, f32), etc.).
///
/// # Supported types
/// - `()` - skip operation for this group
/// - `(f32, f32)` - two constant values (via `Into<TernaryOperand>`)
/// - `(VeRhs<f32, Mapping>, f32)` - VeRhs and constant (via `Into<TernaryOperand>`)
/// - `((R, f32), B)` - as above, with explicit branch validity (via `Into<TernaryOperand>`)
/// - `TernaryOperand<Mapping>` - explicit ternary operand (via `Into<TernaryOperand>`)
/// - `Option<TernaryOperand<Mapping>>` - pass through
///
/// # Example
/// ```ignore
/// // Apply ternary op only to group0
/// pair.vector_fp_ternary(op, (2.0f32, 3.0f32), ())
///
/// // Apply to both groups with different operands
/// pair.vector_fp_ternary(op, (2.0f32, 3.0f32), (4.0f32, 5.0f32))
///
/// // With stash as operand0 for group0
/// pair.vector_fp_ternary(op, (Stash, 3.0f32), ())
/// ```
pub trait IntoGroupTernaryOperand<Mapping: M> {
    /// Converts into a GroupTernaryOperand with the specified mapping.
    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping>;
}
647
/// `()` represents skipping the operation for this group.
///
/// The conversion always yields `None`.
impl<Mapping: M> IntoGroupTernaryOperand<Mapping> for () {
    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
        None
    }
}
654
/// `Option<TernaryOperand<Mapping>>` passes through.
///
/// `GroupTernaryOperand` is an alias for this exact type, so the value is returned unchanged.
impl<Mapping: M> IntoGroupTernaryOperand<Mapping> for Option<TernaryOperand<Mapping>> {
    fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
        self
    }
}
661
662/// Blanket impl: any type that implements `Into<TernaryOperand>` automatically
663/// implements `IntoGroupTernaryOperand` by wrapping in `Some`.
664impl<T, Mapping: M> IntoGroupTernaryOperand<Mapping> for T
665where
666 T: Into<TernaryOperand<Mapping>>,
667{
668 fn into_group_ternary_operand(self) -> GroupTernaryOperand<Mapping> {
669 Some(self.into())
670 }
671}