1use std::fmt::Debug;
2use std::marker::ConstParamTy;
3use std::marker::PhantomData;
4use std::ops::DerefMut;
5use std::sync::LazyLock;
6use std::sync::Mutex;
7
8use furiosa_mapping::*;
9use furiosa_mapping_macro::primitive;
10use furiosa_opt_macro::m;
11
12use crate::prelude::DmTensor;
13use crate::prelude::HbmTensor;
14
15use super::memory_tensor::DmTensorView;
16use super::scalar::Scalar;
17use super::stream_tensor::BeginTensor;
18use super::tensor::Tensor;
19
/// Selects which TU a [`TuContext`] is specialized for.
///
/// Used as a const generic parameter (see `TuContext<const T: Tu>`), hence the
/// [`ConstParamTy`] derive.
#[derive(Debug, PartialEq, Eq, ConstParamTy)]
pub enum Tu {
    // NOTE(review): presumably the primary vs. secondary tensor unit — confirm
    // against the hardware/runtime docs; SOURCE only shows the names.
    Main,
    Sub,
}
28
/// Selects which DMA engine a [`DmaContext`] is specialized for.
///
/// Used as a const generic parameter (see `DmaContext<const DMA: Dma>`), hence
/// the [`ConstParamTy`] derive.
#[derive(Debug, PartialEq, Eq, ConstParamTy)]
pub enum Dma {
    // NOTE(review): names suggest a tensor-DMA engine and a PCIe-DMA engine —
    // confirm; only the identifiers are visible here.
    Tensor,
    Pcie,
}
37
/// Zero-sized capability token for a specific TU, selected by the const
/// parameter `T`.
///
/// Holding `&mut TuContext<T>` grants exclusive access to that unit; instances
/// are only ever created inside [`Context::acquire`]'s singleton.
#[primitive(TuContext)]
#[derive(Debug)]
pub struct TuContext<const T: Tu> {
    // No runtime state — the type itself is the capability.
    _marker: PhantomData<()>,
}

// An exclusive borrow of the token may cross to the device side.
impl<const T: Tu> crate::runtime::DeviceSend for &mut TuContext<T> {}
46
/// Zero-sized capability token for a specific DMA engine, selected by the
/// const parameter `DMA`.
///
/// Holding `&mut DmaContext<DMA>` grants exclusive access to that engine;
/// instances are only ever created inside [`Context::acquire`]'s singleton.
#[primitive(DmaContext)]
#[derive(Debug)]
pub struct DmaContext<const DMA: Dma> {
    // No runtime state — the type itself is the capability.
    _marker: PhantomData<()>,
}

// An exclusive borrow of the token may cross to the device side.
impl<const DMA: Dma> crate::runtime::DeviceSend for &mut DmaContext<DMA> {}
55
/// Bundle of all per-device execution-unit tokens: both TUs and both DMA
/// engines.
///
/// There is exactly one `Context` per process, guarded by the singleton in
/// [`Context::acquire`]; borrowing its fields hands out exclusive access to
/// the individual units.
#[primitive(Context)]
#[derive(Debug)]
pub struct Context {
    pub main: TuContext<{ Tu::Main }>,
    pub sub: TuContext<{ Tu::Sub }>,
    pub tdma: DmaContext<{ Dma::Tensor }>,
    pub pdma: DmaContext<{ Dma::Pcie }>,
}

// An exclusive borrow of the whole bundle may cross to the device side.
impl crate::runtime::DeviceSend for &mut Context {}
71
72impl Context {
73 pub fn acquire() -> impl DerefMut<Target = Context> {
75 static SINGLETON: LazyLock<Mutex<Context>> = LazyLock::new(|| {
76 Mutex::new(Context {
77 main: TuContext::<{ Tu::Main }> { _marker: PhantomData },
78 sub: TuContext::<{ Tu::Sub }> { _marker: PhantomData },
79 tdma: DmaContext::<{ Dma::Tensor }> { _marker: PhantomData },
80 pdma: DmaContext::<{ Dma::Pcie }> { _marker: PhantomData },
81 })
82 });
83 SINGLETON.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
84 }
85}
86
87impl<const DMA: Dma> DmaContext<DMA> {
88 pub fn hbm_cluster_shuffle<D: Scalar, Chip: M, Element: M>(
100 &mut self,
101 _tensor: &HbmTensor<D, Chip, Element>,
102 _shuffle_pattern: &[usize],
103 ) -> HbmTensor<D, Chip, Element> {
104 todo!("dm_cluster_shuffle: Low-level DMA command implementation for cluster <-> cluster HBM transfer")
107 }
108
109 pub fn hbm_chip_shuffle<D: Scalar, Chip: M, Element: M>(
123 &mut self,
124 _tensor: &HbmTensor<D, Chip, Element>,
125 _shuffle_pattern: &[usize],
126 ) -> HbmTensor<D, Chip, Element> {
127 todo!("dm_chip_shuffle: Low-level DMA command implementation for chip<->chip HBM transfer")
128 }
129}
130
131impl DmaContext<{ Dma::Tensor }> {
132 #[primitive(DmaContext::dm_cluster_shuffle)]
144 pub fn dm_cluster_shuffle<D: Scalar, const CLUSTER_DIM: usize, Chip: M, Cluster: M, Slice: M, Element: M>(
145 &mut self,
146 tensor: DmTensorView<D, Chip, Cluster, Slice, Element>,
147 shuffle_pattern: &[usize],
148 ) -> DmTensor<D, Chip, Cluster, Slice, Element> {
149 let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element> = unsafe { DmTensor::from_addr(0) };
150
151 for (target_cluster_idx, source_cluster_idx) in shuffle_pattern.iter().enumerate() {
152 tensor
153 .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(*source_cluster_idx)
154 .to_dm_view(
155 self,
156 shuffled
157 .view_mut()
158 .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(target_cluster_idx),
159 );
160 }
161
162 shuffled
163 }
164
165 #[primitive(DmaContext::dm_chip_shuffle)]
179 pub fn dm_chip_shuffle<D: Scalar, const CHIP_DIM: usize, Chip: M, Cluster: M, Slice: M, Element: M>(
180 &mut self,
181 tensor: DmTensorView<D, Chip, Cluster, Slice, Element>,
182 shuffle_pattern: &[usize; CHIP_DIM],
183 ) -> DmTensor<D, Chip, Cluster, Slice, Element> {
184 let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element> = unsafe { DmTensor::from_addr(0) };
185
186 for (target_chip_idx, source_chip_idx) in shuffle_pattern.iter().enumerate() {
187 tensor
188 .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(*source_chip_idx)
189 .to_dm_view(
190 self,
191 shuffled
192 .view_mut()
193 .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(target_chip_idx),
194 );
195 }
196
197 shuffled
198 }
199}
200
impl<const T: Tu> TuContext<{ T }> {
    /// Starts a TU stream computation over a single DM tensor view, yielding a
    /// [`BeginTensor`] bound to this TU for the view's lifetime `'l`.
    #[primitive(TuContext::begin)]
    pub fn begin<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>(
        &'l mut self,
        tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
    ) -> BeginTensor<'l, { T }, D, Chip, Cluster, Slice, Identity, Element> {
        // SAFETY(review): reads the underlying tensor and transmutes it into
        // the `Identity`-axis stream layout — the exact layout contract of
        // `transmute` is not visible from this file; confirm it matches
        // BeginTensor's expectation.
        BeginTensor::new(self, unsafe { tensor.inner.read().transmute() })
    }

    /// Starts a TU stream computation over two same-shaped views, interleaving
    /// them along a new axis named `I` (lhs at index 0, rhs at index 1).
    #[primitive(TuContext::begin_interleaved)]
    pub fn begin_interleaved<'l, I: AxisName, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>(
        &'l mut self,
        lhs: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
        rhs: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
    ) -> BeginTensor<'l, { T }, D, Chip, Cluster, Slice, Symbol<I>, Element> {
        // Allocate an uninitialized output carrying the extra `Symbol<I>` axis;
        // both write_transpose calls below fully populate it.
        let mut output = Tensor::<D, m![{ Chip }, { Cluster }, { Slice }, { Symbol<I> }, { Element }]>::uninit();

        // Write each input into its slot (0 = lhs, 1 = rhs) along the new axis.
        // NOTE(review): `1 # 2` in the m! tile spec and the trailing `false` to
        // write_transpose are macro/API specifics not documented here — confirm
        // their meaning against the furiosa_mapping docs.
        for (i, input) in [lhs, rhs].into_iter().enumerate() {
            output
                .view_mut()
                .tile::<Symbol<I>, m![{ Chip }, { Cluster }, { Slice }, 1 # 2, { Element }], 1>(i)
                .write_transpose(input.inner, false);
        }

        BeginTensor::new(self, output)
    }
}
231
232impl TuContext<{ Tu::Sub }> {
233 #[primitive(TuContext::parallel_copy_cluster_slice)]
248 pub fn parallel_copy_cluster_slice<
249 'l,
250 const CLUSTER_DIM: usize,
251 AxisToSlice: M,
252 AxisSlicedElement: M,
253 Element2: M,
254 D: Scalar,
255 Chip: M,
256 Cluster: M,
257 Slice: M,
258 Element: M,
259 >(
260 &mut self,
261 tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
262 slice_indices: &[usize; CLUSTER_DIM],
263 ) -> super::memory_tensor::DmTensor<D, Chip, Cluster, Slice, Element2> {
264 let mut sliced = unsafe { DmTensor::from_addr(0) };
265
266 for (cluster_idx, slice_idx) in slice_indices.iter().enumerate() {
267 let cluster_slice = tensor.cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(cluster_idx);
268 cluster_slice
269 .tile::<AxisToSlice, 1, AxisSlicedElement>(*slice_idx)
270 .to_dm_view_pcopy(
271 self,
272 sliced
273 .view_mut()
274 .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(cluster_idx),
275 );
276 }
277
278 sliced
279 }
280
281 #[primitive(TuContext::parallel_copy_chip_slice)]
295 pub fn parallel_copy_chip_slice<
296 'l,
297 const CHIP_DIM: usize,
298 AxisToSlice: M,
299 AxisSlicedElement: M,
300 Element2: M,
301 D: Scalar,
302 Chip: M,
303 Cluster: M,
304 Slice: M,
305 Element: M,
306 >(
307 &mut self,
308 tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
309 slice_indices: &[usize; CHIP_DIM],
310 ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
311 let mut sliced = unsafe { DmTensor::from_addr(0) };
312
313 for (chip_idx, slice_idx) in slice_indices.iter().enumerate() {
314 let chip_slice = tensor.chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(chip_idx);
315 chip_slice
316 .tile::<AxisToSlice, 1, AxisSlicedElement>(*slice_idx)
317 .to_dm_view_pcopy(
318 self,
319 sliced
320 .view_mut()
321 .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(chip_idx),
322 );
323 }
324
325 sliced
326 }
327}