1use std::collections::{HashMap, HashSet};
4use std::marker::PhantomData;
5
6use abi_stable::std_types::{RBox, Tuple2};
7use furiosa_mapping::*;
8use furiosa_mapping_macro::primitive;
9use furiosa_opt_macro::m;
10
/// Byte alignment required of a fetch output packet (checked in `BeginTensor::fetch`).
const FETCH_ALIGN_BYTES: usize = 8;
/// Size of one flit in bytes; cast and collect packets must be exactly this large.
const FLIT_BYTES: usize = 32;

/// Required size, in bytes, of a transpose input packet.
const TRANSPOSE_INPUT_BYTES: usize = 32;

/// Required size, in bytes, of a transpose output packet.
const TRANSPOSE_OUTPUT_BYTES: usize = 32;

/// Elements-per-packet factor used by `verify_transpose` for element types that are not 4 bits wide.
const TRANSPOSE_ELEMENTS_PER_PACKET_NON_4BIT: usize = 8;

/// Elements-per-packet factor used by `verify_transpose` for 4-bit element types.
const TRANSPOSE_ELEMENTS_PER_PACKET_4BIT: usize = 16;

/// Upper bound, in bytes, on the transpose `in_rows` extent.
const TRANSPOSE_MAX_IN_ROWS_BYTES: usize = 8;

/// Accepted transpose `in_cols` sizes for element types that are not 4 bits wide.
const TRANSPOSE_VALID_IN_COLS: &[usize] = &[8, 16, 32];

/// Accepted transpose `in_cols` sizes for 4-bit element types.
const TRANSPOSE_VALID_IN_COLS_4BIT: &[usize] = &[16, 32];

/// Largest `OutPacket::SIZE` accepted by `verify_contract`.
const TEMPORAL_ACCUMULATOR_COLS: usize = 32;

// NOTE(review): not referenced in the visible part of this file; presumably the
// output packet element count expected by the accumulate check — confirm at the use site.
const ACCUMULATE_OUT_PACKET_ELEMENTS: usize = 8;

// NOTE(review): not referenced in the visible part of this file; presumably the
// set of packet sizes accepted by a commit stage — confirm at the use site.
const COMMIT_OUT_PACKET_SIZES: [usize; 4] = [8, 16, 24, 32];
47
#[primitive(AccumulationKind)]
/// Selects the accumulation mode passed to the accumulate verification.
///
/// NOTE(review): the exact hardware meaning of each variant is not visible in
/// this part of the file — confirm against `verify_accumulate`.
#[derive(Clone, Debug)]
pub enum AccumulationKind {
    /// Interleaved accumulation mode.
    Interleaved,
    /// Sequential accumulation mode.
    Sequential,
}
57
58impl std::fmt::Display for AccumulationKind {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 AccumulationKind::Interleaved => write!(f, "Interleaved"),
62 AccumulationKind::Sequential => write!(f, "Sequential"),
63 }
64 }
65}
66
67use crate::cast::*;
68use crate::context::*;
69use crate::memory_tensor::*;
70use crate::scalar::*;
71use crate::tensor::*;
72use crate::vector_engine::scalar::VeScalar;
73use crate::vector_engine::tensor::VectorInitTensor;
74
/// Marker trait for the type-level "position" tag carried by a `StreamTensor`.
///
/// Implementors are zero-sized marker structs; the trait has no methods and only
/// requires `Debug + 'static`.
pub trait Position: std::fmt::Debug + 'static {}

#[derive(Debug)]
/// Position tag used by the `BeginTensor` alias.
pub struct PositionBegin;

#[derive(Debug)]
/// Position tag used by the `FetchTensor` alias.
pub struct PositionFetch;

#[derive(Debug)]
/// Position tag used by the `SwitchTensor` alias.
pub struct PositionSwitch;

#[derive(Debug)]
/// Position tag used by the `CollectTensor` alias.
pub struct PositionCollect;

#[derive(Debug)]
/// Position tag used by the `AccumulationTensor` alias.
pub struct PositionContraction;

#[derive(Debug)]
/// Position tag used by the `VectorFinalTensor` alias.
pub struct PositionVectorFinal;

#[derive(Debug)]
/// Position tag used by the `CastTensor` alias.
pub struct PositionCast;

#[derive(Debug)]
/// Position tag used by the `TransposeTensor` alias.
pub struct PositionTranspose;

impl Position for PositionBegin {}
impl Position for PositionFetch {}
impl Position for PositionSwitch {}
impl Position for PositionCollect {}
impl Position for PositionContraction {}
impl Position for PositionVectorFinal {}
impl Position for PositionCast {}
impl Position for PositionTranspose {}
120
#[derive(Debug)]
/// A tensor travelling through the stream pipeline, tagged at the type level with
/// its current position `P`.
///
/// The wrapped tensor is mapped as nested pairs `[Chip, Cluster, Slice, Time, Packet]`.
pub struct StreamTensor<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> {
    /// Exclusive borrow of the context this stream belongs to.
    pub(crate) ctx: &'l mut TuContext<{ T }>,
    /// The underlying tensor data.
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Time, Packet>>>>>,
    /// Zero-sized marker recording the position type `P`.
    _position: PhantomData<P>,
}
128
129impl<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
130 StreamTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet>
131{
132 pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }];
134
135 pub fn new(ctx: &'l mut TuContext<{ T }>, inner: Tensor<D, Self::Mapping>) -> Self {
137 Self {
138 ctx,
139 inner,
140 _position: PhantomData,
141 }
142 }
143}
144
/// `StreamTensor` tagged with `PositionBegin`; exposes `fetch`.
pub type BeginTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionBegin, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionFetch`; exposes `switch` and `collect`.
pub type FetchTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionFetch, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionSwitch`; exposes `collect`.
pub type SwitchTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionSwitch, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionCollect`; exposes `to_trf`, `to_vrf`,
/// `vector_init`, `align`, `transpose` and `cast`.
pub type CollectTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionCollect, D, Chip, Cluster, Slice, Time, Packet>;
160
#[derive(Debug)]
/// Result of `CollectTensor::align`: a stream operand and a TRF operand that share
/// the same `[Chip, Cluster, Slice, Row, Time, Packet]` mapping.
pub struct AlignedPair<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M> {
    /// Exclusive borrow of the context, carried over from the collect stream.
    ctx: &'l mut TuContext<{ T }>,
    /// The streamed operand, re-mapped from the collect tensor.
    lhs: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
    /// The operand re-mapped from the TRF tensor passed to `align`.
    trf: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
}
168
#[derive(Debug)]
/// A tensor mapped as `[Chip, Cluster, Slice, Row, Time, Packet]` together with its
/// context borrow.
///
/// NOTE(review): presumably the output of a contraction step on `AlignedPair`;
/// the producing code is not visible here — confirm.
pub struct ContractionTensor<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M> {
    /// Exclusive borrow of the context this tensor belongs to.
    ctx: &'l mut TuContext<{ T }>,
    /// The underlying tensor data.
    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Pair<Time, Packet>>>>>>,
}
176
/// `StreamTensor` tagged with `PositionContraction`.
pub type AccumulationTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionContraction, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionVectorFinal`.
pub type VectorFinalTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionVectorFinal, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionCast`; produced by `cast_stream`.
pub type CastTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionCast, D, Chip, Cluster, Slice, Time, Packet>;

/// `StreamTensor` tagged with `PositionTranspose`; produced by `CollectTensor::transpose`.
pub type TransposeTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet> =
    StreamTensor<'l, { T }, PositionTranspose, D, Chip, Cluster, Slice, Time, Packet>;
192
193impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
195 BeginTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
196{
197 #[primitive(BeginTensor::fetch)]
199 pub fn fetch<D2: Scalar, Time2: M, Packet2: M>(self) -> FetchTensor<'l, T, D2, Chip, Cluster, Slice, Time2, Packet2>
200 where
201 D: FetchCast<D2>,
202 {
203 assert_eq!(Cluster::SIZE, 2, "Cluster size must be 2, got {}", Cluster::SIZE);
204 assert_eq!(Slice::SIZE, 256, "Slice size must be 256, got {}", Slice::SIZE);
205 let packet_bytes = D2::size_in_bytes_from_length(Packet2::SIZE);
206 assert_eq!(
207 packet_bytes % FETCH_ALIGN_BYTES,
208 0,
209 "Fetch output packet must be {FETCH_ALIGN_BYTES}-byte aligned, got {packet_bytes} bytes.",
210 );
211 FetchTensor::new(self.ctx, self.inner.map(|v| v.map(|v| v.cast())).transpose(true))
212 }
213}
214pub(crate) fn cast_stream<
227 'l,
228 const T: Tu,
229 D: VeScalar + Cast<D2>,
230 D2: Scalar,
231 Chip: M,
232 Cluster: M,
233 Slice: M,
234 Time: M,
235 InPacket: M,
236 OutPacket: M,
237>(
238 ctx: &'l mut TuContext<{ T }>,
239 inner: Tensor<D, m![{ Chip }, { Cluster }, { Slice }, { Time }, { InPacket }]>,
240) -> CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket> {
241 assert_eq!(
243 D::size_in_bytes_from_length(InPacket::SIZE),
244 FLIT_BYTES,
245 "Cast input packet must be exactly {FLIT_BYTES} bytes (one flit): \
246 {} elements = {} bytes",
247 InPacket::SIZE,
248 D::size_in_bytes_from_length(InPacket::SIZE),
249 );
250
251 let out_flit_elements = D2::length_from_bytes(FLIT_BYTES);
252
253 let in_packet = InPacket::to_value().factorize();
255 let expected_packet = in_packet.pad(out_flit_elements).normalize();
256
257 let out_packet = OutPacket::to_value().factorize();
259 assert_eq!(
260 D2::size_in_bytes_from_length(OutPacket::SIZE),
261 FLIT_BYTES,
262 "Cast output packet must be exactly {FLIT_BYTES} bytes (one flit). \
263 Expected: {expected_packet}, got: {out_packet}",
264 );
265 assert_eq!(
266 expected_packet, out_packet,
267 "Cast packet mismatch. Expected: {expected_packet}, got: {out_packet}",
268 );
269
270 CastTensor::new(ctx, inner.map(|v| v.map(|v| v.cast())).transpose(false))
271}
272
#[derive(Debug, Clone)]
/// Configuration of the switch stage, validated by `SwitchConfig::verify`.
///
/// In the variants below the input slice axis is treated as `[slice2, slice1, slice0]`
/// (outermost to innermost), with `slice2` implied by `InSlice::SIZE / (slice1 * slice0)`.
pub enum SwitchConfig {
    /// Broadcast over the `slice1` and `slice0` factors.
    /// `verify` expects the output time layout `[time1, slice1, time0, slice0]`.
    Broadcast01 {
        /// Extent of the middle slice factor.
        slice1: usize,
        /// Extent of the innermost slice factor.
        slice0: usize,
        /// Extent of the innermost time factor.
        time0: usize,
    },
    /// Broadcast over the `slice1` factor only; `slice0` is preserved in the output slice.
    /// `verify` expects the output time layout `[time0, slice1]`.
    Broadcast1 {
        /// Extent of the middle (broadcast) slice factor.
        slice1: usize,
        /// Extent of the innermost (preserved) slice factor.
        slice0: usize,
    },
    /// Swap of the two inner slice factors.
    /// `verify` expects the output slice layout `[slice2, slice0, slice1]`.
    Transpose {
        /// Extent of the middle slice factor.
        slice1: usize,
        /// Extent of the innermost slice factor.
        slice0: usize,
    },
    /// Exchange of the `slice1` factor between the slice and time axes.
    /// `verify` requires `InSlice::SIZE == slice2 * slice1 * slice0 == 256`.
    InterTranspose {
        /// Extent of the factor exchanged between slice and time.
        slice1: usize,
        /// Extent of the innermost slice factor (preserved).
        slice0: usize,
        /// Extent of the innermost input time factor.
        time0: usize,
    },
    /// Free-form broadcast validated structurally against the mappings.
    CustomBroadcast {
        /// Ring size; must be a power of two and match the value `verify` derives.
        ring_size: usize,
    },
    /// Broadcast over the `slice0` factor; `slice1` moves to the innermost output
    /// slice position. `verify` expects the output time layout `[time0, slice0]`.
    TransposedBroadcast1 {
        /// Extent of the middle (preserved) slice factor.
        slice1: usize,
        /// Extent of the innermost (broadcast) slice factor.
        slice0: usize,
    },
}
334
335fn extract_symbols(mapping: &Mapping) -> HashSet<Ident> {
337 let mut symbols = HashSet::new();
338 let mapping = mapping.factorize();
339 let mut stack = mapping.clone().into_inner();
340 while let Some(factor) = stack.pop() {
341 if let Factor::Term { inner, .. } = factor {
342 match &inner.inner {
343 Atom::Symbol { symbol, .. } => {
344 symbols.insert(*symbol);
345 }
346 Atom::Composite(inner) => {
347 stack.extend(RBox::into_inner(inner.clone()).into_inner());
348 }
349 }
350 }
351 }
352 symbols
353}
354
355fn verify_switch<InSlice: M, InTime: M, OutSlice: M, OutTime: M>(config: &SwitchConfig) {
360 assert_eq!(
361 InSlice::SIZE,
362 OutSlice::SIZE,
363 "Switch input and output slice sizes must match, got {} and {}",
364 InSlice::SIZE,
365 OutSlice::SIZE
366 );
367 config.verify::<InSlice, InTime, OutSlice, OutTime>();
368}
369
370impl SwitchConfig {
371 pub fn verify<InSlice: M, InTime: M, OutSlice: M, OutTime: M>(&self) {
374 match self {
375 SwitchConfig::Broadcast01 { slice1, slice0, time0 } => {
376 assert!(
377 [slice1, slice0, time0].iter().all(|&x| *x > 0),
378 "All dimensions must be greater than 0"
379 );
380 assert_eq!(
381 InSlice::SIZE % (slice1 * slice0),
382 0,
383 "InSlice::SIZE must be divisible by (slice1 * slice0)"
384 );
385 assert_eq!(InTime::SIZE % time0, 0, "InTime::SIZE must be divisible by time0");
386
387 assert_eq!(
389 Mapping::Stride {
390 inner: RBox::new(InSlice::to_value()),
391 stride: *slice1 * *slice0,
392 }
393 .factorize(),
394 Mapping::Stride {
395 inner: RBox::new(OutSlice::to_value()),
396 stride: *slice1 * *slice0,
397 }
398 .factorize(),
399 "OutSlice must preserve slice2 from InSlice",
400 );
401
402 if *slice1 * *slice0 > 1 {
404 let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
405 inner: RBox::new(OutSlice::to_value()),
406 modulo: *slice1 * *slice0,
407 });
408
409 let mut symbols = HashSet::new();
410 symbols.extend(extract_symbols(&InSlice::to_value()));
411 symbols.extend(extract_symbols(&InTime::to_value()));
412 assert_eq!(
413 out_slice_broadcast_symbols.intersection(&symbols).count(),
414 0,
415 "OutSlice broadcast axes must be new axes"
416 );
417 }
418
419 let mut expected_out_time = Mapping::Identity
420 .pair(Mapping::Stride {
422 inner: RBox::new(InTime::to_value()),
423 stride: *time0,
424 })
425 .pair(Mapping::Modulo {
427 inner: RBox::new(Mapping::Stride {
428 inner: RBox::new(InSlice::to_value()),
429 stride: *slice0,
430 }),
431 modulo: *slice1,
432 });
433 if *time0 > 1 {
434 expected_out_time = expected_out_time.pair(Mapping::Modulo {
436 inner: RBox::new(InTime::to_value()),
437 modulo: *time0,
438 })
439 }
440 if *slice0 > 1 {
441 expected_out_time = expected_out_time.pair(Mapping::Modulo {
443 inner: RBox::new(InSlice::to_value()),
444 modulo: *slice0,
445 })
446 }
447
448 assert_eq!(
450 expected_out_time.factorize(),
451 OutTime::to_value().factorize(),
452 "OutTime does not match expected layout: [time1, slice1, time0, slice0]"
453 );
454 }
455
456 SwitchConfig::Broadcast1 { slice1, slice0 } => {
457 assert!(
458 [slice1, slice0].iter().all(|&x| *x > 0),
459 "All dimensions must be greater than 0"
460 );
461 assert_eq!(
462 InSlice::SIZE % (slice1 * slice0),
463 0,
464 "InSlice::SIZE must be divisible by (slice1 * slice0)"
465 );
466
467 assert_eq!(
470 Mapping::Stride {
471 inner: RBox::new(InSlice::to_value()),
472 stride: *slice1 * *slice0,
473 }
474 .factorize(),
475 Mapping::Stride {
476 inner: RBox::new(OutSlice::to_value()),
477 stride: *slice1 * *slice0,
478 }
479 .factorize(),
480 "OutSlice must preserve slice2 from InSlice",
481 );
482
483 if *slice1 > 1 {
485 let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
486 inner: RBox::new(Mapping::Stride {
487 inner: RBox::new(OutSlice::to_value()),
488 stride: *slice0,
489 }),
490 modulo: *slice1,
491 });
492
493 let mut symbols = HashSet::new();
494 symbols.extend(extract_symbols(&InSlice::to_value()));
495 symbols.extend(extract_symbols(&InTime::to_value()));
496 assert_eq!(
497 out_slice_broadcast_symbols.intersection(&symbols).count(),
498 0,
499 "OutSlice broadcast axes must be new axes"
500 );
501 }
502
503 if *slice0 > 1 {
505 assert_eq!(
506 Mapping::Modulo {
507 inner: RBox::new(OutSlice::to_value()),
508 modulo: *slice0,
509 }
510 .factorize(),
511 Mapping::Modulo {
512 inner: RBox::new(InSlice::to_value()),
513 modulo: *slice0,
514 }
515 .factorize(),
516 "OutSlice must preserve slice0 from InSlice",
517 );
518 }
519
520 let mut expected_out_time = Mapping::Identity.pair(InTime::to_value());
522 if *slice1 > 1 {
523 expected_out_time = expected_out_time.pair(Mapping::Modulo {
525 inner: RBox::new(Mapping::Stride {
526 inner: RBox::new(InSlice::to_value()),
527 stride: *slice0,
528 }),
529 modulo: *slice1,
530 });
531 }
532
533 assert_eq!(
534 expected_out_time.factorize(),
535 OutTime::to_value().factorize(),
536 "OutTime does not match expected layout: [time0, slice1]"
537 );
538 }
539
540 SwitchConfig::Transpose { slice1, slice0 } => {
541 assert!(
542 [slice1, slice0].iter().all(|&x| *x > 0),
543 "All dimensions must be greater than 0"
544 );
545
546 assert_eq!(
548 InTime::SIZE,
549 OutTime::SIZE,
550 "Input and output time dimensions must have the same size"
551 );
552 assert_eq!(
553 InTime::to_value().factorize().remove_padding(),
554 OutTime::to_value().factorize().remove_padding(),
555 "Input and output time dimensions must match (excluding padding)"
556 );
557
558 assert_eq!(
560 Mapping::Stride {
561 inner: RBox::new(InSlice::to_value()),
562 stride: *slice1 * *slice0,
563 }
564 .pair(Mapping::Modulo {
565 inner: RBox::new(InSlice::to_value()),
566 modulo: *slice0,
567 })
568 .pair(Mapping::Modulo {
569 inner: RBox::new(Mapping::Stride {
570 inner: RBox::new(InSlice::to_value()),
571 stride: *slice0,
572 }),
573 modulo: *slice1,
574 })
575 .factorize(),
576 OutSlice::to_value().factorize(),
577 "OutSlice does not match expected layout: [slice2, slice0, slice1]"
578 );
579 }
580
581 SwitchConfig::InterTranspose { slice1, slice0, time0 } => {
582 assert!(
583 [slice1, slice0, time0].iter().all(|&x| *x > 0),
584 "All dimensions must be greater than 0"
585 );
586
587 assert_eq!(
588 InSlice::SIZE % (slice1 * slice0),
589 0,
590 "InSlice::SIZE must be divisible by (slice1 * slice0)"
591 );
592 assert_eq!(
593 InTime::SIZE % (slice1 * time0),
594 0,
595 "InTime::SIZE must be divisible by (slice1 * time0)"
596 );
597
598 let slice2 = InSlice::SIZE / (slice1 * slice0);
599 let time2 = InTime::SIZE / (slice1 * time0);
600
601 assert_eq!(slice2 * slice1 * slice0, 256, "All dimensions should multiply to 256");
602 assert_eq!(
603 time2 * slice1 * time0,
604 InTime::SIZE,
605 "time2 * slice1 * time0 must equal InTime::SIZE"
606 );
607
608 assert_eq!(
612 Mapping::Stride {
613 inner: RBox::new(OutSlice::to_value()),
614 stride: *slice1 * *slice0,
615 }
616 .factorize(),
617 Mapping::Stride {
618 inner: RBox::new(InSlice::to_value()),
619 stride: *slice1 * *slice0,
620 }
621 .factorize(),
622 "OutSlice must preserve slice2 from InSlice",
623 );
624 if *slice0 > 1 {
625 assert_eq!(
626 Mapping::Modulo {
627 inner: RBox::new(OutSlice::to_value()),
628 modulo: *slice0,
629 }
630 .factorize(),
631 Mapping::Modulo {
632 inner: RBox::new(InSlice::to_value()),
633 modulo: *slice0,
634 }
635 .factorize(),
636 "OutSlice must preserve slice0 from InSlice",
637 );
638 }
639
640 if *slice1 > 1 {
644 assert_eq!(
645 Mapping::Modulo {
646 inner: RBox::new(Mapping::Stride {
647 inner: RBox::new(OutSlice::to_value()),
648 stride: *slice0,
649 }),
650 modulo: *slice1,
651 }
652 .factorize(),
653 Mapping::Modulo {
654 inner: RBox::new(Mapping::Stride {
655 inner: RBox::new(InTime::to_value()),
656 stride: *time0,
657 }),
658 modulo: *slice1,
659 }
660 .factorize(),
661 "OutSlice time1 must come from InTime"
662 );
663 }
664
665 assert_eq!(
669 Mapping::Stride {
670 inner: RBox::new(OutTime::to_value()),
671 stride: *slice1 * time0,
672 }
673 .factorize(),
674 Mapping::Stride {
675 inner: RBox::new(InTime::to_value()),
676 stride: *time0 * slice1,
677 }
678 .factorize(),
679 "OutTime must preserve 'time2' from InTime",
680 );
681 if *time0 > 1 {
683 assert_eq!(
684 Mapping::Modulo {
685 inner: RBox::new(Mapping::Stride {
686 inner: RBox::new(OutTime::to_value()),
687 stride: *slice1,
688 }),
689 modulo: *time0,
690 }
691 .factorize(),
692 Mapping::Modulo {
693 inner: RBox::new(InTime::to_value()),
694 modulo: *time0,
695 }
696 .factorize(),
697 "OutTime must preserve 'time0' from InTime"
698 );
699 }
700 if *slice1 > 1 {
704 assert_eq!(
705 Mapping::Modulo {
706 inner: RBox::new(OutTime::to_value()),
707 modulo: *slice1,
708 }
709 .factorize(),
710 Mapping::Modulo {
711 inner: RBox::new(Mapping::Stride {
712 inner: RBox::new(InSlice::to_value()),
713 stride: *slice0,
714 }),
715 modulo: *slice1,
716 }
717 .factorize(),
718 "OutTime must preserve 'slice1' from InSlice"
719 );
720 }
721 }
722
723 SwitchConfig::CustomBroadcast { ring_size } => {
735 assert!(
736 ring_size.is_power_of_two(),
737 "Switch ring size must be a power of 2, got {ring_size}"
738 );
739
740 let slice = InSlice::to_value().factorize();
741 let out_slice = OutSlice::to_value().factorize();
742 let time = InTime::to_value().factorize();
743 let out_time = OutTime::to_value().factorize();
744
745 let furiosa_mapping::Division {
748 dividend_residue,
749 divisor_residue,
750 division_terms,
751 ..
752 } = out_slice.clone().divide_strict(slice.clone());
753 assert!(
754 dividend_residue
755 .clone()
756 .divide_strict(slice.clone().mul(time.clone()))
757 .division_terms()
758 .is_empty(),
759 "Switch broadcast axes must be new axes (not present in input Slice or Time)."
760 );
761
762 assert_eq!(
767 dividend_residue.terms_with_stride().len(),
768 dividend_residue.idents().len(),
769 "Switch broadcast axes must each be used exactly once in OutSlice"
770 );
771
772 let factors = out_slice.factors();
774 for (i, factor) in factors.iter().enumerate() {
775 if let Factor::Term { inner, .. } = factor
776 && !division_terms.iter().any(|d| d.term.inner == inner.inner)
777 && matches!(factors.get(i + 1), Some(Factor::Padding { .. }))
778 {
779 panic!("Switch broadcast axis {inner} in output Slice must not be padded.");
780 }
781 }
782
783 let Tuple2(outer, inner) =
785 out_time.split_at(exact_div(OutTime::SIZE, InTime::SIZE).unwrap_or_else(|| {
786 panic!(
787 "Input Time size ({}) does not divide Output Time size ({})",
788 InTime::SIZE,
789 OutTime::SIZE
790 )
791 }));
792 assert_eq!(
793 outer, time,
794 "Switch axes moving from input slice to output time must be at the output time innermost positions. \
795 Expected outer portion of output time to be {time}, got {outer}"
796 );
797
798 let broadcast_divisions = divisor_residue
799 .clone()
800 .divide_relaxed(inner.clone())
801 .exact()
802 .unwrap_or_else(|_| {
803 panic!(
804 "Switch broadcast axes in output time must come from input slice. \
805 Input Slice: {slice}, inner part of output Time: {inner}"
806 )
807 })
808 .division_terms;
809
810 for window in broadcast_divisions.windows(2) {
813 assert!(
814 window[0].divisor_stride > window[1].divisor_stride,
815 "Switch axes moving from input Slice to output Time must preserve their relative order from input Slice. \
816 {} (outer in input Slice) appears inner to {} in output Time",
817 window[0].term,
818 window[1].term
819 );
820 }
821
822 let mut max_non_dc = 0;
825 for d in &division_terms {
826 if d.dividend_stride != d.divisor_stride {
827 max_non_dc = max_non_dc.max(d.divisor_stride * d.resize);
828 }
829 }
830 for t in &divisor_residue.terms_with_stride() {
831 max_non_dc = max_non_dc.max(t.stride * t.resize);
832 }
833
834 let mut expected_ring_size = InSlice::SIZE;
837 for d in &division_terms {
838 if d.dividend_stride == d.divisor_stride && d.divisor_stride >= max_non_dc {
839 expected_ring_size = expected_ring_size.min(d.divisor_stride);
840 }
841 }
842 assert_eq!(
843 *ring_size, expected_ring_size,
844 "Switch ring size mismatch. Expected {expected_ring_size}, got {ring_size}"
845 );
846 }
847
848 SwitchConfig::TransposedBroadcast1 { slice1, slice0 } => {
849 assert!(
850 [slice1, slice0].iter().all(|&x| *x > 0),
851 "All dimensions must be greater than 0"
852 );
853
854 assert_eq!(
855 InSlice::SIZE % (slice1 * slice0),
856 0,
857 "InSlice::SIZE must be divisible by (slice1 * slice0)"
858 );
859
860 assert_eq!(
863 Mapping::Stride {
864 inner: RBox::new(InSlice::to_value()),
865 stride: *slice1 * *slice0,
866 }
867 .factorize(),
868 Mapping::Stride {
869 inner: RBox::new(OutSlice::to_value()),
870 stride: *slice0 * *slice1,
871 }
872 .factorize(),
873 "OutSlice must preserve slice2 from InSlice",
874 );
875
876 if *slice0 > 1 {
878 let out_slice_broadcast_symbols = extract_symbols(&Mapping::Modulo {
879 inner: RBox::new(Mapping::Stride {
880 inner: RBox::new(OutSlice::to_value()),
881 stride: *slice1,
882 }),
883 modulo: *slice0,
884 });
885
886 let mut symbols = HashSet::new();
887 symbols.extend(extract_symbols(&InSlice::to_value()));
888 symbols.extend(extract_symbols(&InTime::to_value()));
889 assert_eq!(
890 out_slice_broadcast_symbols.intersection(&symbols).count(),
891 0,
892 "OutSlice broadcast axes must be new axes"
893 );
894 }
895
896 if *slice1 > 1 {
898 assert_eq!(
899 Mapping::Modulo {
900 inner: RBox::new(OutSlice::to_value()),
901 modulo: *slice1,
902 }
903 .factorize(),
904 Mapping::Modulo {
905 inner: RBox::new(Mapping::Stride {
906 inner: RBox::new(InSlice::to_value()),
907 stride: *slice0
908 }),
909 modulo: *slice1,
910 }
911 .factorize(),
912 "OutSlice must preserve slice1 from InSlice",
913 );
914 }
915
916 let mut expected_out_time = Mapping::Identity.pair(InTime::to_value());
918 if *slice0 > 1 {
919 expected_out_time = expected_out_time.pair(Mapping::Modulo {
921 inner: RBox::new(InSlice::to_value()),
922 modulo: *slice0,
923 });
924 }
925
926 assert_eq!(
927 expected_out_time.factorize(),
928 OutTime::to_value().factorize(),
929 "OutTime does not match expected layout: [time0, slice0]"
930 );
931 }
932 }
933 }
934}
935
936fn verify_collect<D: Scalar, Time: M, Packet: M, Time2: M, Packet2: M>() {
944 let in_packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
945 let aligned_bytes = align_up(in_packet_bytes, FLIT_BYTES);
946 let flit_elements = D::length_from_bytes(FLIT_BYTES);
947
948 assert_eq!(
950 D::size_in_bytes_from_length(Packet2::SIZE),
951 FLIT_BYTES,
952 "Collect output packet must be exactly {FLIT_BYTES} bytes (one flit)."
953 );
954
955 let in_factorized = Packet::to_value().factorize();
957 let padded = in_factorized.pad(D::length_from_bytes(aligned_bytes));
958 let Tuple2(in_outer, in_flit) = padded.split_at(flit_elements);
959
960 let expected_packet = in_flit.normalize();
962 let out_packet = Packet2::to_value().factorize();
963 assert_eq!(
964 expected_packet, out_packet,
965 "Collect packet mismatch. Expected: {expected_packet}, got: {out_packet}"
966 );
967
968 let expected_time = Time::to_value().factorize().mul(in_outer).normalize();
970 let out_time = Time2::to_value().factorize();
971 assert_eq!(
972 expected_time, out_time,
973 "Collect time mismatch. Expected: {expected_time}, got: {out_time}"
974 );
975}
976
/// Validates the mappings of a transpose step from `[Time, Packet]` to
/// `[OutTime, OutPacket]` for element type `D`.
///
/// # Panics
/// Panics if a packet has the wrong byte size, if the `in_rows` axes (the unpadded
/// output packet) are too large or not present in the input `Time`, if the derived
/// `in_cols` size is not allowed for `D::BITS`, or if `OutTime` does not have the
/// expected structure.
fn verify_transpose<D: Scalar, Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    let packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
    // Input packet size is fixed.
    assert_eq!(
        packet_bytes, TRANSPOSE_INPUT_BYTES,
        "Transpose input packet must be {TRANSPOSE_INPUT_BYTES} bytes, got {packet_bytes}"
    );

    let out_packet_bytes = D::size_in_bytes_from_length(OutPacket::SIZE);
    // Output packet size is fixed.
    assert_eq!(
        out_packet_bytes, TRANSPOSE_OUTPUT_BYTES,
        "Transpose output packet must be {TRANSPOSE_OUTPUT_BYTES} bytes, got {out_packet_bytes}",
    );

    // `in_rows` is the output packet with its padding removed; after the
    // transpose these axes are what the input iterated over in time.
    let in_rows = OutPacket::to_value().factorize().remove_padding();
    let in_rows_bytes = D::size_in_bytes_from_length(in_rows.size());
    assert!(
        in_rows_bytes <= TRANSPOSE_MAX_IN_ROWS_BYTES,
        "Transpose `in_rows` must be <= {TRANSPOSE_MAX_IN_ROWS_BYTES} bytes, got {in_rows_bytes}"
    );

    let time = Time::to_value().factorize();
    // Locate `in_rows` inside the input Time; panics unless the division is exact.
    let division_terms = time
        .clone()
        .divide_relaxed(in_rows.clone())
        .exact()
        .unwrap_or_else(|_| panic!("Transpose `in_rows` ({in_rows}) must be present in the input Time ({time})"))
        .division_terms;

    // NOTE(review): `division_terms[0]` assumes the division yields at least one
    // term; an empty result (e.g. a size-1 `in_rows`) would panic on indexing — confirm.
    // The part of Time inner to the first matched term is `packets_per_col`.
    let Tuple2(time_outer, packets_per_col) = time.split_at(division_terms[0].dividend_stride);
    // Drop the `in_rows`-sized inner part, leaving only the outer time axes.
    let Tuple2(time_outer, _) = time_outer.split_at(in_rows.size());
    let (elements_per_packet, valid_in_cols) = match D::BITS {
        4 => (TRANSPOSE_ELEMENTS_PER_PACKET_4BIT, TRANSPOSE_VALID_IN_COLS_4BIT),
        _ => (TRANSPOSE_ELEMENTS_PER_PACKET_NON_4BIT, TRANSPOSE_VALID_IN_COLS),
    };
    let in_cols = packets_per_col.size() * elements_per_packet;
    assert!(
        valid_in_cols.contains(&in_cols),
        "Transpose `in_cols` size ({in_cols}) must be one of {valid_in_cols:?} for {}-bit type",
        D::BITS,
    );

    // From here on `elements_per_packet` is a mapping: the input packet padded
    // to the per-packet element count.
    let elements_per_packet = Packet::to_value().factorize().pad(elements_per_packet);
    let out_time = OutTime::to_value().factorize();

    let Tuple2(out_time_outer, out_time_elements_per_packet) = if packets_per_col.size() > 1 {
        // `packets_per_col` must appear in OutTime; split OutTime around it in
        // the same way Time was split around `in_rows` above.
        let division_terms = out_time
            .clone()
            .divide_relaxed(packets_per_col.clone())
            .exact()
            .unwrap_or_else(|_| {
                panic!("Transpose `packets_per_col` ({packets_per_col}) not found in OutTime ({out_time})")
            })
            .division_terms;
        let Tuple2(outer, num_elems) = out_time.split_at(division_terms[0].dividend_stride);
        let Tuple2(outer, _) = outer.split_at(packets_per_col.size());
        Tuple2(outer, num_elems)
    } else {
        // Single packet per column: the inner part of OutTime is sized by
        // what remains after the outer time axes.
        let out_rows = OutTime::SIZE / time_outer.size();
        out_time.split_at(out_rows)
    };
    let out_rows = packets_per_col
        .clone()
        .mul(out_time_elements_per_packet.clone())
        .normalize();
    let in_cols = packets_per_col.mul(elements_per_packet.clone()).normalize();
    assert!(
        out_rows.size() <= in_cols.size(),
        "Transpose `out_rows` ({}) must be <= `in_cols` ({})",
        out_rows.size(),
        in_cols.size(),
    );
    assert_eq!(
        out_time_elements_per_packet.remove_padding(),
        elements_per_packet.remove_padding(),
        "Transpose `out_rows` ({out_rows}) must match `in_cols` ({in_cols}) (excluding padding)",
    );

    // The outer time axes pass through the transpose unchanged.
    assert_eq!(
        time_outer, out_time_outer,
        "Transpose time mismatch: expected outer OutTime to be {time_outer}, got ({out_time_outer})"
    );
}
1077
1078impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1080 FetchTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1081{
1082 #[primitive(FetchTensor::switch)]
1088 pub fn switch<Slice2: M, Time2: M>(
1089 self,
1090 config: SwitchConfig,
1091 ) -> SwitchTensor<'l, T, D, Chip, Cluster, Slice2, Time2, Packet> {
1092 verify_switch::<Slice, Time, Slice2, Time2>(&config);
1093 SwitchTensor::new(self.ctx, self.inner.transpose(true))
1094 }
1095
1096 #[primitive(FetchTensor::collect)]
1101 pub fn collect<Time2: M, Packet2: M>(self) -> CollectTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
1102 verify_collect::<D, Time, Packet, Time2, Packet2>();
1103 CollectTensor::new(self.ctx, self.inner.transpose(false))
1104 }
1105}
1106impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1110 SwitchTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
1111{
1112 #[primitive(SwitchTensor::collect)]
1118 pub fn collect<Time2: M, Packet2: M>(self) -> CollectTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
1119 verify_collect::<D, Time, Packet, Time2, Packet2>();
1120 CollectTensor::new(self.ctx, self.inner.transpose(false))
1121 }
1122}
1123impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1126 CollectTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
1127{
1128 #[primitive(CollectTensor::to_trf)]
1130 pub fn to_trf<Row: M, Element: M>(self, address: TrfAddress) -> TrfTensor<D, Chip, Cluster, Slice, Row, Element> {
1131 assert!(
1132 [1, 2, 4, 8].contains(&Row::SIZE),
1133 "Row::SIZE must be 1, 2, 4, or 8, got {}",
1134 Row::SIZE
1135 );
1136
1137 let capacity = address.capacity();
1139 let total_trf_bytes = D::size_in_bytes_from_length(Row::SIZE * Element::SIZE);
1140 assert!(
1141 total_trf_bytes <= capacity,
1142 "TRF data ({} bytes = {} rows x {} bytes) exceeds register file capacity ({} bytes for {})",
1143 total_trf_bytes,
1144 Row::SIZE,
1145 D::size_in_bytes_from_length(Element::SIZE),
1146 capacity,
1147 address,
1148 );
1149
1150 assert!(
1152 Row::SIZE <= Time::SIZE,
1153 "Row::SIZE must be <= Time::SIZE, got {} > {}",
1154 Row::SIZE,
1155 Time::SIZE,
1156 );
1157
1158 let time = Time::to_value().factorize();
1160 let Tuple2(time_outer, time_inner) = time.split_at(
1161 exact_div(Time::SIZE, Row::SIZE)
1162 .unwrap_or_else(|| panic!("Row::SIZE ({}) does not divide Time::SIZE ({})", Row::SIZE, Time::SIZE)),
1163 );
1164 let row = Row::to_value().factorize();
1165 assert_eq!(
1166 time_outer, row,
1167 "`to_trf` row mismatch: time_outer != Row: {time_outer} != {row}",
1168 );
1169
1170 let expected_element = time_inner.mul(Packet::to_value().factorize()).normalize();
1172 let element = Element::to_value().factorize();
1173 assert_eq!(
1174 expected_element, element,
1175 "`to_trf` element mismatch: [time_inner, Packet] != Element: {expected_element} != {element}",
1176 );
1177
1178 TrfTensor::new(self.inner.transpose(false), address)
1179 }
1180
1181 #[primitive(CollectTensor::to_vrf)]
1183 pub fn to_vrf<Element2: M>(self, address: Address) -> VrfTensor<D, Chip, Cluster, Slice, Element2>
1184 where
1185 D: VeScalar,
1186 {
1187 VrfTensor::new(self.inner.transpose(false), address)
1188 }
1189
1190 #[primitive(CollectTensor::vector_init)]
1193 pub fn vector_init(self) -> VectorInitTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1194 where
1195 D: VeScalar,
1196 {
1198 VectorInitTensor::new(self.ctx, self.inner)
1199 }
1200
1201 #[primitive(CollectTensor::align)]
1203 pub fn align<OutTime: M, OutPacket: M, Row: M, TrfElement: M>(
1204 self,
1205 trf_tensor: &TrfTensor<D, Chip, Cluster, Slice, Row, TrfElement>,
1206 ) -> AlignedPair<'l, { T }, D, Chip, Cluster, Slice, Row, OutTime, OutPacket>
1207 where
1208 Chip: M,
1209 Cluster: M,
1210 Slice: M,
1211 {
1212 assert!([1, 2, 4, 8].contains(&Row::SIZE), "Row::SIZE should be 1, 2, 4, or 8");
1213
1214 let out_packet_size = D::size_in_bytes_from_length(OutPacket::SIZE);
1215 assert_eq!(
1216 out_packet_size, 64,
1217 "OutPacket must be 64 bytes, got {out_packet_size} bytes"
1218 );
1219
1220 let flit_elements = D::length_from_bytes(FLIT_BYTES);
1222 let Tuple2(out_packet_outer, out_packet_inner) = OutPacket::to_value().factorize().split_at(flit_elements);
1223 let out_packet_inner = out_packet_inner.normalize();
1224 let expected_packet = Packet::to_value().factorize();
1225 assert_eq!(
1226 out_packet_inner, expected_packet,
1227 "`align` packet mismatch: inner flit of OutPacket != Packet: {out_packet_inner} != {expected_packet}",
1228 );
1229
1230 let expected_time = OutTime::to_value()
1233 .factorize()
1234 .mul(out_packet_outer.remove_padding())
1235 .normalize();
1236 let input_time = Time::to_value().factorize();
1237
1238 let tiling_size = expected_time.size() / input_time.size();
1241 let align_div = expected_time
1242 .divide_relaxed(input_time.clone())
1243 .exact()
1244 .expect("`align`: Time does not divide OutTime");
1245 let tiling = align_div.dividend_residue;
1246 let division_terms = align_div.division_terms;
1247
1248 assert!(
1250 division_terms
1251 .windows(2)
1252 .all(|w| w[0].divisor_stride > w[1].divisor_stride),
1253 "`align`: Time axes are reordered in OutTime"
1254 );
1255
1256 assert!(
1258 division_terms.iter().all(|d| d.dividend_stride >= tiling_size),
1259 "`align`: tiling axes must be innermost in OutTime"
1260 );
1261
1262 if tiling.size() > 1 {
1263 let trf_element = TrfElement::to_value().factorize();
1264 assert!(
1265 trf_element.clone().divide_relaxed(tiling.clone()).exact().is_ok(),
1266 "tiling axes must be present in TRF",
1267 );
1268 }
1269
1270 let lhs = self.inner.transpose(true);
1271 let trf = trf_tensor.inner.transpose(true);
1272
1273 AlignedPair {
1274 ctx: self.ctx,
1275 lhs,
1276 trf,
1277 }
1278 }
1279
1280 #[primitive(CollectTensor::transpose)]
1282 pub fn transpose<OutTime: M, OutPacket: M>(
1283 self,
1284 ) -> TransposeTensor<'l, T, D, Chip, Cluster, Slice, OutTime, OutPacket> {
1285 verify_transpose::<D, Time, Packet, OutTime, OutPacket>();
1286 TransposeTensor::new(self.ctx, self.inner.transpose(false))
1287 }
1288}
1289
1290impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
1292 for CollectTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1293{
1294 type CastOutput<D2: Scalar, OutPacket: M>
1295 = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
1296 where
1297 D: Cast<D2>;
1298
1299 #[primitive(CollectTensor::cast)]
1300 fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
1301 where
1302 D: Cast<D2>,
1303 {
1304 cast_stream(self.ctx, self.inner)
1305 }
1306}
1307fn verify_contract<D: Scalar, Packet: M, OutPacket: M>() {
1317 let packet_size = D::size_in_bytes_from_length(Packet::SIZE);
1318 assert_eq!(packet_size, 64, "Packet must be 64 bytes, got {packet_size} bytes");
1319
1320 assert!(
1321 OutPacket::SIZE <= TEMPORAL_ACCUMULATOR_COLS,
1322 "OutPacket::SIZE must be at most {TEMPORAL_ACCUMULATOR_COLS}, got {}",
1323 OutPacket::SIZE
1324 );
1325
1326 assert!(
1327 OutPacket::SIZE.is_power_of_two(),
1328 "OutPacket::SIZE must be a power of two, got {}",
1329 OutPacket::SIZE
1330 );
1331
1332 let packet = Packet::to_value().factorize();
1333 let out_packet = OutPacket::to_value().factorize().remove_padding();
1334
1335 assert!(
1336 (0..=Packet::SIZE.trailing_zeros()).rev().any(|depth| {
1337 let split = 1 << depth;
1338 let outer = packet.clone().stride(split);
1339 outer.remove_padding() == out_packet.clone()
1340 }),
1341 "OutPacket {out_packet} is not a valid contraction of Packet {packet}",
1342 );
1343}
1344
/// Validates the type-level mappings of an `accumulate` step.
///
/// `Row`, `Time` and `Packet` describe the input; `OutTime` and `OutPacket` describe the
/// requested output; `kind` selects which output layout is expected.
///
/// The checks, in order:
/// 1. `Packet::SIZE <= TEMPORAL_ACCUMULATOR_COLS` and
///    `OutPacket::SIZE == ACCUMULATE_OUT_PACKET_ELEMENTS`.
/// 2. A kind-specific decomposition of `OutTime` / `OutPacket` that yields `outer_time`,
///    the part of `OutTime` that must come from `Time`.
/// 3. `Time` must divide exactly by `outer_time`, with the shared axes in the same order.
/// 4. The axes kept in `outer_time` must carry the same padding they have in `Time`.
/// 5. A buffer-size bound derived from the axes that sit inner to the reduced ones.
///
/// # Panics
///
/// Panics with a message prefixed by `accumulate` at the first check that fails.
fn verify_accumulate<Row: M, Time: M, Packet: M, OutTime: M, OutPacket: M>(kind: AccumulationKind) {
    assert!(
        Packet::SIZE <= TEMPORAL_ACCUMULATOR_COLS,
        "accumulate: Packet::SIZE must be at most {TEMPORAL_ACCUMULATOR_COLS}, got {}",
        Packet::SIZE
    );
    assert_eq!(
        OutPacket::SIZE,
        ACCUMULATE_OUT_PACKET_ELEMENTS,
        "accumulate: OutPacket::SIZE must be {ACCUMULATE_OUT_PACKET_ELEMENTS}, got {}",
        OutPacket::SIZE
    );

    // Factorized forms of all four mappings; only the input packet has its padding removed.
    let time = Time::to_value().factorize();
    let packet = Packet::to_value().factorize().remove_padding();
    let out_time = OutTime::to_value().factorize();
    let out_packet = OutPacket::to_value().factorize();

    // `outer_time` is what remains of `OutTime` after peeling off the kind-specific inner part.
    // `packet_outer_size` only matters for the Sequential buffer check at the end.
    let (outer_time, packet_outer_size) = match kind {
        AccumulationKind::Interleaved => {
            // The output packet must be `Row` padded to the fixed output packet length.
            let expected_out_packet = Row::to_value().factorize().pad(ACCUMULATE_OUT_PACKET_ELEMENTS);
            assert_eq!(
                out_packet, expected_out_packet,
                "accumulate ({kind}): OutPacket mismatch. Expected: {expected_out_packet}, got: {out_packet}"
            );

            // Search for a split size such that the part of `OutTime` split off at that size
            // equals the (unpadded) input packet.
            let outer_time = (1..=out_time.size().min(packet.size()).min(TEMPORAL_ACCUMULATOR_COLS))
                .filter(|&split| {
                    // The split must divide `OutTime` evenly...
                    out_time.size() % split == 0
                        // ...a split of 1 is only accepted when the packet itself has size 1...
                        && (split > 1 || packet.size() == 1)
                        // ...and the remaining outer part must not be larger than `Time`.
                        && out_time.size() / split <= time.size()
                })
                .find_map(|split| {
                    let Tuple2(outer_time, sliced_packet) = out_time.split_at(split);

                    if sliced_packet != packet {
                        return None;
                    }

                    Some(outer_time)
                })
                .unwrap_or_else(|| {
                    panic!(
                        "accumulate ({kind}): OutTime mismatch. \
                         Could not decompose OutTime {out_time} into [Time', Packet'] \
                         where Time' is {time} after temporal accumulation and Packet' is a truncation of {packet}"
                    )
                });
            (outer_time, 1)
        }
        AccumulationKind::Sequential => {
            // Pad the packet up to a multiple of the output packet length, then split it into
            // `packet_outer` and an output-packet-sized `packet_inner`.
            let padded = packet
                .clone()
                .pad(align_up(packet.size(), ACCUMULATE_OUT_PACKET_ELEMENTS));
            let Tuple2(packet_outer, packet_inner) = padded.split_at(ACCUMULATE_OUT_PACKET_ELEMENTS);
            let packet_outer_size = packet_outer.size();

            // The inner piece is exactly what `OutPacket` must be.
            assert_eq!(
                packet_inner, out_packet,
                "accumulate ({kind}): OutPacket mismatch. Expected: {packet_inner}, got: {out_packet}"
            );

            // `OutTime` must end with `Row` multiplied by the outer packet part.
            let row_packet = Row::to_value().factorize().mul(packet_outer);
            let Tuple2(outer_time, inner_time) = out_time.split_at(row_packet.size());
            assert_eq!(
                inner_time, row_packet,
                "accumulate ({kind}): OutTime mismatch. Expected {row_packet}, got {inner_time}"
            );

            (outer_time, packet_outer_size)
        }
    };

    // Every axis of `outer_time` must be found in `Time`; the division terms record where.
    let division_terms = time
        .clone()
        .divide_relaxed(outer_time.clone())
        .exact()
        .unwrap_or_else(|_| {
            panic!("accumulate ({kind}): OutTime mismatch. Some axes present in Time are not present in Time': {time}, {outer_time}")
        })
        .division_terms;
    // Strictly decreasing divisor strides mean the shared axes appear in the same relative order.
    assert!(
        division_terms
            .windows(2)
            .all(|w| w[0].divisor_stride > w[1].divisor_stride),
        "accumulate ({kind}): OutTime axes must follow the same order as the Time axes"
    );

    // Maps the stride at which each `Time` term starts to its extent in units of that stride:
    // `size / stride` when the term is directly followed by a padding factor, else its `resize`.
    let mut time_padding_per_stride: HashMap<usize, usize> = HashMap::new();
    let factors = time.factors();
    let mut stride = 1;
    for (i, factor) in factors.iter().enumerate() {
        match factor {
            Factor::Term { resize, .. } => {
                time_padding_per_stride.insert(
                    stride,
                    if let Some(Factor::Padding { size, .. }) = factors.get(i + 1) {
                        size / stride
                    } else {
                        *resize
                    },
                );
                stride *= resize;
            }
            Factor::Padding { size, .. } => {
                // A padding factor moves the running stride to its own size.
                stride = *size;
            }
        }
    }

    // Walk the division terms from the smallest divisor stride upwards.
    let mut sorted_divisions: Vec<&DivisionTerm> = division_terms.iter().collect();
    sorted_divisions.sort_by_key(|d| d.divisor_stride);

    // The first kept axis must start at stride 1 in `outer_time`.
    if let Some(first) = sorted_divisions.first() {
        assert_eq!(
            first.divisor_stride, 1,
            "accumulate ({kind}): Padding mismatch. \
             OutTime {outer_time} has unexpected leading padding not present in Time {time}"
        );
    }

    // Each kept axis, scaled by its extent recorded from `Time`, must end exactly where the
    // next kept axis starts (or at the end of `outer_time` for the last one).
    for (pos, d) in sorted_divisions.iter().enumerate() {
        let expected_end = d.divisor_stride
            * time_padding_per_stride
                .get(&d.dividend_stride)
                .copied()
                .unwrap_or(d.resize);
        let end = sorted_divisions
            .get(pos + 1)
            .map_or(outer_time.size(), |next| next.divisor_stride);
        assert_eq!(
            expected_end, end,
            "accumulate ({kind}): Padding mismatch. \
             Non-reduced axes in OutTime {outer_time} do not preserve padding from Time {time}"
        );
    }

    // Where a kept axis ends inside `Time` (its dividend stride times its recorded extent).
    let padding_end = |d: &DivisionTerm| {
        d.dividend_stride
            * time_padding_per_stride
                .get(&d.dividend_stride)
                .copied()
                .unwrap_or(d.resize)
    };
    // `inner_time` feeds the buffer bound below (the error message calls it the "axes inner
    // to reduce"):
    // - no kept axes at all: 1;
    // - the first division term ends before the end of `Time`: all of `outer_time`;
    // - otherwise the divisor stride at the first gap between consecutive kept axes in
    //   `Time`, or 1 when there is no gap.
    let inner_time = if division_terms.is_empty() {
        1
    } else if padding_end(&division_terms[0]) < time.size() {
        outer_time.size()
    } else {
        division_terms
            .windows(2)
            .find(|w| padding_end(&w[1]) != w[0].dividend_stride)
            .map_or(1, |w| w[0].divisor_stride)
    };

    // Kind-specific buffer bound; both limits are derived from the constant 1024.
    match kind {
        AccumulationKind::Interleaved => {
            let buffer = inner_time * packet.size();
            assert!(
                buffer <= 1024 / ACCUMULATE_OUT_PACKET_ELEMENTS,
                "accumulate ({}): axes inner to reduce must be <= {} in size, got {}",
                kind,
                1024 / ACCUMULATE_OUT_PACKET_ELEMENTS,
                buffer
            );
        }
        AccumulationKind::Sequential => {
            let buffer = inner_time * Row::SIZE * packet_outer_size;
            assert!(
                buffer <= 1024 / TEMPORAL_ACCUMULATOR_COLS,
                "accumulate ({}): axes inner to reduce must be <= {} in size, got {}",
                kind,
                1024 / TEMPORAL_ACCUMULATOR_COLS,
                buffer
            );
        }
    }
}
1576
1577impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M>
1578 AlignedPair<'l, { T }, D, Chip, Cluster, Slice, Row, Time, Packet>
1579{
1580 #[primitive(AlignedPair::contract)]
1583 pub fn contract<OutPacket: M>(
1584 self,
1585 ) -> ContractionTensor<'l, { T }, <D as ContractionCast>::Output, Chip, Cluster, Slice, Row, Time, OutPacket>
1586 where
1587 D: ContractionCast + Cast<<D as ContractionCast>::Output>,
1588 {
1589 verify_contract::<D, Packet, OutPacket>();
1590 let lhs = self.lhs.map(|v| v.map(|v| v.cast()));
1591 let trf = self.trf.map(|v| v.map(|v| v.cast()));
1592 ContractionTensor {
1593 ctx: self.ctx,
1594 inner: lhs.zip_with(&trf, |a, b| a * b).reduce_add(),
1595 }
1596 }
1597}
1598
1599impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Time: M, Packet: M>
1600 ContractionTensor<'l, { T }, D, Chip, Cluster, Slice, Row, Time, Packet>
1601{
1602 #[primitive(ContractionTensor::accumulate)]
1604 pub fn accumulate<OutTime: M, OutPacket: M>(
1605 self,
1606 kind: AccumulationKind,
1607 ) -> AccumulationTensor<'l, { T }, D, Chip, Cluster, Slice, OutTime, OutPacket> {
1608 verify_accumulate::<Row, Time, Packet, OutTime, OutPacket>(kind);
1609 AccumulationTensor::new(self.ctx, self.inner.reduce_add())
1610 }
1611}
1612
1613impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1614 AccumulationTensor<'l, { T }, D, Chip, Cluster, Slice, Time, Packet>
1615{
1616 #[primitive(AccumulationTensor::vector_init)]
1619 pub fn vector_init(self) -> VectorInitTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1620 where
1621 D: VeScalar,
1622 {
1624 VectorInitTensor::new(self.ctx, self.inner)
1625 }
1626}
1627
1628impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
1629 for AccumulationTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1630{
1631 type CastOutput<D2: Scalar, OutPacket: M>
1632 = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
1633 where
1634 D: Cast<D2>;
1635
1636 #[primitive(AccumulationTensor::cast)]
1637 fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
1638 where
1639 D: Cast<D2>,
1640 {
1641 cast_stream(self.ctx, self.inner)
1642 }
1643}
1644
1645impl<'l, const T: Tu, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M> StreamCast<D>
1646 for VectorFinalTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1647{
1648 type CastOutput<D2: Scalar, OutPacket: M>
1649 = CastTensor<'l, T, D2, Chip, Cluster, Slice, Time, OutPacket>
1650 where
1651 D: Cast<D2>;
1652
1653 #[primitive(VectorFinalTensor::cast)]
1654 fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
1655 where
1656 D: Cast<D2>,
1657 {
1658 cast_stream(self.ctx, self.inner)
1659 }
1660}
1661
1662impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1663 VectorFinalTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet>
1664{
1665 #[primitive(VectorFinalTensor::to_vrf)]
1667 pub fn to_vrf<Element: M>(self, address: Address) -> VrfTensor<D, Chip, Cluster, Slice, Element>
1668 where
1669 D: VeScalar,
1670 {
1671 VrfTensor::new(self.inner.transpose(false), address)
1672 }
1673
1674 #[primitive(VectorFinalTensor::transpose)]
1676 pub fn transpose<Time2: M, Packet2: M>(self) -> TransposeTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2> {
1677 TransposeTensor::new(self.ctx, self.inner.transpose(false))
1678 }
1679}
1680
1681fn verify_commit<D: Scalar, Time: M, Packet: M, Element: M>() {
1688 let packet_bytes = D::size_in_bytes_from_length(Packet::SIZE);
1690 assert_eq!(
1691 packet_bytes, FLIT_BYTES,
1692 "Commit input packet must be exactly {FLIT_BYTES} bytes (one flit), got {packet_bytes}",
1693 );
1694
1695 let Tuple2(time, packet) = Element::to_value()
1697 .factorize()
1698 .split_at(exact_div(Element::SIZE, Time::SIZE).expect("Commit element size does not divide time size"));
1699 let input_time = Time::to_value().factorize().normalize();
1700 if input_time.clone().divide_relaxed(time.clone()).exact().is_err()
1701 || time.clone().divide_relaxed(input_time.clone()).exact().is_err()
1702 {
1703 panic!("Commit output Time ({time}) is not a valid transpose of the input Time ({input_time})");
1704 }
1705
1706 let out_packet_elements = Element::SIZE / Time::SIZE;
1708 let out_packet_bytes = D::size_in_bytes_from_length(out_packet_elements);
1709 assert!(
1710 COMMIT_OUT_PACKET_SIZES.contains(&out_packet_bytes),
1711 "Commit output packet must be one of {COMMIT_OUT_PACKET_SIZES:?} bytes, got {out_packet_bytes}",
1712 );
1713
1714 let expected_packet = Packet::to_value().factorize();
1716 assert!(
1717 packet.is_resize_of(&expected_packet),
1718 "Commit packet mismatch. Expected {expected_packet} or a truncation of it, got {packet}",
1719 );
1720}
1721
1722impl<'l, const T: Tu, P: Position, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M>
1724 StreamTensor<'l, { T }, P, D, Chip, Cluster, Slice, Time, Packet>
1725{
1726 #[primitive(StreamTensor::commit)]
1728 pub fn commit<Element: M>(self, address: Address) -> DmTensor<D, Chip, Cluster, Slice, Element> {
1729 verify_commit::<D, Time, Packet, Element>();
1730 DmTensor::new(self.inner.transpose(false), address)
1731 }
1732
1733 #[primitive(StreamTensor::commit_view)]
1735 pub fn commit_view<Element: M>(self, mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>) {
1736 verify_commit::<D, Time, Packet, Element>();
1737 dst.inner.write_transpose(self.inner.view(), false);
1738 }
1739}
/// Rounds `a` up to the nearest multiple of `b`.
///
/// # Panics
///
/// Panics if `b` is zero.
fn align_up(a: usize, b: usize) -> usize {
    assert_ne!(b, 0);
    a.next_multiple_of(b)
}
1746
/// Returns `Some(a / b)` when `b` is non-zero and divides `a` exactly, otherwise `None`.
fn exact_div(a: usize, b: usize) -> Option<usize> {
    // `checked_rem` yields `None` for a zero divisor, so `exact_div(0, 0)` reports
    // "not divisible" instead of reaching a division by zero.
    match a.checked_rem(b) {
        Some(0) => Some(a / b),
        _ => None,
    }
}
1750
1751#[cfg(test)]
1752mod tests {
1753 use super::*;
1754
    // Tests for `verify_transpose::<D, Time, Packet, OutTime, OutPacket>()`.
    // Each sub-module targets one validation rule; the `should_panic` strings name the rule.
    mod transpose {
        use super::*;
        use crate::scalar::bf16;

        // Mapping combinations that must be accepted without panicking.
        mod valid {
            use super::*;
            axes![A = 4, B = 2, C = 8, D = 4, E = 8, F = 8, G = 2, X = 64, Y = 512];

            #[test]
            fn basic() {
                verify_transpose::<i8, m![C, F], m![E # 32], m![C, E], m![F # 32]>();
            }

            #[test]
            fn small() {
                verify_transpose::<i8, m![A], m![B # 32], m![B], m![A # 32]>();
            }

            // Same as `small`, but the output time is written as `B # 8` instead of `B`.
            #[test]
            fn small_no_slicing() {
                verify_transpose::<i8, m![A], m![B # 32], m![B # 8], m![A # 32]>();
            }

            #[test]
            fn large_col() {
                verify_transpose::<i8, m![B, C, D], m![E # 32], m![B, D, E], m![C # 32]>();
            }

            // 16-element bf16 packets.
            #[test]
            fn bf16() {
                verify_transpose::<bf16, m![C, D], m![E # 16], m![C, E], m![D # 16]>();
            }
        }

        // Input packet byte-size rule.
        mod input_packet {
            use super::*;
            axes![C = 8, D = 8, E = 4, F = 16];

            // `E # 16` with `i8` elements is a 16-byte input packet.
            #[test]
            #[should_panic(expected = "Transpose input packet must be 32 bytes, got 16")]
            fn invalid() {
                verify_transpose::<i8, m![C, D], m![E # 16], m![C, E], m![D # 32]>();
            }
        }

        // Output packet byte-size rule.
        mod output_packet {
            use super::*;
            axes![C = 8, D = 8, E = 8];

            // `D # 16` with `i8` elements is a 16-byte output packet.
            #[test]
            #[should_panic(expected = "Transpose output packet must be 32 bytes, got 16")]
            fn invalid() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![D # 16]>();
            }
        }

        // `in_rows` rules: byte-size limit per element type, and presence in the input time.
        mod in_rows {
            use super::*;
            axes![A = 4, C = 8, D = 8, E = 8, F = 32];

            #[test]
            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 16")]
            fn invalid_i4() {
                verify_transpose::<i4, m![C, D], m![E # 64], m![C, E], m![F # 64]>();
            }

            #[test]
            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 32")]
            fn invalid_i8() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![A, D]>();
            }

            #[test]
            #[should_panic(expected = "Transpose `in_rows` must be <= 8 bytes, got 16")]
            fn invalid_bf16() {
                verify_transpose::<bf16, m![C, D], m![E # 16], m![C, E], m![D # 16]>();
            }

            // The output packet uses axis `A`, which does not occur in the input time `[C, D]`.
            #[test]
            #[should_panic(expected = "must be present in the input Time")]
            fn invalid_in_rows_not_in_time() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![C, E], m![A # 32]>();
            }
        }

        // `in_cols` rule: only a fixed set of sizes is allowed, and the set depends on element width.
        mod in_cols {
            use super::*;
            axes![C = 8, D = 8, E = 8, F = 16, G = 4];

            #[test]
            #[should_panic(expected = "Transpose `in_cols` size (64) must be one of [16, 32] for 4-bit type")]
            fn invalid_i4() {
                verify_transpose::<i4, m![F, G], m![E # 64], m![G, E], m![F # 64]>();
            }

            #[test]
            #[should_panic(expected = "Transpose `in_cols` size (64) must be one of [8, 16, 32] for 8-bit type")]
            fn invalid_i8() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![D, E], m![C # 32]>();
            }
        }

        // Rules relating the output time to the input time and input packet.
        mod out_time {
            use super::*;
            axes![A = 2, B = 2, C = 4, D = 2, E = 8, F = 8, G = 16];

            // The outermost output-time axis (`A`) differs from the input's (`B`).
            #[test]
            #[should_panic(expected = "Transpose time mismatch")]
            fn invalid_outer_mismatch() {
                verify_transpose::<i8, m![B, C, D], m![E # 32], m![A, D, E], m![C # 32]>();
            }

            // The output time `[E]` omits `D` from the input time `[C, D]`.
            #[test]
            #[should_panic(expected = "not found in OutTime")]
            fn invalid_missing_packets_per_col() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![E], m![C # 32]>();
            }

            // The output time ends in `F`, while the input packet axis is `E`.
            #[test]
            #[should_panic(expected = "must match")]
            fn invalid_wrong_axis() {
                verify_transpose::<i8, m![C, D], m![E # 32], m![D, F], m![C # 32]>();
            }

            // `E = 4` resizes the axis itself rather than padding it.
            #[test]
            #[should_panic(expected = "must match")]
            fn invalid_non_padding_resize() {
                verify_transpose::<i8, m![C], m![E # 32], m![E = 4], m![C # 32]>();
            }

            #[test]
            #[should_panic(expected = "must be <=")]
            fn invalid_out_rows_exceeds_in_cols() {
                verify_transpose::<i8, m![A], m![B # 32], m![B # 16], m![A # 32]>();
            }
        }
    }
1894
    // Tests for `verify_commit::<D, Time, Packet, Element>()`.
    mod commit {
        use super::*;

        // Mapping combinations that must be accepted without panicking.
        mod valid {
            use super::*;

            axes![M = 4, N = 8, A = 4, B = 3, C = 4];

            // Output packet is `N` with no padding at all (8 bytes of `i8`).
            #[test]
            fn full_truncation() {
                verify_commit::<i8, m![A, B, C], m![N # 32], m![A, B, C, N]>();
            }

            // Output packet keeps part of the input padding (`N # 16`).
            #[test]
            fn partial_truncation() {
                verify_commit::<i8, m![M], m![N # 32], m![M, N # 16]>();
            }

            // Output packet is identical to the input packet (`N # 32`).
            #[test]
            fn no_truncation() {
                verify_commit::<i8, m![M], m![N # 32], m![M, N # 32]>();
            }

            // 16 `bf16` elements form the 32-byte input packet.
            #[test]
            fn bf16() {
                verify_commit::<bf16, m![M], m![N # 16], m![M, N]>();
            }

            // 8 `f32` elements form the 32-byte input packet.
            #[test]
            fn f32() {
                verify_commit::<f32, m![M], m![N # 8], m![M, N]>();
            }

            // Input time is the trivial mapping `m![1]`.
            #[test]
            fn single_time_step() {
                verify_commit::<i8, m![1], m![N # 32], m![N # 8]>();
            }

            // Output packet resizes the axis itself (`N = 4`) rather than changing its padding.
            #[test]
            fn non_padding_resize() {
                verify_commit::<bf16, m![1], m![N # 16], m![N = 4]>();
            }

            // Time axes appear in a different order in `Element`, each with its padding intact.
            #[test]
            fn time_transpose() {
                verify_commit::<bf16, m![A # 32, B], m![N # 16], m![B, A # 32, N = 4]>();
            }
        }

        // Mapping combinations that must be rejected, one per check in `verify_commit`.
        mod invalid {
            use super::*;

            axes![M = 4, N = 8, X = 8, Y = 4, Z = 2];

            #[test]
            #[should_panic(expected = "Commit input packet must be exactly 32 bytes (one flit), got 16")]
            fn input_packet_not_flit() {
                verify_commit::<i8, m![M], m![N # 16], m![M, N]>();
            }

            // 13 bytes is not one of the allowed output packet sizes.
            #[test]
            #[should_panic(expected = "Commit output packet must be one of [8, 16, 24, 32] bytes, got 13")]
            fn out_packet_invalid_size() {
                verify_commit::<i8, m![M], m![N # 32], m![M, N # 13]>();
            }

            // The output packet may not be padded beyond the largest allowed size.
            #[test]
            #[should_panic(expected = "Commit output packet must be one of [8, 16, 24, 32] bytes, got 48")]
            fn extra_padding() {
                verify_commit::<i8, m![M], m![N # 32], m![M, N # 48]>();
            }

            // Output packet axis `X` has the same size as `N` but is a different axis.
            #[test]
            #[should_panic(expected = "Commit packet mismatch")]
            fn different_packet_axes() {
                verify_commit::<i8, m![M], m![N # 32], m![M, X]>();
            }

            // Output time axis `Y` has the same size as `M` but is a different axis.
            #[test]
            #[should_panic(expected = "not a valid transpose of the input Time")]
            fn different_time_axes() {
                verify_commit::<i8, m![M], m![N # 32], m![Y, N # 16]>();
            }

            // Reordering is fine, but the padding of `M` changes from 32 to 16.
            #[test]
            #[should_panic(expected = "not a valid transpose of the input Time")]
            fn time_transpose_padding_mismatch() {
                verify_commit::<bf16, m![M # 32, X], m![N # 16], m![X, M # 16, N = 4]>();
            }

            // `Z` is replaced by padding on `M`; the sizes match but the axes do not.
            #[test]
            #[should_panic(expected = "not a valid transpose of the input Time")]
            fn time_axis_dropped_with_padding() {
                verify_commit::<i8, m![M, Z], m![N # 32], m![M # 8, N]>();
            }

            // `M` is resized to 2 and re-padded to the same total size.
            #[test]
            #[should_panic(expected = "not a valid transpose of the input Time")]
            fn time_axis_resized_with_padding() {
                verify_commit::<i8, m![M # 8], m![N # 32], m![M = 2 # 8, N]>();
            }
        }
    }
1998
    // Tests for `verify_contract::<D, Packet, OutPacket>()`.
    mod contract {
        use super::*;

        axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];

        // The whole packet is contracted away, leaving the trivial mapping.
        #[test]
        fn valid_full_reduction() {
            verify_contract::<i8, m![K], m![1]>();
        }

        // Only part of `K` is contracted; `K / 4` remains in the output packet.
        #[test]
        fn valid_partial_reduction() {
            verify_contract::<i8, m![K], m![K / 4]>();
        }

        // `D` has an allowed size (32) but is not an axis of the input packet.
        #[test]
        #[should_panic(expected = "not a valid contraction")]
        fn invalid_retained_packet_size() {
            verify_contract::<i8, m![K], m![D]>();
        }

        // Keeping all 64 elements exceeds the output packet size limit.
        #[test]
        #[should_panic(expected = "OutPacket::SIZE must be at most 32, got 64")]
        fn invalid_no_reduction() {
            verify_contract::<i8, m![K], m![K]>();
        }

        #[test]
        fn valid_partial_reduction_multi_axis() {
            verify_contract::<i8, m![A, D / 2], m![A, D / 8]>();
        }

        // The padded outer axis `A` is kept; the inner axis `C` is contracted.
        #[test]
        fn valid_padded_packet_inner_reduction() {
            verify_contract::<i8, m![A # 16, C], m![A]>();
        }

        // Same as above, with the padding also present on the output packet.
        #[test]
        fn valid_padded_packet_inner_reduction_with_padding() {
            verify_contract::<i8, m![A # 16, C], m![A # 16]>();
        }

        #[test]
        fn valid_padded_packet_split() {
            verify_contract::<i8, m![B # 8, N], m![B]>();
        }

        // A 32-element `bf16` packet is 64 bytes and already within the output size limit.
        #[test]
        fn valid_no_spatial_reduction_bf16() {
            verify_contract::<bf16, m![D], m![D]>();
        }

        #[test]
        #[should_panic(expected = "OutPacket::SIZE must be a power of two, got 3")]
        fn invalid_non_power_of_two_out_packet() {
            verify_contract::<i8, m![K], m![K = 3]>();
        }

        // `K % 4` has a valid size but is rejected as a contraction of `K`.
        #[test]
        #[should_panic(expected = "not a valid contraction")]
        fn invalid_partial_inner_packet() {
            verify_contract::<i8, m![K], m![K % 4]>();
        }
    }
2067
    // Tests for `verify_accumulate::<Row, Time, Packet, OutTime, OutPacket>(kind)`.
    mod accumulate {
        use super::*;

        axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];

        // The output packet must always have exactly 8 elements.
        mod out_packet_size {
            use super::*;

            #[test]
            fn valid() {
                verify_accumulate::<m![1], m![A], m![1], m![A], m![1 # 8]>(AccumulationKind::Interleaved);
            }

            #[test]
            #[should_panic(expected = "OutPacket::SIZE must be 8, got 32")]
            fn invalid() {
                verify_accumulate::<m![1], m![A], m![1], m![A], m![D]>(AccumulationKind::Interleaved);
            }
        }

        // `AccumulationKind::Interleaved`: the output packet is `Row` padded to 8, and the
        // input packet must be found as the inner part of `OutTime`.
        mod interleaved {
            use super::*;

            // `A` is dropped from the time mapping; `B` is kept.
            #[test]
            fn valid() {
                verify_accumulate::<m![1], m![A, B], m![1], m![B], m![1 # 8]>(AccumulationKind::Interleaved);
            }

            // The kept axis `B` carries the same padding (`# 4`) on both sides.
            #[test]
            fn valid_padding() {
                verify_accumulate::<m![1], m![A # 8, B # 4], m![1], m![B # 4], m![1 # 8]>(
                    AccumulationKind::Interleaved,
                );
            }

            // Nothing is dropped from `Time`; the packet axis `D` is appended to `OutTime`.
            #[test]
            fn valid_no_reduction_with_padding() {
                verify_accumulate::<m![1], m![A # 8, B], m![D], m![A # 8, B, D], m![1 # 8]>(
                    AccumulationKind::Interleaved,
                );
            }

            // The dropped axis `A` sits between the kept axes `C` and `B`.
            #[test]
            fn valid_non_outermost() {
                verify_accumulate::<m![N], m![C, A, B], m![1], m![C, B], m![N]>(AccumulationKind::Interleaved);
            }

            // `Row` has 4 elements, so the output packet is `M # 8`.
            #[test]
            fn valid_four_rows() {
                verify_accumulate::<m![M], m![C, A, B], m![1], m![C, B], m![M # 8]>(AccumulationKind::Interleaved);
            }

            // The only time axis is dropped, leaving the trivial output time.
            #[test]
            fn valid_all_time_reduced() {
                verify_accumulate::<m![N], m![A], m![1], m![1], m![N]>(AccumulationKind::Interleaved);
            }

            // `OutTime` carries a resized packet axis (`D = 16`) instead of the full `D`.
            #[test]
            #[should_panic(expected = "OutTime mismatch")]
            fn invalid_sliced_no_padding() {
                verify_accumulate::<m![N], m![M], m![D], m![M, D = 16], m![N]>(AccumulationKind::Interleaved);
            }

            // `OutTime` ends in `K`, which is not the input packet `D`.
            #[test]
            #[should_panic(expected = "Could not decompose OutTime")]
            fn invalid_packet_size_out_time() {
                verify_accumulate::<m![N], m![M], m![D], m![M, K], m![N]>(AccumulationKind::Interleaved);
            }

            #[test]
            #[should_panic(expected = "OutTime mismatch")]
            fn invalid_resize() {
                verify_accumulate::<m![1], m![A], m![D], m![A, D / 4 % 4], m![1 # 8]>(AccumulationKind::Interleaved);
            }

            // `C` does not occur in the input time `[A, B]`.
            #[test]
            #[should_panic(expected = "OutTime mismatch")]
            fn invalid_out_time() {
                verify_accumulate::<m![N], m![A, B], m![1], m![C], m![N]>(AccumulationKind::Interleaved);
            }

            // Padding of the kept axis `B` changes from 4 to 2.
            #[test]
            #[should_panic(expected = "Padding mismatch")]
            fn invalid_out_time_padding() {
                verify_accumulate::<m![1], m![A, B # 4], m![1], m![B # 2], m![1 # 8]>(AccumulationKind::Interleaved);
            }

            // Kept axes must stay in their original relative order.
            #[test]
            #[should_panic(expected = "OutTime axes must follow the same order as the Time axes")]
            fn invalid_out_time_reorder() {
                verify_accumulate::<m![N], m![A, B], m![1], m![B, A], m![N]>(AccumulationKind::Interleaved);
            }

            // Padding on `B` is dropped in `OutTime`.
            #[test]
            #[should_panic(expected = "Padding mismatch")]
            fn invalid_out_time_no_padding() {
                verify_accumulate::<m![N], m![A, B # 32], m![1], m![A, B], m![N]>(AccumulationKind::Interleaved);
            }

            // The kept axes `[D # 64, C]` total 256 slots, above the Interleaved limit of 128.
            #[test]
            #[should_panic(expected = "axes inner to reduce must be <= 128 in size, got 256")]
            fn invalid_buffer() {
                verify_accumulate::<m![N], m![A, D # 64, C], m![1], m![D # 64, C], m![N]>(
                    AccumulationKind::Interleaved,
                );
            }

            // Same limit, with two dropped axes (`A` and `M # 8`) around the kept ones.
            #[test]
            #[should_panic(expected = "axes inner to reduce must be <= 128 in size, got 256")]
            fn invalid_buffer_multiple_reduce_axes() {
                verify_accumulate::<m![N], m![A, B # 64, M # 8, C], m![1], m![B # 64, C], m![N]>(
                    AccumulationKind::Interleaved,
                );
            }
        }

        // `AccumulationKind::Sequential`: `OutTime` must end with `Row` (times the outer part
        // of the padded packet), and `OutPacket` is the inner 8 elements of the padded packet.
        mod sequential {
            use super::*;

            #[test]
            fn valid() {
                verify_accumulate::<m![N], m![A, B], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // Writing the row as `N # 8` is accepted as well (`N` already has 8 elements).
            #[test]
            fn valid_padded_row() {
                verify_accumulate::<m![N], m![A, B], m![1], m![B, N # 8], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // The only time axis is dropped; `OutTime` consists of the row alone.
            #[test]
            fn valid_all_time_reduced() {
                verify_accumulate::<m![N], m![A], m![1], m![N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            #[test]
            fn valid_no_reduction_with_padding() {
                verify_accumulate::<m![N], m![A # 8, B], m![1], m![A # 8, B, N], m![1 # 8]>(
                    AccumulationKind::Sequential,
                );
            }

            // A 2-element packet `B` is padded to the 8-element output packet.
            #[test]
            fn valid_padded_packet() {
                verify_accumulate::<m![N], m![M], m![B], m![M, N], m![B # 8]>(AccumulationKind::Sequential);
            }

            // The 32-element packet `D` splits into `D / 8` (in `OutTime`) and `D % 8` (in `OutPacket`).
            #[test]
            fn valid_full_temporal_reduction() {
                verify_accumulate::<m![N], m![M], m![D], m![N, D / 8], m![D % 8]>(AccumulationKind::Sequential);
            }

            // `OutPacket` uses axis `A` although the input packet is `B`.
            #[test]
            #[should_panic(expected = "OutPacket mismatch")]
            fn invalid_packet_axis() {
                verify_accumulate::<m![N], m![M], m![B], m![M, N], m![A # 8]>(AccumulationKind::Sequential);
            }

            // `C` does not occur in the input time `[A, B]`.
            #[test]
            #[should_panic(expected = "OutTime mismatch")]
            fn invalid_out_time() {
                verify_accumulate::<m![N], m![A, B], m![1], m![C, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // `OutTime` ends in `M` instead of the row `N`.
            #[test]
            #[should_panic(expected = "OutTime mismatch")]
            fn invalid_out_time_row() {
                verify_accumulate::<m![N], m![A, B], m![1], m![B, M], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // Kept axes must stay in their original relative order.
            #[test]
            #[should_panic(expected = "OutTime axes must follow the same order as the Time axes")]
            fn invalid_out_time_reorder() {
                verify_accumulate::<m![N], m![A, B], m![1], m![B, A, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // Padding on `B` is dropped in `OutTime`.
            #[test]
            #[should_panic(expected = "Padding mismatch")]
            fn invalid_out_time_padding() {
                verify_accumulate::<m![N], m![A, B # 32], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // The row `M` has 4 elements, so `M # 8` in `OutTime` is not the bare row.
            #[test]
            #[should_panic(expected = "Padding mismatch")]
            fn invalid_out_time_padded_row() {
                verify_accumulate::<m![M], m![A, B], m![1], m![B, M # 8], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // Both `A` and `C` are dropped; only `B` is kept.
            #[test]
            fn valid_multi_axis_reduction() {
                verify_accumulate::<m![N], m![A, B, C], m![1], m![B, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // Kept axis `P` (8) times the row `N` (8) gives 64, above the Sequential limit of 32.
            #[test]
            #[should_panic(expected = "axes inner to reduce must be <= 32 in size, got 64")]
            fn invalid_buffer() {
                verify_accumulate::<m![N], m![A, P], m![1], m![P, N], m![1 # 8]>(AccumulationKind::Sequential);
            }

            // The outer part of the packet (`D / 8`) also counts towards the buffer bound.
            #[test]
            #[should_panic(expected = "axes inner to reduce must be <= 32 in size, got 64")]
            fn invalid_buffer_packet_outer() {
                verify_accumulate::<m![1], m![A, N, B], m![D], m![N, B, D / 8], m![D % 8]>(
                    AccumulationKind::Sequential,
                );
            }
        }
    }
2277
2278 mod switch {
2279 use super::super::*;
2280
2281 mod custom_broadcast {
2282 use super::*;
2283
2284 axes![
2285 A = 16,
2286 B = 16,
2287 C = 8,
2288 D = 2,
2289 E = 2,
2290 P = 4,
2291 Q = 8,
2292 R = 8,
2293 S = 256,
2294 X = 4,
2295 Y = 2,
2296 Z = 2,
2297 ];
2298
2299 mod permutation {
2300 use super::*;
2301
2302 #[test]
2303 fn identity() {
2304 verify_switch::<m![S], m![C], m![S], m![C]>(&SwitchConfig::CustomBroadcast { ring_size: 1 });
2305 }
2306
2307 #[test]
2308 fn full_permutation() {
2309 verify_switch::<m![A, B], m![C], m![B % 4, B / 4, A % 4, A / 4], m![C]>(
2310 &SwitchConfig::CustomBroadcast { ring_size: 256 },
2311 );
2312 }
2313
2314 #[test]
2315 fn partial_permutation() {
2316 verify_switch::<m![A, B], m![C], m![A, B % 4, B / 4], m![C]>(&SwitchConfig::CustomBroadcast {
2317 ring_size: 16,
2318 });
2319 }
2320
2321 #[test]
2322 fn three_axis_inner_swap() {
2323 verify_switch::<m![R, Q, P], m![C], m![R, P, Q], m![C]>(&SwitchConfig::CustomBroadcast {
2324 ring_size: 32,
2325 });
2326 }
2327
2328 #[test]
2329 fn three_axis_outer_swap() {
2330 verify_switch::<m![R, Q, P], m![C], m![Q, R, P], m![C]>(&SwitchConfig::CustomBroadcast {
2331 ring_size: 256,
2332 });
2333 }
2334
2335 #[test]
2336 fn padded_identity() {
2337 verify_switch::<m![P # 16, Q # 16], m![C], m![P # 16, Q # 16], m![C]>(
2338 &SwitchConfig::CustomBroadcast { ring_size: 1 },
2339 );
2340 }
2341
2342 #[test]
2343 fn padded_full_swap() {
2344 verify_switch::<m![R # 16, Q # 16], m![C], m![Q # 16, R # 16], m![C]>(
2345 &SwitchConfig::CustomBroadcast { ring_size: 256 },
2346 );
2347 }
2348
2349 #[test]
2350 fn padded_full_swap_different_padding() {
2351 verify_switch::<m![R # 16, Q # 16], m![C], m![Q # 32, R # 8], m![C]>(
2352 &SwitchConfig::CustomBroadcast { ring_size: 256 },
2353 );
2354 }
2355
2356 #[test]
2357 fn padded_partial_permutation() {
2358 verify_switch::<m![R # 16, Q # 16], m![C], m![R # 16, Q # 16 % 4, Q # 16 / 4], m![C]>(
2359 &SwitchConfig::CustomBroadcast { ring_size: 16 },
2360 );
2361 }
2362
2363 #[test]
2364 #[should_panic(
2365 expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
2366 )]
2367 fn permutation_time_change() {
2368 verify_switch::<m![A, B], m![C], m![B, A], m![R]>(&SwitchConfig::CustomBroadcast {
2369 ring_size: 256,
2370 });
2371 }
2372 }
2373
/// `verify_switch` cases run with `SwitchConfig::CustomBroadcast` in which, for most cases,
/// the third mapping argument introduces axes (`X`, `Y`, `Z`, `S`) absent from the first two.
///
/// NOTE(review): judging by the expected panic messages, the four `m![..]` type arguments are
/// input slice, input time, output slice and output time, in that order — confirm against
/// `verify_switch`.
mod broadcast {
    use super::*;

    /// One new axis `X` replaces `B % 4`, which reappears in the fourth argument; ring size 4.
    #[test]
    fn broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&config);
    }

    /// Two new axes (`Y`, `Z`) replace `A % 2` and `B % 2`; ring size 32.
    #[test]
    fn multi_axis_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, A % 2, B % 2]>(&config);
    }

    /// New axis `X` combined with a reordering of `A` into `A % 4, A / 4`; ring size 256.
    #[test]
    fn broadcast_with_permutation() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 256 };
        verify_switch::<m![A, B], m![C], m![A % 4, A / 4, B / 4, X], m![C, B % 4]>(&config);
    }

    /// New axis `Y` with `P / 2` placed ahead of `Q` while `R` stays first; ring size 32.
    #[test]
    fn broadcast_with_inner_permutation() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![R, Q, P], m![C], m![R, P / 2, Q, Y], m![C, P % 2]>(&config);
    }

    /// New axis `Y` takes the last position, replacing `P % 2`; ring size 2.
    #[test]
    fn broadcast_innermost_axis() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 2 };
        verify_switch::<m![R, Q, P], m![C], m![R, Q, P / 2, Y], m![C, P % 2]>(&config);
    }

    /// New axes `Y` and `Z` are separated by `Q, P / 2`; ring size 64.
    #[test]
    fn non_contiguous_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 64 };
        verify_switch::<m![R / 2, Y, Q, P / 2, Z], m![C], m![R / 2, Y, Q, P / 2, Z], m![C, R % 2, P % 2]>(&config);
    }

    /// The third argument is a single new axis `S`; `A` and `B` both appear in the fourth.
    #[test]
    fn full_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 256 };
        verify_switch::<m![A, B], m![C], m![S], m![C, A, B]>(&config);
    }

    /// Like `broadcast`, with `C # 32` in both the second and fourth arguments.
    #[test]
    fn padded_outer_time() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C # 32], m![A, B / 4, X], m![C # 32, B % 4]>(&config);
    }

    /// New axis `X` replaces `Q # 32 % 4`, with `#` applied to both `P` and `Q`; ring size 4.
    #[test]
    fn padded_inner_axis_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 8, Q # 32], m![C], m![P # 8, Q # 32 / 4, X], m![C, Q # 32 % 4]>(&config);
    }

    /// New axis `X` replaces `Q % 4` while only the leading `P # 32` carries `#`; ring size 4.
    #[test]
    fn broadcast_with_padded_outer_axis() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 32, Q], m![C], m![P # 32, Q / 4, X], m![C, Q % 4]>(&config);
    }

    /// New axis `X` replaces `Q # 16 % 4` with both entries written as `# 16`; ring size 4.
    #[test]
    fn padded_both_axes_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 16, Q # 16], m![C], m![P # 16, Q # 16 / 4, X], m![C, Q # 16 % 4]>(&config);
    }

    /// Same as `padded_inner_axis_broadcast`, but with `C # 16` in the second and fourth
    /// arguments.
    #[test]
    fn padded_time_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 8, Q # 32], m![C # 16], m![P # 8, Q # 32 / 4, X], m![C # 16, Q # 32 % 4]>(&config);
    }

    /// `Y` occurs twice in the third argument; a panic is expected.
    #[test]
    #[should_panic(expected = "Switch broadcast axes must each be used exactly once in OutSlice")]
    fn duplicate_broadcast_symbol() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Y], m![C, A % 2, B % 2]>(&config);
    }

    /// `C`, already present in the second argument, is used in place of a new axis; a panic
    /// is expected.
    #[test]
    #[should_panic(
        expected = "Switch broadcast axes must be new axes (not present in input Slice or Time)."
    )]
    fn broadcast_axis_from_time() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 8 };
        verify_switch::<m![A, B], m![C], m![A, B / 8, C], m![C, B % 8]>(&config);
    }

    /// `C` moves from the second argument into the third while `B % 8` moves into the fourth;
    /// a panic is expected.
    #[test]
    #[should_panic(
        expected = "Switch broadcast axes must be new axes (not present in input Slice or Time)."
    )]
    fn inter_transpose() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 256 };
        verify_switch::<m![A, B], m![C], m![A, C, B / 8], m![B % 8]>(&config);
    }

    /// `Y` and `Z` are both introduced, but only `B % 2` reappears in the fourth argument
    /// (`A % 2` does not); this must be accepted.
    #[test]
    fn partial_broadcast_replacement() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, B % 2]>(&config);
    }

    /// The new axis is written as `X # 8`; a panic is expected.
    #[test]
    #[should_panic(expected = "Switch broadcast axis X in output Slice must not be padded.")]
    fn padded_broadcast_axis() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, X # 8, B / 8, Y], m![C, A % 2, B % 8]>(&config);
    }

    /// `A % 2` is placed before `C` in the fourth argument; a panic is expected.
    #[test]
    #[should_panic(
        expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
    )]
    fn moved_axes_not_innermost() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![A % 2, C, B % 2]>(&config);
    }

    /// `B % 2` precedes `A % 2` in the fourth argument, the reverse of `A, B` in the first;
    /// a panic is expected.
    #[test]
    #[should_panic(
        expected = "Switch axes moving from input Slice to output Time must preserve their relative order from input Slice."
    )]
    fn reversed_order_in_time() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 2, Z], m![C, B % 2, A % 2]>(&config);
    }

    /// The second argument uses `C # 32` but the fourth uses `C # 16`; a panic is expected.
    #[test]
    #[should_panic(
        expected = "Switch axes moving from input slice to output time must be at the output time innermost positions."
    )]
    fn outer_time_padding_mismatch() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C # 32], m![A, B / 4, X], m![C # 16, B % 4]>(&config);
    }

    /// `P` is replaced by the new axis `X` and the fourth argument equals the second.
    #[test]
    fn broadcast_replace_in_place() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![R, P], m![C], m![R, X], m![C]>(&config);
    }

    /// Declares six local axes of size 2 via `axes!`; `C` moves from the first argument to
    /// the fourth while `X` is introduced in the third. Ring size 32.
    #[test]
    fn broadcast_with_moved_axis() {
        axes![A = 2, B = 2, C = 2, D = 2, E = 2, X = 2];
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B, C, D, E], m![1], m![E, B, X, A, D], m![C]>(&config)
    }
}
2550
/// `verify_switch` cases in which the last entry of the fourth mapping argument carries an
/// `= 3` suffix (e.g. `B % 4 = 3`).
///
/// NOTE(review): the module name suggests `= 3` slices the axis down to three elements;
/// confirm against the `m!` macro documentation.
mod slicing {
    use super::*;

    /// New axis `X` replaces `B % 4`, which reappears as `B % 4 = 3`; ring size 4.
    #[test]
    fn slicing() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4 = 3]>(&config);
    }

    /// Two new axes (`Y`, `X`); only the `B % 4` entry carries the `= 3` suffix. Ring size 32.
    #[test]
    fn slicing_with_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A / 2, Y, B / 4, X], m![C, A % 2, B % 4 = 3]>(&config);
    }

    /// The first argument is the single axis `S`, split into `S / 4` and `S % 4 = 3`.
    #[test]
    fn single_axis_slicing() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![S], m![C], m![S / 4, X], m![C, S % 4 = 3]>(&config);
    }

    /// Same shape as `slicing`, with `#` applied to both `P` and `Q`.
    #[test]
    fn padded_broadcast_slicing() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 8, Q # 32], m![C], m![P # 8, Q # 32 / 4, X], m![C, Q # 32 % 4 = 3]>(&config);
    }

    /// The fourth argument uses `B / 4 = 3`, although `B / 4` is still present in the third;
    /// a panic is expected.
    #[test]
    #[should_panic(expected = "Switch broadcast axes in output time must come from input slice.")]
    fn wrong_axis_in_slicing() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B / 4 = 3]>(&config);
    }
}
2590
/// Negative cases: each test passes a `ring_size` that `verify_switch` is expected to reject,
/// and pins the exact panic message (including the expected and the supplied value).
mod ring_size {
    use super::*;

    /// A ring size of 3 must be rejected as not being a power of two.
    #[test]
    #[should_panic(expected = "Switch ring size must be a power of 2, got 3")]
    fn non_power_of_two() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 3 };
        verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&config);
    }

    /// Both `A` and `B` are rearranged; supplying 4 must fail with "Expected 256".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 4")]
    fn wrong_full_permutation() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![A, B], m![C], m![B % 4, B / 4, A % 4, A / 4], m![C]>(&config);
    }

    /// Only `B` is rearranged; supplying 256 must fail with "Expected 16".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 16, got 256")]
    fn wrong_partial_permutation() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 256 };
        verify_switch::<m![A, B], m![C], m![A, B % 4, B / 4], m![C]>(&config);
    }

    /// New axis `X` replaces `B % 4`; supplying 32 must fail with "Expected 4".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 4, got 32")]
    fn wrong_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 32 };
        verify_switch::<m![A, B], m![C], m![A, B / 4, X], m![C, B % 4]>(&config);
    }

    /// `Q` and `R` are swapped ahead of the new axis `Y`; supplying 2 must fail with
    /// "Expected 256".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 2")]
    fn wrong_permutation_above_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 2 };
        verify_switch::<m![R, Q, P], m![C], m![Q, R, P / 2, Y], m![C, P % 2]>(&config);
    }

    /// `#`-annotated variant of `wrong_broadcast`; supplying 16 must fail with "Expected 4".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 4, got 16")]
    fn wrong_padded_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 16 };
        verify_switch::<m![A # 16, B # 16], m![C], m![A # 16, B # 16 / 4, X], m![C, B # 16 % 4]>(&config);
    }

    /// `#`-annotated variant of `wrong_partial_permutation`; supplying 256 must fail with
    /// "Expected 16".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 16, got 256")]
    fn wrong_padded_permutation() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 256 };
        verify_switch::<m![A # 16, B # 16], m![C], m![A # 16, B # 16 % 4, B # 16 / 4], m![C]>(&config);
    }

    /// `Q # 32 / 4` is moved ahead of `P # 8` before the new axis `X`; supplying 4 must fail
    /// with "Expected 256".
    #[test]
    #[should_panic(expected = "Switch ring size mismatch. Expected 256, got 4")]
    fn wrong_padded_permutation_above_broadcast() {
        let config = SwitchConfig::CustomBroadcast { ring_size: 4 };
        verify_switch::<m![P # 8, Q # 32], m![C], m![Q # 32 / 4, P # 8, X], m![C, Q # 32 % 4]>(&config);
    }
}
2658 }
2659 }
2660}