furiosa_visa_std/
scalar.rs

1//! Scalar types.
2
3use furiosa_mapping_macro::primitive;
4use num_traits::{Num, One, Zero};
5use rand::distr::StandardUniform;
6use std::fmt::Debug;
7use std::ops::{Add, Div, Mul, Rem, Sub};
8
9/// A trait for scalar types.
10pub trait Scalar: ndarray::LinalgScalar + Debug + Clone + Copy + PartialEq + Num {
11    /// Number of bits per element.
12    const BITS: usize;
13
14    /// Returns the byte size for `length` elements of this scalar type.
15    ///
16    /// Panics if the total bit count is not a multiple of 8.
17    fn size_in_bytes_from_length(length: usize) -> usize {
18        assert_eq!((length * Self::BITS) % 8, 0, "total bits must be byte-aligned");
19        (length * Self::BITS) / 8
20    }
21
22    /// Returns the number of elements that fit in `bytes` bytes.
23    ///
24    /// Panics if the byte count does not evenly divide into whole elements.
25    fn length_from_bytes(bytes: usize) -> usize {
26        assert_eq!(
27            (bytes * 8) % Self::BITS,
28            0,
29            "bytes must correspond to a whole number of elements"
30        );
31        (bytes * 8) / Self::BITS
32    }
33}
34
/// A byte-aligned [`Scalar`] that can be decoded from its little-endian byte representation.
///
/// Excludes sub-byte scalars like [`i4`] for which a single element cannot be addressed at a
/// byte boundary. This is what [`crate::memory_tensor::HostTensor::from_safetensors`] requires,
/// matching the set of dtypes safetensors itself can carry.
pub trait ScalarBytes: Scalar {
    /// Decodes one element from `bytes`; `bytes.len()` must equal `Self::BITS / 8`.
    ///
    /// # Panics
    ///
    /// The implementations in this module panic when `bytes` has the wrong length.
    fn from_le_bytes(bytes: &[u8]) -> Self;
}
44
45macro_rules! impl_scalar_std {
46    ($($t:ty),*) => {
47        $(
48            impl Scalar for $t {
49                const BITS: usize = std::mem::size_of::<Self>() * 8;
50            }
51            impl ScalarBytes for $t {
52                fn from_le_bytes(bytes: &[u8]) -> Self {
53                    <$t>::from_le_bytes(bytes.try_into().expect("byte length mismatch"))
54                }
55            }
56        )*
57    };
58}
59impl_scalar_std!(i8, i16, i32, f32, u8);
60
/// A data type that can be either initialized or uninitialized.
///
/// The arithmetic operator impls for `Opt` propagate `Uninit`: a result is `Init` only when
/// both operands are `Init`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opt<D> {
    /// Initialized value.
    Init(D),
    /// Uninitialized value.
    Uninit,
}

impl<D> Opt<D> {
    /// Maps the initialized value using the provided function, or returns uninitialized.
    ///
    /// `f` is only invoked when `self` is `Init`.
    pub fn map<D2>(self, f: impl FnOnce(D) -> D2) -> Opt<D2> {
        match self {
            Opt::Init(val) => Opt::Init(f(val)),
            Opt::Uninit => Opt::Uninit,
        }
    }

    /// Combines two `Opt` values with a function. Returns `Init` only if both are `Init`.
    ///
    /// `f` is only invoked when both operands are `Init`.
    pub fn zip_map<D2, R>(self, other: Opt<D2>, f: impl FnOnce(D, D2) -> R) -> Opt<R> {
        match (self, other) {
            (Opt::Init(a), Opt::Init(b)) => Opt::Init(f(a, b)),
            _ => Opt::Uninit,
        }
    }

    /// Returns the initialized value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Uninit`. The panic is reported at the caller's location, matching
    /// `Option::unwrap`.
    #[track_caller]
    pub fn unwrap(self) -> D {
        let Opt::Init(val) = self else {
            panic!("Called unwrap on an uninitialized Opt value.");
        };
        val
    }
}
95
96impl<D: Add<Output = D>> Add for Opt<D> {
97    type Output = Self;
98
99    fn add(self, rhs: Self) -> Opt<D> {
100        let Opt::Init(lhs) = self else {
101            return Opt::Uninit;
102        };
103        let Opt::Init(rhs) = rhs else {
104            return Opt::Uninit;
105        };
106        Opt::Init(lhs + rhs)
107    }
108}
109
110impl<D: Sub<Output = D>> Sub for Opt<D> {
111    type Output = Self;
112
113    fn sub(self, rhs: Self) -> Opt<D> {
114        let Opt::Init(lhs) = self else {
115            return Opt::Uninit;
116        };
117        let Opt::Init(rhs) = rhs else {
118            return Opt::Uninit;
119        };
120        Opt::Init(lhs - rhs)
121    }
122}
123
124impl<D: Mul<Output = D>> Mul for Opt<D> {
125    type Output = Self;
126
127    fn mul(self, rhs: Self) -> Opt<D> {
128        let Opt::Init(lhs) = self else {
129            return Opt::Uninit;
130        };
131        let Opt::Init(rhs) = rhs else {
132            return Opt::Uninit;
133        };
134        Opt::Init(lhs * rhs)
135    }
136}
137
138impl<D: Div<Output = D>> Div for Opt<D> {
139    type Output = Self;
140
141    fn div(self, rhs: Self) -> Opt<D> {
142        let Opt::Init(lhs) = self else {
143            return Opt::Uninit;
144        };
145        let Opt::Init(rhs) = rhs else {
146            return Opt::Uninit;
147        };
148        Opt::Init(lhs / rhs)
149    }
150}
151
152impl<D: Zero> Zero for Opt<D> {
153    fn zero() -> Self {
154        Opt::Init(D::zero())
155    }
156
157    fn is_zero(&self) -> bool {
158        let Opt::Init(val) = self else {
159            return false;
160        };
161        val.is_zero()
162    }
163}
164
165impl<D: One> One for Opt<D> {
166    fn one() -> Self {
167        Opt::Init(D::one())
168    }
169}
170
/// 8-bit floating point type.
///
/// A newtype over `::f8::f8`. Arithmetic converts both operands with `float()`, operates there,
/// and converts the result back. Equality is the inner type's derived `PartialEq`.
#[expect(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct f8(::f8::f8);
175
176impl Zero for f8 {
177    fn zero() -> Self {
178        f8(::f8::f8::from(0.0))
179    }
180
181    fn is_zero(&self) -> bool {
182        self.0.float().is_zero()
183    }
184}
185
186impl One for f8 {
187    fn one() -> Self {
188        f8(::f8::f8::from(1.0))
189    }
190}
191
192impl Add<Self> for f8 {
193    type Output = Self;
194
195    fn add(self, rhs: Self) -> Self {
196        f8((self.0.float() + rhs.0.float()).into())
197    }
198}
199
200impl Sub<Self> for f8 {
201    type Output = Self;
202
203    fn sub(self, rhs: Self) -> Self {
204        f8((self.0.float() - rhs.0.float()).into())
205    }
206}
207
208impl Mul<Self> for f8 {
209    type Output = Self;
210
211    fn mul(self, rhs: Self) -> Self {
212        f8((self.0.float() * rhs.0.float()).into())
213    }
214}
215
216impl Div<Self> for f8 {
217    type Output = Self;
218
219    fn div(self, rhs: Self) -> Self {
220        f8((self.0.float() / rhs.0.float()).into())
221    }
222}
223
224impl Rem<Self> for f8 {
225    type Output = Self;
226
227    fn rem(self, rhs: Self) -> Self {
228        f8((self.0.float() % rhs.0.float()).into())
229    }
230}
231
232impl Num for f8 {
233    type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
234
235    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
236        Ok(f8(::f8::f8::from(<f32 as Num>::from_str_radix(str, radix)?)))
237    }
238}
239
240impl rand::distr::Distribution<f8> for StandardUniform {
241    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f8 {
242        let val: f32 = rng.random_range(-1.0..1.0);
243        f8(::f8::f8::from(val))
244    }
245}
246
// `f8` occupies exactly one byte per element.
impl Scalar for f8 {
    const BITS: usize = 8;
}

impl ScalarBytes for f8 {
    /// Decodes a single `f8` from a one-byte slice.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() != 1`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        assert_eq!(bytes.len(), 1, "f8 expects 1 byte");
        // NOTE(review): this relies on `::f8::f8: From<u8>`. Confirm that conversion
        // reinterprets the byte as the f8 bit pattern (as `f8e4m3::from_le_bytes` does) rather
        // than converting the integer value numerically; if it is numeric, decoding is wrong.
        f8(::f8::f8::from(bytes[0]))
    }
}
257
/// 16-bit brain floating point type.
///
/// A newtype over `half::bf16`. Arithmetic and equality are delegated to the inner type.
#[primitive(bf16)]
#[expect(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct bf16(half::bf16);
263
264impl Zero for bf16 {
265    fn zero() -> Self {
266        bf16(half::bf16::from_f32(0.0))
267    }
268
269    fn is_zero(&self) -> bool {
270        self.0.is_zero()
271    }
272}
273
274impl One for bf16 {
275    fn one() -> Self {
276        bf16(half::bf16::from_f32(1.0))
277    }
278}
279
280impl Add<Self> for bf16 {
281    type Output = Self;
282
283    fn add(self, rhs: Self) -> Self {
284        bf16(self.0 + rhs.0)
285    }
286}
287
288impl Sub<Self> for bf16 {
289    type Output = Self;
290
291    fn sub(self, rhs: Self) -> Self {
292        bf16(self.0 - rhs.0)
293    }
294}
295
296impl Mul<Self> for bf16 {
297    type Output = Self;
298
299    fn mul(self, rhs: Self) -> Self {
300        bf16(self.0 * rhs.0)
301    }
302}
303
304impl Div<Self> for bf16 {
305    type Output = Self;
306
307    fn div(self, rhs: Self) -> Self {
308        bf16(self.0 / rhs.0)
309    }
310}
311
312impl Rem<Self> for bf16 {
313    type Output = Self;
314
315    fn rem(self, rhs: Self) -> Self {
316        bf16(self.0 % rhs.0)
317    }
318}
319
320impl Num for bf16 {
321    type FromStrRadixErr = <half::bf16 as Num>::FromStrRadixErr;
322
323    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
324        Ok(bf16(<half::bf16 as Num>::from_str_radix(str, radix)?))
325    }
326}
327
328impl rand::distr::Distribution<bf16> for StandardUniform {
329    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> bf16 {
330        let val: f32 = rng.random_range(-1.0..1.0);
331        bf16(half::bf16::from_f32(val))
332    }
333}
334
335impl Scalar for bf16 {
336    const BITS: usize = 16;
337}
338
339impl ScalarBytes for bf16 {
340    fn from_le_bytes(bytes: &[u8]) -> Self {
341        let raw = u16::from_le_bytes(bytes.try_into().expect("bf16 expects 2 bytes"));
342        bf16(half::bf16::from_bits(raw))
343    }
344}
345
346impl bf16 {
347    /// Creates `bf16` from `f32`.
348    pub fn from_f32(val: f32) -> Self {
349        bf16(half::bf16::from_f32(val))
350    }
351
352    /// Converts to `f32`.
353    pub fn to_f32(self) -> f32 {
354        self.0.to_f32()
355    }
356}
357
358impl From<bf16> for f32 {
359    fn from(val: bf16) -> Self {
360        val.to_f32()
361    }
362}
363
364/// 8-bit floating point type with 4-bit exponent (E4M3).
365#[primitive(f8e4m3)]
366#[allow(non_camel_case_types)]
367#[derive(Clone, Copy, Debug, PartialEq)]
368pub struct f8e4m3(u8);
369
370impl Zero for f8e4m3 {
371    fn zero() -> Self {
372        f8e4m3(crate::float::F8E4_ZERO)
373    }
374
375    fn is_zero(&self) -> bool {
376        crate::float::f8_e4_is_zero(self.0)
377    }
378}
379
380impl One for f8e4m3 {
381    fn one() -> Self {
382        f8e4m3(crate::float::F8E4_ONE)
383    }
384}
385
386impl Add<Self> for f8e4m3 {
387    type Output = Self;
388
389    fn add(self, rhs: Self) -> Self {
390        Self::from_f32(self.to_f32() + rhs.to_f32())
391    }
392}
393
394impl Sub<Self> for f8e4m3 {
395    type Output = Self;
396
397    fn sub(self, rhs: Self) -> Self {
398        Self::from_f32(self.to_f32() - rhs.to_f32())
399    }
400}
401
402impl Mul<Self> for f8e4m3 {
403    type Output = Self;
404
405    fn mul(self, rhs: Self) -> Self {
406        Self::from_f32(self.to_f32() * rhs.to_f32())
407    }
408}
409
410impl Div<Self> for f8e4m3 {
411    type Output = Self;
412
413    fn div(self, rhs: Self) -> Self {
414        Self::from_f32(self.to_f32() / rhs.to_f32())
415    }
416}
417
418impl Rem<Self> for f8e4m3 {
419    type Output = Self;
420
421    fn rem(self, rhs: Self) -> Self {
422        Self::from_f32(self.to_f32() % rhs.to_f32())
423    }
424}
425
426impl Num for f8e4m3 {
427    type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
428
429    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
430        Ok(Self::from_f32(<f32 as Num>::from_str_radix(str, radix)?))
431    }
432}
433
434impl rand::distr::Distribution<f8e4m3> for StandardUniform {
435    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f8e4m3 {
436        let val: f32 = rng.random_range(-1.0..1.0);
437        f8e4m3::from_f32(val)
438    }
439}
440
441impl Scalar for f8e4m3 {
442    const BITS: usize = 8;
443}
444
445impl ScalarBytes for f8e4m3 {
446    fn from_le_bytes(bytes: &[u8]) -> Self {
447        assert_eq!(bytes.len(), 1, "f8e4m3 expects 1 byte");
448        f8e4m3(bytes[0])
449    }
450}
451
452impl f8e4m3 {
453    /// Creates `f8e4m3` from `f32`.
454    pub fn from_f32(val: f32) -> Self {
455        f8e4m3(crate::float::f8_e4_from_f32(val))
456    }
457
458    /// Converts to `f32`.
459    pub fn to_f32(self) -> f32 {
460        crate::float::f8_e4_to_f32(self.0)
461    }
462}
463
464impl From<f8e4m3> for f32 {
465    fn from(val: f8e4m3) -> Self {
466        val.to_f32()
467    }
468}
469
470/// 4-bit signed integer type.
471///
472/// Stored as `i8` with sign-extension: valid range is `[-8, 7]`.
473#[primitive(i4)]
474#[allow(non_camel_case_types)]
475#[derive(Clone, Copy, Debug, PartialEq)]
476pub struct i4(i8);
477
478impl Zero for i4 {
479    fn zero() -> Self {
480        i4(0)
481    }
482
483    fn is_zero(&self) -> bool {
484        self.0 == 0
485    }
486}
487
488impl One for i4 {
489    fn one() -> Self {
490        i4(1)
491    }
492}
493
494impl Add<Self> for i4 {
495    type Output = Self;
496
497    fn add(self, rhs: Self) -> Self {
498        Self::from_lsb(self.0 + rhs.0)
499    }
500}
501
502impl Sub<Self> for i4 {
503    type Output = Self;
504
505    fn sub(self, rhs: Self) -> Self {
506        Self::from_lsb(self.0 - rhs.0)
507    }
508}
509
510impl Mul<Self> for i4 {
511    type Output = Self;
512
513    fn mul(self, rhs: Self) -> Self {
514        Self::from_lsb(self.0 * rhs.0)
515    }
516}
517
518impl Div<Self> for i4 {
519    type Output = Self;
520
521    fn div(self, rhs: Self) -> Self {
522        Self::from_lsb(self.0 / rhs.0)
523    }
524}
525
526impl Rem<Self> for i4 {
527    type Output = Self;
528
529    fn rem(self, rhs: Self) -> Self {
530        Self::from_lsb(self.0 % rhs.0)
531    }
532}
533
534impl Num for i4 {
535    type FromStrRadixErr = <i8 as Num>::FromStrRadixErr;
536
537    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
538        Ok(Self::from_lsb(<i8 as Num>::from_str_radix(str, radix)?))
539    }
540}
541
542impl Scalar for i4 {
543    const BITS: usize = 4;
544}
545
546impl i4 {
547    fn from_lsb(n: i8) -> Self {
548        i4((n << 4) >> 4)
549    }
550
551    /// Creates `i4` from `i32`.
552    pub fn from_i32(val: i32) -> Self {
553        Self::from_lsb(val as i8)
554    }
555
556    /// Converts to `i32`.
557    pub fn to_i32(self) -> i32 {
558        i32::from(self.0)
559    }
560}
561
562impl From<i4> for i32 {
563    fn from(val: i4) -> Self {
564        val.to_i32()
565    }
566}