1use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use super::ext::IndexExt;
7use abi_stable::std_types::RBox;
8use furiosa_mapping_macro::primitive;
9use furiosa_mapping_types::{Atom, Ident, Index, M, Mapping, PaddingKind, Term};
10
11pub trait AxisName: Debug + Clone + 'static {
13 const NAME: Ident;
15 const SIZE: usize;
17}
18
/// Declares one zero-sized marker type per `Name = size` entry and implements
/// `AxisName` for it.
///
/// `NAME` is built from the stringified identifier and `SIZE` is the given
/// literal; a trailing comma is accepted. For example,
/// `axes![Batch = 4, Seq = 8]` declares `Batch` (size 4) and `Seq` (size 8).
///
/// The expansion names `AxisName` and `Ident` without a path, so both must be
/// in scope wherever the macro is invoked.
#[macro_export]
macro_rules! axes {
    ($($axis:ident = $extent:literal),* $(,)?) => {
        $(
            #[allow(non_camel_case_types)]
            #[derive(Debug, Clone)]
            pub struct $axis;

            impl AxisName for $axis {
                const NAME: Ident = Ident::new(::core::stringify!($axis));
                const SIZE: usize = $extent;
            }
        )*
    };
}
43
44#[primitive(mapping::Identity)]
46#[derive(Debug, Clone)]
47pub struct Identity;
48
49impl M for Identity {
51 const SIZE: usize = 1;
52
53 fn to_value() -> Mapping {
54 Mapping::Identity
55 }
56
57 fn map(i: usize) -> Option<Index> {
58 if i == 0 { Some(Index::new()) } else { None }
59 }
60}
61#[primitive(mapping::Symbol)]
65#[derive(Debug, Clone)]
66pub struct Symbol<S: AxisName> {
67 _marker: std::marker::PhantomData<S>,
68}
69
70impl<S: AxisName> M for Symbol<S> {
72 const SIZE: usize = S::SIZE;
73
74 fn to_value() -> Mapping {
75 Mapping::Symbol {
76 symbol: S::NAME,
77 size: S::SIZE,
78 }
79 }
80
81 fn map(i: usize) -> Option<Index> {
82 if i < S::SIZE {
83 let mut index = Index::new();
84 Index::add_term(
85 &mut index,
86 Term {
87 inner: Atom::Symbol {
88 symbol: S::NAME,
89 size: S::SIZE,
90 },
91 stride: 1,
92 modulo: S::SIZE,
93 },
94 i,
95 );
96 Some(index)
97 } else {
98 None
99 }
100 }
101}
102#[primitive(mapping::Stride)]
106#[derive(Debug, Clone)]
107pub struct Stride<L, const SIZE: usize> {
108 _marker: PhantomData<L>,
109}
110impl<L, const SIZE: usize> M for Stride<L, SIZE>
112where
113 L: M,
114{
115 const SIZE: usize = {
116 assert!(L::SIZE % SIZE == 0, "Stride size must divide the original size");
117 L::SIZE / SIZE
118 };
119
120 fn to_value() -> Mapping {
121 Mapping::Stride {
122 inner: RBox::new(L::to_value()),
123 stride: SIZE,
124 }
125 }
126
127 fn map(i: usize) -> Option<Index> {
128 if i < Self::SIZE { L::map(i * SIZE) } else { None }
129 }
130}
131#[primitive(mapping::Modulo)]
135#[derive(Debug, Clone)]
136pub struct Modulo<L, const SIZE: usize> {
137 _marker: PhantomData<L>,
138}
139impl<L, const SIZE: usize> M for Modulo<L, SIZE>
141where
142 L: M,
143{
144 const SIZE: usize = {
145 assert!(L::SIZE % SIZE == 0, "Modulo size must divide the original size");
146 SIZE
147 };
148
149 fn to_value() -> Mapping {
150 Mapping::Modulo {
151 inner: RBox::new(L::to_value()),
152 modulo: SIZE,
153 }
154 }
155
156 fn map(i: usize) -> Option<Index> {
157 if i < Self::SIZE { L::map(i % L::SIZE) } else { None }
158 }
159}
160#[primitive(mapping::Resize)]
164#[derive(Debug, Clone)]
165pub struct Resize<L, const SIZE: usize> {
166 _marker: PhantomData<L>,
167}
168impl<L, const SIZE: usize> M for Resize<L, SIZE>
170where
171 L: M,
172{
173 const SIZE: usize = SIZE;
174
175 fn to_value() -> Mapping {
176 Mapping::Resize {
177 inner: RBox::new(L::to_value()),
178 resize: SIZE,
179 }
180 }
181
182 fn map(i: usize) -> Option<Index> {
183 if i < SIZE { L::map(i) } else { None }
184 }
185}
186#[primitive(mapping::Padding)]
190#[derive(Debug, Clone)]
191pub struct Padding<L, const SIZE: usize> {
192 _marker: PhantomData<L>,
193}
194impl<L, const SIZE: usize> M for Padding<L, SIZE>
196where
197 L: M,
198{
199 const SIZE: usize = SIZE;
200
201 fn to_value() -> Mapping {
202 Mapping::Padding {
203 inner: RBox::new(L::to_value()),
204 padding: SIZE,
205 kind: PaddingKind::Top,
206 }
207 }
208
209 fn map(i: usize) -> Option<Index> {
210 L::map(i)
211 }
212}
213#[primitive(mapping::Pair)]
217#[derive(Debug, Clone)]
218pub struct Pair<L, R> {
219 _marker: PhantomData<(L, R)>,
220}
221impl<L, R> M for Pair<L, R>
223where
224 L: M,
225 R: M,
226{
227 const SIZE: usize = L::SIZE * R::SIZE;
228
229 fn to_value() -> Mapping {
230 Mapping::Pair {
231 left: RBox::new(L::to_value()),
232 right: RBox::new(R::to_value()),
233 }
234 }
235
236 fn map(i: usize) -> Option<Index> {
237 let mut l = L::map(i / R::SIZE)?;
238 let r = R::map(i % R::SIZE)?;
239 Index::add(&mut l, r);
240 Some(l)
241 }
242}
243pub fn assert_div<I: M, E: M, E2: M, const LEN: usize>() {
247 use crate::{DivisionExt, FMappingExt, MappingExt};
248 let e2 = if LEN == 1 {
249 E::to_value()
250 .factorize()
251 .divide_relaxed(I::to_value().factorize())
252 .exact()
253 .into_result()
254 .expect("[assert_div] failed to split by the index expression")
255 .dividend_residue
256 } else {
257 E::to_value()
258 .factorize()
259 .divide_span(I::to_value().factorize(), LEN)
260 .into_result()
261 .and_then(|d| d.exact().into_result())
262 .expect("[assert_div] failed to split by the index expression")
263 .dividend_residue
264 };
265 assert_eq!(
266 e2.clone(), E2::to_value().factorize(),
268 "[assert_div] inconsistent view type after split"
269 );
270}
271
272#[cfg(test)]
273mod tests {
274 use super::*;
275 use crate::{DivisionExt, FMappingExt, IndexExt, MappingExt, StrictDivisionExt};
276 use abi_stable::std_types::{ROption, RResult, Tuple2};
277 use furiosa_mapping_types::{
278 BlockBounds, DivisionError, DivisionSide, DivisionTerm, FMapping, Factor, IndexValueError, RSortedMap,
279 TermBounds,
280 };
281 use furiosa_opt_macro::m;
282
283 #[test]
285 fn unittest_split_at_1() {
286 axes![A = 2, B = 2];
287 let f = <m![A # 4, B # 4]>::to_value().factorize();
288 assert_eq!(f.size(), 16);
289 let Tuple2(outer, inner) = f.split_at(8);
290 assert_eq!(inner.size(), 8);
291 assert_eq!(outer.size(), 2);
292 assert_eq!(inner, <m![A, B # 4]>::to_value().factorize());
293 assert_eq!(outer, <m![1 # 2]>::to_value().factorize());
294 }
295
296 #[test]
298 fn unittest_split_at_2() {
299 axes![A = 2, B = 2, C = 2];
300 let f = <m![A # 4, B, C]>::to_value().factorize();
301 let Tuple2(outer, inner) = f.split_at(8);
302 assert_eq!(inner.size(), 8);
303 assert_eq!(outer.size(), 2);
304 assert_eq!(inner, <m![A, B, C]>::to_value().factorize());
305 assert_eq!(outer, <m![1 # 2]>::to_value().factorize());
306 }
307
308 #[test]
310 fn unittest_split_at_3() {
311 axes![A = 2, B = 2, C = 2, D = 2];
312 let f = <m![A # 4, B # 4, C # 16, D # 4]>::to_value().factorize();
313 let Tuple2(outer, inner) = f.split_at(16);
314 assert_eq!(inner.size(), 16);
315 assert_eq!(outer.size(), 64);
316 assert_eq!(inner, <m![C # 4, D # 4]>::to_value().factorize());
317 assert_eq!(outer, <m![A # 4, B # 4, 1 # 4]>::to_value().factorize());
318 }
319
320 #[test]
325 fn unittest_split_at_4() {
326 axes![A = 4, B = 4];
327 let f = <m![B, A # 9]>::to_value().factorize();
328 assert_eq!(f.size(), 36);
329
330 let Tuple2(outer, inner) = f.split_at(4);
331 assert_eq!(inner.size(), 4);
332 assert_eq!(outer.size(), 9);
333 assert_eq!(inner, <m![[B, A # 9] % 4]>::to_value().factorize());
334 assert_eq!(outer, <m![[B, A # 9] / 4]>::to_value().factorize());
335 }
336
337 #[test]
340 #[should_panic(expected = "not divisible")]
341 fn unittest_split_at_5() {
342 axes![A = 6];
343 let f = <m![A # 9]>::to_value().factorize();
344 f.split_at(4);
345 }
346
347 #[test]
350 fn unittest_split_at_6() {
351 axes![A = 8];
352 let f = <m![A # 16]>::to_value().factorize();
353
354 let Tuple2(outer, inner) = f.split_at(4);
355 assert_eq!(inner.size(), 4);
356 assert_eq!(outer.size(), 4);
357 assert_eq!(inner, <m![A % 4]>::to_value().factorize());
358 assert_eq!(outer, <m![1 # 2, A / 4 % 2]>::to_value().factorize());
359 }
360
361 #[test]
362 fn unittest_split_at_7() {
363 axes![A = 2, B = 2, C = 2, D = 2];
364 let f = <m![A, B # 4, [C, D] # 8]>::to_value().factorize();
365 let Tuple2(outer, inner) = f.split_at(16);
366 assert_eq!(inner.size(), 16);
367 assert_eq!(outer.size(), 4);
368 assert_eq!(inner, <m![B, 1 # 2, C, D]>::to_value().factorize());
369 assert_eq!(outer, <m![A, 1 # 2]>::to_value().factorize());
370 }
371
372 #[test]
374 fn unittest_round_trip_1() {
375 axes![A = 4, B = 2];
376 let f = <m![A, B]>::to_value().factorize();
377 assert_eq!(f, f.to_mapping().factorize());
378 }
379
380 #[test]
381 fn unittest_round_trip_2() {
382 axes![A = 2, B = 2];
383 let f = <m![A # 4, B # 4]>::to_value().factorize();
384 assert_eq!(f, f.to_mapping().factorize());
385 }
386
387 #[test]
388 fn unittest_round_trip_3() {
389 axes![A = 8];
390 let f = <m![A # 16]>::to_value().factorize();
391 assert_eq!(f, f.to_mapping().factorize());
392 }
393
394 #[test]
395 fn unittest_round_trip_4() {
396 axes![A = 4, B = 4];
397 let f = <m![B, A # 9]>::to_value().factorize();
398 assert_eq!(f, f.to_mapping().factorize());
399 }
400
401 #[test]
402 fn unittest_round_trip_5() {
403 axes![A = 2, B = 2, C = 2, D = 2];
404 let f = <m![A # 4, B # 4, C # 16, D # 4]>::to_value().factorize();
405 assert_eq!(f, f.to_mapping().factorize());
406 }
407
408 #[test]
409 fn unittest_round_trip_6() {
410 axes![A = 2, B = 2, C = 2, D = 2];
411 let f = <m![A, B # 4, [C, D] # 8]>::to_value().factorize();
412 assert_eq!(f, f.to_mapping().factorize());
413 }
414
415 #[test]
416 fn unittest_round_trip_7() {
417 axes![A = 8];
418 let f = <m![A % 4]>::to_value().factorize();
419 assert_eq!(f, f.to_mapping().factorize());
420 }
421
422 #[test]
423 fn unittest_round_trip_8() {
424 axes![A = 8];
425 let f = <m![A / 4 % 2]>::to_value().factorize();
426 assert_eq!(f, f.to_mapping().factorize());
427 }
428
429 #[test]
430 fn unittest_round_trip_bottom_padding() {
431 let f = FMapping(
432 vec![Factor::Padding {
433 size: 16,
434 kind: PaddingKind::Bottom,
435 }]
436 .into(),
437 );
438 assert_eq!(f, f.to_mapping().factorize());
439 }
440
441 #[test]
442 fn unittest_modulo_one_identity() {
443 axes![A = 8];
444 let f = <m![A % 1]>::to_value().factorize();
445 let id = <m![1]>::to_value().factorize();
446 assert_eq!(f, id);
447 }
448
449 #[test]
450 fn unittest_divide_infinite_recursion() {
451 axes![A = 8];
452 let dividend = <m![A / 4]>::to_value().factorize();
453 let divisor = <m![A % 4]>::to_value().factorize();
454 assert!(dividend.divide_relaxed(divisor).exact().is_err());
455 }
456
457 #[test]
461 fn unittest_divide_padded_axis_split() {
462 axes![A = 512, R = 15];
463 let dst = <m![1 # 2, A / 8, R # 16 / 4, A % 8, R # 16 % 4]>::to_value();
464 let src = <m![A, R # 16]>::to_value();
465 let result = dst.divide_relaxed(&src).exact();
466 assert!(
467 result.is_ok(),
468 "Dividing dst by src should succeed, but got: {result:?}"
469 );
470 }
471
472 #[test]
473 fn unittest_divide_span_nontrivial_tile() {
474 axes![A = 8, B = 4];
475 let dividend = <m![A, B]>::to_value().factorize();
476 let divisor = <m![B]>::to_value().factorize();
477 let division = dividend.clone().divide_span(divisor.clone(), 2).unwrap();
478 let [term]: [DivisionTerm; 1] = division.division_terms().to_vec().try_into().unwrap();
479
480 assert_eq!(
481 *division.residue(DivisionSide::Dividend),
482 <m![A, B = 2 # 4]>::to_value().factorize()
483 );
484 assert_eq!(
485 *division.residue(DivisionSide::Divisor),
486 <m![1 # 4]>::to_value().factorize()
487 );
488 let Factor::Term {
489 inner: expected_term,
490 resize: expected_resize,
491 } = <m![B]>::to_value().factorize().into_factor()
492 else {
493 panic!("single-axis factorization must yield a term");
494 };
495 assert_eq!(term.term, expected_term);
496 assert_eq!(term.dividend_stride, 1);
497 assert_eq!(term.divisor_stride, 1);
498 assert_eq!(term.resize, expected_resize);
499 }
500
501 #[test]
502 fn unittest_div_span_exact_nontrivial_tile() {
503 axes![A = 8, B = 4];
504 let dividend = <m![A, B]>::to_value().factorize();
505 let divisor = <m![B]>::to_value().factorize();
506 let division = dividend.divide_span(divisor, 2).unwrap().exact().unwrap();
507
508 assert_eq!(division.dividend_residue, <m![A, B = 2 # 4]>::to_value().factorize());
509 assert_eq!(division.divisor_residue, <m![1 # 4]>::to_value().factorize());
510 }
511
512 #[test]
513 fn unittest_div_span_exact_unmatched() {
514 axes![A = 8, B = 4, C = 2];
515 let dividend = <m![A, B]>::to_value().factorize();
516 let divisor = <m![C]>::to_value().factorize();
517 assert!(dividend.divide_span(divisor, 2).and_then(|d| d.exact()).is_err());
518 }
519
520 #[test]
521 fn unittest_divide_span_unmatched() {
522 axes![A = 8, B = 4, C = 2];
523 let dividend = <m![A, B]>::to_value().factorize();
524 let divisor = <m![C]>::to_value().factorize();
525 let division = dividend.divide_span(divisor.clone(), 2).unwrap();
526
527 assert!(division.division_terms().is_empty());
528 assert_eq!(division.dividend_residue, <m![A, B]>::to_value().factorize());
529 assert_eq!(division.divisor_residue, divisor);
530 }
531
532 #[test]
535 fn normalize_merges_padded_stride_modulo_split() {
536 axes![R = 14];
537 let r_padded = <m![R # 16]>::to_value();
538 let outer = Mapping::Stride {
539 inner: RBox::new(r_padded.clone()),
540 stride: 4,
541 };
542 let inner = Mapping::Modulo {
543 inner: RBox::new(r_padded),
544 modulo: 4,
545 };
546 let paired = outer.pair(inner);
547 let factorized = paired.factorize();
548 let normalized = factorized.normalize();
549 let expected = <m![R # 16]>::to_value().factorize();
550
551 assert_eq!(
552 normalized, expected,
553 "Normalized mapping should merge complementary stride/modulo splits of padded mappings, but got: {normalized}"
554 );
555 }
556
557 #[test]
560 fn unittest_divide_partial() {
561 axes![A = 512, B = 4, C = 8];
562 let a = <m![A / 8, B]>::to_value().factorize();
563 let b = <m![B, C]>::to_value().factorize();
564 let division = a.clone().divide_strict(b.clone());
565 assert_eq!(division.division_terms().len(), 1);
566 assert_eq!(
567 division.remainder(DivisionSide::Dividend),
568 <m![A / 8]>::to_value().factorize()
569 );
570 assert_eq!(
571 division.remainder(DivisionSide::Divisor),
572 <m![C]>::to_value().factorize()
573 );
574 }
575
576 #[test]
577 fn unittest_divide_all_matched() {
578 axes![A = 512, B = 4];
579 let a = <m![A / 8, B]>::to_value().factorize();
580 let b = <m![B]>::to_value().factorize();
581 assert!(a.clone().divide_relaxed(b.clone()).exact().is_ok());
582 let division = a.clone().divide_strict(b.clone());
583 assert_eq!(division.division_terms().len(), 1);
584 assert_eq!(
585 division.remainder(DivisionSide::Dividend),
586 <m![A / 8]>::to_value().factorize()
587 );
588 assert_eq!(division.remainder(DivisionSide::Divisor), FMapping::new());
589 }
590
591 #[test]
592 fn unittest_divide_nothing_matched() {
593 axes![A = 512, C = 8];
594 let a = <m![A / 8]>::to_value().factorize();
595 let b = <m![C]>::to_value().factorize();
596 let division = a.clone().divide_strict(b.clone());
597 assert_eq!(division.division_terms().len(), 0);
598 assert_eq!(division.remainder(DivisionSide::Dividend), a);
599 assert_eq!(division.remainder(DivisionSide::Divisor), b);
600 }
601
602 #[test]
603 fn unittest_divide_multiple_matched() {
604 axes![A = 512, B = 4, C = 8, D = 2];
605 let a = <m![A / 4, B, C]>::to_value().factorize();
606 let b = <m![B, C, D]>::to_value().factorize();
607 let division = a.clone().divide_strict(b.clone());
608 assert_eq!(division.division_terms().len(), 2);
609 assert_eq!(
610 division.remainder(DivisionSide::Dividend),
611 <m![A / 4]>::to_value().factorize()
612 );
613 assert_eq!(
614 division.remainder(DivisionSide::Divisor),
615 <m![D]>::to_value().factorize()
616 );
617 }
618
619 #[test]
620 fn unittest_divide_padded_by_self() {
621 axes![A = 3];
622 let a = <m![A # 16]>::to_value().factorize();
623 let b = <m![A # 16]>::to_value().factorize();
624 assert!(a.clone().divide_relaxed(b.clone()).exact().is_ok());
625 let division = a.clone().divide_strict(b.clone());
626 assert_eq!(division.division_terms().len(), 1);
627 }
628
629 #[test]
630 fn unittest_div_exact_fails_on_partial() {
631 axes![A = 512, B = 4, C = 8];
632 let dividend = <m![A / 8, B]>::to_value().factorize();
633 let divisor = <m![B, C]>::to_value().factorize();
634 assert!(dividend.divide_relaxed(divisor).exact().is_err());
635 }
636
637 #[test]
638 fn unittest_divide_span_ne1_unmatched() {
639 axes![A = 8];
640 let a = <m![A / 4]>::to_value().factorize();
641 let b = <m![A % 4]>::to_value().factorize();
642 let division = a.clone().divide_span(b.clone(), 2).unwrap();
643 assert_eq!(division.division_terms().len(), 0);
644 assert_eq!(division.dividend_residue, a);
645 assert_eq!(division.divisor_residue, b);
646 }
647
648 #[test]
651 fn unittest_terms_with_stride() {
652 axes![A = 512, B = 4, C = 8];
653 let division = <m![A / 8, B]>::to_value()
654 .factorize()
655 .divide_strict(<m![B, C]>::to_value().factorize());
656 let q = division.residue(DivisionSide::Dividend).terms_with_stride();
657 assert_eq!(q.len(), 1);
658 assert_eq!((q[0].resize, q[0].stride), (64, 4));
659 let r = division.residue(DivisionSide::Divisor).terms_with_stride();
660 assert_eq!(r.len(), 1);
661 assert_eq!((r[0].resize, r[0].stride), (8, 1));
662 }
663
664 #[test]
667 fn unittest_remainder_single() {
668 axes![A = 512, B = 4, C = 8];
669 let a = <m![A / 8, B]>::to_value().factorize();
670 let b = <m![B, C]>::to_value().factorize();
671 let division = a.clone().divide_strict(b.clone());
672 assert_eq!(
673 division.remainder(DivisionSide::Dividend),
674 <m![A / 8]>::to_value().factorize()
675 );
676 assert_eq!(
677 division.remainder(DivisionSide::Divisor),
678 <m![C]>::to_value().factorize()
679 );
680 }
681
682 #[test]
683 fn unittest_remainder_multiple() {
684 axes![A = 512, B = 4, C = 8, D = 2];
685 let a = <m![A / 4, B, C]>::to_value().factorize();
686 let b = <m![B, C, D]>::to_value().factorize();
687 let division = a.clone().divide_strict(b.clone());
688 assert_eq!(
689 division.remainder(DivisionSide::Dividend),
690 <m![A / 4]>::to_value().factorize()
691 );
692 assert_eq!(
693 division.remainder(DivisionSide::Divisor),
694 <m![D]>::to_value().factorize()
695 );
696 }
697
698 #[test]
699 fn unittest_remainder_interleaved_3_matches() {
700 axes![A = 2, B = 3, C = 5, D = 7, E = 11];
701 let a = <m![A, B, C, D, E]>::to_value().factorize();
702 let b = <m![A, C, E]>::to_value().factorize();
703 let division = a.clone().divide_strict(b);
704 assert_eq!(
705 division.remainder(DivisionSide::Dividend),
706 <m![B, D]>::to_value().factorize()
707 );
708 }
709
710 #[test]
711 fn unittest_remainder_interleaved_partial() {
712 axes![A = 2, B = 3, C = 5, D = 7, E = 11, F = 13];
713 let a = <m![A, B, C, D, E]>::to_value().factorize();
714 let b = <m![B, D, F]>::to_value().factorize();
715 let division = a.clone().divide_strict(b.clone());
716 assert_eq!(
717 division.remainder(DivisionSide::Dividend),
718 <m![A, C, E]>::to_value().factorize()
719 );
720 assert_eq!(
721 division.remainder(DivisionSide::Divisor),
722 <m![F]>::to_value().factorize()
723 );
724 }
725
726 #[test]
727 fn unittest_remainder_preserves_padding() {
728 axes![A = 3, B = 4];
729 let a = <m![A # 16, B]>::to_value().factorize();
730 let division = a.clone().divide_strict(<m![B]>::to_value().factorize());
731 assert_eq!(
732 division.remainder(DivisionSide::Dividend),
733 <m![A # 16]>::to_value().factorize()
734 );
735 }
736
737 #[test]
738 fn unittest_remainder_nothing_matched() {
739 axes![A = 512, C = 8];
740 let a = <m![A / 8]>::to_value().factorize();
741 let b = <m![C]>::to_value().factorize();
742 let division = a.clone().divide_strict(b.clone());
743 assert_eq!(division.remainder(DivisionSide::Dividend), a);
744 assert_eq!(division.remainder(DivisionSide::Divisor), b);
745 }
746
747 #[test]
748 fn unittest_remainder_all_matched() {
749 axes![B = 4];
750 let a = <m![B]>::to_value().factorize();
751 let b = <m![B]>::to_value().factorize();
752 let division = a.clone().divide_strict(b.clone());
753 assert_eq!(division.remainder(DivisionSide::Dividend), FMapping::new());
754 assert_eq!(division.remainder(DivisionSide::Divisor), FMapping::new());
755 }
756
757 #[test]
758 fn unittest_remainder_removes_padding_hole() {
759 axes![A = 4];
760 let a = <m![A # 16]>::to_value().factorize();
761 let division = a.clone().divide_strict(<m![A]>::to_value().factorize());
762 assert_eq!(
763 division.remainder(DivisionSide::Dividend),
764 <m![1 # 4]>::to_value().factorize()
765 );
766 }
767
768 #[test]
769 fn unittest_remainder_non_divisible_padding() {
770 axes![A = 3];
771 let a = <m![A # 16]>::to_value().factorize();
772 let b = <m![A]>::to_value().factorize();
773 assert!(a.clone().divide_relaxed(b.clone()).exact().is_ok());
774 let division = a.clone().divide_strict(b.clone());
775 assert_eq!(division.division_terms().len(), 1);
776 assert_eq!(
777 division.remainder(DivisionSide::Dividend),
778 <m![A # 16 / 4]>::to_value().factorize()
779 );
780 assert_eq!(division.remainder(DivisionSide::Divisor), FMapping::new());
781 }
782
783 #[test]
784 fn unittest_division_remainder_accessor() {
785 axes![A = 3];
786 let dividend = <m![A # 16]>::to_value().factorize();
787 let divisor = <m![A]>::to_value().factorize();
788 let division = dividend.clone().divide_strict(divisor.clone());
789
790 assert_eq!(division.dividend(), ÷nd);
791 assert_eq!(division.divisor(), &divisor);
792 assert_eq!(
793 division.remainder(DivisionSide::Dividend),
794 <m![A # 16 / 4]>::to_value().factorize()
795 );
796 assert_eq!(division.remainder(DivisionSide::Divisor), FMapping::new());
797 }
798
799 #[test]
800 fn unittest_division_residue_accessor_is_mapping_safe() {
801 axes![A = 3];
802 let residue = <m![A # 16]>::to_value()
803 .factorize()
804 .divide_strict(<m![A]>::to_value().factorize())
805 .residue(DivisionSide::Dividend)
806 .clone();
807
808 assert_eq!(residue.clone().normalize(), residue);
809 }
810
811 #[test]
812 fn unittest_division_padding_bounds_compact_bounds() {
813 axes![A = 4];
814 let dividend = <m![A # 16]>::to_value().factorize();
815 let divisor = <m![A]>::to_value().factorize();
816 let division = dividend.divide_strict(divisor);
817 let bounds = division.bounds();
818 let [term]: [DivisionTerm; 1] = division.division_terms().to_vec().try_into().unwrap();
819
820 assert_eq!(
821 bounds,
822 vec![TermBounds {
823 term,
824 dividend: BlockBounds { min: 4, max: 16 },
825 divisor: BlockBounds { min: 4, max: 4 },
826 }]
827 );
828 }
829
830 #[test]
831 fn unittest_divide_allows_inexact_reconstruction() {
832 axes![A = 3];
833 let dividend = <m![A # 16]>::to_value().factorize();
834 let divisor = <m![A]>::to_value().factorize();
835 let division = dividend.divide_strict(divisor);
836
837 assert_eq!(division.division_terms().len(), 1);
838 }
839
840 #[test]
841 fn unittest_division_padding_bounds_allow_semantic_reconstruction() {
842 axes![A = 3];
843 let dividend = <m![A # 16]>::to_value().factorize();
844 let divisor = <m![A]>::to_value().factorize();
845 let division = dividend.divide_strict(divisor);
846 let [term]: [DivisionTerm; 1] = division.division_terms().to_vec().try_into().unwrap();
847
848 assert_eq!(
849 division.bounds(),
850 vec![TermBounds {
851 term,
852 dividend: BlockBounds { min: 4, max: 16 },
853 divisor: BlockBounds { min: 3, max: 3 },
854 }]
855 );
856 }
857
858 #[test]
859 fn unittest_division_padding_bounds_collapse_only_contiguous_padding() {
860 axes![A = 3, B = 2];
861 let dividend = <m![A # 4, B # 16]>::to_value().factorize();
862 let divisor = <m![A]>::to_value().factorize();
863 let division = dividend.divide_strict(divisor);
864 let [term]: [DivisionTerm; 1] = division.division_terms().to_vec().try_into().unwrap();
865
866 assert_eq!(
867 division.bounds(),
868 vec![TermBounds {
869 term,
870 dividend: BlockBounds { min: 4, max: 4 },
871 divisor: BlockBounds { min: 3, max: 3 },
872 }]
873 );
874 }
875
876 #[test]
877 fn unittest_division_padding_bounds_preserve_consecutive_padding_run() {
878 axes![A = 3];
879 let dividend = <m![A # 4 # 16]>::to_value().factorize();
880 let divisor = <m![A]>::to_value().factorize();
881 let division = dividend.divide_strict(divisor);
882 let [term]: [DivisionTerm; 1] = division.division_terms().to_vec().try_into().unwrap();
883
884 assert_eq!(
885 division.bounds(),
886 vec![TermBounds {
887 term,
888 dividend: BlockBounds { min: 4, max: 16 },
889 divisor: BlockBounds { min: 3, max: 3 },
890 }]
891 );
892 }
893
894 #[test]
895 fn unittest_division_residue_marks_removed_block_as_bottom_padding() {
896 axes![A = 3];
897 let dividend = <m![A # 16]>::to_value().factorize();
898 let divisor = <m![A]>::to_value().factorize();
899 let division = dividend.divide_strict(divisor);
900 let factors = division.dividend_residue.factors();
901
902 assert!(matches!(
903 &*factors,
904 [
905 Factor::Padding {
906 size: 3,
907 kind: PaddingKind::Bottom
908 },
909 Factor::Padding {
910 size: 16,
911 kind: PaddingKind::Top
912 }
913 ]
914 ));
915 }
916
917 #[test]
918 fn unittest_division_padding_bounds_multiple_terms_with_padding() {
919 axes![A = 4, B = 3];
920 let dividend = <m![A # 16, B]>::to_value().factorize();
921 let divisor = <m![A, B]>::to_value().factorize();
922 let division = dividend.divide_strict(divisor);
923 let bounds = division.bounds();
924 let [a_term, b_term]: [DivisionTerm; 2] = division.division_terms().to_vec().try_into().unwrap();
925
926 assert_eq!(
927 bounds,
928 vec![
929 TermBounds {
930 term: a_term,
931 dividend: BlockBounds { min: 4, max: 16 },
932 divisor: BlockBounds { min: 4, max: 4 },
933 },
934 TermBounds {
935 term: b_term,
936 dividend: BlockBounds { min: 3, max: 3 },
937 divisor: BlockBounds { min: 3, max: 3 },
938 },
939 ]
940 );
941 }
942
943 #[test]
944 fn unittest_division_padding_bounds_partial_division_only_reports_matched_terms() {
945 axes![A = 512, B = 4, C = 8, D = 2];
946 let dividend = <m![A / 4, B, C]>::to_value().factorize();
947 let divisor = <m![B, C, D]>::to_value().factorize();
948 let division = dividend.divide_strict(divisor);
949 let bounds = division.bounds();
950 let [b_term, c_term]: [DivisionTerm; 2] = division.division_terms().to_vec().try_into().unwrap();
951
952 assert_eq!(division.division_terms().len(), 2);
953 assert_eq!(
954 bounds,
955 vec![
956 TermBounds {
957 term: b_term,
958 dividend: BlockBounds { min: 4, max: 4 },
959 divisor: BlockBounds { min: 4, max: 4 },
960 },
961 TermBounds {
962 term: c_term,
963 dividend: BlockBounds { min: 8, max: 8 },
964 divisor: BlockBounds { min: 8, max: 8 },
965 },
966 ]
967 );
968 }
969
970 #[test]
973 fn unittest_split_at_non_divisible_padding() {
974 axes![A = 3];
975 let a = <m![A # 16]>::to_value().factorize();
976 let Tuple2(outer, inner) = a.split_at(4);
977 assert_eq!(inner, <m![A # 4]>::to_value().factorize());
978 assert_eq!(outer, <m![1 # 4]>::to_value().factorize());
979 }
980
981 #[test]
982 fn unittest_stride_partial_live_prefix_padding() {
983 axes![R = 5];
984 let fm = <m![R # 16 / 4]>::to_value().factorize();
985 assert_eq!(fm, <m![1 # 2, [R # 8] / 4]>::to_value().factorize());
986
987 assert_eq!(fm.eval(0).ident_value(Ident::R), RResult::ROk(0));
988 assert_eq!(fm.eval(1).ident_value(Ident::R), RResult::ROk(4));
989 assert_eq!(
990 fm.eval(2).ident_value(Ident::R),
991 RResult::RErr(IndexValueError::Invalid)
992 );
993 assert_eq!(
994 fm.eval(3).ident_value(Ident::R),
995 RResult::RErr(IndexValueError::Invalid)
996 );
997 }
998
999 #[test]
1000 fn unittest_stride_single_live_slot_collapses_to_padding() {
1001 axes![R = 3];
1002 let fm = <m![R # 64 / 8]>::to_value().factorize();
1003 assert_eq!(fm, <m![1 # 8]>::to_value().factorize());
1004
1005 assert_eq!(fm.eval(0).ident_value(Ident::R), RResult::ROk(0));
1006 assert_eq!(
1007 fm.eval(1).ident_value(Ident::R),
1008 RResult::RErr(IndexValueError::Invalid)
1009 );
1010 assert_eq!(
1011 fm.eval(7).ident_value(Ident::R),
1012 RResult::RErr(IndexValueError::Invalid)
1013 );
1014 }
1015
1016 #[test]
1017 fn unittest_mul_scales_single_live_slot_padding() {
1018 axes![R = 3, A = 5];
1019 let fm = <m![R # 16 / 4, A]>::to_value().factorize();
1020 assert_eq!(fm, <m![A # 20]>::to_value().factorize());
1021 }
1022
1023 #[test]
1024 fn unittest_mul_scales_single_live_slot_padding_large_stride() {
1025 axes![R = 3, A = 5];
1026 let fm = <m![R # 64 / 8, A]>::to_value().factorize();
1027 assert_eq!(fm, <m![A # 40]>::to_value().factorize());
1028 }
1029
1030 #[test]
1031 fn unittest_mul_scales_partial_live_prefix_padding() {
1032 axes![R = 5, A = 5];
1033 let fm = <m![R # 16 / 4, A]>::to_value().factorize();
1034 assert_eq!(fm, <m![[[[R # 8] / 4] # 4], A]>::to_value().factorize());
1035 }
1036
1037 #[test]
1038 fn unittest_recursive_stride_over_nested_padding_composite() {
1039 axes![A = 5];
1040 let fm = <m![[A # 16 / 4] / 2]>::to_value().factorize();
1041 assert_eq!(fm, <m![1 # 2]>::to_value().factorize());
1042 }
1043
1044 #[test]
1045 fn unittest_multi_char_ident() {
1046 axes![Batch = 4, Seq = 8, Hidden = 16];
1048 let f = <m![Batch, Seq, Hidden]>::to_value().factorize();
1049 assert_eq!(f.size(), 4 * 8 * 16);
1050
1051 assert_eq!(<m![Batch]>::SIZE, 4);
1053 assert_eq!(<m![Seq]>::SIZE, 8);
1054 assert_eq!(<m![Hidden]>::SIZE, 16);
1055
1056 let idents = f.idents();
1058 assert_eq!(idents.len(), 3);
1059 assert_eq!(idents[0].as_str(), "Hidden");
1060 assert_eq!(idents[1].as_str(), "Seq");
1061 assert_eq!(idents[2].as_str(), "Batch");
1062 }
1063
1064 #[test]
1065 fn unittest_multi_char_ident_with_ops() {
1066 axes![Abxcbjhkdfhjdkf = 32];
1068 let f = <m![Abxcbjhkdfhjdkf / 8]>::to_value().factorize();
1069 assert_eq!(f.size(), 4);
1070
1071 let f2 = <m![Abxcbjhkdfhjdkf % 8]>::to_value().factorize();
1072 assert_eq!(f2.size(), 8);
1073
1074 let f3 = <m![Abxcbjhkdfhjdkf # 64]>::to_value().factorize();
1075 assert_eq!(f3.size(), 64);
1076
1077 let idents = f.idents();
1079 assert_eq!(idents[0].as_str(), "Abxcbjhkdfhjdkf");
1080 }
1081
1082 #[test]
1083 fn unittest_multi_char_ident_mixed() {
1084 axes![A = 2, Batch = 4, C = 8];
1086 let f = <m![A, Batch, C]>::to_value().factorize();
1087 assert_eq!(f.size(), 2 * 4 * 8);
1088
1089 let idents = f.idents();
1090 assert_eq!(idents.len(), 3);
1091 }
1092
1093 #[test]
1094 fn unittest_multi_char_ident_underscore() {
1095 axes![ASDKNASGDHJKAWD_CXVXCKVHSDF = 16];
1097 let f = <m![ASDKNASGDHJKAWD_CXVXCKVHSDF]>::to_value().factorize();
1098 assert_eq!(f.size(), 16);
1099 assert_eq!(f.idents()[0].as_str(), "ASDKNASGDHJKAWD_CXVXCKVHSDF");
1100 }
1101
1102 #[test]
1103 fn unittest_find_symbol_size_in_composite() {
1104 axes![R = 13, B = 2];
1105 let fm = <m![R # 16 % 4, B]>::to_value().factorize();
1106 assert_eq!(fm.find_symbol_size(Ident::R), ROption::RSome(13));
1107 assert_eq!(fm.find_symbol_size(Ident::B), ROption::RSome(2));
1108 assert_eq!(fm.find_symbol_size(Ident::A), ROption::RNone);
1109 }
1110
1111 #[test]
1115 fn add_term_composite_stride_recursion() {
1116 axes![R = 3];
1117 type P = m![R # 64 / 8];
1118 let fm = P::to_value().factorize();
1119
1120 for i in 0..P::SIZE {
1122 match P::map(i) {
1123 Some(idx) => assert_eq!(idx.ident_value(Ident::R), fm.eval(i).ident_value(Ident::R), "pos {i}"),
1124 None => assert!(
1125 fm.eval(i).ident_value(Ident::R).is_err(),
1126 "pos {i}: map returned None but eval succeeded"
1127 ),
1128 }
1129 }
1130
1131 assert_eq!(fm.eval(0).ident_value(Ident::R), RResult::ROk(0));
1133 assert_eq!(
1135 fm.eval(1).ident_value(Ident::R),
1136 RResult::RErr(IndexValueError::Invalid)
1137 );
1138
1139 {
1141 axes![R = 50];
1142 type Q = m![R # 64 / 8];
1143 let fm2 = Q::to_value().factorize();
1144 for i in 0..Q::SIZE {
1145 match Q::map(i) {
1146 Some(idx) => assert_eq!(
1147 idx.ident_value(Ident::R),
1148 fm2.eval(i).ident_value(Ident::R),
1149 "R=50 pos {i}"
1150 ),
1151 None => assert!(
1152 fm2.eval(i).ident_value(Ident::R).is_err(),
1153 "R=50 pos {i}: map returned None but eval succeeded"
1154 ),
1155 }
1156 }
1157 assert_eq!(fm2.eval(0).ident_value(Ident::R), RResult::ROk(0));
1159 assert_eq!(fm2.eval(1).ident_value(Ident::R), RResult::ROk(8));
1160 assert_eq!(fm2.eval(6).ident_value(Ident::R), RResult::ROk(48));
1161 assert_eq!(
1163 fm2.eval(7).ident_value(Ident::R),
1164 RResult::RErr(IndexValueError::Invalid)
1165 );
1166 }
1167 }
1168
1169 #[test]
1172 fn add_term_composite_stride_recursion_mixed_idents() {
1173 axes![A = 2, B = 3];
1174 type P = m![[A, B] # 10 / 2];
1175 let fm = P::to_value().factorize();
1176
1177 for i in 0..P::SIZE {
1178 match P::map(i) {
1179 Some(idx) => {
1180 assert_eq!(idx.ident_value(Ident::A), fm.eval(i).ident_value(Ident::A), "A pos {i}");
1181 assert_eq!(idx.ident_value(Ident::B), fm.eval(i).ident_value(Ident::B), "B pos {i}");
1182 }
1183 None => {
1184 assert!(fm.eval(i).ident_value(Ident::A).is_err(), "A pos {i}");
1185 assert!(fm.eval(i).ident_value(Ident::B).is_err(), "B pos {i}");
1186 }
1187 }
1188 }
1189
1190 assert_eq!(fm.eval(0).ident_value(Ident::A), RResult::ROk(0));
1191 assert_eq!(fm.eval(0).ident_value(Ident::B), RResult::ROk(0));
1192 assert_eq!(fm.eval(1).ident_value(Ident::A), RResult::ROk(0));
1193 assert_eq!(fm.eval(1).ident_value(Ident::B), RResult::ROk(2));
1194 assert_eq!(fm.eval(2).ident_value(Ident::A), RResult::ROk(1));
1195 assert_eq!(fm.eval(2).ident_value(Ident::B), RResult::ROk(1));
1196 assert_eq!(
1197 fm.eval(3).ident_value(Ident::A),
1198 RResult::RErr(IndexValueError::Invalid)
1199 );
1200 assert_eq!(
1201 fm.eval(3).ident_value(Ident::B),
1202 RResult::RErr(IndexValueError::Invalid)
1203 );
1204 assert_eq!(
1205 fm.eval(4).ident_value(Ident::A),
1206 RResult::RErr(IndexValueError::Invalid)
1207 );
1208 assert_eq!(
1209 fm.eval(4).ident_value(Ident::B),
1210 RResult::RErr(IndexValueError::Invalid)
1211 );
1212 }
1213
1214 #[test]
1215 fn ident_value_rejects_composite_terms() {
1216 axes![A = 2, B = 3];
1217 let term = <m![[A, B] # 10 / 2]>::to_value()
1218 .factorize()
1219 .terms_with_stride()
1220 .into_iter()
1221 .next()
1222 .unwrap()
1223 .term;
1224 let index = Index(RResult::ROk(RSortedMap::from_iter([(term, 0)])));
1225 assert_eq!(
1226 index.ident_value(Ident::A),
1227 RResult::RErr(IndexValueError::NonFlattened)
1228 );
1229 }
1230
1231 #[test]
1232 fn ident_value_treats_absent_as_zero_and_invalid_as_error() {
1233 axes![A = 4];
1234 let mut index = Index::new();
1235 index.add_term(
1236 Term {
1237 inner: Atom::Symbol {
1238 symbol: Ident::A,
1239 size: 4,
1240 },
1241 stride: 1,
1242 modulo: 4,
1243 },
1244 2,
1245 );
1246 assert_eq!(index.ident_value(Ident::B), RResult::ROk(0));
1247
1248 let invalid = <m![A # 8]>::to_value().factorize().eval(5);
1249 assert_eq!(invalid.ident_value(Ident::A), RResult::RErr(IndexValueError::Invalid));
1250 }
1251
1252 #[test]
1255 fn unittest_divide_relaxed_uses_top_padding() {
1256 axes![A = 4];
1257 let division = <m![A # 16]>::to_value()
1258 .factorize()
1259 .divide_relaxed(<m![A]>::to_value().factorize());
1260 assert_eq!(
1261 division.dividend_residue,
1262 FMapping(
1263 vec![Factor::Padding {
1264 size: 16,
1265 kind: PaddingKind::Top,
1266 }]
1267 .into()
1268 )
1269 );
1270 assert_eq!(division.divisor_residue, <m![1 # 4]>::to_value().factorize());
1271 }
1272
1273 #[test]
1274 fn unittest_exact_on_relaxed() {
1275 axes![A = 8, B = 4];
1276 let ok = <m![A, B]>::to_value()
1277 .factorize()
1278 .divide_relaxed(<m![B]>::to_value().factorize())
1279 .exact()
1280 .unwrap();
1281 assert_eq!(ok.dividend_residue, <m![A, 1 # 4]>::to_value().factorize());
1282 assert_eq!(ok.divisor_residue, <m![1 # 4]>::to_value().factorize());
1283
1284 let err = <m![A]>::to_value()
1285 .factorize()
1286 .divide_relaxed(<m![B]>::to_value().factorize())
1287 .exact();
1288 assert!(matches!(err, RResult::RErr(DivisionError::DivisorTermCannotDivide)));
1289 }
1290
1291 #[test]
1292 fn unittest_padding_same_kind_merge() {
1293 let mut f = FMapping::new();
1294 f = f.padding(4, PaddingKind::Bottom);
1295 f = f.padding(8, PaddingKind::Bottom);
1296 assert_eq!(
1297 f,
1298 FMapping(
1299 vec![Factor::Padding {
1300 size: 8,
1301 kind: PaddingKind::Bottom,
1302 }]
1303 .into()
1304 )
1305 );
1306
1307 let mut f = FMapping::new();
1308 f = f.padding(4, PaddingKind::Top);
1309 f = f.padding(8, PaddingKind::Top);
1310 assert_eq!(
1311 f,
1312 FMapping(
1313 vec![Factor::Padding {
1314 size: 8,
1315 kind: PaddingKind::Top,
1316 }]
1317 .into()
1318 )
1319 );
1320
1321 let mut f = FMapping::new();
1322 f = f.padding(4, PaddingKind::Bottom);
1323 f = f.padding(8, PaddingKind::Top);
1324 assert_eq!(
1325 f,
1326 FMapping(
1327 vec![
1328 Factor::Padding {
1329 size: 4,
1330 kind: PaddingKind::Bottom,
1331 },
1332 Factor::Padding {
1333 size: 8,
1334 kind: PaddingKind::Top,
1335 },
1336 ]
1337 .into()
1338 )
1339 );
1340 }
1341
1342 #[test]
1346 fn unittest_bounds_skips_inner_bottom() {
1347 axes![A = 3, B = 2, C = 3];
1348 let dividend = <m![A # 4, B, C # 8]>::to_value().factorize();
1349 let divisor = <m![A, C]>::to_value().factorize();
1350 let division = dividend.divide_strict(divisor);
1351 let bounds = division.bounds();
1352 let [a_term, c_term]: [DivisionTerm; 2] = division.division_terms().to_vec().try_into().unwrap();
1353 assert_eq!(
1354 bounds,
1355 vec![
1356 TermBounds {
1357 term: a_term,
1358 dividend: BlockBounds { min: 4, max: 4 },
1359 divisor: BlockBounds { min: 3, max: 3 },
1360 },
1361 TermBounds {
1362 term: c_term,
1363 dividend: BlockBounds { min: 4, max: 8 },
1364 divisor: BlockBounds { min: 3, max: 3 },
1365 },
1366 ]
1367 );
1368 }
1369
1370 #[test]
1373 fn unittest_bounds_with_merged_divisor_bottom() {
1374 axes![A = 4, B = 3];
1375 let dividend = <m![A, B]>::to_value().factorize();
1376 let divisor = <m![A, B]>::to_value().factorize();
1377 let division = dividend.divide_strict(divisor);
1378 let bounds = division.bounds();
1379 let [a_term, b_term]: [DivisionTerm; 2] = division.division_terms().to_vec().try_into().unwrap();
1380 assert_eq!(
1381 bounds,
1382 vec![
1383 TermBounds {
1384 term: a_term,
1385 dividend: BlockBounds { min: 4, max: 4 },
1386 divisor: BlockBounds { min: 4, max: 4 },
1387 },
1388 TermBounds {
1389 term: b_term,
1390 dividend: BlockBounds { min: 3, max: 3 },
1391 divisor: BlockBounds { min: 3, max: 3 },
1392 },
1393 ]
1394 );
1395 }
1396
1397 #[test]
1398 fn serde_round_trip() {
1399 axes![A = 512, B = 8, C = 4, D = 3];
1400
1401 let mappings: Vec<Mapping> = vec![
1402 Mapping::Identity,
1403 <m![A]>::to_value(),
1404 <m![B / 4]>::to_value(),
1405 <m![B % 4]>::to_value(),
1406 <m![A = 2]>::to_value(),
1407 <m![D # 4]>::to_value(),
1408 <m![A, C]>::to_value(),
1409 ];
1410
1411 for m in mappings {
1412 let json = serde_json::to_string(&m).unwrap();
1413 let deserialized: Mapping = serde_json::from_str(&json).unwrap();
1414 assert_eq!(m, deserialized, "Round-trip failed for: {json}");
1415 }
1416 }
1417}