furiosa_mapping_types/
lib.rs

1//! Mapping expressions.
2
3#![feature(register_tool)]
4#![register_tool(tcp)]
5#![warn(missing_docs)]
6#![warn(missing_debug_implementations)]
7#![forbid(unused_must_use)]
8
9mod sorted_map;
10pub use sorted_map::RSortedMap;
11
12use abi_stable::{
13    StableAbi,
14    std_types::{RBox, RResult, RVec},
15};
16use std::{
17    fmt::{self, Debug, Display, Formatter},
18    marker::PhantomData,
19};
20
21use furiosa_mapping_macro::primitive;
22use itertools::Itertools;
23
24/// Axis identifiers.
25#[primitive(mapping::Ident)]
26#[repr(C)]
27// SAFETY: &'static str is not formally ABI-stable, but its layout (*const u8, usize)
28// is de facto stable across all Rust versions and extremely unlikely to change.
29#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
30#[sabi(unsafe_opaque_fields)]
31pub struct Ident(&'static str);
32
33#[expect(missing_docs)]
34impl Ident {
35    /// Creates a new identifier.
36    ///
37    /// The identifier must start with an uppercase ASCII letter and contain
38    /// only ASCII alphanumeric characters or underscores.
39    pub const fn new(s: &'static str) -> Self {
40        let b = s.as_bytes();
41        assert!(!b.is_empty(), "Ident must not be empty");
42        assert!(
43            b[0].is_ascii_uppercase(),
44            "Ident must start with an uppercase ASCII letter"
45        );
46        let mut i = 1;
47        while i < b.len() {
48            assert!(
49                b[i].is_ascii_alphanumeric() || b[i] == b'_',
50                "Ident must contain only ASCII alphanumeric or underscore characters"
51            );
52            i += 1;
53        }
54        Self(s)
55    }
56
57    /// Returns the string representation.
58    pub fn as_str(&self) -> &'static str {
59        self.0
60    }
61
62    pub const A: Self = Self("A");
63    pub const B: Self = Self("B");
64    pub const C: Self = Self("C");
65    pub const D: Self = Self("D");
66    pub const E: Self = Self("E");
67    pub const F: Self = Self("F");
68    pub const G: Self = Self("G");
69    pub const H: Self = Self("H");
70    pub const I: Self = Self("I");
71    pub const J: Self = Self("J");
72    pub const K: Self = Self("K");
73    pub const L: Self = Self("L");
74    pub const M: Self = Self("M");
75    pub const N: Self = Self("N");
76    pub const O: Self = Self("O");
77    pub const P: Self = Self("P");
78    pub const Q: Self = Self("Q");
79    pub const R: Self = Self("R");
80    pub const S: Self = Self("S");
81    pub const T: Self = Self("T");
82    pub const U: Self = Self("U");
83    pub const V: Self = Self("V");
84    pub const W: Self = Self("W");
85    pub const X: Self = Self("X");
86    pub const Y: Self = Self("Y");
87    pub const Z: Self = Self("Z");
88}
89
90impl Display for Ident {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(f, "{}", self.0)
93    }
94}
95
96impl From<Ident> for &'static str {
97    fn from(value: Ident) -> Self {
98        value.0
99    }
100}
101
102impl serde::Serialize for Ident {
103    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
104        serializer.serialize_str(self.0)
105    }
106}
107
108impl<'de> serde::Deserialize<'de> for Ident {
109    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
110        let s: String = serde::Deserialize::deserialize(deserializer)?;
111        Ident::try_from(s.as_str()).map_err(|e| serde::de::Error::custom(format!("invalid Ident: {e}")))
112    }
113}
114
115impl serde_lite::Deserialize for Ident {
116    fn deserialize(val: &serde_lite::Intermediate) -> Result<Self, serde_lite::Error> {
117        let s: String = serde_lite::Deserialize::deserialize(val)?;
118        Ident::try_from(s.as_str()).map_err(|e| serde_lite::Error::custom(format!("invalid Ident: {e}")))
119    }
120}
121
122impl<'a> TryFrom<&'a str> for Ident {
123    type Error = &'a str;
124
125    fn try_from(value: &'a str) -> std::result::Result<Self, Self::Error> {
126        use lasso::ThreadedRodeo;
127        use std::sync::LazyLock;
128        static INTERNER: LazyLock<ThreadedRodeo> = LazyLock::new(ThreadedRodeo::new);
129
130        let key = INTERNER.get_or_intern(value);
131        let interned: &'static str = INTERNER.resolve(&key);
132        std::panic::catch_unwind(|| Self::new(interned)).map_err(|_| value)
133    }
134}
135
136/// Mapping expression enum.
137#[repr(C)]
138#[derive(StableAbi, Debug, Clone, PartialEq, Eq, Hash)]
139pub enum Mapping {
140    /// Identity mapping.
141    Identity,
142    /// Symbol mapping.
143    Symbol {
144        /// Symbol.
145        symbol: Ident,
146        /// Size.
147        size: usize,
148    },
149    /// Stride mapping.
150    Stride {
151        /// Inner mapping.
152        inner: RBox<Mapping>,
153        /// Stride size.
154        stride: usize,
155    },
156    /// Modulo mapping.
157    Modulo {
158        /// Inner mapping.
159        inner: RBox<Mapping>,
160        /// Stride size.
161        modulo: usize,
162    },
163    /// Resize mapping.
164    Resize {
165        /// Inner mapping.
166        inner: RBox<Mapping>,
167        /// Truncate size.
168        resize: usize,
169    },
170    /// Padding mapping.
171    Padding {
172        /// Inner mapping.
173        inner: RBox<Mapping>,
174        /// Size after padding.
175        padding: usize,
176        /// Accessibility of this padding region.
177        kind: PaddingKind,
178    },
179    /// Pair mapping.
180    Pair {
181        /// Left mapping.
182        left: RBox<Mapping>,
183        /// Right mapping.
184        right: RBox<Mapping>,
185    },
186}
187
188impl Display for Mapping {
189    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
190        fn flatten_pair<'a>(acc: &mut Vec<&'a Mapping>, m: &'a Mapping) {
191            match m {
192                Mapping::Pair { left, right } => {
193                    flatten_pair(acc, left);
194                    flatten_pair(acc, right);
195                }
196                _ => acc.push(m),
197            }
198        }
199
200        match self {
201            Self::Identity => write!(f, "1"),
202            Self::Symbol { symbol, size: _ } => {
203                // We hide the size just for readability.
204                write!(f, "{symbol}")
205            }
206            Self::Stride { inner, stride } => write!(f, "{inner} / {stride}"),
207            Self::Modulo { inner, modulo } => write!(f, "{inner} % {modulo}"),
208            Self::Resize { inner, resize } => write!(f, "{inner} = {resize}"),
209            Self::Padding {
210                inner,
211                padding,
212                kind: PaddingKind::Top,
213            } => write!(f, "{inner} # {padding}"),
214            Self::Padding {
215                inner,
216                padding,
217                kind: PaddingKind::Bottom,
218            } => write!(f, "{inner} #_ {padding}"),
219            Self::Pair { left, right } => {
220                // Collect all nested pairs and print them as flattened.
221                let mut elements = vec![];
222                flatten_pair(&mut elements, left);
223                flatten_pair(&mut elements, right);
224                write!(f, "({})", elements.iter().join(", "))
225            }
226        }
227    }
228}
229
230/// Serde-compatible mirror of [`Mapping`] using `Box` instead of `RBox`, so that
231/// the standard derive macros work. Used only for serialization/deserialization.
232#[derive(serde::Serialize, serde::Deserialize, serde_lite::Deserialize)]
233enum MappingSerde {
234    Identity,
235    Symbol {
236        symbol: Ident,
237        size: usize,
238    },
239    Stride {
240        inner: Box<MappingSerde>,
241        stride: usize,
242    },
243    Modulo {
244        inner: Box<MappingSerde>,
245        modulo: usize,
246    },
247    Resize {
248        inner: Box<MappingSerde>,
249        resize: usize,
250    },
251    Padding {
252        inner: Box<MappingSerde>,
253        padding: usize,
254        kind: PaddingKind,
255    },
256    Pair {
257        left: Box<MappingSerde>,
258        right: Box<MappingSerde>,
259    },
260}
261
262impl From<Mapping> for MappingSerde {
263    fn from(m: Mapping) -> Self {
264        match m {
265            Mapping::Identity => Self::Identity,
266            Mapping::Symbol { symbol, size } => Self::Symbol { symbol, size },
267            Mapping::Stride { inner, stride } => Self::Stride {
268                inner: Box::new(RBox::into_inner(inner).into()),
269                stride,
270            },
271            Mapping::Modulo { inner, modulo } => Self::Modulo {
272                inner: Box::new(RBox::into_inner(inner).into()),
273                modulo,
274            },
275            Mapping::Resize { inner, resize } => Self::Resize {
276                inner: Box::new(RBox::into_inner(inner).into()),
277                resize,
278            },
279            Mapping::Padding { inner, padding, kind } => Self::Padding {
280                inner: Box::new(RBox::into_inner(inner).into()),
281                padding,
282                kind,
283            },
284            Mapping::Pair { left, right } => Self::Pair {
285                left: Box::new(RBox::into_inner(left).into()),
286                right: Box::new(RBox::into_inner(right).into()),
287            },
288        }
289    }
290}
291
292impl From<MappingSerde> for Mapping {
293    fn from(m: MappingSerde) -> Self {
294        match m {
295            MappingSerde::Identity => Self::Identity,
296            MappingSerde::Symbol { symbol, size } => Self::Symbol { symbol, size },
297            MappingSerde::Stride { inner, stride } => Self::Stride {
298                inner: RBox::new((*inner).into()),
299                stride,
300            },
301            MappingSerde::Modulo { inner, modulo } => Self::Modulo {
302                inner: RBox::new((*inner).into()),
303                modulo,
304            },
305            MappingSerde::Resize { inner, resize } => Self::Resize {
306                inner: RBox::new((*inner).into()),
307                resize,
308            },
309            MappingSerde::Padding { inner, padding, kind } => Self::Padding {
310                inner: RBox::new((*inner).into()),
311                padding,
312                kind,
313            },
314            MappingSerde::Pair { left, right } => Self::Pair {
315                left: RBox::new((*left).into()),
316                right: RBox::new((*right).into()),
317            },
318        }
319    }
320}
321
322impl serde::Serialize for Mapping {
323    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
324        MappingSerde::from(self.clone()).serialize(s)
325    }
326}
327
328impl<'de> serde::Deserialize<'de> for Mapping {
329    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
330        MappingSerde::deserialize(d).map(Into::into)
331    }
332}
333
334impl serde_lite::Deserialize for Mapping {
335    fn deserialize(val: &serde_lite::Intermediate) -> Result<Self, serde_lite::Error> {
336        MappingSerde::deserialize(val).map(Into::into)
337    }
338}
339
340/// Atomic mapping expression.
341#[repr(C)]
342#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
343pub enum Atom {
344    /// Symbolic atomic mapping expression.
345    Symbol {
346        /// Symbol of the axis.
347        symbol: Ident,
348        /// Size of the axis.
349        size: usize,
350    },
351    /// Composite mapping expression.
352    Composite(RBox<FMapping>),
353}
354
355/// `inner / stride % modulo`.
356#[repr(C)]
357#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
358pub struct Term {
359    /// Inner mapping expression.
360    pub inner: Atom,
361    /// Stride of the mapping.
362    pub stride: usize,
363    /// Modulo of the mapping.
364    pub modulo: usize,
365}
366
367impl Display for Term {
368    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369        match &self.inner {
370            Atom::Symbol { symbol, size } => {
371                if self.stride == 1 && self.modulo == *size {
372                    write!(f, "{}", symbol)
373                } else if self.stride == 1 {
374                    write!(f, "{}%{}", symbol, self.modulo)
375                } else if self.modulo == *size {
376                    write!(f, "{}//{}", symbol, self.stride)
377                } else {
378                    write!(f, "({}//{})%{}", symbol, self.stride, self.modulo)
379                }
380            }
381            Atom::Composite(inner) => {
382                if self.stride == 1 && self.modulo == inner.size() {
383                    write!(f, "({})", inner)
384                } else if self.stride == 1 {
385                    write!(f, "({})%{}", inner, self.modulo)
386                } else if self.modulo == inner.size() {
387                    write!(f, "({})//{}", inner, self.stride)
388                } else {
389                    write!(f, "(({})//{})%{}", inner, self.stride, self.modulo)
390                }
391            }
392        }
393    }
394}
395
396/// Kind of a padding factor.
397#[repr(C)]
398#[derive(
399    StableAbi,
400    Debug,
401    Clone,
402    Copy,
403    PartialEq,
404    Eq,
405    PartialOrd,
406    Ord,
407    Hash,
408    serde::Serialize,
409    serde::Deserialize,
410    serde_lite::Deserialize,
411)]
412pub enum PaddingKind {
413    /// Accessible padding.
414    Top,
415    /// Inaccessible padding.
416    Bottom,
417}
418
419/// Factor representation of a mapping expression.
420#[repr(C)]
421#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
422pub enum Factor {
423    /// Term.
424    Term {
425        /// Inner term.
426        inner: Term,
427        /// Resize of the mapping.
428        resize: usize,
429    },
430    /// Padding.
431    Padding {
432        /// Size after padding.
433        size: usize,
434        /// Accessibility of this padding region.
435        kind: PaddingKind,
436    },
437}
438
439/// Factor mapping expression.
440///
441/// Factors are ordered from innermost (index 0) to outermost (last index).
442/// `push`/`pop`/`last` operate on the outermost factor.
443#[repr(C)]
444#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
445pub struct FMapping(pub RVec<Factor>);
446
447impl Default for FMapping {
448    fn default() -> Self {
449        Self::new()
450    }
451}
452
453impl FMapping {
454    /// Creates a new empty factor mapping.
455    pub fn new() -> Self {
456        Self(RVec::new())
457    }
458
459    /// Returns the size of the factor mapping.
460    pub fn size(&self) -> usize {
461        let mut x = 1;
462        for term in self.0.iter().rev() {
463            match term {
464                Factor::Padding { size, .. } => {
465                    return x * *size;
466                }
467                Factor::Term { resize, .. } => {
468                    x *= *resize;
469                }
470            }
471        }
472        x
473    }
474}
475
476impl Display for FMapping {
477    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
478        let terms: Vec<String> = self
479            .0
480            .iter()
481            .map(|term| match term {
482                Factor::Padding {
483                    size,
484                    kind: PaddingKind::Top,
485                } => format!("pad({})", size),
486                Factor::Padding {
487                    size,
488                    kind: PaddingKind::Bottom,
489                } => format!("bottom_pad({})", size),
490                Factor::Term {
491                    inner: Term { inner, stride, modulo },
492                    resize,
493                } => {
494                    let atom_str = match inner {
495                        Atom::Symbol { symbol, size } => format!("{}:{}", symbol, size),
496                        Atom::Composite(inner) => format!("({})", inner),
497                    };
498                    format!("term({} / {} % {} = {})", atom_str, stride, modulo, resize)
499                }
500            })
501            .collect();
502        write!(f, "FMapping[{}]", terms.join(" * "))
503    }
504}
505
506/// Error during division of mapping expressions.
507#[repr(C)]
508#[derive(StableAbi, Debug)]
509pub enum DivisionError {
510    /// No divisor terms found.
511    NoDivisorTerms,
512    /// Divisor term cannot divide dividend.
513    DivisorTermCannotDivide,
514}
515
516/// Selects which side of a [`DivisionTerm`] to operate on.
517#[repr(C)]
518#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
519pub enum DivisionSide {
520    /// The dividend (left-hand side of the division).
521    Dividend,
522    /// The divisor (right-hand side of the division).
523    Divisor,
524}
525
526/// Information about a single matched term in a division.
527#[repr(C)]
528#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
529pub struct DivisionTerm {
530    /// Stride of the dividend term.
531    pub dividend_stride: usize,
532    /// Stride of the divisor term.
533    pub divisor_stride: usize,
534    /// Divisor term.
535    pub term: Term,
536    /// Divisor resize.
537    pub resize: usize,
538}
539
540impl DivisionTerm {
541    /// Returns the stride of this matched term on the selected side.
542    pub fn stride(&self, side: DivisionSide) -> usize {
543        match side {
544            DivisionSide::Dividend => self.dividend_stride,
545            DivisionSide::Divisor => self.divisor_stride,
546        }
547    }
548}
549
550/// Bounds for the padded block removed for a matched term.
551#[repr(C)]
552#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
553pub struct BlockBounds {
554    /// Minimum block size chosen by the current replay semantics.
555    pub min: usize,
556    /// Largest normalized padding boundary that still encloses the removed content.
557    pub max: usize,
558}
559
560/// Per-term compact padding bounds.
561#[repr(C)]
562#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
563pub struct TermBounds {
564    /// Matched term metadata for this bounds row.
565    pub term: DivisionTerm,
566    /// Bounds reconstructed on the dividend side.
567    pub dividend: BlockBounds,
568    /// Bounds reconstructed on the divisor side.
569    pub divisor: BlockBounds,
570}
571
572/// Determines the padding kind used for matched-hole markers in division.
573pub trait DivisionMode {
574    /// The padding kind to use for matched-term holes.
575    const PADDING_KIND: PaddingKind;
576}
577
578/// Marker for analysis-capable division results.
579#[repr(C)]
580#[derive(StableAbi, Debug, Clone, Copy)]
581pub struct Strict;
582impl DivisionMode for Strict {
583    const PADDING_KIND: PaddingKind = PaddingKind::Bottom;
584}
585
586/// Marker for read-accessible division results.
587#[repr(C)]
588#[derive(StableAbi, Debug, Clone, Copy)]
589pub struct Relaxed;
590impl DivisionMode for Relaxed {
591    const PADDING_KIND: PaddingKind = PaddingKind::Top;
592}
593
594/// Marker for span-division results without padding analysis.
595#[repr(C)]
596#[derive(StableAbi, Debug, Clone, Copy)]
597pub struct Span;
598impl DivisionMode for Span {
599    const PADDING_KIND: PaddingKind = PaddingKind::Top;
600}
601
602/// Result of dividing two factor mappings.
603#[repr(C)]
604#[derive(StableAbi, Debug, Clone)]
605pub struct Division<M: DivisionMode> {
606    /// Information about each matched divisor term.
607    pub division_terms: RVec<DivisionTerm>,
608    /// Original dividend before matching.
609    pub dividend: FMapping,
610    /// Dividend residue. Strict results preserve Bottom padding;
611    /// relaxed results have all padding converted to Top.
612    pub dividend_residue: FMapping,
613    /// Original divisor before matching.
614    pub divisor: FMapping,
615    /// Divisor residue. Strict results preserve Bottom padding;
616    /// relaxed results have all padding converted to Top.
617    pub divisor_residue: FMapping,
618    _mode: PhantomData<M>,
619}
620
621impl<M: DivisionMode> Division<M> {
622    /// Creates a new division result.
623    pub fn new(
624        dividend: FMapping,
625        dividend_residue: FMapping,
626        mut division_terms: Vec<DivisionTerm>,
627        divisor: FMapping,
628        divisor_residue: FMapping,
629    ) -> Self {
630        division_terms.sort_by(|a, b| b.dividend_stride.cmp(&a.dividend_stride));
631
632        Self {
633            division_terms: division_terms.into(),
634            dividend,
635            dividend_residue,
636            divisor,
637            divisor_residue,
638            _mode: PhantomData,
639        }
640    }
641
642    /// Returns the original dividend.
643    pub fn dividend(&self) -> &FMapping {
644        &self.dividend
645    }
646
647    /// Returns the original divisor.
648    pub fn divisor(&self) -> &FMapping {
649        &self.divisor
650    }
651
652    /// Returns the residue for the selected side.
653    pub fn residue(&self, side: DivisionSide) -> &FMapping {
654        match side {
655            DivisionSide::Dividend => &self.dividend_residue,
656            DivisionSide::Divisor => &self.divisor_residue,
657        }
658    }
659
660    /// Returns the original mapping for the selected side.
661    pub fn mapping(&self, side: DivisionSide) -> &FMapping {
662        match side {
663            DivisionSide::Dividend => &self.dividend,
664            DivisionSide::Divisor => &self.divisor,
665        }
666    }
667
668    /// Returns matched division terms in dividend-order.
669    pub fn division_terms(&self) -> &[DivisionTerm] {
670        &self.division_terms
671    }
672}
673
674/// A Term factor with its position (stride) in an FMapping.
675///
676/// Returned by [`FMapping::terms_with_stride`].
677#[repr(C)]
678#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
679pub struct TermPosition {
680    /// The term.
681    pub term: Term,
682    /// Size of this term (number of positions it contributes).
683    pub resize: usize,
684    /// Effective stride of this term in the FMapping
685    /// (product of all inner factors' sizes).
686    pub stride: usize,
687}
688
689/// Index mapping for tensor operations.
690#[repr(C)]
691#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
692pub struct Index(pub RResult<RSortedMap<Term, usize>, ()>);
693
694/// Error returned when querying ident contributions from an [`Index`].
695#[repr(C)]
696#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
697pub enum IndexValueError {
698    /// The index is invalid, typically because it points into padding.
699    Invalid,
700    /// The index still contains Composite terms and is not evaluable per ident.
701    NonFlattened,
702}
703
704impl Default for Index {
705    fn default() -> Self {
706        Self(RResult::ROk(RSortedMap::new()))
707    }
708}
709
710impl Display for Index {
711    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
712        match &self.0 {
713            RResult::ROk(map) => {
714                let terms = map.iter().map(|(k, v)| format!("{k} = {v}")).join(", ");
715                write!(f, "Index[{}]", terms)
716            }
717            RResult::RErr(_) => {
718                write!(f, "Invalid Index")
719            }
720        }
721    }
722}
723
724/// Mapping expression that describes memory layout and computes size for a given shape.
725#[primitive(mapping::M)]
726// ANCHOR: trait_m
727pub trait M: Debug + Clone {
728    /// The computed size for the given shape.
729    const SIZE: usize;
730
731    /// Converts the mapping expression type into a value.
732    fn to_value() -> Mapping;
733
734    /// Converts a buffer index to a tensor index, returning `None` if out-of-bounds.
735    fn map(i: usize) -> Option<Index>;
736}
737// ANCHOR_END: trait_m