1use crate::scalar::{bf16, f8e4m3, i4};
2use furiosa_mapping::M;
3
4use super::scalar::Scalar;
5use furiosa_mapping_macro::primitive;
6
7pub trait FetchCast<D: Scalar>: Into<D> + Cast<D> {}
9
10impl<D> FetchCast<D> for D where D: Scalar {}
29
30impl FetchCast<i32> for i8 {}
31impl FetchCast<f32> for bf16 {}
32impl FetchCast<f32> for f8e4m3 {}
33impl FetchCast<i32> for i4 {}
34
35pub trait Cast<D: Scalar> {
37 fn cast(self) -> D;
39}
40
41impl<D: Scalar> Cast<D> for D {
42 fn cast(self) -> D {
43 self
44 }
45}
46
47impl Cast<i32> for i8 {
48 fn cast(self) -> i32 {
49 self as i32
50 }
51}
52
53impl Cast<i8> for i32 {
54 fn cast(self) -> i8 {
55 self as i8
56 }
57}
58
/// Converts a `bf16` to `f32` by delegating to `bf16::to_f32`.
impl Cast<f32> for bf16 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Converts an `f32` to `bf16` by delegating to `bf16::from_f32`.
///
/// NOTE(review): rounding and NaN/overflow handling are whatever
/// `bf16::from_f32` implements. That type is defined outside this module.
impl Cast<bf16> for f32 {
    fn cast(self) -> bf16 {
        bf16::from_f32(self)
    }
}
70
/// Converts an `f8e4m3` to `f32` by delegating to `f8e4m3::to_f32`.
impl Cast<f32> for f8e4m3 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Converts an `f32` to `f8e4m3` by delegating to `f8e4m3::from_f32`.
///
/// NOTE(review): rounding, saturation, and NaN handling are whatever
/// `f8e4m3::from_f32` implements. That type is defined outside this module.
impl Cast<f8e4m3> for f32 {
    fn cast(self) -> f8e4m3 {
        f8e4m3::from_f32(self)
    }
}
82
/// Converts an `i4` to `i32` by delegating to `i4::to_i32`.
impl Cast<i32> for i4 {
    fn cast(self) -> i32 {
        self.to_i32()
    }
}

/// Converts an `i32` to `i4` by delegating to `i4::from_i32`.
///
/// NOTE(review): behavior for values outside the `i4` range (wrap, saturate,
/// or panic) is defined by `i4::from_i32`. Confirm before relying on it.
impl Cast<i4> for i32 {
    fn cast(self) -> i4 {
        i4::from_i32(self)
    }
}
94
95pub trait ContractionCast: Scalar {
97 type Output: Scalar;
99}
100
/// Element-type conversion for types parameterized by a scalar type `D`.
///
/// Presumably implemented by stream types whose elements are `D` (per the
/// name). Confirm against implementors. The meaning of the `OutPacket` mapping
/// (a `furiosa_mapping::M`) and of the `#[primitive]` attribute are defined
/// outside this module.
pub trait StreamCast<D: Scalar> {
    /// Result type of [`StreamCast::cast`] when converting to scalar type `D2`
    /// with output packet mapping `OutPacket`.
    ///
    /// Only available when `D` can be converted to `D2` via [`Cast`].
    type CastOutput<D2: Scalar, OutPacket: M>
    where
        D: Cast<D2>;

    #[primitive(StreamCast::cast)]
    /// Consumes `self` and returns a `Self::CastOutput<D2, OutPacket>`.
    ///
    /// Requires `D: Cast<D2>`. The per-element conversion presumably uses that
    /// impl; confirm against the `#[primitive]` implementation.
    fn cast<D2: Scalar, OutPacket: M>(self) -> Self::CastOutput<D2, OutPacket>
    where
        D: Cast<D2>;
}
122
123impl ContractionCast for i8 {
124 type Output = i32;
125}
126
127impl ContractionCast for bf16 {
128 type Output = f32;
129}
130
131impl ContractionCast for f8e4m3 {
132 type Output = f32;
133}
134
135impl ContractionCast for i4 {
136 type Output = i32;
137}