1use rand::Rng;
4use rand::distr::StandardUniform;
5use std::fmt::{self, Display, Formatter};
6use std::marker::PhantomData;
7
8use furiosa_mapping::*;
9use furiosa_mapping_macro::primitive;
10use furiosa_opt_macro::m;
11
12use crate::context::*;
13use crate::scalar::*;
14use crate::tensor::*;
15use crate::tensor_view::*;
16use crate::vector_engine::scalar::VeScalar;
17
/// Raw 64-bit device address used when placing tensors in HBM or device memory.
pub type Address = u64;
22
/// Placement of a tensor inside the TRF (tensor register file): one of the two
/// halves, or the whole file.
#[primitive(TrfAddress)]
#[derive(Clone, Debug)]
pub enum TrfAddress {
    /// First half of the register file (capacity 32_768, see [`Self::capacity`]).
    FirstHalf,
    /// Second half of the register file (capacity 32_768).
    SecondHalf,
    /// The entire register file (capacity 65_536).
    Full,
}
34
35impl TrfAddress {
36 pub fn capacity(&self) -> usize {
40 match self {
41 Self::Full => 65_536,
42 Self::FirstHalf | Self::SecondHalf => 32_768,
43 }
44 }
45}
46
47impl Display for TrfAddress {
48 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49 match self {
50 Self::FirstHalf => write!(f, "TrfAddress::FirstHalf"),
51 Self::SecondHalf => write!(f, "TrfAddress::SecondHalf"),
52 Self::Full => write!(f, "TrfAddress::Full"),
53 }
54 }
55}
56
/// Tensor resident in host memory, wrapping a plain [`Tensor`] whose mapping is
/// just the element mapping `Element` (no device axes, no device address).
#[primitive(HostTensor)]
#[derive(Debug, Clone)]
pub struct HostTensor<D: Scalar, Element: M> {
    // Backing host-side data.
    inner: Tensor<D, Element>,
}
63
64impl<D: Scalar, Element: M> From<Tensor<D, Element>> for HostTensor<D, Element> {
65 fn from(inner: Tensor<D, Element>) -> Self {
66 Self { inner }
67 }
68}
69
impl<D: Scalar, Element: M> HostTensor<D, Element> {
    // Inherent associated type (unstable feature): the host tensor's mapping
    // is just the element mapping.
    pub type Mapping = Element;

    /// Borrows the backing [`Tensor`].
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Element> {
        &self.inner
    }

    /// Borrows the raw element storage.
    pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
        self.inner.data()
    }

    /// Builds a host tensor from a flat buffer of (possibly uninitialized)
    /// scalars. Buffer-length validation is presumably done by
    /// `Tensor::from_buf` — TODO confirm.
    pub fn from_buf(data: Vec<Opt<D>>) -> Self {
        Tensor::from_buf(data).into()
    }

    /// Loads a host tensor from a safetensors view.
    ///
    /// # Errors
    /// Returns `TensorInvalidInfo` when the view's shape does not match the
    /// flattened leaf sizes of `Element`, or when its byte length is not
    /// `Element::SIZE * (D::BITS / 8)`.
    pub fn from_safetensors(view: &safetensors::tensor::TensorView<'_>) -> Result<Self, safetensors::SafeTensorError>
    where
        D: ScalarBytes,
    {
        // Flatten the mapping tree into its leaf axis sizes, left-to-right;
        // this is the shape the safetensors view must carry.
        fn flat_shape(mapping: &Mapping, out: &mut Vec<usize>) {
            match mapping {
                Mapping::Pair { left, right } => {
                    flat_shape(left, out);
                    flat_shape(right, out);
                }
                _ => out.push(mapping.size()),
            }
        }
        let mut expected_shape = Vec::new();
        flat_shape(&Element::to_value(), &mut expected_shape);
        if view.shape() != expected_shape.as_slice() {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }
        // Bytes per scalar. NOTE(review): assumes `D::BITS` is a multiple of
        // 8 — confirm behavior for sub-byte scalar types.
        let stride = D::BITS / 8;
        if view.data().len() != Element::SIZE * stride {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }
        // Decode little-endian scalars; every element is marked initialized.
        let data: Vec<Opt<D>> = view
            .data()
            .chunks_exact(stride)
            .map(|b| Opt::Init(D::from_le_bytes(b)))
            .collect();
        Ok(Tensor::from_buf(data).into())
    }

    /// Host tensor filled with zeros (see `Tensor::zero`).
    pub fn zero() -> Self {
        Tensor::zero().into()
    }

    /// Host tensor with every element drawn from `rng`'s standard uniform
    /// distribution.
    #[primitive(HostTensor::rand)]
    pub fn rand(rng: &mut impl Rng) -> Self
    where
        StandardUniform: rand::distr::Distribution<D>,
    {
        Tensor::rand(rng).into()
    }

    /// Host tensor whose elements are left uninitialized.
    pub fn uninit() -> Self {
        Tensor::uninit().into()
    }

    /// Uploads this tensor to HBM at `address` over PCIe DMA, delegating to
    /// the currently selected runtime backend.
    #[primitive(HostTensor::to_hbm)]
    pub async fn to_hbm<Chip: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Pcie }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2> {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_hbm(self, address).await
    }

    /// Unwraps into the backing [`Tensor`].
    pub fn into_inner(self) -> Tensor<D, Self::Mapping> {
        self.inner
    }

    /// Copies the elements out into a flat buffer.
    pub fn to_buf(&self) -> Vec<Opt<D>> {
        self.inner.to_buf()
    }
}
164
/// Tensor resident in HBM, mapped over a chip axis (`Chip`) and an element
/// axis (`Element`), located at a fixed device `address`.
#[primitive(HbmTensor)]
#[derive(Debug)]
pub struct HbmTensor<D: Scalar, Chip: M, Element: M> {
    // Simulated/backing contents; the mapping pairs the chip axis with the
    // element axis.
    inner: Tensor<D, Pair<Chip, Element>>,
    // Base address of the tensor in HBM.
    address: Address,
}
172
// Marker impls: HBM-resident tensors, references to them, and their views may
// cross into device-side execution. NOTE(review): `DeviceSend`'s exact
// contract lives in `crate::runtime` — confirm views are safe to send.
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for &HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for &mut HbmTensor<D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensorView<'_, D, Chip, Element> {}
impl<D: Scalar, Chip: M, Element: M> crate::runtime::DeviceSend for HbmTensorViewMut<'_, D, Chip, Element> {}
179
impl<D: Scalar, Chip: M, Element: M> HbmTensor<D, Chip, Element> {
    // Inherent associated type (unstable feature): chip axis then element axis.
    pub type Mapping = m![{ Chip }, { Element }];

    /// Wraps an already-mapped tensor together with its HBM address.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
        Self { inner, address }
    }

    /// Borrows the backing [`Tensor`].
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Self::Mapping> {
        &self.inner
    }

    /// Base HBM address of this tensor.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Total size in bytes of one such tensor in HBM.
    pub fn size() -> usize {
        <Pair<Chip, Element> as M>::SIZE * std::mem::size_of::<D>()
    }

    /// Borrows the raw element storage.
    pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
        self.inner.data()
    }

    /// Reinterprets `address` as an `HbmTensor` with uninitialized contents.
    ///
    /// # Safety
    /// `address` must designate a valid HBM region of at least [`Self::size`]
    /// bytes, and elements must not be read before being written.
    #[primitive(HbmTensor::from_addr)]
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(Tensor::uninit(), address)
    }

    /// Downloads this tensor to host memory over PCIe DMA via the current
    /// runtime backend, remapping to `Element2`.
    #[primitive(HbmTensor::to_host)]
    pub async fn to_host<Element2: M>(&self, _dma: &mut DmaContext<{ Dma::Pcie }>) -> HostTensor<D, Element2> {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_host(self).await
    }

    /// HBM-to-HBM copy to `address`, remapping the element axis to `Element2`.
    /// Any DMA engine may drive the copy (`DMA` is a const parameter).
    #[primitive(HbmTensor::to_hbm)]
    pub fn to_hbm<const DMA: Dma, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ DMA }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2> {
        HbmTensor::new(self.inner.transpose(true), address)
    }

    /// Indirect (gathered) DMA into device memory. Not yet implemented.
    #[primitive(HbmTensor::dma_gather)]
    pub fn dma_gather<Cluster2: M, Slice2: M, Element2: M, Element3: M>(
        &self,
        _index_tensor: &HbmTensor<i32, Chip, Element3>,
        _address: Address,
    ) -> DmTensor<D, Chip, Cluster2, Slice2, Element2> {
        todo!()
    }

    /// Borrows the whole tensor as an immutable view.
    #[primitive(HbmTensor::view)]
    pub fn view<'l>(&'l self) -> HbmTensorView<'l, D, Chip, Element> {
        HbmTensorView {
            inner: self.inner.view(),
            address: self.address,
        }
    }

    /// Borrows the whole tensor as a mutable view.
    #[primitive(HbmTensor::view_mut)]
    pub fn view_mut<'l>(&'l mut self) -> HbmTensorViewMut<'l, D, Chip, Element> {
        HbmTensorViewMut {
            inner: self.inner.view_mut(),
            address: self.address,
        }
    }
}
268
impl<D: Scalar, Chip: M, Element: M> HbmTensor<D, Chip, Element> {
    /// Copies this tensor into device memory at `address` via the tensor DMA
    /// engine, introducing cluster/slice axes and remapping elements to
    /// `Element2`.
    #[primitive(HbmTensor::to_dm)]
    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
        DmTensor::new(self.inner.transpose(true), address)
    }
}
/// Immutable borrowed view of an [`HbmTensor`] (or a tile of one), carrying
/// the base tensor's HBM address.
#[primitive(HbmTensorView)]
#[derive(Debug, Clone)]
pub struct HbmTensorView<'l, D: Scalar, Chip: M, Element: M> {
    inner: TensorView<'l, D, Pair<Chip, Element>>,
    // Address of the *base* tensor; tiling does not adjust it (see `tile`).
    address: Address,
}
290
impl<'l, D: Scalar, Chip: M, Element: M> HbmTensorView<'l, D, Chip, Element> {
    // Inherent associated type (unstable feature): chip axis then element axis.
    pub type Mapping = m![{ Chip }, { Element }];

    /// HBM address of the viewed (base) tensor.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Copies this view into a mutable HBM view, remapping elements to
    /// `Element2`. Any DMA engine may drive the copy.
    #[primitive(HbmTensorView::to_hbm_view)]
    pub fn to_hbm_view<const DMA: Dma, Element2: M>(
        self,
        _dma: &mut DmaContext<{ DMA }>,
        mut dst: HbmTensorViewMut<'l, D, Chip, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Copies this view into fresh device memory at `address` via the tensor
    /// DMA engine.
    #[primitive(HbmTensorView::to_dm)]
    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice, Element2> {
        DmTensor::new(self.inner.read().transpose(true), address)
    }

    /// Copies this view into an existing mutable device-memory view via the
    /// tensor DMA engine.
    pub fn to_dm_view<Cluster: M, Slice: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Restricts the element axis to a `LEN`-long tile starting at `start`.
    ///
    /// NOTE(review): the tile keeps the base tensor's `address` unchanged —
    /// confirm callers do not expect an offset-adjusted address.
    #[primitive(HbmTensorView::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(&self, start: usize) -> HbmTensorView<'l, D, Chip, Element2> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        HbmTensorView {
            inner,
            address: self.address,
        }
    }
}
339
/// Mutable borrowed view of an [`HbmTensor`] (or a tile of one), carrying the
/// base tensor's HBM address.
#[primitive(HbmTensorViewMut)]
#[derive(Debug)]
pub struct HbmTensorViewMut<'l, D: Scalar, Chip: M, Element: M> {
    inner: TensorViewMut<'l, D, Pair<Chip, Element>>,
    // Address of the *base* tensor; tiling does not adjust it (see `tile`).
    address: Address,
}
347
impl<'l, D: Scalar, Chip: M, Element: M> HbmTensorViewMut<'l, D, Chip, Element> {
    /// HBM address of the viewed (base) tensor.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Restricts the element axis to a `LEN`-long tile starting at `start`.
    ///
    /// NOTE(review): the tile keeps the base tensor's `address` unchanged —
    /// confirm callers do not expect an offset-adjusted address.
    #[primitive(HbmTensorViewMut::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        &mut self,
        start: usize,
    ) -> HbmTensorViewMut<'l, D, Chip, Element2> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        HbmTensorViewMut {
            inner,
            address: self.address,
        }
    }
}
367
/// Tensor resident in device memory (DM), mapped over chip, cluster, slice and
/// element axes, located at a fixed device `address`.
#[primitive(DmTensor)]
#[derive(Debug)]
pub struct DmTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
    address: Address,
    // NOTE(review): every type parameter already occurs in `inner`'s type, so
    // this PhantomData looks redundant — confirm the `primitive` macro does
    // not rely on the field before removing it.
    _marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
376
377impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensor<D, Chip, Cluster, Slice, Element> {
378 pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];
380
381 pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
382 Self {
383 inner,
384 address,
385 _marker: PhantomData,
386 }
387 }
388
389 #[primitive(DmTensor::from_addr)]
396 pub unsafe fn from_addr(address: Address) -> Self {
397 Self::new(Tensor::uninit(), address)
398 }
399
400 #[primitive(DmTensor::to_hbm)]
402 pub fn to_hbm<Element2: M>(
403 &self,
404 _dma: &mut DmaContext<{ Dma::Tensor }>,
405 address: Address,
406 ) -> HbmTensor<D, Chip, Element2> {
407 HbmTensor::new(self.inner.transpose(true), address)
408 }
409
410 #[primitive(DmTensor::dma_scatter)]
420 pub fn dma_scatter<Key: M, Element2: M, Element3: M>(
421 &self,
422 index: &HbmTensor<i32, Chip, Element3>,
423 output: &mut HbmTensor<D, Chip, Element2>,
424 scaled: bool,
425 ) {
426 let src = Pair::<Slice, Element>::to_value().factorize();
427 let key = Key::to_value().factorize();
428 assert!(
429 src.clone().divide_relaxed(key).exact().is_ok(),
430 "scatter key `{:?}` must be fully contained in source `{src:?}`. \
431 If the key axis is split across Chip and Element, indirect DMA cannot address it.",
432 Key::to_value().factorize()
433 );
434
435 self.inner
436 .write_scatter::<Key, _, _>(&mut output.inner, &index.inner, scaled);
437 }
438
439 pub fn to_dm<Slice2: M, Element2: M>(
441 &self,
442 _dma: &mut DmaContext<{ Dma::Tensor }>,
443 address: Address,
444 ) -> DmTensor<D, Chip, Cluster, Slice2, Element2> {
445 DmTensor::new(self.inner.transpose(true), address)
446 }
447
448 pub fn to_dm_pcopy<Slice2: M, Element2: M>(
452 &self,
453 sub: &mut TuContext<{ Tu::Sub }>,
454 dst: &mut DmTensor<D, Chip, Cluster, Slice2, Element2>,
455 ) {
456 self.view().to_dm_view_pcopy(sub, dst.view_mut());
457 }
458
459 #[primitive(DmTensor::view)]
461 pub fn view<'l>(&'l self) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element> {
462 DmTensorView {
463 inner: self.inner.view(),
464 }
465 }
466
467 #[primitive(DmTensor::view_mut)]
469 pub fn view_mut<'l>(&'l mut self) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element> {
470 DmTensorViewMut {
471 inner: self.inner.view_mut(),
472 }
473 }
474
475 #[primitive(DmTensor::reshape)]
482 pub unsafe fn reshape<Chip2: M, Cluster2: M, Slice2: M, Element2: M>(
483 self,
484 ) -> DmTensor<D, Chip2, Cluster2, Slice2, Element2> {
485 DmTensor::new(
486 unsafe {
487 self.inner
488 .clone()
489 .reshape::<Chip, Chip2, m![{ Chip2 }, { Cluster }, { Slice }, { Element }]>()
490 .reshape::<Cluster, Cluster2, m![{ Chip2 }, { Cluster2 }, { Slice }, { Element }]>()
491 .reshape::<Slice, Slice2, m![{ Chip2 }, { Cluster2 }, { Slice2 }, { Element }]>()
492 .reshape::<Element, Element2, m![{ Chip2 }, { Cluster2 }, { Slice2 }, { Element2 }]>()
493 },
494 self.address,
495 )
496 }
497}
498
/// Mutable borrowed view of a [`DmTensor`] (or a tile of one).
#[primitive(DmTensorViewMut)]
#[derive(Debug)]
pub struct DmTensorViewMut<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // Crate-visible so sibling tensor types can write into it directly.
    pub(crate) inner: TensorViewMut<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
}
505
/// Immutable borrowed view of a [`DmTensor`] (or a tile of one).
#[primitive(DmTensorView)]
#[derive(Debug, Clone)]
pub struct DmTensorView<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    // Crate-visible so sibling tensor types can read from it directly.
    pub(crate) inner: TensorView<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
}
512
513impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M>
514 From<DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>> for DmTensorView<'l, D, Chip, Cluster, Slice, Element>
515{
516 fn from(view: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element>) -> Self {
517 Self {
518 inner: view.inner.into(),
519 }
520 }
521}
522
impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensorView<'l, D, Chip, Cluster, Slice, Element> {
    // Inherent associated type (unstable feature): full axis order of a
    // device-memory tensor.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Copies this view into a mutable HBM view via the tensor DMA engine,
    /// remapping elements to `Element2`.
    #[primitive(DmTensorView::to_hbm_view)]
    pub fn to_hbm_view<Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: HbmTensorViewMut<'l, D, Chip, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Copies this view into another mutable device-memory view via the
    /// tensor DMA engine, remapping slice/element axes.
    #[primitive(DmTensorView::to_dm_view)]
    pub fn to_dm_view<Slice2: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Same copy as [`Self::to_dm_view`], but attributed to the sub tensor
    /// unit (pcopy) rather than a DMA engine.
    pub fn to_dm_view_pcopy<Slice2: M, Element2: M>(
        self,
        _sub: &mut TuContext<{ Tu::Sub }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2>,
    ) {
        dst.inner.write_transpose(self.inner, true);
    }

    /// Restricts the chip axis to a `LEN`-long tile starting at `start`.
    /// NOTE(review): unlike `slice_tile`/`tile`, this carries no `#[primitive]`
    /// attribute — confirm whether that is intentional.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip2, Cluster, Slice, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Restricts the cluster axis to a `LEN`-long tile starting at `start`.
    pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster2, Slice, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Restricts the slice axis to a `LEN`-long tile starting at `start`.
    #[primitive(DmTensorView::slice_tile)]
    pub fn slice_tile<Index: M, const LEN: usize, Slice2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice2, Element> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }

    /// Restricts the element axis to a `LEN`-long tile starting at `start`.
    #[primitive(DmTensorView::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element2> {
        let inner = self.inner.tile::<Index, _, LEN>(start);
        DmTensorView { inner }
    }
}
594
595impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element> {
596 pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
598 &mut self,
599 start: usize,
600 ) -> DmTensorViewMut<'l, D, Chip2, Cluster, Slice, Element> {
601 let inner = self.inner.tile::<Index, _, LEN>(start);
602 DmTensorViewMut { inner }
603 }
604
605 pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
607 &mut self,
608 start: usize,
609 ) -> DmTensorViewMut<'l, D, Chip, Cluster2, Slice, Element> {
610 let inner = self.inner.tile::<Index, _, LEN>(start);
611 DmTensorViewMut { inner }
612 }
613
614 #[primitive(DmTensorViewMut::tile)]
616 pub fn tile<Index: M, const LEN: usize, Element2: M>(
617 &mut self,
618 start: usize,
619 ) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2> {
620 let inner = self.inner.tile::<Index, _, LEN>(start);
621 DmTensorViewMut { inner }
622 }
623}
624
/// Tensor resident in the tensor register file (TRF), mapped over chip,
/// cluster, slice, row and element axes.
#[primitive(TrfTensor)]
#[derive(Debug)]
pub struct TrfTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Element: M> {
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Row, Element>>>>>,
    // Retained for placement bookkeeping even though nothing reads it yet.
    #[expect(dead_code)]
    address: TrfAddress,
    // NOTE(review): all type parameters occur in `inner`'s type; this
    // PhantomData looks redundant — confirm before removing.
    _marker: PhantomData<(D, Chip, Cluster, Slice, Row, Element)>,
}
634
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Row: M, Element: M> TrfTensor<D, Chip, Cluster, Slice, Row, Element> {
    // Inherent associated type (unstable feature): full axis order of a TRF
    // tensor.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Row }, { Element }];

    /// Wraps an already-mapped tensor with its TRF placement.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: TrfAddress) -> Self {
        Self {
            inner,
            address,
            _marker: PhantomData,
        }
    }

    /// Reinterprets a TRF region as a tensor with uninitialized contents.
    ///
    /// # Safety
    /// The `address` region must be valid for this tensor's mapping, and
    /// elements must not be read before being written.
    pub unsafe fn from_addr(address: TrfAddress) -> Self {
        Self::new(Tensor::uninit(), address)
    }

    /// Mutable view over the whole tensor.
    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
        self.inner.view_mut()
    }

    /// Immutable view over the whole tensor.
    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
        self.inner.view()
    }
}
667
/// Tensor resident in the vector register file (VRF), for vector-engine
/// scalar types, mapped over chip, cluster, slice and element axes.
#[primitive(VrfTensor)]
#[derive(Debug, Clone)]
pub struct VrfTensor<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M> {
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>>,
    // Retained for placement bookkeeping even though nothing reads it yet.
    #[expect(dead_code)]
    address: Address,
    // NOTE(review): all type parameters occur in `inner`'s type; this
    // PhantomData looks redundant — confirm before removing.
    _marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
677
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M> VrfTensor<D, Chip, Cluster, Slice, Element> {
    // Inherent associated type (unstable feature): full axis order of a VRF
    // tensor.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Wraps an already-mapped tensor with its VRF address.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping>, address: Address) -> Self {
        Self {
            inner,
            address,
            _marker: PhantomData,
        }
    }

    /// Reinterprets a VRF region as a tensor with uninitialized contents.
    ///
    /// # Safety
    /// `address` must designate a valid VRF region for this tensor's mapping,
    /// and elements must not be read before being written.
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(Tensor::uninit(), address)
    }

    /// Mutable view over the whole tensor.
    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
        self.inner.view_mut()
    }

    /// Immutable view over the whole tensor.
    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
        self.inner.view()
    }
}
710
/// Tensor shaped for the DPE, mapped over chip, cluster, slice, time, row and
/// packet axes. Unlike the other device tensors it carries no address field
/// (and no `#[primitive]` attribute — confirm whether that is intentional).
#[derive(Debug)]
pub struct DpeTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Row: M, Packet: M> {
    inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Time, Pair<Row, Packet>>>>>>,
}
716
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Row: M, Packet: M>
    DpeTensor<D, Chip, Cluster, Slice, Time, Row, Packet>
{
    // Inherent associated type (unstable feature): full axis order of a DPE
    // tensor.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Time }, { Row }, { Packet }];

    /// Mutable view over the whole tensor.
    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping> {
        self.inner.view_mut()
    }

    /// Immutable view over the whole tensor.
    pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping> {
        self.inner.view()
    }
}