furiosa_visa_std/
stream_tensor.rs

1//! TensorValue streams.
2
3use std::collections::{HashMap, HashSet};
4use std::marker::PhantomData;
5
6use abi_stable::std_types::{RBox, Tuple2};
7use furiosa_mapping::*;
8use furiosa_mapping_macro::primitive;
9use furiosa_opt_macro::m;
10
/// Required byte alignment of the fetch engine's output packet.
///
/// `BeginTensor::fetch` asserts that the output packet size in bytes is a
/// multiple of this value.
const FETCH_ALIGN_BYTES: usize = 8;

/// Size of a single flit in bytes.
///
/// Data flows through the switching network in flit-sized units.
/// Both the collect engine and cast engine normalize packets to exactly one flit.
const FLIT_BYTES: usize = 32;

/// Transpose engine input packet size in bytes.
const TRANSPOSE_INPUT_BYTES: usize = 32;

/// Transpose engine output packet size in bytes.
const TRANSPOSE_OUTPUT_BYTES: usize = 32;

/// Number of elements fetched per transpose packet for non-4-bit types.
const TRANSPOSE_ELEMENTS_PER_PACKET_NON_4BIT: usize = 8;

/// Number of elements fetched per transpose packet for 4-bit types.
const TRANSPOSE_ELEMENTS_PER_PACKET_4BIT: usize = 16;

/// Maximum size for transpose `in_rows` in bytes.
const TRANSPOSE_MAX_IN_ROWS_BYTES: usize = 8;

/// Valid `in_cols` values for the transpose engine (non-4-bit types).
const TRANSPOSE_VALID_IN_COLS: &[usize] = &[8, 16, 32];

/// Valid `in_cols` values for the transpose engine (4-bit types).
const TRANSPOSE_VALID_IN_COLS_4BIT: &[usize] = &[16, 32];

/// Number of columns in the reducer's temporal accumulator.
const TEMPORAL_ACCUMULATOR_COLS: usize = 32;

/// Number of elements in `accumulate`'s output packet (i32/f32).
const ACCUMULATE_OUT_PACKET_ELEMENTS: usize = 8;

/// Valid output packet sizes for the commit engine in bytes.
const COMMIT_OUT_PACKET_SIZES: [usize; 4] = [8, 16, 24, 32];
47
48/// Contraction mode for accumulation.
49#[primitive(AccumulationKind)]
50#[derive(Clone, Debug)]
51pub enum AccumulationKind {
52    /// Interleaved accumulation: Outputs data element-by-element across all Rows.
53    Interleaved,
54    /// Sequential accumulation: Outputs reduced data in each Row sequentially.
55    Sequential,
56}
57
58impl std::fmt::Display for AccumulationKind {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        match self {
61            AccumulationKind::Interleaved => write!(f, "Interleaved"),
62            AccumulationKind::Sequential => write!(f, "Sequential"),
63        }
64    }
65}
66
67use crate::cast::*;
68use crate::context::*;
69use crate::memory_tensor::*;
70use crate::scalar::*;
71use crate::tensor::*;
72use crate::vector_engine::scalar::VeScalar;
73use crate::vector_engine::tensor::VectorInitTensor;
74
/// Typestate marker recording where a stream tensor currently sits in the pipeline.
///
/// Vector Engine stages are intentionally not represented here: `VectorTensor`
/// carries its own typestate.
pub trait Position: std::fmt::Debug + 'static {}

/// Position: the pipeline has just begun.
#[derive(Debug)]
pub struct PositionBegin;
impl Position for PositionBegin {}

/// Position: the fetch engine has run.
#[derive(Debug)]
pub struct PositionFetch;
impl Position for PositionFetch {}

/// Position: the switch engine has run.
#[derive(Debug)]
pub struct PositionSwitch;
impl Position for PositionSwitch {}

/// Position: the switch engine's collect engine has run (packet normalized to 32 bytes).
#[derive(Debug)]
pub struct PositionCollect;
impl Position for PositionCollect {}

/// Position: the contraction engine has run.
#[derive(Debug)]
pub struct PositionContraction;
impl Position for PositionContraction {}

/// Position: the vector engine has run (`vector_final`).
#[derive(Debug)]
pub struct PositionVectorFinal;
impl Position for PositionVectorFinal {}

/// Position: the cast engine has run.
#[derive(Debug)]
pub struct PositionCast;
impl Position for PositionCast {}

/// Position: the transpose engine has run.
#[derive(Debug)]
pub struct PositionTranspose;
impl Position for PositionTranspose {}
120
121/// Tensor streamed through the pipeline.
122#[derive(Debug)]
123pub struct StreamTensor<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> {
124    pub(crate) ctx: &'l mut TuContext<{ T }>,
125    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Time, Packet>>>>>,
126    _position: PhantomData<P>,
127}
128
129impl<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
130    StreamTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet>
131{
132    /// Mapping type alias.
133    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }];
134
135    /// Creates a new stream tensor.
136    pub fn new(ctx: &'l mut TuContext<{ T }>, inner: Tensor<D, Self::Mapping>) -> Self {
137        Self {
138            ctx,
139            inner,
140            _position: PhantomData,
141        }
142    }
143}
144
145/// Tensor streamed after the beginning.
146pub type BeginTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
147    StreamTensor<'l, { T }, PositionBegin, D, Chip, Cluster, Slice, Time, Packet>;
148
149/// Tensor streamed after the fetch engine.
150pub type FetchTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
151    StreamTensor<'l, { T }, PositionFetch, D, Chip, Cluster, Slice, Time, Packet>;
152
153/// Tensor streamed after the switch engine.
154pub type SwitchTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
155    StreamTensor<'l, { T }, PositionSwitch, D, Chip, Cluster, Slice, Time, Packet>;
156
157/// Tensor after collect engine: packet is exactly 32 bytes (one flit).
158pub type CollectTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
159    StreamTensor<'l, { T }, PositionCollect, D, Chip, Cluster, Slice, Time, Packet>;
160
161/// Pair of aligned tensors ready for contraction (Feed Buffer + TRF Sequencer output).
162#[derive(Debug)]
163pub struct AlignedPair<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M> {
164    ctx: &'l mut TuContext<{ T }>,
165    lhs: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
166    trf: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
167}
168
169/// Intermediate tensor after contraction (LAT reduce within Packet),
170/// before accumulation (accumulator reduce across Time).
171#[derive(Debug)]
172pub struct ContractionTensor<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M> {
173    ctx: &'l mut TuContext<{ T }>,
174    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
175}
176
177/// Tensor streamed after the contraction engine.
178pub type AccumulationTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
179    StreamTensor<'l, { T }, PositionContraction, D, Chip, Cluster, Slice, Time, Packet>;
180
181/// Tensor after the vector engine (after `vector_final()`).
182pub type VectorFinalTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
183    StreamTensor<'l, { T }, PositionVectorFinal, D, Chip, Cluster, Slice, Time, Packet>;
184
185/// Tensor streamed after the cast engine.
186pub type CastTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
187    StreamTensor<'l, { T }, PositionCast, D, Chip, Cluster, Slice, Time, Packet>;
188
189/// Tensor streamed after the transpose engine.
190pub type TransposeTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
191    StreamTensor<'l, { T }, PositionTranspose, D, Chip, Cluster, Slice, Time, Packet>;
192
193// ANCHOR: fetch_impl
194impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
195    BeginTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
196{
197    /// Performs fetch operation to create a fetched tensor.
198    #[primitive(BeginTensor::fetch)]
199    pub fn fetch<D2: Scalar, Time2: M, Packet2: M>(self) -> FetchTensor<'l, T, D2, Chip, Cluster, Slice, Time2, Packet2>
200    where
201        D: FetchCast<D2>,
202    {
203        assert_eq!(Cluster::SIZE, 2, "Cluster size must be 2, got {}", Cluster::SIZE);
204        assert_eq!(Slice::SIZE, 256, "Slice size must be 256, got {}", Slice::SIZE);
205        let packet_bytes = D2::size_in_bytes_from_length(Packet2::SIZE);
206        assert_eq!(
207            packet_bytes % FETCH_ALIGN_BYTES,
208            0,
209            "Fetch output packet must be {FETCH_ALIGN_BYTES}-byte aligned, got {packet_bytes} bytes.",
210        );
211        FetchTensor::new(self.ctx, self.inner.map(|v| v.map(|v| v.cast())).transpose(true))
212    }
213}
214// ANCHOR_END: fetch_impl
215
216/// Creates a CastTensor after validating cast engine constraints.
217///
218/// The cast engine operates on a single 32-byte flit. The input packet must
219/// be exactly one flit (32 bytes). After casting, the output packet is padded
220/// back to 32 bytes. Time passes through unchanged.
221///
222/// Constraints checked:
223/// 1. Input packet must be exactly one flit (32 bytes).
224/// 2. Output packet must be exactly one flit (32 bytes).
225/// 3. The data terms must match (only padding differs).
226pub(crate) fn cast_stream<
227    'l,
228    const T: Tu,
229    D: VeScalar + Cast<D2>,
230    D2: Scalar,
231    Chip: M,
232    Cluster: M,
233    Slice: M,
234    Time: M,
235    InPacket: M,
236    OutPacket: M,
237>(
238    ctx: &'l mut TuContext<{ T }>,
239    inner: Tensor<D, m![{ Chip }, { Cluster }, { Slice }, { Time }, { InPacket }]>,
240) -> CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket> {
241    // Input packet must be exactly one flit.
242    assert_eq!(
243        D::size_in_bytes_from_length(InPacket::SIZE),
244        FLIT_BYTES,
245        "Cast input packet must be exactly {FLIT_BYTES} bytes (one flit): \
246         {} elements = {} bytes",
247        InPacket::SIZE,
248        D::size_in_bytes_from_length(InPacket::SIZE),
249    );
250
251    let out_flit_elements = D2::length_from_bytes(FLIT_BYTES);
252
253    // Cast elements and pad to 32 bytes.
254    let in_packet = InPacket::to_value().factorize();
255    let expected_packet = in_packet.pad(out_flit_elements).normalize();
256
257    // Output packet must be exactly one flit.
258    let out_packet = OutPacket::to_value().factorize();
259    assert_eq!(
260        D2::size_in_bytes_from_length(OutPacket::SIZE),
261        FLIT_BYTES,
262        "Cast output packet must be exactly {FLIT_BYTES} bytes (one flit). \
263         Expected: {expected_packet}, got: {out_packet}",
264    );
265    assert_eq!(
266        expected_packet, out_packet,
267        "Cast packet mismatch. Expected: {expected_packet}, got: {out_packet}",
268    );
269
270    CastTensor::new(ctx, inner.map(|v| v.map(|v| v.cast())).transpose(false))
271}
272
/// Configuration for the `switch` operation.
///
/// In the layouts below the slice dimension is split as
/// `[slice2 | slice1 | slice0]`, where `slice0` is the innermost part. `tile`
/// denotes a newly introduced (broadcast) axis that must not appear in the
/// input slice or time mapping.
#[derive(Debug, Clone)]
pub enum SwitchConfig {
    /// Replicates data across slices along slice dimensions 0 and 1.
    /// Slice: \[slice2 | slice1 | slice0\] to \[slice2 | tile | tile\]
    /// Time:  \[time1 | time0\] to \[time1 | slice1 | time0 | slice0\]
    Broadcast01 {
        /// Slice dimension 1 size.
        slice1: usize,
        /// Slice dimension 0 size.
        slice0: usize,
        /// Time dimension 0 size.
        time0: usize,
    },
    /// Replicates data across slices along slice dimension 1.
    /// Slice: \[slice2 | slice1 | slice0\] to \[slice2 | tile | slice0\]
    /// Time:  \[time0\] to \[time0 | slice1\]
    Broadcast1 {
        /// Slice dimension 1 size.
        slice1: usize,
        /// Slice dimension 0 size.
        slice0: usize,
    },
    /// Swaps the slice1 and slice0 parts of the slice dimension. Time is unchanged.
    /// Slice: \[slice2 | slice1 | slice0\] to \[slice2 | slice0 | slice1\]
    Transpose {
        /// Slice dimension 1 size.
        slice1: usize,
        /// Slice dimension 0 size.
        slice0: usize,
    },
    /// Swaps and transposes between the slice and time dimensions.
    /// Slice: \[slice2 | slice1 | slice0\] to \[slice2 | time1 | slice0\]
    /// Time:  \[time2 | time1 | time0\] to \[time2 | time0 | slice1\]
    InterTranspose {
        /// Slice dimension 1 size.
        slice1: usize,
        /// Slice dimension 0 size.
        slice0: usize,
        /// Time dimension 0 size.
        time0: usize,
    },
    /// Routes data across slices using a custom snoop bitmap.
    /// The bitmap is computed by the compiler from the input shape and topology
    /// parameters.
    CustomBroadcast {
        /// Ring group size for the custom routing.
        ring_size: usize,
    },
    /// Swaps the slice1 and slice0 parts of the slice dimension and replicates
    /// data across slices along slice dimension 0. The behavior is equivalent
    /// to applying `Transpose` and `Broadcast1` at once.
    /// Slice: \[slice2 | slice1 | slice0\] to \[slice2 | tile | slice1\]
    /// Time:  \[time0\] to \[time0 | slice0\]
    TransposedBroadcast1 {
        /// Slice dimension 1 size.
        slice1: usize,
        /// Slice dimension 0 size.
        slice0: usize,
    },
}
334
335/// Gathers all symbols present in a mapping.
336fn extract_symbols(mapping: &Mapping) -> HashSet<Ident> {
337    let mut symbols = HashSet::new();
338    let mapping = mapping.factorize();
339    let mut stack = mapping.clone().into_inner();
340    while let Some(factor) = stack.pop() {
341        if let Factor::Term { inner, .. } = factor {
342            match &inner.inner {
343                Atom::Symbol { symbol, .. } => {
344                    symbols.insert(*symbol);
345                }
346                Atom::Composite(inner) => {
347                    stack.extend(RBox::into_inner(inner.clone()).into_inner());
348                }
349            }
350        }
351    }
352    symbols
353}
354
355/// Validates switch engine constraints:
356/// 1. Switch input and output slice sizes must match.
357///
358/// Delegates to [`SwitchConfig::verify`] for topology-specific checks.
359fn verify_switch<InSlice: M, InTime: M, OutSlice: M, OutTime: M>(config: &SwitchConfig) {
360    assert_eq!(
361        InSlice::SIZE,
362        OutSlice::SIZE,
363        "Switch input and output slice sizes must match, got {} and {}",
364        InSlice::SIZE,
365        OutSlice::SIZE
366    );
367    config.verify::<InSlice, InTime, OutSlice, OutTime>();
368}
369
370impl SwitchConfig {
371    /// Verifies that the switch configuration is compatible with the provided
372    /// input and output slice/time mappings.
373    pub fn verify<InSlice: M, InTime: M, OutSlice: M, OutTime: M>(&self) {
374        match self {
375            SwitchConfig::Broadcast01 { slice1, slice0, time0 } => {
376                assert!(
377                    [slice1, slice0, time0].iter().all(|&x| *x > 0),
378                    "All dimensions must be greater than 0"
379                );
380                assert_eq!(
381                    InSlice::SIZE % (slice1 * slice0),
382                    0,
383                    "InSlice::SIZE must be divisible by (slice1 * slice0)"
384                );
385                assert_eq!(InTime::SIZE % time0, 0, "InTime::SIZE must be divisible by time0");
386
387                // m![InSlice / (slice1 * slice0)] = m![OutSlice / (slice1 * slice0)]
388                assert_eq!(
389                    Mapping::Stride {
390                        inner: RBox::new(InSlice::to_value()),
391                        stride: *slice1 * *slice0,
392                    }
393                    .factorize(),
394                    Mapping::Stride {
395                        inner: RBox::new(OutSlice::to_value()),
396                        stride: *slice1 * *slice0,
397                    }
398                    .factorize(),
399                    "OutSlice must preserve slice2 from InSlice",
400                );
401
402                // Broadcast axes in OutSlice should be new.
403                if *slice1 * *slice0 > 1 {
404                    let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
405                        inner: RBox::new(OutSlice::to_value()),
406                        modulo: *slice1 * *slice0,
407                    });
408
409                    let mut symbols = HashSet::new();
410                    symbols.extend(extract_symbols(&InSlice::to_value()));
411                    symbols.extend(extract_symbols(&InTime::to_value()));
412                    assert_eq!(
413                        out_slice_broadcast_symbols.intersection(&symbols).count(),
414                        0,
415                        "OutSlice broadcast axes must be new axes"
416                    );
417                }
418
419                let mut expected_out_time = Mapping::Identity
420                    // time1 = m![InTime / time0]
421                    .pair(Mapping::Stride {
422                        inner: RBox::new(InTime::to_value()),
423                        stride: *time0,
424                    })
425                    // slice1 = m![InSlice / slice0 % slice1]
426                    .pair(Mapping::Modulo {
427                        inner: RBox::new(Mapping::Stride {
428                            inner: RBox::new(InSlice::to_value()),
429                            stride: *slice0,
430                        }),
431                        modulo: *slice1,
432                    });
433                if *time0 > 1 {
434                    // time0 = m![InTime % time0]
435                    expected_out_time = expected_out_time.pair(Mapping::Modulo {
436                        inner: RBox::new(InTime::to_value()),
437                        modulo: *time0,
438                    })
439                }
440                if *slice0 > 1 {
441                    // slice0 = m![InSlice % slice0]
442                    expected_out_time = expected_out_time.pair(Mapping::Modulo {
443                        inner: RBox::new(InSlice::to_value()),
444                        modulo: *slice0,
445                    })
446                }
447
448                // OutTime must match [time1, slice1, time0, slice0]
449                assert_eq!(
450                    expected_out_time.factorize(),
451                    OutTime::to_value().factorize(),
452                    "OutTime does not match expected layout: [time1, slice1, time0, slice0]"
453                );
454            }
455
456            SwitchConfig::Broadcast1 { slice1, slice0 } => {
457                assert!(
458                    [slice1, slice0].iter().all(|&x| *x > 0),
459                    "All dimensions must be greater than 0"
460                );
461                assert_eq!(
462                    InSlice::SIZE % (slice1 * slice0),
463                    0,
464                    "InSlice::SIZE must be divisible by (slice1 * slice0)"
465                );
466
467                // m![InSlice / (slice1 * slice0)] = m![OutSlice / (tile * slice0)]
468                // Since tile has the same size as slice1, this is: m![InSlice / (slice1 * slice0)] = m![OutSlice / (slice1 * slice0)]
469                assert_eq!(
470                    Mapping::Stride {
471                        inner: RBox::new(InSlice::to_value()),
472                        stride: *slice1 * *slice0,
473                    }
474                    .factorize(),
475                    Mapping::Stride {
476                        inner: RBox::new(OutSlice::to_value()),
477                        stride: *slice1 * *slice0,
478                    }
479                    .factorize(),
480                    "OutSlice must preserve slice2 from InSlice",
481                );
482
483                // Broadcast axes in OutSlice should be new.
484                if *slice1 > 1 {
485                    let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
486                        inner: RBox::new(Mapping::Stride {
487                            inner: RBox::new(OutSlice::to_value()),
488                            stride: *slice0,
489                        }),
490                        modulo: *slice1,
491                    });
492
493                    let mut symbols = HashSet::new();
494                    symbols.extend(extract_symbols(&InSlice::to_value()));
495                    symbols.extend(extract_symbols(&InTime::to_value()));
496                    assert_eq!(
497                        out_slice_broadcast_symbols.intersection(&symbols).count(),
498                        0,
499                        "OutSlice broadcast axes must be new axes"
500                    );
501                }
502
503                // slice0 is preserved at the innermost part of OutSlice
504                if *slice0 > 1 {
505                    assert_eq!(
506                        Mapping::Modulo {
507                            inner: RBox::new(OutSlice::to_value()),
508                            modulo: *slice0,
509                        }
510                        .factorize(),
511                        Mapping::Modulo {
512                            inner: RBox::new(InSlice::to_value()),
513                            modulo: *slice0,
514                        }
515                        .factorize(),
516                        "OutSlice must preserve slice0 from InSlice",
517                    );
518                }
519
520                // OutTime must match [time0, slice1]
521                let mut expected_out_time = Mapping::Identity.pair(InTime::to_value());
522                if *slice1 > 1 {
523                    // slice1 = m![InSlice / slice0 % slice1]
524                    expected_out_time = expected_out_time.pair(Mapping::Modulo {
525                        inner: RBox::new(Mapping::Stride {
526                            inner: RBox::new(InSlice::to_value()),
527                            stride: *slice0,
528                        }),
529                        modulo: *slice1,
530                    });
531                }
532
533                assert_eq!(
534                    expected_out_time.factorize(),
535                    OutTime::to_value().factorize(),
536                    "OutTime does not match expected layout: [time0, slice1]"
537                );
538            }
539
540            SwitchConfig::Transpose { slice1, slice0 } => {
541                assert!(
542                    [slice1, slice0].iter().all(|&x| *x > 0),
543                    "All dimensions must be greater than 0"
544                );
545
546                // Time can be padded, but padding must not change the size.
547                assert_eq!(
548                    InTime::SIZE,
549                    OutTime::SIZE,
550                    "Input and output time dimensions must have the same size"
551                );
552                assert_eq!(
553                    InTime::to_value().factorize().remove_padding(),
554                    OutTime::to_value().factorize().remove_padding(),
555                    "Input and output time dimensions must match (excluding padding)"
556                );
557
558                // OutSlice must match [slice2, slice0, slice1] (slice1 and slice0 are swapped)
559                assert_eq!(
560                    Mapping::Stride {
561                        inner: RBox::new(InSlice::to_value()),
562                        stride: *slice1 * *slice0,
563                    }
564                    .pair(Mapping::Modulo {
565                        inner: RBox::new(InSlice::to_value()),
566                        modulo: *slice0,
567                    })
568                    .pair(Mapping::Modulo {
569                        inner: RBox::new(Mapping::Stride {
570                            inner: RBox::new(InSlice::to_value()),
571                            stride: *slice0,
572                        }),
573                        modulo: *slice1,
574                    })
575                    .factorize(),
576                    OutSlice::to_value().factorize(),
577                    "OutSlice does not match expected layout: [slice2, slice0, slice1]"
578                );
579            }
580
581            SwitchConfig::InterTranspose { slice1, slice0, time0 } => {
582                assert!(
583                    [slice1, slice0, time0].iter().all(|&x| *x > 0),
584                    "All dimensions must be greater than 0"
585                );
586
587                assert_eq!(
588                    InSlice::SIZE % (slice1 * slice0),
589                    0,
590                    "InSlice::SIZE must be divisible by (slice1 * slice0)"
591                );
592                assert_eq!(
593                    InTime::SIZE % (slice1 * time0),
594                    0,
595                    "InTime::SIZE must be divisible by (slice1 * time0)"
596                );
597
598                let slice2 = InSlice::SIZE / (slice1 * slice0);
599                let time2 = InTime::SIZE / (slice1 * time0);
600
601                assert_eq!(slice2 * slice1 * slice0, 256, "All dimensions should multiply to 256");
602                assert_eq!(
603                    time2 * slice1 * time0,
604                    InTime::SIZE,
605                    "time2 * slice1 * time0 must equal InTime::SIZE"
606                );
607
608                // InSlice  = [slice2, slice1,  slice0]
609                // OutSlice = [slice2, time1, slice0]
610                // slice2 and slice0 are preserved
611                assert_eq!(
612                    Mapping::Stride {
613                        inner: RBox::new(OutSlice::to_value()),
614                        stride: *slice1 * *slice0,
615                    }
616                    .factorize(),
617                    Mapping::Stride {
618                        inner: RBox::new(InSlice::to_value()),
619                        stride: *slice1 * *slice0,
620                    }
621                    .factorize(),
622                    "OutSlice must preserve slice2 from InSlice",
623                );
624                if *slice0 > 1 {
625                    assert_eq!(
626                        Mapping::Modulo {
627                            inner: RBox::new(OutSlice::to_value()),
628                            modulo: *slice0,
629                        }
630                        .factorize(),
631                        Mapping::Modulo {
632                            inner: RBox::new(InSlice::to_value()),
633                            modulo: *slice0,
634                        }
635                        .factorize(),
636                        "OutSlice must preserve slice0 from InSlice",
637                    );
638                }
639
640                // InTime   = [time2,  time1, time0]
641                // OutSlice = [slice2, time1, slice0]
642                // time1 comes from InTime
643                if *slice1 > 1 {
644                    assert_eq!(
645                        Mapping::Modulo {
646                            inner: RBox::new(Mapping::Stride {
647                                inner: RBox::new(OutSlice::to_value()),
648                                stride: *slice0,
649                            }),
650                            modulo: *slice1,
651                        }
652                        .factorize(),
653                        Mapping::Modulo {
654                            inner: RBox::new(Mapping::Stride {
655                                inner: RBox::new(InTime::to_value()),
656                                stride: *time0,
657                            }),
658                            modulo: *slice1,
659                        }
660                        .factorize(),
661                        "OutSlice time1 must come from InTime"
662                    );
663                }
664
665                // InTime  = [time2, time1, time0]
666                // OutTime = [time2, time0,  slice1]
667                // time2 is preserved
668                assert_eq!(
669                    Mapping::Stride {
670                        inner: RBox::new(OutTime::to_value()),
671                        stride: *slice1 * time0,
672                    }
673                    .factorize(),
674                    Mapping::Stride {
675                        inner: RBox::new(InTime::to_value()),
676                        stride: *time0 * slice1,
677                    }
678                    .factorize(),
679                    "OutTime must preserve 'time2' from InTime",
680                );
681                // time0 is moved from InTime to OutTime
682                if *time0 > 1 {
683                    assert_eq!(
684                        Mapping::Modulo {
685                            inner: RBox::new(Mapping::Stride {
686                                inner: RBox::new(OutTime::to_value()),
687                                stride: *slice1,
688                            }),
689                            modulo: *time0,
690                        }
691                        .factorize(),
692                        Mapping::Modulo {
693                            inner: RBox::new(InTime::to_value()),
694                            modulo: *time0,
695                        }
696                        .factorize(),
697                        "OutTime must preserve 'time0' from InTime"
698                    );
699                }
700                // InSlice = [slice2, slice1, slice0]
701                // OutTime = [time2,  time0,   slice1]
702                // slice1 in OutTime matches the one in InSlice
703                if *slice1 > 1 {
704                    assert_eq!(
705                        Mapping::Modulo {
706                            inner: RBox::new(OutTime::to_value()),
707                            modulo: *slice1,
708                        }
709                        .factorize(),
710                        Mapping::Modulo {
711                            inner: RBox::new(Mapping::Stride {
712                                inner: RBox::new(InSlice::to_value()),
713                                stride: *slice0,
714                            }),
715                            modulo: *slice1,
716                        }
717                        .factorize(),
718                        "OutTime must preserve 'slice1' from InSlice"
719                    );
720                }
721            }
722
723            // Custom topologies allow:
724            // 1. Arbitrary slice permutations in Slice.
725            // 2. Slice to Time broadcasts with slicing.
726
727            // Constraints checked:
728            // 1. Broadcast axes must use new unique axes (not in input `Slice` or `Time`).
729            // 2. Broadcast axes should not be padded.
730            // 3. Outer portion of output `Time` must match input `Time`.
731            // 4. Axes moving from `Slice` to `Time` appear at the innermost output `Time` positions.
732            // 5. Axes moving from `Slice` to `Time` must preserve their relative order from input `Slice`.
733            // 6. Ring size must match the outermost non-directcast boundary and be a power of 2.
734            SwitchConfig::CustomBroadcast { ring_size } => {
735                assert!(
736                    ring_size.is_power_of_two(),
737                    "Switch ring size must be a power of 2, got {ring_size}"
738                );
739
740                let slice = InSlice::to_value().factorize();
741                let out_slice = OutSlice::to_value().factorize();
742                let time = InTime::to_value().factorize();
743                let out_time = OutTime::to_value().factorize();
744
745                // Identify broadcast axes. Broadcast axes use new axes, not
746                // present in InSlice.
747                let furiosa_mapping::Division {
748                    dividend_residue,
749                    divisor_residue,
750                    division_terms,
751                    ..
752                } = out_slice.clone().divide_strict(slice.clone());
753                assert!(
754                    dividend_residue
755                        .clone()
756                        .divide_strict(slice.clone().mul(time.clone()))
757                        .division_terms()
758                        .is_empty(),
759                    "Switch broadcast axes must be new axes (not present in input Slice or Time)."
760                );
761
762                // Each broadcast axis must be used exactly once in OutSlice.
763                // `terms_with_stride` counts all terms (including duplicates),
764                // `idents` returns unique axes. A mismatch means an axis
765                // appears more than once.
766                assert_eq!(
767                    dividend_residue.terms_with_stride().len(),
768                    dividend_residue.idents().len(),
769                    "Switch broadcast axes must each be used exactly once in OutSlice"
770                );
771
772                // Broadcast axes must not have padding.
773                let factors = out_slice.factors();
774                for (i, factor) in factors.iter().enumerate() {
775                    if let Factor::Term { inner, .. } = factor
776                        && !division_terms.iter().any(|d| d.term.inner == inner.inner)
777                        && matches!(factors.get(i + 1), Some(Factor::Padding { .. }))
778                    {
779                        panic!("Switch broadcast axis {inner} in output Slice must not be padded.");
780                    }
781                }
782
783                // OutTime = [InTime (outer) | moved axes (inner)].
784                let Tuple2(outer, inner) =
785                    out_time.split_at(exact_div(OutTime::SIZE, InTime::SIZE).unwrap_or_else(|| {
786                        panic!(
787                            "Input Time size ({}) does not divide Output Time size ({})",
788                            InTime::SIZE,
789                            OutTime::SIZE
790                        )
791                    }));
792                assert_eq!(
793                    outer, time,
794                    "Switch axes moving from input slice to output time must be at the output time innermost positions. \
795                     Expected outer portion of output time to be {time}, got {outer}"
796                );
797
798                let broadcast_divisions = divisor_residue
799                    .clone()
800                    .divide_relaxed(inner.clone())
801                    .exact()
802                    .unwrap_or_else(|_| {
803                        panic!(
804                            "Switch broadcast axes in output time must come from input slice. \
805                            Input Slice: {slice}, inner part of output Time: {inner}"
806                        )
807                    })
808                    .division_terms;
809
810                // Axes moving from Slice to Time must preserve their relative
811                // order from input Slice.
812                for window in broadcast_divisions.windows(2) {
813                    assert!(
814                        window[0].divisor_stride > window[1].divisor_stride,
815                        "Switch axes moving from input Slice to output Time must preserve their relative order from input Slice. \
816                         {} (outer in input Slice) appears inner to {} in output Time",
817                        window[0].term,
818                        window[1].term
819                    );
820                }
821
822                // Find the span of the outermost non-directcast axis in the
823                // input Slice.
824                let mut max_non_dc = 0;
825                for d in &division_terms {
826                    if d.dividend_stride != d.divisor_stride {
827                        max_non_dc = max_non_dc.max(d.divisor_stride * d.resize);
828                    }
829                }
830                for t in &divisor_residue.terms_with_stride() {
831                    max_non_dc = max_non_dc.max(t.stride * t.resize);
832                }
833
834                // Find the directcast axis with the smallest stride that is at
835                // or above `max_non_dc`.
836                let mut expected_ring_size = InSlice::SIZE;
837                for d in &division_terms {
838                    if d.dividend_stride == d.divisor_stride && d.divisor_stride >= max_non_dc {
839                        expected_ring_size = expected_ring_size.min(d.divisor_stride);
840                    }
841                }
842                assert_eq!(
843                    *ring_size, expected_ring_size,
844                    "Switch ring size mismatch. Expected {expected_ring_size}, got {ring_size}"
845                );
846            }
847
848            SwitchConfig::TransposedBroadcast1 { slice1, slice0 } => {
849                assert!(
850                    [slice1, slice0].iter().all(|&x| *x > 0),
851                    "All dimensions must be greater than 0"
852                );
853
854                assert_eq!(
855                    InSlice::SIZE % (slice1 * slice0),
856                    0,
857                    "InSlice::SIZE must be divisible by (slice1 * slice0)"
858                );
859
860                // m![InSlice / (slice1 * slice0)] = m![OutSlice / (tile * slice1)]
861                // Since tile has the same size as slice0, this is: m![InSlice / (slice1 * slice0)] = m![OutSlice / (tile * slice1)]
862                assert_eq!(
863                    Mapping::Stride {
864                        inner: RBox::new(InSlice::to_value()),
865                        stride: *slice1 * *slice0,
866                    }
867                    .factorize(),
868                    Mapping::Stride {
869                        inner: RBox::new(OutSlice::to_value()),
870                        stride: *slice0 * *slice1,
871                    }
872                    .factorize(),
873                    "OutSlice must preserve slice2 from InSlice",
874                );
875
876                // Broadcast axes in OutSlice should be new.
877                if *slice0 > 1 {
878                    let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
879                        inner: RBox::new(Mapping::Stride {
880                            inner: RBox::new(OutSlice::to_value()),
881                            stride: *slice1,
882                        }),
883                        modulo: *slice0,
884                    });
885
886                    let mut symbols = HashSet::new();
887                    symbols.extend(extract_symbols(&InSlice::to_value()));
888                    symbols.extend(extract_symbols(&InTime::to_value()));
889                    assert_eq!(
890                        out_slice_broadcast_symbols.intersection(&symbols).count(),
891                        0,
892                        "OutSlice broadcast axes must be new axes"
893                    );
894                }
895
896                // slice1 is preserved at the innermost part of OutSlice
897                if *slice1 > 1 {
898                    assert_eq!(
899                        Mapping::Modulo {
900                            inner: RBox::new(OutSlice::to_value()),
901                            modulo: *slice1,
902                        }
903                        .factorize(),
904                        Mapping::Modulo {
905                            inner: RBox::new(Mapping::Stride {
906                                inner: RBox::new(InSlice::to_value()),
907                                stride: *slice0
908                            }),
909                            modulo: *slice1,
910                        }
911                        .factorize(),
912                        "OutSlice must preserve slice1 from InSlice",
913                    );
914                }
915
916                // OutTime must match [time0, slice0]
917                let mut expected_out_time = Mapping::Identity.pair(InTime::to_value());
918                if *slice0 > 1 {
919                    // slice0 = m![InSlice % slice0]
920                    expected_out_time = expected_out_time.pair(Mapping::Modulo {
921                        inner: RBox::new(InSlice::to_value()),
922                        modulo: *slice0,
923                    });
924                }
925
926                assert_eq!(
927                    expected_out_time.factorize(),
928                    OutTime::to_value().factorize(),
929                    "OutTime does not match expected layout: [time0, slice0]"
930                );
931            }
932        }
933    }
934}
935
936/// Validates collect engine constraints: normalizes packet to exactly one flit (32 bytes).
937///
938/// Pads the input packet to flit-aligned boundary, then splits:
939/// - Inner 32 bytes → Packet2 (one flit)
940/// - Outer flit portion → absorbed into Time2
941///
942/// For packets already ≤ 32 bytes, only padding is added.
943fn verify_collect<D: Scalar, Time: M, Packet: M, Time2: M, Packet2: M>() {
944    let in_packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
945    let aligned_bytes = align_up(in_packet_bytes, FLIT_BYTES);
946    let flit_elements = D::length_from_bytes(FLIT_BYTES);
947
948    // Output packet must be exactly one flit.
949    assert_eq!(
950        D::size_in_bytes_from_length(Packet2::SIZE),
951        FLIT_BYTES,
952        "Collect output packet must be exactly {FLIT_BYTES} bytes (one flit)."
953    );
954
955    // Pad input packet to flit-aligned boundary, then split at flit boundary.
956    let in_factorized = Packet::to_value().factorize();
957    let padded = in_factorized.pad(D::length_from_bytes(aligned_bytes));
958    let Tuple2(in_outer, in_flit) = padded.split_at(flit_elements);
959
960    // Output packet = inner flit.
961    let expected_packet = in_flit.normalize();
962    let out_packet = Packet2::to_value().factorize();
963    assert_eq!(
964        expected_packet, out_packet,
965        "Collect packet mismatch. Expected: {expected_packet}, got: {out_packet}"
966    );
967
968    // Time2 = Time × outer flit portion.
969    let expected_time = Time::to_value().factorize().mul(in_outer).normalize();
970    let out_time = Time2::to_value().factorize();
971    assert_eq!(
972        expected_time, out_time,
973        "Collect time mismatch. Expected: {expected_time}, got: {out_time}"
974    );
975}
976
/// Validates hardware constraints for the transpose engine.
///
/// Expected layouts:
/// - input: `Time = [time_outer, in_rows, packets_per_col]` with a 32-byte `Packet`;
/// - output: `OutTime = [time_outer, packets_per_col, elements_per_packet]` and
///   `OutPacket = [in_rows # padding]`.
///
/// Constraints checked:
/// 1. `Packet` and `OutPacket` must be 32 bytes
/// 2. `in_rows` * sizeof(D) <= 8 bytes
/// 3. `in_cols` must be 8, 16, or 32 (4-bit: 16 or 32 only)
/// 4. `out_rows` <= `in_cols`
/// 5. The per-packet part of `OutTime` equals the (padded) `Packet` once padding
///    is removed on both sides, so slicing may only drop padding
/// 6. The outer part of `Time` equals the outer part of `OutTime`
///
/// # Panics
///
/// Panics if any constraint above is violated, or if `in_rows` cannot be found
/// in `Time` (or `packets_per_col` in `OutTime`).
fn verify_transpose<D: Scalar, Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    // Packet must be TRANSPOSE_INPUT_BYTES bytes.
    let packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
    assert_eq!(
        packet_bytes, TRANSPOSE_INPUT_BYTES,
        "Transpose input packet must be {TRANSPOSE_INPUT_BYTES} bytes, got {packet_bytes}"
    );

    // OutPacket must be TRANSPOSE_OUTPUT_BYTES bytes.
    let out_packet_bytes = D::size_in_bytes_from_length(OutPacket::SIZE);
    assert_eq!(
        out_packet_bytes, TRANSPOSE_OUTPUT_BYTES,
        "Transpose output packet must be {TRANSPOSE_OUTPUT_BYTES} bytes, got {out_packet_bytes}",
    );

    // OutPacket is `[in_rows # padding]`.
    // `in_rows` * sizeof(D) <= 8 bytes
    let in_rows = OutPacket::to_value().factorize().remove_padding();
    let in_rows_bytes = D::size_in_bytes_from_length(in_rows.size());
    assert!(
        in_rows_bytes <= TRANSPOSE_MAX_IN_ROWS_BYTES,
        "Transpose `in_rows` must be <= {TRANSPOSE_MAX_IN_ROWS_BYTES} bytes, got {in_rows_bytes}"
    );

    // `Time = [..., in_rows, packets_per_col]`
    // Check that `in_rows` matches OutPacket `in_rows`.
    let time = Time::to_value().factorize();
    let division_terms = time
        .clone()
        .divide_relaxed(in_rows.clone())
        .exact()
        .unwrap_or_else(|_| panic!("Transpose `in_rows` ({in_rows}) must be present in the input Time ({time})"))
        .division_terms;

    // `in_cols` = `packets_per_col` * `elements_per_packet` must be in {8, 16, 32} (4-bit: {16, 32})
    // The part of `Time` inner to the located `in_rows` term is `packets_per_col`;
    // dropping the innermost `in_rows.size()` elements from the rest leaves `time_outer`.
    // NOTE(review): `division_terms[0]` assumes (a) the division produced at least one
    // term (otherwise this is an index-out-of-bounds panic), and (b) entry 0 is the
    // innermost `in_rows` term when `in_rows` spans several axes. TODO confirm both.
    let Tuple2(time_outer, packets_per_col) = time.split_at(division_terms[0].dividend_stride);
    let Tuple2(time_outer, _) = time_outer.split_at(in_rows.size());
    let (elements_per_packet, valid_in_cols) = match D::BITS {
        4 => (TRANSPOSE_ELEMENTS_PER_PACKET_4BIT, TRANSPOSE_VALID_IN_COLS_4BIT),
        _ => (TRANSPOSE_ELEMENTS_PER_PACKET_NON_4BIT, TRANSPOSE_VALID_IN_COLS),
    };
    let in_cols = packets_per_col.size() * elements_per_packet;
    assert!(
        valid_in_cols.contains(&in_cols),
        "Transpose `in_cols` size ({in_cols}) must be one of {valid_in_cols:?} for {}-bit type",
        D::BITS,
    );

    // Rebinds `elements_per_packet` from an element count to the `Packet`
    // layout padded to that count.
    let elements_per_packet = Packet::to_value().factorize().pad(elements_per_packet);
    let out_time = OutTime::to_value().factorize();

    // `OutTime = [time_outer, packets_per_col, elements_per_packet]`.
    // `elements_per_packet` may be sliced: `[in_cols x in_rows] → [out_rows x in_rows]`.
    // Make sure only padding is removed.
    let Tuple2(out_time_outer, out_time_elements_per_packet) = if packets_per_col.size() > 1 {
        // Locate `packets_per_col` in `OutTime`: everything inner to it is the
        // per-packet part, and what is left above it is the outer part.
        let division_terms = out_time
            .clone()
            .divide_relaxed(packets_per_col.clone())
            .exact()
            .unwrap_or_else(|_| {
                panic!("Transpose `packets_per_col` ({packets_per_col}) not found in OutTime ({out_time})")
            })
            .division_terms;
        let Tuple2(outer, num_elems) = out_time.split_at(division_terms[0].dividend_stride);
        let Tuple2(outer, _) = outer.split_at(packets_per_col.size());
        Tuple2(outer, num_elems)
    } else {
        // No `packets_per_col` axis to locate: treat the innermost
        // `OutTime::SIZE / time_outer.size()` elements as the per-packet part.
        // Integer division here is not checked for exactness; a non-exact split
        // is expected to surface in the `time_outer` comparison below.
        let out_rows = OutTime::SIZE / time_outer.size();
        out_time.split_at(out_rows)
    };
    let out_rows = packets_per_col
        .clone()
        .mul(out_time_elements_per_packet.clone())
        .normalize();
    let in_cols = packets_per_col.mul(elements_per_packet.clone()).normalize();
    assert!(
        out_rows.size() <= in_cols.size(),
        "Transpose `out_rows` ({}) must be <= `in_cols` ({})",
        out_rows.size(),
        in_cols.size(),
    );
    assert_eq!(
        out_time_elements_per_packet.remove_padding(),
        elements_per_packet.remove_padding(),
        "Transpose `out_rows` ({out_rows}) must match `in_cols` ({in_cols}) (excluding padding)",
    );

    // Outer Time axes should match outer OutTime axes.
    assert_eq!(
        time_outer, out_time_outer,
        "Transpose time mismatch: expected outer OutTime to be {time_outer}, got ({out_time_outer})"
    );
}
1077
1078// ANCHOR: switch_impl
impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
    FetchTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
{
    /// Performs switching operation to create a switched tensor.
    ///
    /// Applies switching network routing only. The packet passes through
    /// unchanged — no padding, no reshaping. Use [`SwitchTensor::collect`]
    /// afterwards to normalize the packet to flit-sized chunks.
    ///
    /// # Panics
    ///
    /// Panics if `verify_switch` rejects `config` for the
    /// `Slice`/`Time` → `Slice2`/`Time2` layout change.
    #[primitive(FetchTensor::switch)]
    pub fn switch<Slice2: M, Time2: M>(
        self,
        config: SwitchConfig,
    ) -> SwitchTensor<'l, T, D, Chip, Cluster, Slice2, Time2, Packet> {
        verify_switch::<Slice, Time, Slice2, Time2>(&config);
        SwitchTensor::new(self.ctx, self.inner.transpose(true))
    }

    /// Skips the switching network and goes directly to collect.
    ///
    /// Slice and Time are preserved from fetch; only the packet is normalized
    /// to flit-sized chunks.
    ///
    /// # Panics
    ///
    /// Panics if `verify_collect` rejects the layouts, e.g. when `Packet2` is
    /// not exactly one flit or `Time2` does not absorb the outer flits of `Packet`.
    #[primitive(FetchTensor::collect)]
    pub fn collect<Time2: M, Packet2: M>(self) -> CollectTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
        verify_collect::<D, Time, Packet, Time2, Packet2>();
        CollectTensor::new(self.ctx, self.inner.transpose(false))
    }
}
1106// ANCHOR_END: switch_impl
1107
1108// ANCHOR: collect_impl
impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
    SwitchTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
{
    /// Normalizes packet to exactly 32 bytes (one flit).
    ///
    /// Pads to flit-aligned boundary, then splits: inner 32 bytes become Packet2,
    /// outer flit portion is absorbed into Time2.
    /// For packets already ≤ 32 bytes, only padding is added.
    ///
    /// # Panics
    ///
    /// Panics if `verify_collect` rejects the `Time2`/`Packet2` layouts.
    #[primitive(SwitchTensor::collect)]
    pub fn collect<Time2: M, Packet2: M>(self) -> CollectTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
        verify_collect::<D, Time, Packet, Time2, Packet2>();
        CollectTensor::new(self.ctx, self.inner.transpose(false))
    }
}
1123// ANCHOR_END: collect_impl
1124
impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
    CollectTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
{
    /// Stores to the tensor register file.
    ///
    /// `Time` is split into `[Row, time_inner]`; each TRF row then holds
    /// `[time_inner, Packet]` as its `Element`.
    ///
    /// # Panics
    ///
    /// Panics if `Row::SIZE` is not 1, 2, 4, or 8; if `Row::SIZE * Element::SIZE`
    /// elements exceed `address.capacity()` bytes; if `Row::SIZE` is larger than
    /// or does not divide `Time::SIZE`; if the outer `Row::SIZE` part of `Time`
    /// is not `Row`; or if `[time_inner, Packet]` is not `Element`.
    #[primitive(CollectTensor::to_trf)]
    pub fn to_trf<Row: M, Element: M>(self, address: TrfAddress) -> TrfTensor<D, Chip, Cluster, Slice, Row, Element> {
        assert!(
            [1, 2, 4, 8].contains(&Row::SIZE),
            "Row::SIZE must be 1, 2, 4, or 8, got {}",
            Row::SIZE
        );

        // Trf data should fit in the register file.
        let capacity = address.capacity();
        let total_trf_bytes = D::size_in_bytes_from_length(Row::SIZE * Element::SIZE);
        assert!(
            total_trf_bytes <= capacity,
            "TRF data ({} bytes = {} rows x {} bytes) exceeds register file capacity ({} bytes for {})",
            total_trf_bytes,
            Row::SIZE,
            D::size_in_bytes_from_length(Element::SIZE),
            capacity,
            address,
        );

        // |row| <= |time|
        assert!(
            Row::SIZE <= Time::SIZE,
            "Row::SIZE must be <= Time::SIZE, got {} > {}",
            Row::SIZE,
            Time::SIZE,
        );

        // [time_outer] = [Row]
        // Splitting off the innermost `Time::SIZE / Row::SIZE` elements leaves an
        // outer part of size `Row::SIZE`, which must be exactly `Row`.
        let time = Time::to_value().factorize();
        let Tuple2(time_outer, time_inner) = time.split_at(
            exact_div(Time::SIZE, Row::SIZE)
                .unwrap_or_else(|| panic!("Row::SIZE ({}) does not divide Time::SIZE ({})", Row::SIZE, Time::SIZE)),
        );
        let row = Row::to_value().factorize();
        assert_eq!(
            time_outer, row,
            "`to_trf` row mismatch: time_outer != Row: {time_outer} != {row}",
        );

        // [time_inner, Packet] = [Element]
        let expected_element = time_inner.mul(Packet::to_value().factorize()).normalize();
        let element = Element::to_value().factorize();
        assert_eq!(
            expected_element, element,
            "`to_trf` element mismatch: [time_inner, Packet] != Element: {expected_element} != {element}",
        );

        TrfTensor::new(self.inner.transpose(false), address)
    }

    /// Stores to the vector register file.
    ///
    /// This method itself performs no layout assertions on `Element2`.
    /// NOTE(review): confirm whether `VrfTensor::new` validates the layout.
    #[primitive(CollectTensor::to_vrf)]
    pub fn to_vrf<Element2: M>(self, address: Address) -> VrfTensor<D, Chip, Cluster, Slice, Element2>
    where
        D: VeScalar,
    {
        VrfTensor::new(self.inner.transpose(false), address)
    }

    // ANCHOR: collect_vector_init
    /// Initializes Vector Engine processing for this tensor.
    #[primitive(CollectTensor::vector_init)]
    pub fn vector_init(self) -> VectorInitTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
    where
        D: VeScalar,
        // ANCHOR_END: collect_vector_init
    {
        VectorInitTensor::new(self.ctx, self.inner)
    }

    /// Aligns LHS stream (via Feed Buffer) and RHS TRF (via TRF Sequencer) to computation shape.
    ///
    /// Layout contract checked here:
    /// - `OutPacket` is 64 bytes and its innermost flit equals `Packet`.
    /// - `Time` divides `OutTime * (outer flit part of OutPacket, padding removed)`.
    ///   The residue of that division is the set of "tiling" axes.
    /// - Retained `Time` axes keep their relative order.
    /// - Every retained `Time` axis sits at a stride of at least the tiling size,
    ///   i.e. tiling axes are innermost.
    /// - Non-trivial tiling axes must be present in the TRF's `TrfElement`.
    ///
    /// # Panics
    ///
    /// Panics if `Row::SIZE` is not 1, 2, 4, or 8, or if any check above fails.
    #[primitive(CollectTensor::align)]
    pub fn align<OutTime: M, OutPacket: M, Row: M, TrfElement: M>(
        self,
        trf_tensor: &TrfTensor<D, Chip, Cluster, Slice, Row, TrfElement>,
    ) -> AlignedPair<'l, { T }, D, Chip, Cluster, Slice, Row, OutTime, OutPacket>
    where
        Chip: M,
        Cluster: M,
        Slice: M,
    {
        assert!([1, 2, 4, 8].contains(&Row::SIZE), "Row::SIZE should be 1, 2, 4, or 8");

        let out_packet_size = D::size_in_bytes_from_length(OutPacket::SIZE);
        assert_eq!(
            out_packet_size, 64,
            "OutPacket must be 64 bytes, got {out_packet_size} bytes"
        );

        // Inner flit of OutPacket must match input Packet.
        let flit_elements = D::length_from_bytes(FLIT_BYTES);
        let Tuple2(out_packet_outer, out_packet_inner) = OutPacket::to_value().factorize().split_at(flit_elements);
        let out_packet_inner = out_packet_inner.normalize();
        let expected_packet = Packet::to_value().factorize();
        assert_eq!(
            out_packet_inner, expected_packet,
            "`align` packet mismatch: inner flit of OutPacket != Packet: {out_packet_inner} != {expected_packet}",
        );

        // Time must equal OutTime * outer flit portion of OutPacket.
        // Padding is stripped for the collect_flits = 1 case.
        let expected_time = OutTime::to_value()
            .factorize()
            .mul(out_packet_outer.remove_padding())
            .normalize();
        let input_time = Time::to_value().factorize();

        // Time broadcast axes are the innermost in `OutTime`.
        // These are present in TRF, but absent in the input data.
        let tiling_size = expected_time.size() / input_time.size();
        let align_div = expected_time
            .divide_relaxed(input_time.clone())
            .exact()
            .expect("`align`: Time does not divide OutTime");
        // Axes of the expected time that are not explained by `Time`.
        let tiling = align_div.dividend_residue;
        let division_terms = align_div.division_terms;

        // Non-tiling axes must follow the same order in both mappings.
        assert!(
            division_terms
                .windows(2)
                .all(|w| w[0].divisor_stride > w[1].divisor_stride),
            "`align`: Time axes are reordered in OutTime"
        );

        // Tiling axes are the innermost axes in `OutTime`.
        assert!(
            division_terms.iter().all(|d| d.dividend_stride >= tiling_size),
            "`align`: tiling axes must be innermost in OutTime"
        );

        // The TRF operand must carry every tiling axis.
        if tiling.size() > 1 {
            let trf_element = TrfElement::to_value().factorize();
            assert!(
                trf_element.clone().divide_relaxed(tiling.clone()).exact().is_ok(),
                "tiling axes must be present in TRF",
            );
        }

        let lhs = self.inner.transpose(true);
        let trf = trf_tensor.inner.transpose(true);

        AlignedPair {
            ctx: self.ctx,
            lhs,
            trf,
        }
    }

    /// Performs transpose operation.
    ///
    /// # Panics
    ///
    /// Panics if `verify_transpose` rejects the `Time`/`Packet` →
    /// `OutTime`/`OutPacket` layout change.
    #[primitive(CollectTensor::transpose)]
    pub fn transpose<OutTime: M, OutPacket: M>(
        self,
    ) -> TransposeTensor<'l, T, D, Chip, Cluster, Slice, OutTime, OutPacket> {
        verify_transpose::<D, Time, Packet, OutTime, OutPacket>();
        TransposeTensor::new(self.ctx, self.inner.transpose(false))
    }
}
1289
1290// ANCHOR: cast_impl
impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
    for CollectTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
{
    /// Result of a cast: the same `Chip`/`Cluster`/`Slice`/`Time` layout, with
    /// element type `D2` and packet layout `OutPacket`.
    type CastOutput<D2: Scalar, OutPacket: M>
        = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
    where
        D: Cast<D2>;

    /// Casts the collected stream's elements from `D` to `D2`.
    ///
    /// Delegates entirely to `cast_stream`. Any validation of `OutPacket`
    /// happens there and is not visible in this impl.
    #[primitive(CollectTensor::cast)]
    fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
    where
        D: Cast<D2>,
    {
        cast_stream(self.ctx, self.inner)
    }
}
1307// ANCHOR_END: cast_impl
1308
1309/// Validates contraction.
1310///
1311/// Checks:
1312/// 1. `Packet` should be 64 bytes.
1313/// 2. `OutPacket::SIZE` should be a power of two and at most
1314///    `TEMPORAL_ACCUMULATOR_COLS` and be obtainable from `Packet` by splitting
1315///    at a power-of-two sized boundary.
1316fn verify_contract<D: Scalar, Packet: M, OutPacket: M>() {
1317    let packet_size = D::size_in_bytes_from_length(Packet::SIZE);
1318    assert_eq!(packet_size, 64, "Packet must be 64 bytes, got {packet_size} bytes");
1319
1320    assert!(
1321        OutPacket::SIZE <= TEMPORAL_ACCUMULATOR_COLS,
1322        "OutPacket::SIZE must be at most {TEMPORAL_ACCUMULATOR_COLS}, got {}",
1323        OutPacket::SIZE
1324    );
1325
1326    assert!(
1327        OutPacket::SIZE.is_power_of_two(),
1328        "OutPacket::SIZE must be a power of two, got {}",
1329        OutPacket::SIZE
1330    );
1331
1332    let packet = Packet::to_value().factorize();
1333    let out_packet = OutPacket::to_value().factorize().remove_padding();
1334
1335    assert!(
1336        (0..=Packet::SIZE.trailing_zeros()).rev().any(|depth| {
1337            let split = 1 << depth;
1338            let outer = packet.clone().stride(split);
1339            outer.remove_padding() == out_packet.clone()
1340        }),
1341        "OutPacket {out_packet} is not a valid contraction of Packet {packet}",
1342    );
1343}
1344
1345/// Validates accumulation constraints on output shape and accumulator size.
1346///
1347/// Checks:
1348/// 1. `Packet::SIZE` must be at most 32 (`i32`/`f32`).
1349/// 2. `OutPacket::SIZE` must be 8 (`i32`/`f32`).
1350/// 3. Output factor composition matches retained (non-reduced) axes:
1351///    - Interleaved: `OutTime = [retained Time, retained Packet (may be sliced)]`, `OutPacket = [Row # 8]`.
1352///    - Sequential: `OutTime = [retained Time, Row, packet_outer]`, `OutPacket = [packet_inner # 8]`.
1353/// 4. Accumulator fits hardware limit (1024 elements):
1354///    - Interleaved: `inner_time * ReducedPacket <= 128` (since `align_up(Row, 8) = 8`).
1355///    - Sequential: `inner_time * Row * packet_outer <= 32` (since `align_up(ReducedPacket, 32) = 32`).
1356fn verify_accumulate<Row: M, Time: M, Packet: M, OutTime: M, OutPacket: M>(kind: AccumulationKind) {
1357    assert!(
1358        Packet::SIZE <= TEMPORAL_ACCUMULATOR_COLS,
1359        "accumulate: Packet::SIZE must be at most {TEMPORAL_ACCUMULATOR_COLS}, got {}",
1360        Packet::SIZE
1361    );
1362    assert_eq!(
1363        OutPacket::SIZE,
1364        ACCUMULATE_OUT_PACKET_ELEMENTS,
1365        "accumulate: OutPacket::SIZE must be {ACCUMULATE_OUT_PACKET_ELEMENTS}, got {}",
1366        OutPacket::SIZE
1367    );
1368
1369    let time = Time::to_value().factorize();
1370    let packet = Packet::to_value().factorize().remove_padding();
1371    let out_time = OutTime::to_value().factorize();
1372    let out_packet = OutPacket::to_value().factorize();
1373
1374    // Determine `outer_time` and `packet_outer_size` based on contraction kind.
1375    // `packet_outer_size` is 1 for Interleaved (no packet split into OutTime),
1376    // and the size of the outer packet portion for Sequential.
1377    let (outer_time, packet_outer_size) = match kind {
1378        AccumulationKind::Interleaved => {
1379            // `OutPacket = [Row # 8]`
1380            let expected_out_packet = Row::to_value().factorize().pad(ACCUMULATE_OUT_PACKET_ELEMENTS);
1381            assert_eq!(
1382                out_packet, expected_out_packet,
1383                "accumulate ({kind}): OutPacket mismatch. Expected: {expected_out_packet}, got: {out_packet}"
1384            );
1385
1386            // `OutTime = [Time', Packet (may be sliced)]`
1387            // Search for the `Packet / Time'` boundary.
1388            let outer_time = (1..=out_time.size().min(packet.size()).min(TEMPORAL_ACCUMULATOR_COLS))
1389                .filter(|&split| {
1390                    out_time.size() % split == 0
1391                        // If `packet` is not the identity mapping, require it
1392                        // to be present in `OutTime`.
1393                        && (split > 1 || packet.size() == 1)
1394                        // `Time'::SIZE` <= `Time::SIZE`.
1395                        && out_time.size() / split <= time.size()
1396                })
1397                .find_map(|split| {
1398                    let Tuple2(outer_time, sliced_packet) = out_time.split_at(split);
1399
1400                    // Slicing may only remove padding.
1401                    if sliced_packet != packet {
1402                        return None;
1403                    }
1404
1405                    Some(outer_time)
1406                })
1407                .unwrap_or_else(|| {
1408                    panic!(
1409                        "accumulate ({kind}): OutTime mismatch. \
1410                         Could not decompose OutTime {out_time} into [Time', Packet'] \
1411                         where Time' is {time} after temporal accumulation and Packet' is a truncation of {packet}"
1412                    )
1413                });
1414            (outer_time, 1)
1415        }
1416        AccumulationKind::Sequential => {
1417            // `OutTime   = [Time', Row, packet_outer]`
1418            // `OutPacket = [packet_inner # ACCUMULATE_OUT_PACKET_ELEMENTS]`
1419            //
1420            // `packet` is padded to the next multiple of ACCUMULATE_OUT_PACKET_ELEMENTS,
1421            // then split at ACCUMULATE_OUT_PACKET_ELEMENTS:
1422            // - `packet_inner` (ACCUMULATE_OUT_PACKET_ELEMENTS elements): becomes `OutPacket`
1423            // - `packet_outer`                                          : absorbed into `OutTime`
1424            let padded = packet
1425                .clone()
1426                .pad(align_up(packet.size(), ACCUMULATE_OUT_PACKET_ELEMENTS));
1427            let Tuple2(packet_outer, packet_inner) = padded.split_at(ACCUMULATE_OUT_PACKET_ELEMENTS);
1428            let packet_outer_size = packet_outer.size();
1429
1430            // `OutPacket = [packet_inner # ACCUMULATE_OUT_PACKET_ELEMENTS]`.
1431            assert_eq!(
1432                packet_inner, out_packet,
1433                "accumulate ({kind}): OutPacket mismatch. Expected: {packet_inner}, got: {out_packet}"
1434            );
1435
1436            // `OutTime` ends with `[Row, packet_outer]`
1437            let row_packet = Row::to_value().factorize().mul(packet_outer);
1438            let Tuple2(outer_time, inner_time) = out_time.split_at(row_packet.size());
1439            assert_eq!(
1440                inner_time, row_packet,
1441                "accumulate ({kind}): OutTime mismatch. Expected {row_packet}, got {inner_time}"
1442            );
1443
1444            (outer_time, packet_outer_size)
1445        }
1446    };
1447
1448    // The outer portion of `Time` should divide `Time`.
1449    // Some axes can be reduced in temporal accumulation.
1450    let division_terms = time
1451        .clone()
1452        .divide_relaxed(outer_time.clone())
1453        .exact()
1454        .unwrap_or_else(|_| {
1455            panic!("accumulate ({kind}): OutTime mismatch. Some axes present in Time are not present in Time': {time}, {outer_time}")
1456        })
1457        .division_terms;
1458    // Non-reduced axes must have their order preserved in `OutTime`.
1459    assert!(
1460        division_terms
1461            .windows(2)
1462            .all(|w| w[0].divisor_stride > w[1].divisor_stride),
1463        "accumulate ({kind}): OutTime axes must follow the same order as the Time axes"
1464    );
1465
1466    // Each retained axis in `outer_time` must preserve its padding from `time`.
1467    // We store `padding_size / stride` per term (always exact since padding
1468    // aligns to stride boundaries) and verify that the stride boundaries between
1469    // consecutive retained axes in `outer_time` match the gaps produced by
1470    // padding in `time`.
1471    let mut time_padding_per_stride: HashMap<usize, usize> = HashMap::new();
1472    let factors = time.factors();
1473    let mut stride = 1;
1474    // Build `cumulative_stride` : `padding_per_stride` table for each term in `time`.
1475    // `m![A # 8, B # 4]` with `axes![A = 4, B = 2]` -> { 1: 8, 8: 4 }
1476    for (i, factor) in factors.iter().enumerate() {
1477        match factor {
1478            Factor::Term { resize, .. } => {
1479                time_padding_per_stride.insert(
1480                    stride,
1481                    if let Some(Factor::Padding { size, .. }) = factors.get(i + 1) {
1482                        size / stride
1483                    } else {
1484                        *resize
1485                    },
1486                );
1487                stride *= resize;
1488            }
1489            Factor::Padding { size, .. } => {
1490                stride = *size;
1491            }
1492        }
1493    }
1494
1495    // Sort retained axes inner-to-outer by their position in `outer_time` (divisor).
1496    let mut sorted_divisions: Vec<&DivisionTerm> = division_terms.iter().collect();
1497    sorted_divisions.sort_by_key(|d| d.divisor_stride);
1498
1499    // The first divisor should have a stride of 1.
1500    // This catches unexpected padding that creeps into `outer_time` when
1501    // `split_at` absorbs padding from an adjacent axis (e.g., `M # 8` in
1502    // `row_packet` leaking a leading padding of 2 into `outer_time`).
1503    if let Some(first) = sorted_divisions.first() {
1504        assert_eq!(
1505            first.divisor_stride, 1,
1506            "accumulate ({kind}): Padding mismatch. \
1507             OutTime {outer_time} has unexpected leading padding not present in Time {time}"
1508        );
1509    }
1510
1511    // For each retained axis, its padding end must equal the start of the next retained axis.
1512    for (pos, d) in sorted_divisions.iter().enumerate() {
1513        let expected_end = d.divisor_stride
1514            * time_padding_per_stride
1515                .get(&d.dividend_stride)
1516                .copied()
1517                .unwrap_or(d.resize);
1518        let end = sorted_divisions
1519            .get(pos + 1)
1520            // The last term's padding ends at `outer_time.size()`.
1521            .map_or(outer_time.size(), |next| next.divisor_stride);
1522        assert_eq!(
1523            expected_end, end,
1524            "accumulate ({kind}): Padding mismatch. \
1525             Non-reduced axes in OutTime {outer_time} do not preserve padding from Time {time}"
1526        );
1527    }
1528
1529    // Calculate axis size inner to outermost reduce.
1530    let padding_end = |d: &DivisionTerm| {
1531        d.dividend_stride
1532            * time_padding_per_stride
1533                .get(&d.dividend_stride)
1534                .copied()
1535                .unwrap_or(d.resize)
1536    };
1537    let inner_time = if division_terms.is_empty() {
1538        // Case 1: All axes reduced.
1539        1
1540    } else if padding_end(&division_terms[0]) < time.size() {
1541        // Case 2: The outermost axis was reduced.
1542        outer_time.size()
1543    } else {
1544        // Case 3: The outermost retained factor reaches the top of `Time`.
1545        // Walk outer-to-inner looking for the first gap between adjacent division terms.
1546        division_terms
1547            .windows(2)
1548            .find(|w| padding_end(&w[1]) != w[0].dividend_stride)
1549            .map_or(1, |w| w[0].divisor_stride)
1550    };
1551
1552    // Check buffer limits
1553    match kind {
1554        AccumulationKind::Interleaved => {
1555            let buffer = inner_time * packet.size();
1556            assert!(
1557                buffer <= 1024 / ACCUMULATE_OUT_PACKET_ELEMENTS,
1558                "accumulate ({}): axes inner to reduce must be <= {} in size, got {}",
1559                kind,
1560                1024 / ACCUMULATE_OUT_PACKET_ELEMENTS,
1561                buffer
1562            );
1563        }
1564        AccumulationKind::Sequential => {
1565            let buffer = inner_time * Row::SIZE * packet_outer_size;
1566            assert!(
1567                buffer <= 1024 / TEMPORAL_ACCUMULATOR_COLS,
1568                "accumulate ({}): axes inner to reduce must be <= {} in size, got {}",
1569                kind,
1570                1024 / TEMPORAL_ACCUMULATOR_COLS,
1571                buffer
1572            );
1573        }
1574    }
1575}
1576
1577impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M>
1578    AlignedPair<'l, { T }, D, Chip, Cluster, Slice, Row, Time, Packet>
1579{
1580    /// Performs contraction (LAT reduce within Packet).
1581    /// Data type is widened during contraction: i4/i8 -> i32, f8/bf16 -> f32.
1582    #[primitive(AlignedPair::contract)]
1583    pub fn contract<OutPacket: M>(
1584        self,
1585    ) -> ContractionTensor<'l, { T }, <D as ContractionCast>::Output, Chip, Cluster, Slice, Row, Time, OutPacket>
1586    where
1587        D: ContractionCast + Cast<<D as ContractionCast>::Output>,
1588    {
1589        verify_contract::<D, Packet, OutPacket>();
1590        let lhs = self.lhs.map(|v| v.map(|v| v.cast()));
1591        let trf = self.trf.map(|v| v.map(|v| v.cast()));
1592        ContractionTensor {
1593            ctx: self.ctx,
1594            inner: lhs.zip_with(&trf, |a, b| a * b).reduce_add(),
1595        }
1596    }
1597}
1598
1599impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M>
1600    ContractionTensor<'l, { T }, D, Chip, Cluster, Slice, Row, Time, Packet>
1601{
1602    /// Performs accumulation (accumulator reduce across Time).
1603    #[primitive(ContractionTensor::accumulate)]
1604    pub fn accumulate<OutTime: M, OutPacket: M>(
1605        self,
1606        kind: AccumulationKind,
1607    ) -> AccumulationTensor<'l, { T }, D, Chip, Cluster, Slice, OutTime, OutPacket> {
1608        verify_accumulate::<Row, Time, Packet, OutTime, OutPacket>(kind);
1609        AccumulationTensor::new(self.ctx, self.inner.reduce_add())
1610    }
1611}
1612
1613impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1614    AccumulationTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
1615{
1616    // ANCHOR: accumulation_vector_init
1617    /// Initializes Vector Engine processing from contraction output.
1618    #[primitive(AccumulationTensor::vector_init)]
1619    pub fn vector_init(self) -> VectorInitTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1620    where
1621        D: VeScalar,
1622        // ANCHOR_END: accumulation_vector_init
1623    {
1624        VectorInitTensor::new(self.ctx, self.inner)
1625    }
1626}
1627
1628impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
1629    for AccumulationTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1630{
1631    type CastOutput<D2: Scalar, OutPacket: M>
1632        = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
1633    where
1634        D: Cast<D2>;
1635
1636    #[primitive(AccumulationTensor::cast)]
1637    fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
1638    where
1639        D: Cast<D2>,
1640    {
1641        cast_stream(self.ctx, self.inner)
1642    }
1643}
1644
1645impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
1646    for VectorFinalTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1647{
1648    type CastOutput<D2: Scalar, OutPacket: M>
1649        = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
1650    where
1651        D: Cast<D2>;
1652
1653    #[primitive(VectorFinalTensor::cast)]
1654    fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
1655    where
1656        D: Cast<D2>,
1657    {
1658        cast_stream(self.ctx, self.inner)
1659    }
1660}
1661
1662impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1663    VectorFinalTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1664{
1665    /// Stores to the vector register file after VE pipeline.
1666    #[primitive(VectorFinalTensor::to_vrf)]
1667    pub fn to_vrf<Element: M>(self, address: Address) -> VrfTensor<D, Chip, Cluster, Slice, Element>
1668    where
1669        D: VeScalar,
1670    {
1671        VrfTensor::new(self.inner.transpose(false), address)
1672    }
1673
1674    /// Performs transpose operation (transitions to Transpose engine).
1675    #[primitive(VectorFinalTensor::transpose)]
1676    pub fn transpose<Time2: M, Packet2: M>(self) -> TransposeTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
1677        TransposeTensor::new(self.ctx, self.inner.transpose(false))
1678    }
1679}
1680
1681/// Verifies commit engine constraints.
1682///
1683/// Constraints checked:
1684/// 1. Input packet must be exactly one flit (32 bytes).
1685/// 2. Output packet must be 8, 16, 24, or 32 bytes.
1686/// 3. Truncation may only remove elements from Packet.
1687fn verify_commit<D: Scalar, Time: M, Packet: M, Element: M>() {
1688    // Input packet must be exactly one flit.
1689    let packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
1690    assert_eq!(
1691        packet_bytes, FLIT_BYTES,
1692        "Commit input packet must be exactly {FLIT_BYTES} bytes (one flit), got {packet_bytes}",
1693    );
1694
1695    // Time can be transposed.
1696    let Tuple2(time, packet) = Element::to_value()
1697        .factorize()
1698        .split_at(exact_div(Element::SIZE, Time::SIZE).expect("Commit element size does not divide time size"));
1699    let input_time = Time::to_value().factorize().normalize();
1700    if input_time.clone().divide_relaxed(time.clone()).exact().is_err()
1701        || time.clone().divide_relaxed(input_time.clone()).exact().is_err()
1702    {
1703        panic!("Commit output Time ({time}) is not a valid transpose of the input Time ({input_time})");
1704    }
1705
1706    // Output packet must be 8, 16, 24, or 32 bytes.
1707    let out_packet_elements = Element::SIZE / Time::SIZE;
1708    let out_packet_bytes = D::size_in_bytes_from_length(out_packet_elements);
1709    assert!(
1710        COMMIT_OUT_PACKET_SIZES.contains(&out_packet_bytes),
1711        "Commit output packet must be one of {COMMIT_OUT_PACKET_SIZES:?} bytes, got {out_packet_bytes}",
1712    );
1713
1714    // The resulting packet can be a slice of Packet by `commit_in_size`.
1715    let expected_packet = Packet::to_value().factorize();
1716    assert!(
1717        packet.is_resize_of(&expected_packet),
1718        "Commit packet mismatch. Expected {expected_packet} or a truncation of it, got {packet}",
1719    );
1720}
1721
1722// ANCHOR: commit_impl
1723impl<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1724    StreamTensor<'l, { T }, P, D, Chip, Cluster, Slice, Time, Packet>
1725{
1726    /// Commits to the data memory.
1727    #[primitive(StreamTensor::commit)]
1728    pub fn commit<Element: M>(self, address: Address) -> DmTensor<D, Chip, Cluster, Slice, Element> {
1729        verify_commit::<D, Time, Packet, Element>();
1730        DmTensor::new(self.inner.transpose(false), address)
1731    }
1732
1733    /// Commits to mutable tensor view in the data memory.
1734    #[primitive(StreamTensor::commit_view)]
1735    pub fn commit_view<Element: M>(self, mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>) {
1736        verify_commit::<D, Time, Packet, Element>();
1737        dst.inner.write_transpose(self.inner.view(), false);
1738    }
1739}
1740// ANCHOR_END: commit_impl
1741
/// Rounds `a` up to the nearest multiple of `b`.
///
/// # Panics
///
/// Panics if `b` is zero.
fn align_up(a: usize, b: usize) -> usize {
    assert_ne!(b, 0);
    a.next_multiple_of(b)
}
1746
/// Returns `Some(a / b)` when `b` divides `a` exactly, and `None` otherwise.
///
/// A zero divisor always yields `None`, including `exact_div(0, 0)`, instead of
/// panicking with a division by zero.
fn exact_div(a: usize, b: usize) -> Option<usize> {
    // `checked_rem` returns `None` for `b == 0`, so the division below never sees a zero divisor.
    match a.checked_rem(b) {
        Some(0) => Some(a / b),
        _ => None,
    }
}
1750
1751#[cfg(test)]
1752mod tests {
1753    use super::*;
1754
1755    mod transpose {
1756        use super::*;
1757        use crate::scalar::bf16;
1758
1759        mod valid {
1760            use super::*;
1761            axes![A = 4, B = 2, C = 8, D = 4, E = 8, F = 8, G = 2, X = 64, Y = 512];
1762
1763            #[test]
1764            fn basic() {
1765                verify_transpose::<i8, m![C, F], m![E # 32], m![C, E], m![F # 32]>();
1766            }
1767
1768            #[test]
1769            fn small() {
1770                // `elements_per_packet = B # 8` is sliced to `B`
1771                verify_transpose::<i8, m![A], m![B # 32], m![B], m![A # 32]>();
1772            }
1773
1774            #[test]
1775            fn small_no_slicing() {
1776                verify_transpose::<i8, m![A], m![B # 32], m![B # 8], m![A # 32]>();
1777            }
1778
1779            #[test]
1780            fn large_col() {
1781                verify_transpose::<i8, m![B, C, D], m![E # 32], m![B, D, E], m![C # 32]>();
1782            }
1783
1784            #[test]
1785            fn bf16() {
1786                verify_transpose::<bf16, m![C, D], m![E # 16], m![C, E], m![D # 16]>();
1787            }
1788        }
1789
1790        mod input_packet {
1791            use super::*;
1792            axes![C = 8, D = 8, E = 4, F = 16];
1793
1794            #[test]
1795            #[should_panic(expected = "Transpose input packet must be 32 bytes, got 16")]
1796            fn invalid() {
1797                verify_transpose::<i8, m![C, D], m![E # 16], m![C, E], m![D # 32]>();
1798            }
1799        }
1800
1801        mod output_packet {
1802            use super::*;
1803            axes![C = 8, D = 8, E = 8];
1804
1805            #[test]
1806            #[should_panic(expected = "Transpose output packet must be 32 bytes, got 16")]
1807            fn invalid() {
1808                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![D # 16]>();
1809            }
1810        }
1811
1812        mod in_rows {
1813            use super::*;
1814            axes![A = 4, C = 8, D = 8, E = 8, F = 32];
1815
1816            #[test]
1817            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 16")]
1818            fn invalid_i4() {
1819                verify_transpose::<i4, m![C, D], m![E # 64], m![C, E], m![F # 64]>();
1820            }
1821
1822            #[test]
1823            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 32")]
1824            fn invalid_i8() {
1825                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![A, D]>();
1826            }
1827
1828            #[test]
1829            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 16")]
1830            fn invalid_bf16() {
1831                verify_transpose::<bf16, m![C, D], m![E # 16], m![C, E], m![D # 16]>();
1832            }
1833
1834            #[test]
1835            #[should_panic(expected = "must be present in the input Time")]
1836            fn invalid_in_rows_not_in_time() {
1837                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![A # 32]>();
1838            }
1839        }
1840
1841        mod in_cols {
1842            use super::*;
1843            axes![C = 8, D = 8, E = 8, F = 16, G = 4];
1844
1845            #[test]
1846            #[should_panic(expected = "Transpose `in_cols` size (64) must be one of [16, 32] for 4-bit type")]
1847            fn invalid_i4() {
1848                verify_transpose::<i4, m![F, G], m![E # 64], m![G, E], m![F # 64]>();
1849            }
1850
1851            #[test]
1852            #[should_panic(expected = "Transpose `in_cols` size (64) must be one of [8, 16, 32] for 8-bit type")]
1853            fn invalid_i8() {
1854                verify_transpose::<i8, m![C, D], m![E # 32], m![D, E], m![C # 32]>();
1855            }
1856        }
1857
1858        mod out_time {
1859            use super::*;
1860            axes![A = 2, B = 2, C = 4, D = 2, E = 8, F = 8, G = 16];
1861
1862            #[test]
1863            #[should_panic(expected = "Transpose time mismatch")]
1864            fn invalid_outer_mismatch() {
1865                verify_transpose::<i8, m![B, C, D], m![E # 32], m![A, D, E], m![C # 32]>();
1866            }
1867
1868            #[test]
1869            #[should_panic(expected = "not found in OutTime")]
1870            fn invalid_missing_packets_per_col() {
1871                verify_transpose::<i8, m![C, D], m![E # 32], m![E], m![C # 32]>();
1872            }
1873
1874            #[test]
1875            #[should_panic(expected = "must match")]
1876            fn invalid_wrong_axis() {
1877                verify_transpose::<i8, m![C, D], m![E # 32], m![D, F], m![C # 32]>();
1878            }
1879
1880            #[test]
1881            #[should_panic(expected = "must match")]
1882            fn invalid_non_padding_resize() {
1883                // E = 4 discards non-padded elements
1884                verify_transpose::<i8, m![C], m![E # 32], m![E = 4], m![C # 32]>();
1885            }
1886
1887            #[test]
1888            #[should_panic(expected = "must be <=")]
1889            fn invalid_out_rows_exceeds_in_cols() {
1890                verify_transpose::<i8, m![A], m![B # 32], m![B # 16], m![A # 32]>();
1891            }
1892        }
1893    }
1894
1895    mod commit {
1896        use super::*;
1897
1898        mod valid {
1899            use super::*;
1900
1901            axes![M = 4, N = 8, A = 4, B = 3, C = 4];
1902
1903            #[test]
1904            fn full_truncation() {
1905                verify_commit::<i8, m![A, B, C], m![N # 32], m![A, B, C, N]>();
1906            }
1907
1908            #[test]
1909            fn partial_truncation() {
1910                verify_commit::<i8, m![M], m![N # 32], m![M, N # 16]>();
1911            }
1912
1913            #[test]
1914            fn no_truncation() {
1915                verify_commit::<i8, m![M], m![N # 32], m![M, N # 32]>();
1916            }
1917
1918            #[test]
1919            fn bf16() {
1920                verify_commit::<bf16, m![M], m![N # 16], m![M, N]>();
1921            }
1922
1923            #[test]
1924            fn f32() {
1925                verify_commit::<f32, m![M], m![N # 8], m![M, N]>();
1926            }
1927
1928            #[test]
1929            fn single_time_step() {
1930                verify_commit::<i8, m![1], m![N # 32], m![N # 8]>();
1931            }
1932
1933            #[test]
1934            fn non_padding_resize() {
1935                verify_commit::<bf16, m![1], m![N # 16], m![N = 4]>();
1936            }
1937
1938            #[test]
1939            fn time_transpose() {
1940                verify_commit::<bf16, m![A # 32, B], m![N # 16], m![B, A # 32, N = 4]>();
1941            }
1942        }
1943
1944        mod invalid {
1945            use super::*;
1946
1947            axes![M = 4, N = 8, X = 8, Y = 4, Z = 2];
1948
1949            #[test]
1950            #[should_panic(expected = "Commit input packet must be exactly 32 bytes (one flit), got 16")]
1951            fn input_packet_not_flit() {
1952                verify_commit::<i8, m![M], m![N # 16], m![M, N]>();
1953            }
1954
1955            #[test]
1956            #[should_panic(expected = "Commit output packet must be one of [8, 16, 24, 32] bytes, got 13")]
1957            fn out_packet_invalid_size() {
1958                verify_commit::<i8, m![M], m![N # 32], m![M, N # 13]>();
1959            }
1960
1961            #[test]
1962            #[should_panic(expected = "Commit output packet must be one of [8, 16, 24, 32] bytes, got 48")]
1963            fn extra_padding() {
1964                verify_commit::<i8, m![M], m![N # 32], m![M, N # 48]>();
1965            }
1966
1967            #[test]
1968            #[should_panic(expected = "Commit packet mismatch")]
1969            fn different_packet_axes() {
1970                verify_commit::<i8, m![M], m![N # 32], m![M, X]>();
1971            }
1972
1973            #[test]
1974            #[should_panic(expected = "not a valid transpose of the input Time")]
1975            fn different_time_axes() {
1976                verify_commit::<i8, m![M], m![N # 32], m![Y, N # 16]>();
1977            }
1978
1979            #[test]
1980            #[should_panic(expected = "not a valid transpose of the input Time")]
1981            fn time_transpose_padding_mismatch() {
1982                verify_commit::<bf16, m![M # 32, X], m![N # 16], m![X, M # 16, N = 4]>();
1983            }
1984
1985            #[test]
1986            #[should_panic(expected = "not a valid transpose of the input Time")]
1987            fn time_axis_dropped_with_padding() {
1988                verify_commit::<i8, m![M, Z], m![N # 32], m![M # 8, N]>();
1989            }
1990
1991            #[test]
1992            #[should_panic(expected = "not a valid transpose of the input Time")]
1993            fn time_axis_resized_with_padding() {
1994                verify_commit::<i8, m![M # 8], m![N # 32], m![M = 2 # 8, N]>();
1995            }
1996        }
1997    }
1998
1999    mod contract {
2000        use super::*;
2001
2002        axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];
2003
2004        #[test]
2005        fn valid_full_reduction() {
2006            verify_contract::<i8, m![K], m![1]>();
2007        }
2008
2009        #[test]
2010        fn valid_partial_reduction() {
2011            // K % 4 reduced
2012            verify_contract::<i8, m![K], m![K / 4]>();
2013        }
2014
2015        #[test]
2016        #[should_panic(expected = "not a valid contraction")]
2017        fn invalid_retained_packet_size() {
2018            verify_contract::<i8, m![K], m![D]>();
2019        }
2020
2021        #[test]
2022        #[should_panic(expected = "OutPacket::SIZE must be at most 32, got 64")]
2023        fn invalid_no_reduction() {
2024            // Temporal accumulator only has 32 columns, cannot fit 64 packet
2025            verify_contract::<i8, m![K], m![K]>();
2026        }
2027
2028        #[test]
2029        fn valid_partial_reduction_multi_axis() {
2030            // `D / 2 % 4` is reduced, retained_packet is `[A, D / 8]`.
2031            verify_contract::<i8, m![A, D / 2], m![A, D / 8]>();
2032        }
2033
2034        #[test]
2035        fn valid_padded_packet_inner_reduction() {
2036            verify_contract::<i8, m![A # 16, C], m![A]>();
2037        }
2038
2039        #[test]
2040        fn valid_padded_packet_inner_reduction_with_padding() {
2041            verify_contract::<i8, m![A # 16, C], m![A # 16]>();
2042        }
2043
2044        #[test]
2045        fn valid_padded_packet_split() {
2046            verify_contract::<i8, m![B # 8, N], m![B]>();
2047        }
2048
2049        #[test]
2050        fn valid_no_spatial_reduction_bf16() {
2051            // Tree depth 0: all 32 bf16 elements pass through, no reduction.
2052            verify_contract::<bf16, m![D], m![D]>();
2053        }
2054
2055        #[test]
2056        #[should_panic(expected = "OutPacket::SIZE must be a power of two, got 3")]
2057        fn invalid_non_power_of_two_out_packet() {
2058            verify_contract::<i8, m![K], m![K = 3]>();
2059        }
2060
2061        #[test]
2062        #[should_panic(expected = "not a valid contraction")]
2063        fn invalid_partial_inner_packet() {
2064            verify_contract::<i8, m![K], m![K % 4]>();
2065        }
2066    }
2067
2068    mod accumulate {
2069        use super::*;
2070
2071        axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];
2072
2073        mod out_packet_size {
2074            use super::*;
2075
2076            #[test]
2077            fn valid() {
2078                verify_accumulate::<m![1], m![A], m![1], m![A], m![1 # 8]>(AccumulationKind::Interleaved);
2079            }
2080
2081            #[test]
2082            #[should_panic(expected = "OutPacket::SIZE must be 8, got 32")]
2083            fn invalid() {
2084                verify_accumulate::<m![1], m![A], m![1], m![A], m![D]>(AccumulationKind::Interleaved);
2085            }
2086        }
2087
2088        mod interleaved {
2089            use super::*;
2090
2091            #[test]
2092            fn valid() {
2093                verify_accumulate::<m![1], m![A, B], m![1], m![B], m![1 # 8]>(AccumulationKind::Interleaved);
2094            }
2095
2096            #[test]
2097            fn valid_padding() {
2098                verify_accumulate::<m![1], m![A # 8, B # 4], m![1], m![B # 4], m![1 # 8]>(
2099                    AccumulationKind::Interleaved,
2100                );
2101            }
2102
2103            #[test]
2104            fn valid_no_reduction_with_padding() {
2105                verify_accumulate::<m![1], m![A # 8, B], m![D], m![A # 8, B, D], m![1 # 8]>(
2106                    AccumulationKind::Interleaved,
2107                );
2108            }
2109
2110            #[test]
2111            fn valid_non_outermost() {
2112                verify_accumulate::<m![N], m![C, A, B], m![1], m![C, B], m![N]>(AccumulationKind::Interleaved);
2113            }
2114
2115            #[test]
2116            fn valid_four_rows() {
2117                verify_accumulate::<m![M], m![C, A, B], m![1], m![C, B], m![M # 8]>(AccumulationKind::Interleaved);
2118            }
2119
2120            #[test]
2121            fn valid_all_time_reduced() {
2122                verify_accumulate::<m![N], m![A], m![1], m![1], m![N]>(AccumulationKind::Interleaved);
2123            }
2124
2125            #[test]
2126            #[should_panic(expected = "OutTime mismatch")]
2127            fn invalid_sliced_no_padding() {
2128                // Packet truncation is only allowed on padded elements.
2129                // Otherwise, input data would be silently discarded.
2130                verify_accumulate::<m![N], m![M], m![D], m![M, D = 16], m![N]>(AccumulationKind::Interleaved);
2131            }
2132
2133            #[test]
2134            #[should_panic(expected = "Could not decompose OutTime")]
2135            fn invalid_packet_size_out_time() {
2136                verify_accumulate::<m![N], m![M], m![D], m![M, K], m![N]>(AccumulationKind::Interleaved);
2137            }
2138
2139            #[test]
2140            #[should_panic(expected = "OutTime mismatch")]
2141            fn invalid_resize() {
2142                verify_accumulate::<m![1], m![A], m![D], m![A, D / 4 % 4], m![1 # 8]>(AccumulationKind::Interleaved);
2143            }
2144
2145            #[test]
2146            #[should_panic(expected = "OutTime mismatch")]
2147            fn invalid_out_time() {
2148                verify_accumulate::<m![N], m![A, B], m![1], m![C], m![N]>(AccumulationKind::Interleaved);
2149            }
2150
2151            #[test]
2152            #[should_panic(expected = "Padding mismatch")]
2153            fn invalid_out_time_padding() {
2154                verify_accumulate::<m![1], m![A, B # 4], m![1], m![B # 2], m![1 # 8]>(AccumulationKind::Interleaved);
2155            }
2156
2157            #[test]
2158            #[should_panic(expected = "OutTime axes must follow the same order as the Time axes")]
2159            fn invalid_out_time_reorder() {
2160                verify_accumulate::<m![N], m![A, B], m![1], m![B, A], m![N]>(AccumulationKind::Interleaved);
2161            }
2162
2163            #[test]
2164            #[should_panic(expected = "Padding mismatch")]
2165            fn invalid_out_time_no_padding() {
2166                verify_accumulate::<m![N], m![A, B # 32], m![1], m![A, B], m![N]>(AccumulationKind::Interleaved);
2167            }
2168
2169            #[test]
2170            #[should_panic(expected = "axes inner to reduce must be <= 128 in size, got 256")]
2171            fn invalid_buffer() {
2172                verify_accumulate::<m![N], m![A, D # 64, C], m![1], m![D # 64, C], m![N]>(
2173                    AccumulationKind::Interleaved,
2174                );
2175            }
2176
2177            #[test]
2178            #[should_panic(expected = "axes inner to reduce must be <= 128 in size, got 256")]
2179            fn invalid_buffer_multiple_reduce_axes() {
2180                verify_accumulate::<m![N], m![A, B # 64, M # 8, C], m![1], m![B # 64, C], m![N]>(
2181                    AccumulationKind::Interleaved,
2182                );
2183            }
2184        }
2185
2186        mod sequential {
2187            use super::*;
2188
2189            #[test]
2190            fn valid() {
2191                verify_accumulate::<m![N], m![A, B], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
2192            }
2193
2194            #[test]
2195            fn valid_padded_row() {
2196                verify_accumulate::<m![N], m![A, B], m![1], m![B, N # 8], m![1 # 8]>(AccumulationKind::Sequential);
2197            }
2198
2199            #[test]
2200            fn valid_all_time_reduced() {
2201                verify_accumulate::<m![N], m![A], m![1], m![N], m![1 # 8]>(AccumulationKind::Sequential);
2202            }
2203
2204            #[test]
2205            fn valid_no_reduction_with_padding() {
2206                verify_accumulate::<m![N], m![A # 8, B], m![1], m![A # 8, B, N], m![1 # 8]>(
2207                    AccumulationKind::Sequential,
2208                );
2209            }
2210
2211            #[test]
2212            fn valid_padded_packet() {
2213                verify_accumulate::<m![N], m![M], m![B], m![M, N], m![B # 8]>(AccumulationKind::Sequential);
2214            }
2215
2216            #[test]
2217            fn valid_full_temporal_reduction() {
2218                verify_accumulate::<m![N], m![M], m![D], m![N, D / 8], m![D % 8]>(AccumulationKind::Sequential);
2219            }
2220
2221            #[test]
2222            #[should_panic(expected = "OutPacket mismatch")]
2223            fn invalid_packet_axis() {
2224                verify_accumulate::<m![N], m![M], m![B], m![M, N], m![A # 8]>(AccumulationKind::Sequential);
2225            }
2226
2227            #[test]
2228            #[should_panic(expected = "OutTime mismatch")]
2229            fn invalid_out_time() {
2230                verify_accumulate::<m![N], m![A, B], m![1], m![C, N], m![1 # 8]>(AccumulationKind::Sequential);
2231            }
2232
2233            #[test]
2234            #[should_panic(expected = "OutTime mismatch")]
2235            fn invalid_out_time_row() {
2236                verify_accumulate::<m![N], m![A, B], m![1], m![B, M], m![1 # 8]>(AccumulationKind::Sequential);
2237            }
2238
2239            #[test]
2240            #[should_panic(expected = "OutTime axes must follow the same order as the Time axes")]
2241            fn invalid_out_time_reorder() {
2242                verify_accumulate::<m![N], m![A, B], m![1], m![B, A, N], m![1 # 8]>(AccumulationKind::Sequential);
2243            }
2244
2245            #[test]
2246            #[should_panic(expected = "Padding mismatch")]
2247            fn invalid_out_time_padding() {
2248                verify_accumulate::<m![N], m![A, B # 32], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
2249            }
2250
2251            #[test]
2252            #[should_panic(expected = "Padding mismatch")]
2253            fn invalid_out_time_padded_row() {
2254                verify_accumulate::<m![M], m![A, B], m![1], m![B, M # 8], m![1 # 8]>(AccumulationKind::Sequential);
2255            }
2256
2257            #[test]
2258            fn valid_multi_axis_reduction() {
2259                verify_accumulate::<m![N], m![A, B, C], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
2260            }
2261
2262            #[test]
2263            #[should_panic(expected = "axes inner to reduce must be <= 32 in size, got 64")]
2264            fn invalid_buffer() {
2265                verify_accumulate::<m![N], m![A, P], m![1], m![P, N], m![1 # 8]>(AccumulationKind::Sequential);
2266            }
2267
2268            #[test]
2269            #[should_panic(expected = "axes inner to reduce must be <= 32 in size, got 64")]
2270            fn invalid_buffer_packet_outer() {
2271                verify_accumulate::<m![1], m![A, N, B], m![D], m![N, B, D / 8], m![D % 8]>(
2272                    AccumulationKind::Sequential,
2273                );
2274            }
2275        }
2276    }
2277
2278    mod switch {
2279        use super::super::*;
2280
2281        mod custom_broadcast {
2282            use super::*;
2283
2284            axes![
2285                A = 16,
2286                B = 16,
2287                C = 8,
2288                D = 2,
2289                E = 2,
2290                P = 4,
2291                Q = 8,
2292                R = 8,
2293                S = 256,
2294                X = 4,
2295                Y = 2,
2296                Z = 2,
2297            ];
2298
2299            mod permutation {
2300                use super::*;
2301
2302                #[test]
2303                fn identity() {
2304                    verify_switch::<m![S], m![C], m![S], m![C]>(&SwitchConfig::CustomBroadcast { ring_size: 1 });
2305                }
2306
2307                #[test]
2308                fn full_permutation() {
2309                    verify_switch::<m![A, B], m![C], m![B % 4, B / 4, A % 4, A / 4], m![C]>(
2310                        &SwitchConfig::CustomBroadcast { ring_size: 256 },
2311                    );
2312                }
2313
2314                #[test]
2315                fn partial_permutation() {
2316                    verify_switch::<m![A, B], m![C], m![A, B % 4, B / 4], m![C]>(&SwitchConfig::CustomBroadcast {
2317                        ring_size: 16,
2318                    });
2319                }
2320
2321                #[test]
2322                fn three_axis_inner_swap() {
2323                    verify_switch::<m![R, Q, P], m![C], m![R, P, Q], m![C]>(&SwitchConfig::CustomBroadcast {
2324                        ring_size: 32,
2325                    });
2326                }
2327
2328                #[test]
2329                fn three_axis_outer_swap() {
2330                    verify_switch::<m![R, Q, P], m![C], m![Q, R, P], m![C]>(&SwitchConfig::CustomBroadcast {
2331                        ring_size: 256,
2332                    });
2333                }
2334
2335                #[test]
2336                fn padded_identity() {
2337                    verify_switch::<m![P # 16, Q # 16], m![C], m![P # 16, Q # 16], m![C]>(
2338                        &SwitchConfig::CustomBroadcast { ring_size: 1 },
2339                    );
2340                }
2341
2342                #[test]
2343                fn padded_full_swap() {
2344                    verify_switch::<m![R # 16, Q # 16], m![C], m![Q # 16, R # 16], m![C]>(
2345                        &SwitchConfig::CustomBroadcast { ring_size: 256 },
2346                    );
2347                }
2348
2349                #[test]
2350                fn padded_full_swap_different_padding() {
2351                    verify_switch::<m![R # 16, Q # 16], m![C], m![Q # 32, R # 8], m![C]>(
2352                        &SwitchConfig::CustomBroadcast { ring_size: 256 },
2353                    );
2354                }
2355
2356                #[test]
2357                fn padded_partial_permutation() {
2358                    verify_switch::<m![R # 16, Q # 16], m![C], m![R # 16, Q # 16 % 4, Q # 16 / 4], m![C]>(
2359                        &SwitchConfig::CustomBroadcast { ring_size: 16 },
2360                    );
2361                }
2362
2363                #[test]
2364                #[should_panic(
2365                    expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
2366                )]
2367                fn permutation_time_change() {
2368                    verify_switch::<m![A, B], m![C], m![B, A], m![R]>(&SwitchConfig::CustomBroadcast {
2369                        ring_size: 256,
2370                    });
2371                }
2372            }
2373
2374            mod broadcast {
2375                use super::*;
2376
2377                #[test]
2378                fn broadcast() {
2379                    verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&SwitchConfig::CustomBroadcast {
2380                        ring_size: 4,
2381                    });
2382                }
2383
2384                #[test]
2385                fn multi_axis_broadcast() {
2386                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, A % 2, B % 2]>(
2387                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2388                    );
2389                }
2390
2391                #[test]
2392                fn broadcast_with_permutation() {
2393                    verify_switch::<m![A, B], m![C], m![A % 4, A / 4, B / 4, X], m![C, B % 4]>(
2394                        &SwitchConfig::CustomBroadcast { ring_size: 256 },
2395                    );
2396                }
2397
2398                #[test]
2399                fn broadcast_with_inner_permutation() {
2400                    verify_switch::<m![R, Q, P], m![C], m![R, P / 2, Q, Y], m![C, P % 2]>(
2401                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2402                    );
2403                }
2404
2405                #[test]
2406                fn broadcast_innermost_axis() {
2407                    verify_switch::<m![R, Q, P], m![C], m![R, Q, P / 2, Y], m![C, P % 2]>(
2408                        &SwitchConfig::CustomBroadcast { ring_size: 2 },
2409                    );
2410                }
2411
2412                #[test]
2413                fn non_contiguous_broadcast() {
2414                    // Move R % 2 and P % 2 to Time (skipping Q).
2415                    verify_switch::<m![R, Q, P], m![C], m![R / 2, Y, Q, P / 2, Z], m![C, R % 2, P % 2]>(
2416                        &SwitchConfig::CustomBroadcast { ring_size: 64 },
2417                    );
2418                }
2419
2420                #[test]
2421                fn full_broadcast() {
2422                    verify_switch::<m![A, B], m![C], m![S], m![C, A, B]>(&SwitchConfig::CustomBroadcast {
2423                        ring_size: 256,
2424                    });
2425                }
2426
2427                #[test]
2428                fn padded_outer_time() {
2429                    verify_switch::<m![A, B], m![C # 32], m![A, B / 4, X], m![C # 32, B % 4]>(
2430                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2431                    );
2432                }
2433
2434                #[test]
2435                fn padded_inner_axis_broadcast() {
2436                    verify_switch::<m![P # 8, Q # 32], m![C], m![P # 8, Q # 32 / 4, X], m![C, Q # 32 % 4]>(
2437                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2438                    );
2439                }
2440
2441                #[test]
2442                fn broadcast_with_padded_outer_axis() {
2443                    verify_switch::<m![P # 32, Q], m![C], m![P # 32, Q / 4, X], m![C, Q % 4]>(
2444                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2445                    );
2446                }
2447
2448                #[test]
2449                fn padded_both_axes_broadcast() {
2450                    verify_switch::<m![P # 16, Q # 16], m![C], m![P # 16, Q # 16 / 4, X], m![C, Q # 16 % 4]>(
2451                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2452                    );
2453                }
2454
2455                #[test]
2456                fn padded_time_broadcast() {
2457                    verify_switch::<m![P # 8, Q # 32], m![C # 16], m![P # 8, Q # 32 / 4, X], m![C # 16, Q # 32 % 4]>(
2458                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2459                    );
2460                }
2461
2462                #[test]
2463                #[should_panic(expected = "Switch broadcast axes must each be used exactly once in OutSlice")]
2464                fn duplicate_broadcast_symbol() {
2465                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Y], m![C, A % 2, B % 2]>(
2466                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2467                    );
2468                }
2469
2470                #[test]
2471                #[should_panic(
2472                    expected = "Switch broadcast axes must be new axes (not present in input Slice or Time)."
2473                )]
2474                fn broadcast_axis_from_time() {
2475                    verify_switch::<m![A, B], m![C], m![A, B / 8, C], m![C, B % 8]>(&SwitchConfig::CustomBroadcast {
2476                        ring_size: 8,
2477                    });
2478                }
2479
2480                #[test]
2481                #[should_panic(
2482                    expected = "Switch broadcast axes must be new axes (not present in input Slice or Time)."
2483                )]
2484                fn inter_transpose() {
2485                    verify_switch::<m![A, B], m![C], m![A, C, B / 8], m![B % 8]>(&SwitchConfig::CustomBroadcast {
2486                        ring_size: 256,
2487                    });
2488                }
2489
2490                #[test]
2491                fn partial_broadcast_replacement() {
2492                    // A % 2 replaced by broadcast
2493                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, B % 2]>(
2494                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2495                    );
2496                }
2497
2498                #[test]
2499                #[should_panic(expected = "Switch broadcast axis X in output Slice must not be padded.")]
2500                fn padded_broadcast_axis() {
2501                    verify_switch::<m![A, B], m![C], m![A / 2, X # 8, B / 8, Y], m![C, A % 2, B % 8]>(
2502                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2503                    );
2504                }
2505
2506                #[test]
2507                #[should_panic(
2508                    expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
2509                )]
2510                fn moved_axes_not_innermost() {
2511                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![A % 2, C, B % 2]>(
2512                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2513                    );
2514                }
2515
2516                #[test]
2517                #[should_panic(
2518                    expected = "Switch axes moving from input Slice to output Time must preserve their relative order from input Slice."
2519                )]
2520                fn reversed_order_in_time() {
2521                    // A % 2 and B % 2 are reversed in output Time.
2522                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, B % 2, A % 2]>(
2523                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2524                    );
2525                }
2526
2527                #[test]
2528                #[should_panic(
2529                    expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
2530                )]
2531                fn outer_time_padding_mismatch() {
2532                    verify_switch::<m![A, B], m![C # 32], m![A, B / 4, X], m![C # 16, B % 4]>(
2533                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2534                    );
2535                }
2536
2537                #[test]
2538                fn broadcast_replace_in_place() {
2539                    verify_switch::<m![R, P], m![C], m![R, X], m![C]>(&SwitchConfig::CustomBroadcast { ring_size: 4 });
2540                }
2541
2542                #[test]
2543                fn broadcast_with_moved_axis() {
2544                    axes![A = 2, B = 2, C = 2, D = 2, E = 2, X = 2];
2545                    verify_switch::<m![A, B, C, D, E], m![1], m![E, B, X, A, D], m![C]>(
2546                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2547                    )
2548                }
2549            }
2550
2551            mod slicing {
2552                use super::*;
2553
2554                #[test]
2555                fn slicing() {
2556                    verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4 = 3]>(
2557                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2558                    );
2559                }
2560
2561                #[test]
2562                fn slicing_with_broadcast() {
2563                    verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 4, X], m![C, A % 2, B % 4 = 3]>(
2564                        &SwitchConfig::CustomBroadcast { ring_size: 32 },
2565                    );
2566                }
2567
2568                #[test]
2569                fn single_axis_slicing() {
2570                    verify_switch::<m![S], m![C], m![S / 4, X], m![C, S % 4 = 3]>(&SwitchConfig::CustomBroadcast {
2571                        ring_size: 4,
2572                    });
2573                }
2574
2575                #[test]
2576                fn padded_broadcast_slicing() {
2577                    verify_switch::<m![P # 8, Q # 32], m![C], m![P # 8, Q # 32 / 4, X], m![C, Q # 32 % 4 = 3]>(
2578                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2579                    );
2580                }
2581
2582                #[test]
2583                #[should_panic(expected = "Switch broadcast axes in output time must come from input slice.")]
2584                fn wrong_axis_in_slicing() {
2585                    verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B / 4 = 3]>(
2586                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2587                    );
2588                }
2589            }
2590
2591            mod ring_size {
2592                use super::*;
2593
2594                #[test]
2595                #[should_panic(expected = "Switch ring size must be a power of 2, got 3")]
2596                fn non_power_of_two() {
2597                    verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&SwitchConfig::CustomBroadcast {
2598                        ring_size: 3,
2599                    });
2600                }
2601
2602                #[test]
2603                #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 4")]
2604                fn wrong_full_permutation() {
2605                    verify_switch::<m![A, B], m![C], m![B % 4, B / 4, A % 4, A / 4], m![C]>(
2606                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2607                    );
2608                }
2609
2610                #[test]
2611                #[should_panic(expected = "Switch ring size mismatch. Expected 16, got 256")]
2612                fn wrong_partial_permutation() {
2613                    verify_switch::<m![A, B], m![C], m![A, B % 4, B / 4], m![C]>(&SwitchConfig::CustomBroadcast {
2614                        ring_size: 256,
2615                    });
2616                }
2617
2618                #[test]
2619                #[should_panic(expected = "Switch ring size mismatch. Expected 4, got 32")]
2620                fn wrong_broadcast() {
2621                    verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&SwitchConfig::CustomBroadcast {
2622                        ring_size: 32,
2623                    });
2624                }
2625
2626                #[test]
2627                #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 2")]
2628                fn wrong_permutation_above_broadcast() {
2629                    verify_switch::<m![R, Q, P], m![C], m![Q, R, P / 2, Y], m![C, P % 2]>(
2630                        &SwitchConfig::CustomBroadcast { ring_size: 2 },
2631                    );
2632                }
2633
2634                #[test]
2635                #[should_panic(expected = "Switch ring size mismatch. Expected 4, got 16")]
2636                fn wrong_padded_broadcast() {
2637                    verify_switch::<m![A # 16, B # 16], m![C], m![A # 16, B # 16 / 4, X], m![C, B # 16 % 4]>(
2638                        &SwitchConfig::CustomBroadcast { ring_size: 16 },
2639                    );
2640                }
2641
2642                #[test]
2643                #[should_panic(expected = "Switch ring size mismatch. Expected 16, got 256")]
2644                fn wrong_padded_permutation() {
2645                    verify_switch::<m![A # 16, B # 16], m![C], m![A # 16, B # 16 % 4, B # 16 / 4], m![C]>(
2646                        &SwitchConfig::CustomBroadcast { ring_size: 256 },
2647                    );
2648                }
2649
2650                #[test]
2651                #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 4")]
2652                fn wrong_padded_permutation_above_broadcast() {
2653                    verify_switch::<m![P # 8, Q # 32], m![C], m![Q # 32 / 4, P # 8, X], m![C, Q # 32 % 4]>(
2654                        &SwitchConfig::CustomBroadcast { ring_size: 4 },
2655                    );
2656                }
2657            }
2658        }
2659    }
2660}