furiosa_visa_std/
context.rs

1use std::fmt::Debug;
2use std::marker::ConstParamTy;
3use std::marker::PhantomData;
4use std::ops::DerefMut;
5use std::sync::LazyLock;
6use std::sync::Mutex;
7
8use furiosa_mapping::*;
9use furiosa_mapping_macro::primitive;
10use furiosa_opt_macro::m;
11
12use crate::prelude::DmTensor;
13use crate::prelude::HbmTensor;
14
15use super::memory_tensor::DmTensorView;
16use super::scalar::Scalar;
17use super::stream_tensor::BeginTensor;
18use super::tensor::Tensor;
19
20/// Tensor units.
21#[derive(Debug, PartialEq, Eq, ConstParamTy)]
22pub enum Tu {
23    /// Main context.
24    Main,
25    /// Sub context.
26    Sub,
27}
28
29/// DMA units.
30#[derive(Debug, PartialEq, Eq, ConstParamTy)]
31pub enum Dma {
32    /// Tensor DMA.
33    Tensor,
34    /// PCIe DMA.
35    Pcie,
36}
37
/// Context for a specific tensor unit.
///
/// Zero-sized token type: holding `&mut TuContext<T>` represents exclusive
/// access to tensor unit `T`. The field is private, so instances cannot be
/// constructed outside this module; the only construction site in this file
/// is inside [`Context::acquire`], which hands them out behind a lock.
#[primitive(TuContext)]
#[derive(Debug)]
pub struct TuContext<const T: Tu> {
    // Zero-sized placeholder; private so external code cannot forge a token.
    _marker: PhantomData<()>,
}
44
// Marker impl: an exclusive tensor-unit context reference may be sent to the
// device (contract defined by `crate::runtime::DeviceSend`).
impl<const T: Tu> crate::runtime::DeviceSend for &mut TuContext<T> {}
46
/// Context for a specific DMA engine.
///
/// Zero-sized token type: holding `&mut DmaContext<DMA>` represents exclusive
/// access to DMA engine `DMA`. The field is private, so instances cannot be
/// constructed outside this module; they are only created inside
/// [`Context::acquire`].
#[primitive(DmaContext)]
#[derive(Debug)]
pub struct DmaContext<const DMA: Dma> {
    // Zero-sized placeholder; private so external code cannot forge a token.
    _marker: PhantomData<()>,
}
53
// Marker impl: an exclusive DMA context reference may be sent to the device
// (contract defined by `crate::runtime::DeviceSend`).
impl<const DMA: Dma> crate::runtime::DeviceSend for &mut DmaContext<DMA> {}
55
/// Device context.
///
/// Aggregates one context token per tensor unit and per DMA engine; the
/// single process-wide instance is obtained through [`Context::acquire`].
#[primitive(Context)]
#[derive(Debug)]
pub struct Context {
    /// Tensor unit for the main context.
    pub main: TuContext<{ Tu::Main }>,
    /// Tensor unit for the sub context.
    pub sub: TuContext<{ Tu::Sub }>,
    /// Tensor DMA context.
    pub tdma: DmaContext<{ Dma::Tensor }>,
    /// PCIe DMA context.
    pub pdma: DmaContext<{ Dma::Pcie }>,
}
69
// Marker impl: the whole device context may be sent to the device by
// exclusive reference (contract defined by `crate::runtime::DeviceSend`).
impl crate::runtime::DeviceSend for &mut Context {}
71
72impl Context {
73    /// Acquire the tensor units.
74    pub fn acquire() -> impl DerefMut<Target = Context> {
75        static SINGLETON: LazyLock<Mutex<Context>> = LazyLock::new(|| {
76            Mutex::new(Context {
77                main: TuContext::<{ Tu::Main }> { _marker: PhantomData },
78                sub: TuContext::<{ Tu::Sub }> { _marker: PhantomData },
79                tdma: DmaContext::<{ Dma::Tensor }> { _marker: PhantomData },
80                pdma: DmaContext::<{ Dma::Pcie }> { _marker: PhantomData },
81            })
82        });
83        SINGLETON.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
84    }
85}
86
87impl<const DMA: Dma> DmaContext<DMA> {
88    /// Perform cluster shuffle operation using DMA commands (HBM <-> HBM transfer).
89    /// This operation redistributes data across clusters according to the shuffle pattern.
90    ///
91    /// # Arguments
92    /// * `tensor` - Input tensor with cluster dimension
93    /// * `shuffle_pattern` - Array mapping source cluster to destination cluster
94    ///
95    /// # Example
96    /// For Cluster=2 with shuffle_pattern=\[1,0\]:
97    /// - Data from Cluster 1 goes to Cluster 0
98    /// - Data from Cluster 0 goes to Cluster 1
99    pub fn hbm_cluster_shuffle<D: Scalar, Chip: M, Element: M>(
100        &mut self,
101        _tensor: &HbmTensor<D, Chip, Element>,
102        _shuffle_pattern: &[usize],
103    ) -> HbmTensor<D, Chip, Element> {
104        // Low-level implementation using DMA commands for inter-cluster data transfer
105        // This will be lowered to actual hardware DMA commands by the compiler
106        todo!("dm_cluster_shuffle: Low-level DMA command implementation for cluster <-> cluster HBM transfer")
107    }
108
109    /// Perform chip shuffle using PCIe/Tensor DMA commands (HBM <-> HBM transfer across chips).
110    /// This operation redistributes data across chips according to the shuffle pattern.
111    ///
112    /// # Arguments
113    /// * `tensor` - Input tensor with chip dimension
114    /// * `shuffle_pattern` - Array mapping source chip to destination chip
115    ///
116    /// # Example
117    /// For Chip=4 with shuffle_pattern=\[1,2,3,0\]:
118    /// - Data from Chip 1 goes to Chip 0
119    /// - Data from Chip 2 goes to Chip 1
120    /// - Data from Chip 3 goes to Chip 2
121    /// - Data from Chip 0 goes to Chip 3
122    pub fn hbm_chip_shuffle<D: Scalar, Chip: M, Element: M>(
123        &mut self,
124        _tensor: &HbmTensor<D, Chip, Element>,
125        _shuffle_pattern: &[usize],
126    ) -> HbmTensor<D, Chip, Element> {
127        todo!("dm_chip_shuffle: Low-level DMA command implementation for chip<->chip HBM transfer")
128    }
129}
130
impl DmaContext<{ Dma::Tensor }> {
    /// Perform cluster shuffle operation using DMA commands (DM <-> DM transfer).
    /// This operation redistributes data across clusters according to the shuffle pattern.
    ///
    /// # Arguments
    /// * `tensor` - Input tensor with cluster dimension
    /// * `shuffle_pattern` - Array mapping source cluster to destination cluster
    ///   (`shuffle_pattern[target] == source`)
    ///
    /// # Example
    /// For Cluster=2 with shuffle_pattern=\[1,0\]:
    /// - Data from Cluster 1 goes to Cluster 0
    /// - Data from Cluster 0 goes to Cluster 1
    ///
    /// NOTE(review): unlike [`Self::dm_chip_shuffle`], the pattern here is an
    /// unsized slice even though `CLUSTER_DIM` is a const parameter; a slice
    /// shorter or longer than `CLUSTER_DIM` silently changes how many
    /// clusters are copied — consider `&[usize; CLUSTER_DIM]`.
    #[primitive(DmaContext::dm_cluster_shuffle)]
    pub fn dm_cluster_shuffle<D: Scalar, const CLUSTER_DIM: usize, Chip: M, Cluster: M, Slice: M, Element: M>(
        &mut self,
        tensor: DmTensorView<D, Chip, Cluster, Slice, Element>,
        shuffle_pattern: &[usize],
    ) -> DmTensor<D, Chip, Cluster, Slice, Element> {
        // SAFETY: address 0 appears to be a placeholder that the lowering
        // compiler resolves to a real DM allocation — TODO(review) confirm.
        let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element> = unsafe { DmTensor::from_addr(0) };

        // For each target cluster, issue one DMA copy from its source cluster.
        for (target_cluster_idx, source_cluster_idx) in shuffle_pattern.iter().enumerate() {
            tensor
                .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(*source_cluster_idx)
                .to_dm_view(
                    self,
                    shuffled
                        .view_mut()
                        .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(target_cluster_idx),
                );
        }

        shuffled
    }

    /// Perform chip shuffle using Tensor DMA commands (DM <-> DM transfer across chips).
    /// This operation redistributes data across chips according to the shuffle pattern.
    ///
    /// # Arguments
    /// * `tensor` - Input tensor with chip dimension
    /// * `shuffle_pattern` - Array mapping source chip to destination chip
    ///   (`shuffle_pattern[target] == source`); exactly one entry per chip
    ///
    /// # Example
    /// For Chip=4 with shuffle_pattern=\[1,2,3,0\]:
    /// - Data from Chip 1 goes to Chip 0
    /// - Data from Chip 2 goes to Chip 1
    /// - Data from Chip 3 goes to Chip 2
    /// - Data from Chip 0 goes to Chip 3
    #[primitive(DmaContext::dm_chip_shuffle)]
    pub fn dm_chip_shuffle<D: Scalar, const CHIP_DIM: usize, Chip: M, Cluster: M, Slice: M, Element: M>(
        &mut self,
        tensor: DmTensorView<D, Chip, Cluster, Slice, Element>,
        shuffle_pattern: &[usize; CHIP_DIM],
    ) -> DmTensor<D, Chip, Cluster, Slice, Element> {
        // SAFETY: address 0 appears to be a placeholder that the lowering
        // compiler resolves to a real DM allocation — TODO(review) confirm.
        let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element> = unsafe { DmTensor::from_addr(0) };

        // For each target chip, issue one DMA copy from its source chip.
        for (target_chip_idx, source_chip_idx) in shuffle_pattern.iter().enumerate() {
            tensor
                .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(*source_chip_idx)
                .to_dm_view(
                    self,
                    shuffled
                        .view_mut()
                        .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(target_chip_idx),
                );
        }

        shuffled
    }
}
200
impl<const T: Tu> TuContext<{ T }> {
    /// Begin a tensor unit operation in this context.
    ///
    /// Produces a [`BeginTensor`] over the data of `tensor` with an
    /// `Identity` interleave mapping (no interleaving).
    #[primitive(TuContext::begin)]
    pub fn begin<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>(
        &'l mut self,
        tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
    ) -> BeginTensor<'l, { T }, D, Chip, Cluster, Slice, Identity, Element> {
        // SAFETY: the mappings differ only by `Identity`.
        BeginTensor::new(self, unsafe { tensor.inner.read().transmute() })
    }

    /// Begin a tensor unit operation in this context with interleaved tensors.
    ///
    /// Stages `lhs` and `rhs` into a new tensor carrying an extra
    /// `Symbol<I>` axis (lhs at index 0, rhs at index 1) and begins the
    /// operation on that combined tensor.
    #[primitive(TuContext::begin_interleaved)]
    pub fn begin_interleaved<'l, I: AxisName, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>(
        &'l mut self,
        lhs: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
        rhs: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
    ) -> BeginTensor<'l, { T }, D, Chip, Cluster, Slice, Symbol<I>, Element> {
        // Uninitialized staging tensor with the extra `Symbol<I>` axis.
        let mut output = Tensor::<D, m![{ Chip }, { Cluster }, { Slice }, { Symbol<I> }, { Element }]>::uninit();

        // Write lhs into interleave index 0 and rhs into index 1.
        for (i, input) in [lhs, rhs].into_iter().enumerate() {
            output
                .view_mut()
                .tile::<Symbol<I>, m![{ Chip }, { Cluster }, { Slice }, 1 # 2, { Element }], 1>(i)
                .write_transpose(input.inner, false);
        }

        BeginTensor::new(self, output)
    }
}
231
232impl TuContext<{ Tu::Sub }> {
233    /// Perform asymmetric cluster slice operation using ParallelCopy (stos command).
234    /// This operation allows different clusters to select different slice positions.
235    ///
236    /// # Arguments
237    /// * `tensor` - Input tensor with cluster dimension
238    /// * `slice_indices` - Array of slice indices, one per cluster
239    ///
240    /// # Example
241    /// For Cluster=2 with slice_indices=\[1,0\]:
242    /// - Cluster 0 selects slice position 1
243    /// - Cluster 1 selects slice position 0
244    ///
245    /// # Restrictions
246    /// The `AxisToSlice` should be the outermost axis in `Element`.
247    #[primitive(TuContext::parallel_copy_cluster_slice)]
248    pub fn parallel_copy_cluster_slice<
249        'l,
250        const CLUSTER_DIM: usize,
251        AxisToSlice: M,
252        AxisSlicedElement: M,
253        Element2: M,
254        D: Scalar,
255        Chip: M,
256        Cluster: M,
257        Slice: M,
258        Element: M,
259    >(
260        &mut self,
261        tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
262        slice_indices: &[usize; CLUSTER_DIM],
263    ) -> super::memory_tensor::DmTensor<D, Chip, Cluster, Slice, Element2> {
264        let mut sliced = unsafe { DmTensor::from_addr(0) };
265
266        for (cluster_idx, slice_idx) in slice_indices.iter().enumerate() {
267            let cluster_slice = tensor.cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(cluster_idx);
268            cluster_slice
269                .tile::<AxisToSlice, 1, AxisSlicedElement>(*slice_idx)
270                .to_dm_view_pcopy(
271                    self,
272                    sliced
273                        .view_mut()
274                        .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(cluster_idx),
275                );
276        }
277
278        sliced
279    }
280
281    /// Perform asymmetric chip slice operation using ParallelCopy (stos command).
282    /// This operation allows different chips to select different slice positions.
283    ///
284    /// # Arguments
285    /// * `tensor` - Input tensor with chip dimension
286    /// * `slice_indices` - Array of slice indices, one per chip
287    ///
288    /// # Example
289    /// For Chip=4 with slice_indices=\[3,0,1,2\]:
290    /// - Chip 0 selects slice position 3
291    /// - Chip 1 selects slice position 0
292    /// - Chip 2 selects slice position 1
293    /// - Chip 3 selects slice position 2
294    #[primitive(TuContext::parallel_copy_chip_slice)]
295    pub fn parallel_copy_chip_slice<
296        'l,
297        const CHIP_DIM: usize,
298        AxisToSlice: M,
299        AxisSlicedElement: M,
300        Element2: M,
301        D: Scalar,
302        Chip: M,
303        Cluster: M,
304        Slice: M,
305        Element: M,
306    >(
307        &mut self,
308        tensor: DmTensorView<'l, D, Chip, Cluster, Slice, Element>,
309        slice_indices: &[usize; CHIP_DIM],
310    ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
311        let mut sliced = unsafe { DmTensor::from_addr(0) };
312
313        for (chip_idx, slice_idx) in slice_indices.iter().enumerate() {
314            let chip_slice = tensor.chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(chip_idx);
315            chip_slice
316                .tile::<AxisToSlice, 1, AxisSlicedElement>(*slice_idx)
317                .to_dm_view_pcopy(
318                    self,
319                    sliced
320                        .view_mut()
321                        .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(chip_idx),
322                );
323        }
324
325        sliced
326    }
327}