furiosa_visa_std/
memory_tensor.rs

1//! Tensors placed on memory.
2
3use rand::Rng;
4use rand::distr::StandardUniform;
5use std::fmt::{self, Display, Formatter};
6use std::marker::PhantomData;
7
8use furiosa_mapping::*;
9use furiosa_mapping_macro::primitive;
10use furiosa_opt_macro::m;
11
12use crate::context::*;
13use crate::scalar::*;
14use crate::tensor::*;
15use crate::tensor_view::*;
16use crate::vector_engine::scalar::VeScalar;
17
/// Address.
///
/// Raw 64-bit device address used by every address-taking API in this module.
///
/// TODO: check that every address is 64-bit.
pub type Address = u64;
22
/// Address in the tensor register file.
///
/// Unlike [`Address`], a TRF location is identified by the region of the
/// register file it occupies rather than by a numeric offset.
#[primitive(TrfAddress)]
#[derive(Clone, Debug)]
pub enum TrfAddress {
    /// Address in the first half of TRF.
    FirstHalf,
    /// Address in the second half of TRF.
    SecondHalf,
    /// Address in the full TRF.
    Full,
}
34
35impl TrfAddress {
36    /// Total TRF capacity in bytes for this address mode.
37    /// - `Full`: 65,536 bytes (8 bank rows × 128 rows × 2 columns × 32 bytes)
38    /// - `FirstHalf` / `SecondHalf`: 32,768 bytes (half of Full)
39    pub fn capacity(&self) -> usize {
40        match self {
41            Self::Full => 65_536,
42            Self::FirstHalf | Self::SecondHalf => 32_768,
43        }
44    }
45}
46
47impl Display for TrfAddress {
48    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49        match self {
50            Self::FirstHalf => write!(f, "TrfAddress::FirstHalf"),
51            Self::SecondHalf => write!(f, "TrfAddress::SecondHalf"),
52            Self::Full => write!(f, "TrfAddress::Full"),
53        }
54    }
55}
56
/// Tensor stored in host memory.
///
/// Thin newtype over [`Tensor`] marking the data as CPU-resident; moving it to
/// device memory goes through the explicit transfer methods (e.g. `to_hbm`).
#[primitive(HostTensor)]
#[derive(Debug, Clone)]
pub struct HostTensor<D: Scalar, Element: M> {
    // The wrapped CPU-side tensor data.
    inner: Tensor<D, Element>,
}
63
64impl<D: Scalar, Element: M> From<Tensor<D, Element>> for HostTensor<D, Element> {
65    fn from(inner: Tensor<D, Element>) -> Self {
66        Self { inner }
67    }
68}
69
impl<D: Scalar, Element: M> HostTensor<D, Element> {
    /// Mapping type alias.
    pub type Mapping = Element;

    /// Borrows the wrapped [`Tensor`].
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Element> {
        &self.inner
    }

    /// Borrows the raw `ndarray` storage backing this tensor.
    pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
        self.inner.data()
    }

    /// Creates a tensor from a vector.
    pub fn from_buf(data: Vec<Opt<D>>) -> Self {
        Tensor::from_buf(data).into()
    }

    /// Creates a tensor from a `safetensors` tensor view.
    ///
    /// The view's per-axis shape must match `Element`'s pair-flattened size list (e.g.
    /// `m![H, X]` expects safetensors shape `[H.size, X.size]`) and its bytes are decoded as
    /// little-endian `D` values — LE is mandated by the safetensors format spec, not our
    /// choice. Returns [`safetensors::SafeTensorError::TensorInvalidInfo`] on any mismatch.
    pub fn from_safetensors(view: &safetensors::tensor::TensorView<'_>) -> Result<Self, safetensors::SafeTensorError>
    where
        D: ScalarBytes,
    {
        // Flattens the runtime mapping tree into a per-axis size list: each
        // `Pair` node contributes its children in order, every other node
        // contributes a single axis of its own size.
        fn flat_shape(mapping: &Mapping, out: &mut Vec<usize>) {
            match mapping {
                Mapping::Pair { left, right } => {
                    flat_shape(left, out);
                    flat_shape(right, out);
                }
                _ => out.push(mapping.size()),
            }
        }
        let mut expected_shape = Vec::new();
        flat_shape(&Element::to_value(), &mut expected_shape);
        if view.shape() != expected_shape.as_slice() {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }
        // Bytes per scalar; assumes `D::BITS` is a multiple of 8 (and >= 8) —
        // TODO confirm sub-byte scalar types never reach this path.
        let stride = D::BITS / 8;
        if view.data().len() != Element::SIZE * stride {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }
        // Decode each `stride`-byte chunk as a little-endian scalar; the length
        // check above guarantees `chunks_exact` leaves no remainder.
        let data: Vec<Opt<D>> = view
            .data()
            .chunks_exact(stride)
            .map(|b| Opt::Init(D::from_le_bytes(b)))
            .collect();
        Ok(Tensor::from_buf(data).into())
    }

    /// Creates a tensor filled with zeros.
    pub fn zero() -> Self {
        Tensor::zero().into()
    }

    /// Creates a tensor filled with random values.
    #[primitive(HostTensor::rand)]
    pub fn rand(rng: &mut impl Rng) -> Self
    where
        StandardUniform: rand::distr::Distribution<D>,
    {
        Tensor::rand(rng).into()
    }

    /// Creates an uninitialized tensor.
    pub fn uninit() -> Self {
        Tensor::uninit().into()
    }

    /// Converts to HBM tensor.
    ///
    /// The transfer itself is delegated to the active backend; the PCIe DMA
    /// context is required by the signature but not inspected here.
    ///
    /// TODO: `address` should be optional.
    #[primitive(HostTensor::to_hbm)]
    pub async fn to_hbm<Chip: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Pcie }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2> {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_hbm(self, address).await
    }

    /// Consumes self and returns the inner tensor value.
    pub fn into_inner(self) -> Tensor<D, Self::Mapping> {
        self.inner
    }

    /// Returns the tensor data as a flat vector in physical layout order.
    pub fn to_buf(&self) -> Vec<Opt<D>> {
        self.inner.to_buf()
    }
}
164
/// Tensor stored in HBM memory.
///
/// Couples the tensor data with the device [`Address`] it lives at; views
/// created from this handle carry the same base address.
#[primitive(HbmTensor)]
#[derive(Debug)]
pub struct HbmTensor<D: Scalar, Chip: M, Element: M> {
    // CPU-side model of the device-resident data.
    inner: Tensor<D, Pair<Chip, Element>>,
    // Base HBM address of the tensor.
    address: Address,
}
172
// Manual impl: inner `Tensor` is not DeviceSend
// These marker impls opt the HBM tensor handle — in owned, borrowed, and view
// forms — into `crate::runtime::DeviceSend`, which an automatic/derived path
// would refuse because of the inner `Tensor` payload.
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for &HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for &mut HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensorView<'_, D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensorViewMut<'_, D, Chip, Element> {}
179
impl<D: Scalar, Chip: M, Element: M> HbmTensor<D, Chip, Element> {
    /// Mapping type alias.
    pub type Mapping = m![{ Chip }, { Element }];

    /// Pairs tensor data with its HBM base address.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
        Self { inner, address }
    }

    /// Borrows the wrapped [`Tensor`].
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Self::Mapping> {
        &self.inner
    }

    /// Returns the HBM address of this tensor.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Size in bytes.
    pub fn size() -> usize {
        <Pair<Chip, Element> as M>::SIZE * std::mem::size_of::<D>()
    }

    /// Borrows the raw `ndarray` storage backing this tensor.
    pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
        self.inner.data()
    }

    /// Creates an HBM tensor handle at the given raw address.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the underlying data layout is compatible
    /// with the tensor mapping.
    #[primitive(HbmTensor::from_addr)]
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(Tensor::uninit(), address)
    }

    /// Converts to host tensor.
    ///
    /// The transfer itself is delegated to the active backend; the PCIe DMA
    /// context is required by the signature but not inspected here.
    ///
    /// TODO: we should optionally receive the intermediate stream's mapping expression.
    #[primitive(HbmTensor::to_host)]
    pub async fn to_host<Element2: M>(&self, _dma: &mut DmaContext<{ Dma::Pcie }>) -> HostTensor<D, Element2> {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_host(self).await
    }

    /// Converts to HBM tensor.
    #[primitive(HbmTensor::to_hbm)]
    pub fn to_hbm<const DMA: Dma, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ DMA }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2> {
        HbmTensor::new(self.inner.transpose(true), address)
    }

    /// Gather DRAM rows into SRAM at positions given by index tensor.
    ///
    /// Implements `index_select` along dim 0: `output[i] = self[index[i]]`.
    /// Inverse of [`DmTensor::dma_scatter`].
    ///
    /// TODO: implement CPU reference and LIR translation (same pattern as dma_scatter).
    #[primitive(HbmTensor::dma_gather)]
    pub fn dma_gather<Cluster2: M, Slice2: M, Element2: M, Element3: M>(
        &self,
        _index_tensor: &HbmTensor<i32, Chip, Element3>,
        _address: Address,
    ) -> DmTensor<D, Chip, Cluster2, Slice2, Element2> {
        todo!()
    }

    /// Creates an immutable view of the whole tensor.
    #[primitive(HbmTensor::view)]
    pub fn view<'l>(&'l self) -> HbmTensorView<'l, D, Chip, Element> {
        HbmTensorView {
            inner: self.inner.view(),
            address: self.address,
        }
    }

    /// Creates a mutable view of the whole tensor.
    #[primitive(HbmTensor::view_mut)]
    pub fn view_mut<'l>(&'l mut self) -> HbmTensorViewMut<'l, D, Chip, Element> {
        HbmTensorViewMut {
            inner: self.inner.view_mut(),
            address: self.address,
        }
    }
}
268
269// ANCHOR: dma_impl
270impl<D: Scalar, Chip: M, Element: M> HbmTensor<D, Chip, Element> {
271    /// Converts to data memory tensor.
272    #[primitive(HbmTensor::to_dm)]
273    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
274        &self,
275        _dma: &mut DmaContext<{ Dma::Tensor }>,
276        address: Address,
277    ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
278        DmTensor::new(self.inner.transpose(true), address)
279    }
280}
281// ANCHOR_END: dma_impl
282
/// View of an HBM tensor.
#[primitive(HbmTensorView)]
#[derive(Debug, Clone)]
pub struct HbmTensorView<'l, D: Scalar, Chip: M, Element: M> {
    // Borrowed view into the owning tensor's data.
    inner: TensorView<'l, D, Pair<Chip, Element>>,
    // Base HBM address inherited from the owning tensor.
    address: Address,
}
290
291impl<'l, D: Scalar, Chip: M, Element: M> HbmTensorView<'l, D, Chip, Element> {
292    /// Mapping type alias.
293    pub type Mapping = m![{ Chip }, { Element }];
294
295    /// Returns the base HBM address of this view.
296    pub fn address(&self) -> Address {
297        self.address
298    }
299
300    /// Writes to HBM tensor view.
301    #[primitive(HbmTensorView::to_hbm_view)]
302    pub fn to_hbm_view<const DMA: Dma, Element2: M>(
303        self,
304        _dma: &mut DmaContext<{ DMA }>,
305        mut dst: HbmTensorViewMut<'l, D, Chip, Element2>,
306    ) {
307        dst.inner.write_transpose(self.inner, true);
308    }
309
310    /// Converts to data memory tensor.
311    #[primitive(HbmTensorView::to_dm)]
312    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
313        self,
314        _dma: &mut DmaContext<{ Dma::Tensor }>,
315        address: Address,
316    ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
317        DmTensor::new(self.inner.read().transpose(true), address)
318    }
319
320    /// Writes to data memory tensor view.
321    pub fn to_dm_view<Cluster: M, Slice: M, Element2: M>(
322        self,
323        _dma: &mut DmaContext<{ Dma::Tensor }>,
324        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2>,
325    ) {
326        dst.inner.write_transpose(self.inner, true);
327    }
328
329    /// Creates immutable views by splitting along a tile expression.
330    #[primitive(HbmTensorView::tile)]
331    pub fn tile<Index: M, const LEN: usize, Element2: M>(&self, start: usize) -> HbmTensorView<'l, D, Chip, Element2> {
332        let inner = self.inner.tile::<Index, _, LEN>(start);
333        HbmTensorView {
334            inner,
335            address: self.address,
336        }
337    }
338}
339
/// Mutable view of an HBM tensor.
#[primitive(HbmTensorViewMut)]
#[derive(Debug)]
pub struct HbmTensorViewMut<'l, D: Scalar, Chip: M, Element: M> {
    // Mutable borrow of the owning tensor's data.
    inner: TensorViewMut<'l, D, Pair<Chip, Element>>,
    // Base HBM address inherited from the owning tensor.
    address: Address,
}
347
348impl<'l, D: Scalar, Chip: M, Element: M> HbmTensorViewMut<'l, D, Chip, Element> {
349    /// Returns the base HBM address of this view.
350    pub fn address(&self) -> Address {
351        self.address
352    }
353
354    /// Creates mutable views by splitting along a tile expression.
355    #[primitive(HbmTensorViewMut::tile)]
356    pub fn tile<Index: M, const LEN: usize, Element2: M>(
357        &mut self,
358        start: usize,
359    ) -> HbmTensorViewMut<'l, D, Chip, Element2> {
360        let inner = self.inner.tile::<Index, _, LEN>(start);
361        HbmTensorViewMut {
362            inner,
363            address: self.address,
364        }
365    }
366}
367
/// Tensor stored in data memory.
#[primitive(DmTensor)]
#[derive(Debug)]
pub struct DmTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // CPU-side model of the data-memory-resident data.
    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
    // Base DM address of the tensor.
    address: Address,
    // NOTE(review): every type parameter already occurs in `inner`, so this
    // PhantomData looks redundant — confirm before removing.
    _marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
376
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensor<D, Chip, Cluster, Slice, Element> {
    /// Mapping type alias.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Pairs tensor data with its DM base address.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
        Self {
            inner,
            address,
            _marker: PhantomData,
        }
    }

    /// Creates a DM tensor handle at the given raw address.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the underlying data layout is compatible
    /// with the tensor mapping.
    #[primitive(DmTensor::from_addr)]
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(Tensor::uninit(), address)
    }

    /// Converts to HBM tensor.
    #[primitive(DmTensor::to_hbm)]
    pub fn to_hbm<Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2> {
        HbmTensor::new(self.inner.transpose(true), address)
    }

    /// Scatter SRAM values to DRAM at positions given by index tensor.
    ///
    /// ```text
    /// data:   [N, K, V]
    /// index:  [N, K]
    /// output: [N, X, V]
    ///
    /// (data - Chip).divide(K) = [N, V]
    /// ```
    #[primitive(DmTensor::dma_scatter)]
    pub fn dma_scatter<Key: M, Element2: M, Element3: M>(
        &self,
        index: &HbmTensor<i32, Chip, Element3>,
        output: &mut HbmTensor<D, Chip, Element2>,
        scaled: bool,
    ) {
        // Precondition: the key mapping must divide the per-chip source
        // mapping exactly, otherwise the scatter cannot be addressed.
        let src = Pair::<Slice, Element>::to_value().factorize();
        let key = Key::to_value().factorize();
        assert!(
            src.clone().divide_relaxed(key).exact().is_ok(),
            "scatter key `{:?}` must be fully contained in source `{src:?}`. \
             If the key axis is split across Chip and Element, indirect DMA cannot address it.",
            // Recomputed because `key` was moved into `divide_relaxed` above.
            Key::to_value().factorize()
        );

        self.inner
            .write_scatter::<Key, _, _>(&mut output.inner, &index.inner, scaled);
    }

    /// Converts to data memory tensor.
    pub fn to_dm<Slice2: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice2, Element2> {
        DmTensor::new(self.inner.transpose(true), address)
    }

    /// Copies data to another DM tensor via parallel copy.
    ///
    /// Convenience wrapper: `self.view().to_dm_view_pcopy(sub, dst.view_mut())`.
    pub fn to_dm_pcopy<Slice2: M, Element2: M>(
        &self,
        sub: &mut TuContext<{ Tu::Sub }>,
        dst: &mut DmTensor<D, Chip, Cluster, Slice2, Element2>,
    ) {
        self.view().to_dm_view_pcopy(sub, dst.view_mut());
    }

    /// Creates an immutable view of the whole tensor.
    #[primitive(DmTensor::view)]
    pub fn view<'l>(&'l self) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element> {
        DmTensorView {
            inner: self.inner.view(),
        }
    }

    /// Creates a mutable view of the whole tensor.
    #[primitive(DmTensor::view_mut)]
    pub fn view_mut<'l>(&'l mut self) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element> {
        DmTensorViewMut {
            inner: self.inner.view_mut(),
        }
    }

    /// Reshapes the tensor to a different mapping at the same address.
    ///
    /// Each level (Chip, Cluster, Slice, Element) is reshaped in turn while the
    /// remaining levels are held fixed.
    ///
    /// NOTE(review): the `clone()` below may be avoidable — `address` is `Copy`,
    /// so partially moving `inner` out of `self` should compile unless
    /// `DmTensor` gains a `Drop` impl; confirm `Tensor::clone` cost first.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the new mapping accurately describes
    /// the data currently at this address.
    #[primitive(DmTensor::reshape)]
    pub unsafe fn reshape<Chip2: M, Cluster2: M, Slice2: M, Element2: M>(
        self,
    ) -> DmTensor<D, Chip2, Cluster2, Slice2, Element2> {
        DmTensor::new(
            unsafe {
                self.inner
                    .clone()
                    .reshape::<Chip, Chip2, m![{ Chip2 }, { Cluster }, { Slice }, { Element }]>()
                    .reshape::<Cluster, Cluster2, m![{ Chip2 }, { Cluster2 }, { Slice }, { Element }]>()
                    .reshape::<Slice, Slice2, m![{ Chip2 }, { Cluster2 }, { Slice2 }, { Element }]>()
                    .reshape::<Element, Element2, m![{ Chip2 }, { Cluster2 }, { Slice2 }, { Element2 }]>()
            },
            self.address,
        )
    }
}
498
/// Mutable view of a data memory tensor.
#[primitive(DmTensorViewMut)]
#[derive(Debug)]
pub struct DmTensorViewMut<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // Mutable borrow of the owning tensor's data; `pub(crate)` so sibling view
    // types in this module can write through it directly.
    pub(crate) inner: TensorViewMut<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
}
505
/// View of a data memory tensor.
#[primitive(DmTensorView)]
#[derive(Debug, Clone)]
pub struct DmTensorView<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // Immutable borrow of the owning tensor's data; `pub(crate)` so sibling
    // view types in this module can read it directly.
    pub(crate) inner: TensorView<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
}
512
513impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>
514    From<DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>> for DmTensorView<'l, D, Chip, Cluster, Slice, Element>
515{
516    fn from(view: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>) -> Self {
517        Self {
518            inner: view.inner.into(),
519        }
520    }
521}
522
impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensorView<'l, D, Chip, Cluster, Slice, Element> {
    /// Mapping type alias.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Writes data to a mutable tensor view for HBM.
    #[primitive(DmTensorView::to_hbm_view)]
    pub fn to_hbm_view<Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: HbmTensorViewMut<'l, D, Chip, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Writes data to a mutable tensor view for data memory.
    #[primitive(DmTensorView::to_dm_view)]
    pub fn to_dm_view<Slice2: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Writes data to a mutable tensor view for data memory via parallel copy.
    ///
    /// Same data movement as [`Self::to_dm_view`], but takes the sub tensor
    /// unit context instead of a DMA context.
    pub fn to_dm_view_pcopy<Slice2: M, Element2: M>(
        self,
        _sub: &mut TuContext<{ Tu::Sub }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Creates immutable views by splitting along a tile expression over Chip.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip2, Cluster, Slice, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Creates immutable views by splitting along a tile expression over Cluster.
    pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster2, Slice, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Creates immutable views by splitting along a tile expression over Slice.
    #[primitive(DmTensorView::slice_tile)]
    pub fn slice_tile<Index: M, const LEN: usize, Slice2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice2, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Creates immutable views by splitting along a tile expression over Element.
    #[primitive(DmTensorView::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element2> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }
}
594
595impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element> {
596    /// Creates mutable views by splitting along a tile expression over Chip.
597    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
598        &mut self,
599        start: usize,
600    ) -> DmTensorViewMut<'l, D, Chip2, Cluster, Slice, Element> {
601        let inner = self.inner.tile::<Index, _, LEN>(start);
602        DmTensorViewMut { inner }
603    }
604
605    /// Creates mutable views by splitting along a tile expression over Cluster.
606    pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
607        &mut self,
608        start: usize,
609    ) -> DmTensorViewMut<'l, D, Chip, Cluster2, Slice, Element> {
610        let inner = self.inner.tile::<Index, _, LEN>(start);
611        DmTensorViewMut { inner }
612    }
613
614    /// Creates mutable views by splitting along a tile expression over Element.
615    #[primitive(DmTensorViewMut::tile)]
616    pub fn tile<Index: M, const LEN: usize, Element2: M>(
617        &mut self,
618        start: usize,
619    ) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2> {
620        let inner = self.inner.tile::<Index, _, LEN>(start);
621        DmTensorViewMut { inner }
622    }
623}
624
/// Tensor stored in the tensor register file.
#[primitive(TrfTensor)]
#[derive(Debug)]
pub struct TrfTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Element: M> {
    // CPU-side model of the TRF-resident data.
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Element>>>>>,
    // TRF region this tensor occupies; currently unread.
    #[expect(dead_code)]
    address: TrfAddress,
    _marker: PhantomData<(D, Chip, Cluster, Slice, Row, Element)>,
}
634
635impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Element: M> TrfTensor<D, Chip, Cluster, Slice, Row, Element> {
636    /// Mapping type alias.
637    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Row }, { Element }];
638
639    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: TrfAddress) -> Self {
640        Self {
641            inner,
642            address,
643            _marker: PhantomData,
644        }
645    }
646
647    /// Creates a TRF tensor handle at the given raw address.
648    ///
649    /// # Safety
650    ///
651    /// The caller must ensure that the underlying data layout is compatible
652    /// with the tensor mapping.
653    pub unsafe fn from_addr(address: TrfAddress) -> Self {
654        Self::new(Tensor::uninit(), address)
655    }
656
657    /// Creates a mutable view into the tensor.
658    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
659        self.inner.view_mut()
660    }
661
662    /// Creates an immutable view into the tensor.
663    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
664        self.inner.view()
665    }
666}
667
/// Tensor stored in the vector register file (VRF).
#[primitive(VrfTensor)]
#[derive(Debug, Clone)]
pub struct VrfTensor<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // CPU-side model of the VRF-resident data.
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
    // VRF address of this tensor; currently unread.
    #[expect(dead_code)]
    address: Address,
    _marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
677
678impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M> VrfTensor<D, Chip, Cluster, Slice, Element> {
679    /// Mapping type alias.
680    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];
681
682    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
683        Self {
684            inner,
685            address,
686            _marker: PhantomData,
687        }
688    }
689
690    /// Creates a VRF tensor handle at the given raw address.
691    ///
692    /// # Safety
693    ///
694    /// The caller must ensure that the underlying data layout is compatible
695    /// with the tensor mapping.
696    pub unsafe fn from_addr(address: Address) -> Self {
697        Self::new(Tensor::uninit(), address)
698    }
699
700    /// Creates a mutable view into the tensor.
701    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
702        self.inner.view_mut()
703    }
704
705    /// Creates an immutable view into the tensor.
706    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
707        self.inner.view()
708    }
709}
710
/// Tensor stored in dot product engine.
///
/// NOTE(review): unlike the other tensor types in this module, this struct has
/// no `#[primitive(...)]` attribute and carries no address — confirm whether
/// that is intentional.
#[derive(Debug)]
pub struct DpeTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Row: M, Packet: M> {
    // CPU-side model of the DPE-resident data.
    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Time, Pair<Row, Packet>>>>>>,
}
716
717impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Row: M, Packet: M>
718    DpeTensor<D, Chip, Cluster, Slice, Time, Row, Packet>
719{
720    /// Mapping type alias.
721    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Time }, { Row }, { Packet }];
722
723    /// Creates a mutable view into the tensor.
724    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
725        self.inner.view_mut()
726    }
727
728    /// Creates an immutable view into the tensor.
729    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
730        self.inner.view()
731    }
732}