furiosa_visa_std/
cast.rs

1use crate::scalar::{bf16, f8e4m3, i4};
2use furiosa_mapping::M;
3
4use super::scalar::Scalar;
5use furiosa_mapping_macro::primitive;
6
7/// Trait for types that can be cast during fetch operations.
8pub trait FetchCast<D: Scalar>: Into<D> + Cast<D> {}
9
// TODO: add `FetchCast` impls for the remaining fetch conversions listed below.
11// Int4ToInt5,
12// Int4ToInt32,
13// Int8ToInt9,
14// Int8ToInt32,
15// Int16ToInt32,
16// Float8e4m3ToFloat32,
17// Float8e5m2ToFloat32,
18// Bfloat16ToFloat32,
19// Float16ToFloat32,
20// Float32ToBfloat16,
21// // Renegade-S only
22// Int4ToInt9,
23// Int16ToInt9,
24// Float8e4m3ToBfloat16,
25// Float8e5m2ToBfloat16,
26
27// Identity casts
28impl<D> FetchCast<D> for D where D: Scalar {}
29
30impl FetchCast<i32> for i8 {}
31impl FetchCast<f32> for bf16 {}
32impl FetchCast<f32> for f8e4m3 {}
33impl FetchCast<i32> for i4 {}
34
35/// Trait for casting between scalar types.
36pub trait Cast<D: Scalar> {
37    /// Casts self to target type D.
38    fn cast(self) -> D;
39}
40
41impl<D: Scalar> Cast<D> for D {
42    fn cast(self) -> D {
43        self
44    }
45}
46
47impl Cast<i32> for i8 {
48    fn cast(self) -> i32 {
49        self as i32
50    }
51}
52
53impl Cast<i8> for i32 {
54    fn cast(self) -> i8 {
55        self as i8
56    }
57}
58
/// Widens a `bf16` to `f32` by delegating to `bf16::to_f32`.
impl Cast<f32> for bf16 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Narrows an `f32` to `bf16` by delegating to `bf16::from_f32`.
///
/// NOTE(review): rounding and NaN/overflow handling are whatever
/// `bf16::from_f32` (from `crate::scalar`) implements — confirm it matches
/// the hardware conversion.
impl Cast<bf16> for f32 {
    fn cast(self) -> bf16 {
        bf16::from_f32(self)
    }
}
70
/// Widens an `f8e4m3` to `f32` by delegating to `f8e4m3::to_f32`.
impl Cast<f32> for f8e4m3 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Narrows an `f32` to `f8e4m3` by delegating to `f8e4m3::from_f32`.
///
/// NOTE(review): rounding and out-of-range handling are decided by
/// `f8e4m3::from_f32` (from `crate::scalar`), not visible here — confirm.
impl Cast<f8e4m3> for f32 {
    fn cast(self) -> f8e4m3 {
        f8e4m3::from_f32(self)
    }
}
82
/// Widens an `i4` to `i32` by delegating to `i4::to_i32`.
impl Cast<i32> for i4 {
    fn cast(self) -> i32 {
        self.to_i32()
    }
}

/// Narrows an `i32` to `i4` by delegating to `i4::from_i32`.
///
/// NOTE(review): whether out-of-range inputs wrap, saturate or panic is decided
/// by `i4::from_i32` (from `crate::scalar`), not visible here — confirm.
impl Cast<i4> for i32 {
    fn cast(self) -> i4 {
        i4::from_i32(self)
    }
}
94
95/// Output type for contraction (DPE accumulates in wider type).
96pub trait ContractionCast: Scalar {
97    /// The wider scalar type that accumulates contraction results.
98    type Output: Scalar;
99}
100
/// Trait for stream tensors that can be cast to a different scalar type.
///
/// `D` is the scalar type of the tensor's elements before the cast.
///
/// FIXME: This trait exists purely to disambiguate tensor `cast()` from scalar
/// `Cast::cast()` for rust-analyzer.
pub trait StreamCast<D: Scalar> {
    /// The output tensor type after casting to scalar type `D2`.
    ///
    /// Only defined when the element type supports the scalar conversion
    /// (`D: Cast<D2>`). `OutPacket` is presumably the packet mapping of the
    /// output tensor (see the padding note on [`StreamCast::cast`]) — confirm.
    type CastOutput<D2: Scalar, OutPacket: M>
    where
        D: Cast<D2>;

    /// Casts this tensor's scalar type from `D` to `D2`.
    ///
    /// The cast engine operates on a single 32-byte flit.
    /// The input packet must already be exactly 32 bytes (ensured by the Collect Engine).
    /// After casting, the output packet is padded back to 32 bytes.
    /// Time passes through unchanged.
    ///
    /// Consumes `self`; only available when a scalar conversion `D: Cast<D2>` exists.
    // NOTE(review): `#[primitive]` comes from `furiosa_mapping_macro` and its
    // expansion is not visible here; keep this signature in sync with it.
    #[primitive(StreamCast::cast)]
    fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
    where
        D: Cast<D2>;
}
122
123impl ContractionCast for i8 {
124    type Output = i32;
125}
126
127impl ContractionCast for bf16 {
128    type Output = f32;
129}
130
131impl ContractionCast for f8e4m3 {
132    type Output = f32;
133}
134
135impl ContractionCast for i4 {
136    type Output = i32;
137}