1use furiosa_mapping_macro::primitive;
4use num_traits::{Num, One, Zero};
5use rand::distr::StandardUniform;
6use std::fmt::Debug;
7use std::ops::{Add, Div, Mul, Rem, Sub};
8
9pub trait Scalar: ndarray::LinalgScalar + Debug + Clone + Copy + PartialEq + Num {
11 const BITS: usize;
13
14 fn size_in_bytes_from_length(length: usize) -> usize {
18 assert_eq!((length * Self::BITS) % 8, 0, "total bits must be byte-aligned");
19 (length * Self::BITS) / 8
20 }
21
22 fn length_from_bytes(bytes: usize) -> usize {
26 assert_eq!(
27 (bytes * 8) % Self::BITS,
28 0,
29 "bytes must correspond to a whole number of elements"
30 );
31 (bytes * 8) / Self::BITS
32 }
33}
34
/// A [`Scalar`] that can be decoded from its little-endian byte representation.
pub trait ScalarBytes: Scalar {
    /// Decodes one value from `bytes`.
    ///
    /// # Panics
    ///
    /// The implementations in this file panic when `bytes` does not have exactly the
    /// number of bytes the type occupies.
    fn from_le_bytes(bytes: &[u8]) -> Self;
}
44
45macro_rules! impl_scalar_std {
46 ($($t:ty),*) => {
47 $(
48 impl Scalar for $t {
49 const BITS: usize = std::mem::size_of::<Self>() * 8;
50 }
51 impl ScalarBytes for $t {
52 fn from_le_bytes(bytes: &[u8]) -> Self {
53 <$t>::from_le_bytes(bytes.try_into().expect("byte length mismatch"))
54 }
55 }
56 )*
57 };
58}
59impl_scalar_std!(i8, i16, i32, f32, u8);
60
/// A value that is either initialized (`Init`) or not yet written (`Uninit`).
///
/// The arithmetic operator impls for `Opt` propagate `Uninit`: if either operand is
/// `Uninit`, the result is `Uninit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opt<D> {
    /// An initialized value.
    Init(D),
    /// No value is present.
    Uninit,
}

impl<D> Opt<D> {
    /// Applies `f` to an initialized value; `Uninit` is passed through unchanged.
    pub fn map<D2>(self, f: impl FnOnce(D) -> D2) -> Opt<D2> {
        match self {
            Opt::Init(val) => Opt::Init(f(val)),
            Opt::Uninit => Opt::Uninit,
        }
    }

    /// Combines `self` and `other` with `f` when both are initialized.
    ///
    /// Returns `Uninit` without calling `f` if either side is `Uninit`.
    pub fn zip_map<D2, R>(self, other: Opt<D2>, f: impl FnOnce(D, D2) -> R) -> Opt<R> {
        match (self, other) {
            (Opt::Init(a), Opt::Init(b)) => Opt::Init(f(a, b)),
            _ => Opt::Uninit,
        }
    }

    /// Returns the initialized value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Uninit`. Because of `#[track_caller]`, the panic is
    /// reported at the caller's location, as with `Option::unwrap`, instead of
    /// inside this method.
    #[track_caller]
    pub fn unwrap(self) -> D {
        let Opt::Init(val) = self else {
            panic!("Called unwrap on an uninitialized Opt value.");
        };
        val
    }
}
95
96impl<D: Add<Output = D>> Add for Opt<D> {
97 type Output = Self;
98
99 fn add(self, rhs: Self) -> Opt<D> {
100 let Opt::Init(lhs) = self else {
101 return Opt::Uninit;
102 };
103 let Opt::Init(rhs) = rhs else {
104 return Opt::Uninit;
105 };
106 Opt::Init(lhs + rhs)
107 }
108}
109
110impl<D: Sub<Output = D>> Sub for Opt<D> {
111 type Output = Self;
112
113 fn sub(self, rhs: Self) -> Opt<D> {
114 let Opt::Init(lhs) = self else {
115 return Opt::Uninit;
116 };
117 let Opt::Init(rhs) = rhs else {
118 return Opt::Uninit;
119 };
120 Opt::Init(lhs - rhs)
121 }
122}
123
124impl<D: Mul<Output = D>> Mul for Opt<D> {
125 type Output = Self;
126
127 fn mul(self, rhs: Self) -> Opt<D> {
128 let Opt::Init(lhs) = self else {
129 return Opt::Uninit;
130 };
131 let Opt::Init(rhs) = rhs else {
132 return Opt::Uninit;
133 };
134 Opt::Init(lhs * rhs)
135 }
136}
137
138impl<D: Div<Output = D>> Div for Opt<D> {
139 type Output = Self;
140
141 fn div(self, rhs: Self) -> Opt<D> {
142 let Opt::Init(lhs) = self else {
143 return Opt::Uninit;
144 };
145 let Opt::Init(rhs) = rhs else {
146 return Opt::Uninit;
147 };
148 Opt::Init(lhs / rhs)
149 }
150}
151
152impl<D: Zero> Zero for Opt<D> {
153 fn zero() -> Self {
154 Opt::Init(D::zero())
155 }
156
157 fn is_zero(&self) -> bool {
158 let Opt::Init(val) = self else {
159 return false;
160 };
161 val.is_zero()
162 }
163}
164
165impl<D: One> One for Opt<D> {
166 fn one() -> Self {
167 Opt::Init(D::one())
168 }
169}
170
#[expect(non_camel_case_types)]
/// 8-bit floating-point scalar backed by the external `f8` crate's `f8` type.
///
/// Arithmetic in this file widens both operands with `float()`, computes in that
/// wider type, and converts the result back with `.into()`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct f8(::f8::f8);
175
176impl Zero for f8 {
177 fn zero() -> Self {
178 f8(::f8::f8::from(0.0))
179 }
180
181 fn is_zero(&self) -> bool {
182 self.0.float().is_zero()
183 }
184}
185
186impl One for f8 {
187 fn one() -> Self {
188 f8(::f8::f8::from(1.0))
189 }
190}
191
192impl Add<Self> for f8 {
193 type Output = Self;
194
195 fn add(self, rhs: Self) -> Self {
196 f8((self.0.float() + rhs.0.float()).into())
197 }
198}
199
200impl Sub<Self> for f8 {
201 type Output = Self;
202
203 fn sub(self, rhs: Self) -> Self {
204 f8((self.0.float() - rhs.0.float()).into())
205 }
206}
207
208impl Mul<Self> for f8 {
209 type Output = Self;
210
211 fn mul(self, rhs: Self) -> Self {
212 f8((self.0.float() * rhs.0.float()).into())
213 }
214}
215
216impl Div<Self> for f8 {
217 type Output = Self;
218
219 fn div(self, rhs: Self) -> Self {
220 f8((self.0.float() / rhs.0.float()).into())
221 }
222}
223
224impl Rem<Self> for f8 {
225 type Output = Self;
226
227 fn rem(self, rhs: Self) -> Self {
228 f8((self.0.float() % rhs.0.float()).into())
229 }
230}
231
232impl Num for f8 {
233 type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
234
235 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
236 Ok(f8(::f8::f8::from(<f32 as Num>::from_str_radix(str, radix)?)))
237 }
238}
239
240impl rand::distr::Distribution<f8> for StandardUniform {
241 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f8 {
242 let val: f32 = rng.random_range(-1.0..1.0);
243 f8(::f8::f8::from(val))
244 }
245}
246
247impl Scalar for f8 {
248 const BITS: usize = 8;
249}
250
251impl ScalarBytes for f8 {
252 fn from_le_bytes(bytes: &[u8]) -> Self {
253 assert_eq!(bytes.len(), 1, "f8 expects 1 byte");
254 f8(::f8::f8::from(bytes[0]))
255 }
256}
257
#[primitive(bf16)]
/// bfloat16 scalar backed by [`half::bf16`].
///
/// Arithmetic, equality and zero detection all delegate to `half::bf16`.
#[expect(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct bf16(half::bf16);
263
264impl Zero for bf16 {
265 fn zero() -> Self {
266 bf16(half::bf16::from_f32(0.0))
267 }
268
269 fn is_zero(&self) -> bool {
270 self.0.is_zero()
271 }
272}
273
274impl One for bf16 {
275 fn one() -> Self {
276 bf16(half::bf16::from_f32(1.0))
277 }
278}
279
280impl Add<Self> for bf16 {
281 type Output = Self;
282
283 fn add(self, rhs: Self) -> Self {
284 bf16(self.0 + rhs.0)
285 }
286}
287
288impl Sub<Self> for bf16 {
289 type Output = Self;
290
291 fn sub(self, rhs: Self) -> Self {
292 bf16(self.0 - rhs.0)
293 }
294}
295
296impl Mul<Self> for bf16 {
297 type Output = Self;
298
299 fn mul(self, rhs: Self) -> Self {
300 bf16(self.0 * rhs.0)
301 }
302}
303
304impl Div<Self> for bf16 {
305 type Output = Self;
306
307 fn div(self, rhs: Self) -> Self {
308 bf16(self.0 / rhs.0)
309 }
310}
311
312impl Rem<Self> for bf16 {
313 type Output = Self;
314
315 fn rem(self, rhs: Self) -> Self {
316 bf16(self.0 % rhs.0)
317 }
318}
319
320impl Num for bf16 {
321 type FromStrRadixErr = <half::bf16 as Num>::FromStrRadixErr;
322
323 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
324 Ok(bf16(<half::bf16 as Num>::from_str_radix(str, radix)?))
325 }
326}
327
328impl rand::distr::Distribution<bf16> for StandardUniform {
329 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> bf16 {
330 let val: f32 = rng.random_range(-1.0..1.0);
331 bf16(half::bf16::from_f32(val))
332 }
333}
334
335impl Scalar for bf16 {
336 const BITS: usize = 16;
337}
338
339impl ScalarBytes for bf16 {
340 fn from_le_bytes(bytes: &[u8]) -> Self {
341 let raw = u16::from_le_bytes(bytes.try_into().expect("bf16 expects 2 bytes"));
342 bf16(half::bf16::from_bits(raw))
343 }
344}
345
346impl bf16 {
347 pub fn from_f32(val: f32) -> Self {
349 bf16(half::bf16::from_f32(val))
350 }
351
352 pub fn to_f32(self) -> f32 {
354 self.0.to_f32()
355 }
356}
357
358impl From<bf16> for f32 {
359 fn from(val: bf16) -> Self {
360 val.to_f32()
361 }
362}
363
364#[primitive(f8e4m3)]
366#[allow(non_camel_case_types)]
367#[derive(Clone, Copy, Debug, PartialEq)]
368pub struct f8e4m3(u8);
369
370impl Zero for f8e4m3 {
371 fn zero() -> Self {
372 f8e4m3(crate::float::F8E4_ZERO)
373 }
374
375 fn is_zero(&self) -> bool {
376 crate::float::f8_e4_is_zero(self.0)
377 }
378}
379
380impl One for f8e4m3 {
381 fn one() -> Self {
382 f8e4m3(crate::float::F8E4_ONE)
383 }
384}
385
386impl Add<Self> for f8e4m3 {
387 type Output = Self;
388
389 fn add(self, rhs: Self) -> Self {
390 Self::from_f32(self.to_f32() + rhs.to_f32())
391 }
392}
393
394impl Sub<Self> for f8e4m3 {
395 type Output = Self;
396
397 fn sub(self, rhs: Self) -> Self {
398 Self::from_f32(self.to_f32() - rhs.to_f32())
399 }
400}
401
402impl Mul<Self> for f8e4m3 {
403 type Output = Self;
404
405 fn mul(self, rhs: Self) -> Self {
406 Self::from_f32(self.to_f32() * rhs.to_f32())
407 }
408}
409
410impl Div<Self> for f8e4m3 {
411 type Output = Self;
412
413 fn div(self, rhs: Self) -> Self {
414 Self::from_f32(self.to_f32() / rhs.to_f32())
415 }
416}
417
418impl Rem<Self> for f8e4m3 {
419 type Output = Self;
420
421 fn rem(self, rhs: Self) -> Self {
422 Self::from_f32(self.to_f32() % rhs.to_f32())
423 }
424}
425
426impl Num for f8e4m3 {
427 type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
428
429 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
430 Ok(Self::from_f32(<f32 as Num>::from_str_radix(str, radix)?))
431 }
432}
433
434impl rand::distr::Distribution<f8e4m3> for StandardUniform {
435 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f8e4m3 {
436 let val: f32 = rng.random_range(-1.0..1.0);
437 f8e4m3::from_f32(val)
438 }
439}
440
441impl Scalar for f8e4m3 {
442 const BITS: usize = 8;
443}
444
445impl ScalarBytes for f8e4m3 {
446 fn from_le_bytes(bytes: &[u8]) -> Self {
447 assert_eq!(bytes.len(), 1, "f8e4m3 expects 1 byte");
448 f8e4m3(bytes[0])
449 }
450}
451
452impl f8e4m3 {
453 pub fn from_f32(val: f32) -> Self {
455 f8e4m3(crate::float::f8_e4_from_f32(val))
456 }
457
458 pub fn to_f32(self) -> f32 {
460 crate::float::f8_e4_to_f32(self.0)
461 }
462}
463
464impl From<f8e4m3> for f32 {
465 fn from(val: f8e4m3) -> Self {
466 val.to_f32()
467 }
468}
469
#[primitive(i4)]
/// 4-bit signed integer stored sign-extended in an `i8` (values -8..=7).
///
/// Arithmetic results and `from_i32` pass through `from_lsb`, which keeps only the
/// low four bits and sign-extends them. Out-of-range results therefore wrap around,
/// e.g. `7 + 1` becomes `-8`.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct i4(i8);
477
478impl Zero for i4 {
479 fn zero() -> Self {
480 i4(0)
481 }
482
483 fn is_zero(&self) -> bool {
484 self.0 == 0
485 }
486}
487
488impl One for i4 {
489 fn one() -> Self {
490 i4(1)
491 }
492}
493
494impl Add<Self> for i4 {
495 type Output = Self;
496
497 fn add(self, rhs: Self) -> Self {
498 Self::from_lsb(self.0 + rhs.0)
499 }
500}
501
502impl Sub<Self> for i4 {
503 type Output = Self;
504
505 fn sub(self, rhs: Self) -> Self {
506 Self::from_lsb(self.0 - rhs.0)
507 }
508}
509
510impl Mul<Self> for i4 {
511 type Output = Self;
512
513 fn mul(self, rhs: Self) -> Self {
514 Self::from_lsb(self.0 * rhs.0)
515 }
516}
517
518impl Div<Self> for i4 {
519 type Output = Self;
520
521 fn div(self, rhs: Self) -> Self {
522 Self::from_lsb(self.0 / rhs.0)
523 }
524}
525
526impl Rem<Self> for i4 {
527 type Output = Self;
528
529 fn rem(self, rhs: Self) -> Self {
530 Self::from_lsb(self.0 % rhs.0)
531 }
532}
533
534impl Num for i4 {
535 type FromStrRadixErr = <i8 as Num>::FromStrRadixErr;
536
537 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
538 Ok(Self::from_lsb(<i8 as Num>::from_str_radix(str, radix)?))
539 }
540}
541
542impl Scalar for i4 {
543 const BITS: usize = 4;
544}
545
546impl i4 {
547 fn from_lsb(n: i8) -> Self {
548 i4((n << 4) >> 4)
549 }
550
551 pub fn from_i32(val: i32) -> Self {
553 Self::from_lsb(val as i8)
554 }
555
556 pub fn to_i32(self) -> i32 {
558 i32::from(self.0)
559 }
560}
561
562impl From<i4> for i32 {
563 fn from(val: i4) -> Self {
564 val.to_i32()
565 }
566}