1#![feature(register_tool)]
4#![register_tool(tcp)]
5#![warn(missing_docs)]
6#![warn(missing_debug_implementations)]
7#![forbid(unused_must_use)]
8
9mod sorted_map;
10pub use sorted_map::RSortedMap;
11
12use abi_stable::{
13 StableAbi,
14 std_types::{RBox, RResult, RVec},
15};
16use std::{
17 fmt::{self, Debug, Display, Formatter},
18 marker::PhantomData,
19};
20
21use furiosa_mapping_macro::primitive;
22use itertools::Itertools;
23
#[primitive(mapping::Ident)]
/// Name of a symbol, as used by [`Mapping::Symbol`] and [`Atom::Symbol`].
///
/// A valid identifier is non-empty, starts with an uppercase ASCII letter and continues with
/// ASCII alphanumerics or `_`; `Ident::new` and `TryFrom<&str>` enforce this. Equality,
/// ordering and hashing compare the string contents.
#[repr(C)]
#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
// NOTE(review): opaque fields presumably because `&'static str` has no `StableAbi` impl;
// confirm before changing the field type.
#[sabi(unsafe_opaque_fields)]
pub struct Ident(&'static str);
32
33#[expect(missing_docs)]
34impl Ident {
35 pub const fn new(s: &'static str) -> Self {
40 let b = s.as_bytes();
41 assert!(!b.is_empty(), "Ident must not be empty");
42 assert!(
43 b[0].is_ascii_uppercase(),
44 "Ident must start with an uppercase ASCII letter"
45 );
46 let mut i = 1;
47 while i < b.len() {
48 assert!(
49 b[i].is_ascii_alphanumeric() || b[i] == b'_',
50 "Ident must contain only ASCII alphanumeric or underscore characters"
51 );
52 i += 1;
53 }
54 Self(s)
55 }
56
57 pub fn as_str(&self) -> &'static str {
59 self.0
60 }
61
62 pub const A: Self = Self("A");
63 pub const B: Self = Self("B");
64 pub const C: Self = Self("C");
65 pub const D: Self = Self("D");
66 pub const E: Self = Self("E");
67 pub const F: Self = Self("F");
68 pub const G: Self = Self("G");
69 pub const H: Self = Self("H");
70 pub const I: Self = Self("I");
71 pub const J: Self = Self("J");
72 pub const K: Self = Self("K");
73 pub const L: Self = Self("L");
74 pub const M: Self = Self("M");
75 pub const N: Self = Self("N");
76 pub const O: Self = Self("O");
77 pub const P: Self = Self("P");
78 pub const Q: Self = Self("Q");
79 pub const R: Self = Self("R");
80 pub const S: Self = Self("S");
81 pub const T: Self = Self("T");
82 pub const U: Self = Self("U");
83 pub const V: Self = Self("V");
84 pub const W: Self = Self("W");
85 pub const X: Self = Self("X");
86 pub const Y: Self = Self("Y");
87 pub const Z: Self = Self("Z");
88}
89
90impl Display for Ident {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
96impl From<Ident> for &'static str {
97 fn from(value: Ident) -> Self {
98 value.0
99 }
100}
101
102impl serde::Serialize for Ident {
103 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
104 serializer.serialize_str(self.0)
105 }
106}
107
108impl<'de> serde::Deserialize<'de> for Ident {
109 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
110 let s: String = serde::Deserialize::deserialize(deserializer)?;
111 Ident::try_from(s.as_str()).map_err(|e| serde::de::Error::custom(format!("invalid Ident: {e}")))
112 }
113}
114
115impl serde_lite::Deserialize for Ident {
116 fn deserialize(val: &serde_lite::Intermediate) -> Result<Self, serde_lite::Error> {
117 let s: String = serde_lite::Deserialize::deserialize(val)?;
118 Ident::try_from(s.as_str()).map_err(|e| serde_lite::Error::custom(format!("invalid Ident: {e}")))
119 }
120}
121
122impl<'a> TryFrom<&'a str> for Ident {
123 type Error = &'a str;
124
125 fn try_from(value: &'a str) -> std::result::Result<Self, Self::Error> {
126 use lasso::ThreadedRodeo;
127 use std::sync::LazyLock;
128 static INTERNER: LazyLock<ThreadedRodeo> = LazyLock::new(ThreadedRodeo::new);
129
130 let key = INTERNER.get_or_intern(value);
131 let interned: &'static str = INTERNER.resolve(&key);
132 std::panic::catch_unwind(|| Self::new(interned)).map_err(|_| value)
133 }
134}
135
#[repr(C)]
/// Symbolic mapping expression tree.
///
/// The `Display` impl renders it in infix notation, e.g. `(A / 2 % 4, B = 8)`.
/// Serde support goes through the private `MappingSerde` proxy.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq, Hash)]
pub enum Mapping {
    /// The constant mapping, displayed as `1`.
    Identity,
    /// A named symbol.
    Symbol {
        /// Name of the symbol.
        symbol: Ident,
        /// Size of the symbol (not shown by `Display`).
        size: usize,
    },
    /// `inner` with a stride applied; displayed as `inner / stride`.
    Stride {
        /// The strided mapping.
        inner: RBox<Mapping>,
        /// The stride.
        stride: usize,
    },
    /// `inner` taken modulo `modulo`; displayed as `inner % modulo`.
    Modulo {
        /// The reduced mapping.
        inner: RBox<Mapping>,
        /// The modulus.
        modulo: usize,
    },
    /// `inner` resized to `resize`; displayed as `inner = resize`.
    Resize {
        /// The resized mapping.
        inner: RBox<Mapping>,
        /// The new size.
        resize: usize,
    },
    /// `inner` with padding; displayed as `inner # padding` for [`PaddingKind::Top`] and
    /// `inner #_ padding` for [`PaddingKind::Bottom`].
    Padding {
        /// The padded mapping.
        inner: RBox<Mapping>,
        /// The padding amount.
        padding: usize,
        /// Which side the padding is on.
        kind: PaddingKind,
    },
    /// Two mappings side by side. `Display` flattens nested pairs into a single tuple,
    /// e.g. `(A, B, C)`.
    Pair {
        /// Left component.
        left: RBox<Mapping>,
        /// Right component.
        right: RBox<Mapping>,
    },
}
187
188impl Display for Mapping {
189 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
190 fn flatten_pair<'a>(acc: &mut Vec<&'a Mapping>, m: &'a Mapping) {
191 match m {
192 Mapping::Pair { left, right } => {
193 flatten_pair(acc, left);
194 flatten_pair(acc, right);
195 }
196 _ => acc.push(m),
197 }
198 }
199
200 match self {
201 Self::Identity => write!(f, "1"),
202 Self::Symbol { symbol, size: _ } => {
203 write!(f, "{symbol}")
205 }
206 Self::Stride { inner, stride } => write!(f, "{inner} / {stride}"),
207 Self::Modulo { inner, modulo } => write!(f, "{inner} % {modulo}"),
208 Self::Resize { inner, resize } => write!(f, "{inner} = {resize}"),
209 Self::Padding {
210 inner,
211 padding,
212 kind: PaddingKind::Top,
213 } => write!(f, "{inner} # {padding}"),
214 Self::Padding {
215 inner,
216 padding,
217 kind: PaddingKind::Bottom,
218 } => write!(f, "{inner} #_ {padding}"),
219 Self::Pair { left, right } => {
220 let mut elements = vec![];
222 flatten_pair(&mut elements, left);
223 flatten_pair(&mut elements, right);
224 write!(f, "({})", elements.iter().join(", "))
225 }
226 }
227 }
228}
229
/// Serde proxy for [`Mapping`] that uses plain `Box` children instead of `RBox`.
///
/// `Mapping`'s `serde::Serialize`, `serde::Deserialize` and `serde_lite::Deserialize` impls
/// all convert through this type. Variant and field names come from the serde derives, so
/// renaming them changes the serialized format.
#[derive(serde::Serialize, serde::Deserialize, serde_lite::Deserialize)]
enum MappingSerde {
    // Mirrors `Mapping::Identity`.
    Identity,
    // Mirrors `Mapping::Symbol`.
    Symbol {
        symbol: Ident,
        size: usize,
    },
    // Mirrors `Mapping::Stride`.
    Stride {
        inner: Box<MappingSerde>,
        stride: usize,
    },
    // Mirrors `Mapping::Modulo`.
    Modulo {
        inner: Box<MappingSerde>,
        modulo: usize,
    },
    // Mirrors `Mapping::Resize`.
    Resize {
        inner: Box<MappingSerde>,
        resize: usize,
    },
    // Mirrors `Mapping::Padding`.
    Padding {
        inner: Box<MappingSerde>,
        padding: usize,
        kind: PaddingKind,
    },
    // Mirrors `Mapping::Pair`.
    Pair {
        left: Box<MappingSerde>,
        right: Box<MappingSerde>,
    },
}
261
262impl From<Mapping> for MappingSerde {
263 fn from(m: Mapping) -> Self {
264 match m {
265 Mapping::Identity => Self::Identity,
266 Mapping::Symbol { symbol, size } => Self::Symbol { symbol, size },
267 Mapping::Stride { inner, stride } => Self::Stride {
268 inner: Box::new(RBox::into_inner(inner).into()),
269 stride,
270 },
271 Mapping::Modulo { inner, modulo } => Self::Modulo {
272 inner: Box::new(RBox::into_inner(inner).into()),
273 modulo,
274 },
275 Mapping::Resize { inner, resize } => Self::Resize {
276 inner: Box::new(RBox::into_inner(inner).into()),
277 resize,
278 },
279 Mapping::Padding { inner, padding, kind } => Self::Padding {
280 inner: Box::new(RBox::into_inner(inner).into()),
281 padding,
282 kind,
283 },
284 Mapping::Pair { left, right } => Self::Pair {
285 left: Box::new(RBox::into_inner(left).into()),
286 right: Box::new(RBox::into_inner(right).into()),
287 },
288 }
289 }
290}
291
292impl From<MappingSerde> for Mapping {
293 fn from(m: MappingSerde) -> Self {
294 match m {
295 MappingSerde::Identity => Self::Identity,
296 MappingSerde::Symbol { symbol, size } => Self::Symbol { symbol, size },
297 MappingSerde::Stride { inner, stride } => Self::Stride {
298 inner: RBox::new((*inner).into()),
299 stride,
300 },
301 MappingSerde::Modulo { inner, modulo } => Self::Modulo {
302 inner: RBox::new((*inner).into()),
303 modulo,
304 },
305 MappingSerde::Resize { inner, resize } => Self::Resize {
306 inner: RBox::new((*inner).into()),
307 resize,
308 },
309 MappingSerde::Padding { inner, padding, kind } => Self::Padding {
310 inner: RBox::new((*inner).into()),
311 padding,
312 kind,
313 },
314 MappingSerde::Pair { left, right } => Self::Pair {
315 left: RBox::new((*left).into()),
316 right: RBox::new((*right).into()),
317 },
318 }
319 }
320}
321
322impl serde::Serialize for Mapping {
323 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
324 MappingSerde::from(self.clone()).serialize(s)
325 }
326}
327
328impl<'de> serde::Deserialize<'de> for Mapping {
329 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
330 MappingSerde::deserialize(d).map(Into::into)
331 }
332}
333
334impl serde_lite::Deserialize for Mapping {
335 fn deserialize(val: &serde_lite::Intermediate) -> Result<Self, serde_lite::Error> {
336 MappingSerde::deserialize(val).map(Into::into)
337 }
338}
339
#[repr(C)]
/// The base quantity of a [`Term`].
#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Atom {
    /// A named symbol of a given size.
    Symbol {
        /// Name of the symbol.
        symbol: Ident,
        /// Size of the symbol; a `Term` whose modulo equals it is displayed without `%`.
        size: usize,
    },
    /// A nested [`FMapping`]; its size is taken from [`FMapping::size`].
    Composite(RBox<FMapping>),
}
354
#[repr(C)]
/// An [`Atom`] viewed through a stride and a modulo.
///
/// `Display` renders it as `atom//stride` and/or `atom%modulo`. It omits a stride of `1` and
/// a modulo equal to the atom's size.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Term {
    /// The underlying atom.
    pub inner: Atom,
    /// Stride applied to the atom (`//` in `Display`).
    pub stride: usize,
    /// Modulus applied after the stride (`%` in `Display`).
    pub modulo: usize,
}
366
367impl Display for Term {
368 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369 match &self.inner {
370 Atom::Symbol { symbol, size } => {
371 if self.stride == 1 && self.modulo == *size {
372 write!(f, "{}", symbol)
373 } else if self.stride == 1 {
374 write!(f, "{}%{}", symbol, self.modulo)
375 } else if self.modulo == *size {
376 write!(f, "{}//{}", symbol, self.stride)
377 } else {
378 write!(f, "({}//{})%{}", symbol, self.stride, self.modulo)
379 }
380 }
381 Atom::Composite(inner) => {
382 if self.stride == 1 && self.modulo == inner.size() {
383 write!(f, "({})", inner)
384 } else if self.stride == 1 {
385 write!(f, "({})%{}", inner, self.modulo)
386 } else if self.modulo == inner.size() {
387 write!(f, "({})//{}", inner, self.stride)
388 } else {
389 write!(f, "(({})//{})%{}", inner, self.stride, self.modulo)
390 }
391 }
392 }
393 }
394}
395
#[repr(C)]
/// Which side a padding is applied on.
///
/// [`Strict`] uses `Bottom`; [`Relaxed`] and [`Span`] use `Top`.
#[derive(
    StableAbi,
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    serde::Serialize,
    serde::Deserialize,
    serde_lite::Deserialize,
)]
pub enum PaddingKind {
    /// Top padding: `#` in [`Mapping`] display, `pad(..)` in [`FMapping`] display.
    Top,
    /// Bottom padding: `#_` in [`Mapping`] display, `bottom_pad(..)` in [`FMapping`] display.
    Bottom,
}
418
#[repr(C)]
/// One factor of an [`FMapping`].
#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Factor {
    /// A [`Term`] resized to `resize`; displayed as `term(atom / stride % modulo = resize)`.
    Term {
        /// The strided/modulo'd term.
        inner: Term,
        /// Size this factor contributes to [`FMapping::size`].
        resize: usize,
    },
    /// A padding factor; displayed as `pad(size)` or `bottom_pad(size)`.
    ///
    /// [`FMapping::size`] counts the last padding's `size` in place of every factor that
    /// precedes it.
    Padding {
        /// Padded size.
        size: usize,
        /// Which side the padding is on.
        kind: PaddingKind,
    },
}
438
#[repr(C)]
/// A mapping expressed as an ordered sequence of [`Factor`]s.
///
/// `Display` renders it as `FMapping[f1 * f2 * ...]`.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FMapping(pub RVec<Factor>);
446
447impl Default for FMapping {
448 fn default() -> Self {
449 Self::new()
450 }
451}
452
453impl FMapping {
454 pub fn new() -> Self {
456 Self(RVec::new())
457 }
458
459 pub fn size(&self) -> usize {
461 let mut x = 1;
462 for term in self.0.iter().rev() {
463 match term {
464 Factor::Padding { size, .. } => {
465 return x * *size;
466 }
467 Factor::Term { resize, .. } => {
468 x *= *resize;
469 }
470 }
471 }
472 x
473 }
474}
475
476impl Display for FMapping {
477 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
478 let terms: Vec<String> = self
479 .0
480 .iter()
481 .map(|term| match term {
482 Factor::Padding {
483 size,
484 kind: PaddingKind::Top,
485 } => format!("pad({})", size),
486 Factor::Padding {
487 size,
488 kind: PaddingKind::Bottom,
489 } => format!("bottom_pad({})", size),
490 Factor::Term {
491 inner: Term { inner, stride, modulo },
492 resize,
493 } => {
494 let atom_str = match inner {
495 Atom::Symbol { symbol, size } => format!("{}:{}", symbol, size),
496 Atom::Composite(inner) => format!("({})", inner),
497 };
498 format!("term({} / {} % {} = {})", atom_str, stride, modulo, resize)
499 }
500 })
501 .collect();
502 write!(f, "FMapping[{}]", terms.join(" * "))
503 }
504}
505
#[repr(C)]
/// Reason a division could not be carried out.
///
/// NOTE(review): the variant descriptions are inferred from their names; confirm against the
/// code that produces them.
#[derive(StableAbi, Debug)]
pub enum DivisionError {
    /// There were no divisor terms to divide by.
    NoDivisorTerms,
    /// A divisor term could not divide the dividend.
    DivisorTermCannotDivide,
}
515
#[repr(C)]
/// Selects one operand of a division.
///
/// Used by [`DivisionTerm::stride`], `Division::residue` and `Division::mapping`.
#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
pub enum DivisionSide {
    /// The dividend side.
    Dividend,
    /// The divisor side.
    Divisor,
}
525
#[repr(C)]
/// A [`Term`] taking part in a division, together with its stride on each side.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
pub struct DivisionTerm {
    /// Stride of the term on the dividend side; `Division::new` sorts terms by it, descending.
    pub dividend_stride: usize,
    /// Stride of the term on the divisor side.
    pub divisor_stride: usize,
    /// The term itself.
    pub term: Term,
    /// Resize associated with the term.
    pub resize: usize,
}
539
540impl DivisionTerm {
541 pub fn stride(&self, side: DivisionSide) -> usize {
543 match side {
544 DivisionSide::Dividend => self.dividend_stride,
545 DivisionSide::Divisor => self.divisor_stride,
546 }
547 }
548}
549
#[repr(C)]
/// Lower and upper bounds of a block.
///
/// NOTE(review): whether `max` is inclusive is not visible here; confirm with the producer.
#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockBounds {
    /// Lower bound.
    pub min: usize,
    /// Upper bound.
    pub max: usize,
}
559
#[repr(C)]
/// A [`DivisionTerm`] with its block bounds on the dividend and divisor sides.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
pub struct TermBounds {
    /// The division term.
    pub term: DivisionTerm,
    /// Bounds on the dividend side.
    pub dividend: BlockBounds,
    /// Bounds on the divisor side.
    pub divisor: BlockBounds,
}
571
/// Type-level marker selecting the mode of a [`Division`] ([`Strict`], [`Relaxed`], [`Span`]).
pub trait DivisionMode {
    /// The [`PaddingKind`] associated with this mode.
    const PADDING_KIND: PaddingKind;
}
577
#[repr(C)]
/// Division mode whose [`DivisionMode::PADDING_KIND`] is [`PaddingKind::Bottom`].
#[derive(StableAbi, Debug, Clone, Copy)]
pub struct Strict;
impl DivisionMode for Strict {
    const PADDING_KIND: PaddingKind = PaddingKind::Bottom;
}
585
#[repr(C)]
/// Division mode whose [`DivisionMode::PADDING_KIND`] is [`PaddingKind::Top`].
#[derive(StableAbi, Debug, Clone, Copy)]
pub struct Relaxed;
impl DivisionMode for Relaxed {
    const PADDING_KIND: PaddingKind = PaddingKind::Top;
}
593
#[repr(C)]
/// Division mode whose [`DivisionMode::PADDING_KIND`] is [`PaddingKind::Top`].
#[derive(StableAbi, Debug, Clone, Copy)]
pub struct Span;
impl DivisionMode for Span {
    const PADDING_KIND: PaddingKind = PaddingKind::Top;
}
601
#[repr(C)]
/// Components of a division between a dividend and a divisor [`FMapping`].
///
/// `M` is a [`DivisionMode`] marker ([`Strict`], [`Relaxed`], [`Span`]) carried only at the
/// type level.
#[derive(StableAbi, Debug, Clone)]
pub struct Division<M: DivisionMode> {
    /// Per-term breakdown of the division. `Division::new` sorts it by descending
    /// `dividend_stride`.
    pub division_terms: RVec<DivisionTerm>,
    /// The dividend mapping.
    pub dividend: FMapping,
    /// Residue on the dividend side (see `Division::residue`).
    pub dividend_residue: FMapping,
    /// The divisor mapping.
    pub divisor: FMapping,
    /// Residue on the divisor side (see `Division::residue`).
    pub divisor_residue: FMapping,
    // Zero-sized marker tying this value to its division mode `M`.
    _mode: PhantomData<M>,
}
620
621impl<M: DivisionMode> Division<M> {
622 pub fn new(
624 dividend: FMapping,
625 dividend_residue: FMapping,
626 mut division_terms: Vec<DivisionTerm>,
627 divisor: FMapping,
628 divisor_residue: FMapping,
629 ) -> Self {
630 division_terms.sort_by(|a, b| b.dividend_stride.cmp(&a.dividend_stride));
631
632 Self {
633 division_terms: division_terms.into(),
634 dividend,
635 dividend_residue,
636 divisor,
637 divisor_residue,
638 _mode: PhantomData,
639 }
640 }
641
642 pub fn dividend(&self) -> &FMapping {
644 &self.dividend
645 }
646
647 pub fn divisor(&self) -> &FMapping {
649 &self.divisor
650 }
651
652 pub fn residue(&self, side: DivisionSide) -> &FMapping {
654 match side {
655 DivisionSide::Dividend => &self.dividend_residue,
656 DivisionSide::Divisor => &self.divisor_residue,
657 }
658 }
659
660 pub fn mapping(&self, side: DivisionSide) -> &FMapping {
662 match side {
663 DivisionSide::Dividend => &self.dividend,
664 DivisionSide::Divisor => &self.divisor,
665 }
666 }
667
668 pub fn division_terms(&self) -> &[DivisionTerm] {
670 &self.division_terms
671 }
672}
673
#[repr(C)]
/// A [`Term`] together with a resize and a stride.
///
/// NOTE(review): presumably locates the term within a larger mapping; confirm with callers.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
pub struct TermPosition {
    /// The term.
    pub term: Term,
    /// Resize associated with the term.
    pub resize: usize,
    /// Stride associated with the term.
    pub stride: usize,
}
688
#[repr(C)]
/// An assignment of a `usize` value to each [`Term`], or an invalid marker.
///
/// `ROk(map)` holds the sorted term-to-value map, displayed as `Index[t1 = v1, ...]`.
/// `RErr(())` marks an invalid index, displayed as `Invalid Index`. The default value is an
/// empty valid map.
#[derive(StableAbi, Debug, Clone, PartialEq, Eq)]
pub struct Index(pub RResult<RSortedMap<Term, usize>, ()>);
693
#[repr(C)]
/// Error returned when a value cannot be obtained from an [`Index`].
///
/// NOTE(review): the variant descriptions are inferred from their names; confirm against the
/// code that produces them.
#[derive(StableAbi, Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexValueError {
    /// The index is invalid.
    Invalid,
    /// The index is not in flattened form.
    NonFlattened,
}
703
704impl Default for Index {
705 fn default() -> Self {
706 Self(RResult::ROk(RSortedMap::new()))
707 }
708}
709
710impl Display for Index {
711 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
712 match &self.0 {
713 RResult::ROk(map) => {
714 let terms = map.iter().map(|(k, v)| format!("{k} = {v}")).join(", ");
715 write!(f, "Index[{}]", terms)
716 }
717 RResult::RErr(_) => {
718 write!(f, "Invalid Index")
719 }
720 }
721 }
722}
723
#[primitive(mapping::M)]
/// A type that statically describes a [`Mapping`].
pub trait M: Debug + Clone {
    /// Size of the described mapping.
    /// NOTE(review): presumably the number of positions accepted by `map`; confirm.
    const SIZE: usize;

    /// Returns the runtime [`Mapping`] value described by this type.
    fn to_value() -> Mapping;

    /// Maps position `i` to an [`Index`].
    /// NOTE(review): the meaning of `None` (e.g. out-of-range `i`) is not visible here; confirm
    /// with implementors.
    fn map(i: usize) -> Option<Index>;
}
737