furiosa_visa_std/
tensor.rs

1use std::marker::PhantomData;
2
3use abi_stable::std_types::RSlice;
4use furiosa_mapping::*;
5use ndarray::ArrayD;
6use ndarray::IxDyn;
7use rand::Rng;
8use rand::distr::StandardUniform;
9
10use super::raw_tensor::*;
11use super::scalar::*;
12use super::tensor_view::*;
13
14/// Tensor with scalar type `D` with axes determined by `Mapping`.
15#[derive(Debug, Clone)]
16pub struct Tensor<D: Scalar, Mapping: M> {
17    inner: RawTensor<D>,
18    _marker: PhantomData<Mapping>,
19}
20
21impl<D: Scalar, Mapping: M> Tensor<D, Mapping> {
    /// Creates a new tensor from a buffer.
    ///
    /// Construction is delegated to the `from_buf` implementation of
    /// `crate::runtime::CurrentBackend`, instantiated for `D` and `Mapping`.
    ///
    /// NOTE(review): `data` is presumably expected in the same flat order that
    /// `Tensor::to_buf` returns — confirm against the backend implementation.
    #[furiosa_mapping_macro::primitive(Tensor::from_buf)]
    pub fn from_buf(data: Vec<Opt<D>>) -> Self {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::from_buf::<D, Mapping>(data)
    }
27
    /// Returns the tensor data as a flat vector in the physical layout order defined by `Mapping`.
    ///
    /// The conversion is delegated to the `to_buf` implementation of
    /// `crate::runtime::CurrentBackend`, instantiated for `D` and `Mapping`.
    /// The tensor is only borrowed; the result is a freshly owned `Vec`.
    #[furiosa_mapping_macro::primitive(Tensor::to_buf)]
    pub fn to_buf(&self) -> Vec<Opt<D>> {
        <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_buf::<D, Mapping>(self)
    }
33
34    pub(crate) fn from_raw(inner: RawTensor<D>) -> Self {
35        Self {
36            inner,
37            _marker: PhantomData,
38        }
39    }
40
41    pub(crate) fn raw(&self) -> &RawTensor<D> {
42        &self.inner
43    }
44
45    /// Creates a new tensor from a function.
46    pub fn from_fn<F>(f: F) -> Self
47    where
48        F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
49    {
50        Self {
51            inner: RawTensor::<D>::from_fn::<Mapping, F>(f),
52            _marker: PhantomData,
53        }
54    }
55
56    /// Creates a zero tensor.
57    pub fn zero() -> Self {
58        Self::from_fn(|_, _| Opt::Init(D::zero()))
59    }
60
61    /// Creates a random tensor.
62    pub fn rand(rng: &mut impl Rng) -> Self
63    where
64        StandardUniform: rand::distr::Distribution<D>,
65    {
66        Self::from_fn(|_, _| Opt::Init(rng.random::<D>()))
67    }
68
69    /// Creates an uninitialized tensor.
70    pub fn uninit() -> Self {
71        Self::from_fn(|_, _| Opt::Uninit)
72    }
73
74    /// Creates a mutable view of the tensor.
75    pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Mapping> {
76        TensorViewMut::new(&mut self.inner)
77    }
78
79    /// Creates an immutable view of the tensor.
80    pub fn view<'l>(&'l self) -> TensorView<'l, D, Mapping> {
81        TensorView::new(&self.inner)
82    }
83
84    /// Returns a reference to the underlying data array.
85    pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
86        &self.inner.data
87    }
88
89    /// Returns a mutable reference to the underlying data array.
90    pub fn data_mut(&mut self) -> &mut ArrayD<Opt<D>> {
91        &mut self.inner.data
92    }
93
94    /// Transmutes the tensor to a different mapping.
95    ///
96    /// # Safety
97    ///
98    /// The caller must ensure that the underlying data layout is compatible with the new mapping.
99    pub unsafe fn transmute<Mapping2: M>(self) -> Tensor<D, Mapping2> {
100        Tensor {
101            inner: self.inner,
102            _marker: PhantomData,
103        }
104    }
105
106    /// Reshapes the tensor to a different mapping.
107    /// change Src into Dst, final mapping is Mapping2.
108    ///
109    /// # Safety
110    ///
111    /// TODO(jeongmin.park); document safety.
112    pub unsafe fn reshape<Src: M, Dst: M, Mapping2: M>(self) -> Tensor<D, Mapping2> {
113        if Src::to_value() == Dst::to_value() {
114            assert_eq!(Mapping::to_value(), Mapping2::to_value());
115
116            return Tensor {
117                inner: self.inner,
118                _marker: PhantomData,
119            };
120        }
121
122        assert_eq!(Src::SIZE, Dst::SIZE);
123        assert_eq!(
124            Src::to_value().factorize().into_inner().len(),
125            Dst::to_value().factorize().into_inner().len(),
126            "TODO: when Src/Dst have different length, src/dst quotient have different paddings"
127        );
128
129        let src_quotient = Mapping::to_value()
130            .divide_relaxed(&Src::to_value())
131            .exact()
132            .expect("[reshape] failed to divide by the mapping by reshape source expression")
133            .dividend_residue;
134
135        let dst_quotient = Mapping2::to_value()
136            .divide_relaxed(&Dst::to_value())
137            .exact()
138            .expect("[reshape] failed to divide by the output mapping by reshape destination expression")
139            .dividend_residue;
140
141        assert_eq!(
142            src_quotient, dst_quotient,
143            "[reshape] inconsistent reshape: quotient parts do not match"
144        );
145
146        let mut output = Tensor::<D, Mapping2>::uninit();
147
148        for idx in 0..Src::SIZE {
149            let src_offset = {
150                let mut offset = Index::new();
151                offset.add_mapping::<Src>(idx);
152                offset
153            };
154            let dst_offset = {
155                let mut offset = Index::new();
156                offset.add_mapping::<Dst>(idx);
157                offset
158            };
159
160            output.inner.write_broadcast(
161                &self.inner,
162                src_quotient.clone(),
163                FMapping::new(),
164                &src_offset,
165                &dst_offset,
166            );
167        }
168
169        output
170    }
171
172    /// Maps the tensor with a function.
173    pub fn map<D2: Scalar, F>(&self, f: F) -> Tensor<D2, Mapping>
174    where
175        F: Fn(&Opt<D>) -> Opt<D2>,
176    {
177        Tensor {
178            inner: self.inner.map(f),
179            _marker: PhantomData,
180        }
181    }
182
183    /// Zips two tensors with a function.
184    pub fn zip_with<D2: Scalar, D3: Scalar, F>(&self, other: &Tensor<D2, Mapping>, f: F) -> Tensor<D3, Mapping>
185    where
186        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
187    {
188        Tensor {
189            inner: self.inner.zip_with(&other.inner, f),
190            _marker: PhantomData,
191        }
192    }
193
194    /// Performs reduction (sum) over axes not present in `Dst`.
195    pub fn reduce_add<Dst: M>(&self) -> Tensor<D, Dst> {
196        Tensor {
197            inner: self.inner.reduce_add(&gen_axes::<Dst>()),
198            _marker: PhantomData,
199        }
200    }
201
202    /// Reduces axes present in self but absent in Dst, using a custom reduce function.
203    pub fn reduce<Dst: M>(&self, reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>, identity: Opt<D>) -> Tensor<D, Dst> {
204        Tensor {
205            inner: self.inner.reduce(&gen_axes::<Dst>(), reduce_fn, identity),
206            _marker: PhantomData,
207        }
208    }
209
210    /// Performs transpose.
211    pub fn transpose<Dst: M>(&self, allow_broadcast: bool) -> Tensor<D, Dst> {
212        let mut dst = Tensor::<D, Dst>::uninit();
213        dst.view_mut().write_transpose(self.view(), allow_broadcast);
214        dst
215    }
216
217    /// Performs reduction followed by broadcasting to match destination axes.
218    /// This is useful when the destination has broadcast axes that should be preserved.
219    ///
220    /// # Examples
221    /// ```ignore
222    ///
223    /// let src: Tensor<f32, m![ABCD]> = Tensor::from_vec(...);  // Shape: ABCD
224    /// let dst: Tensor<f32, m![ABE]> = src.reduce_then_broadcast();  // Reduces CD, broadcasts to ABE
225    /// ```
226    pub fn reduce_then_broadcast<Dst: M>(&self) -> Tensor<D, Dst> {
227        let mut dst = Tensor::<D, Dst>::uninit();
228
229        // Perform reduction (sum)
230        let reduced = self.inner.reduce_add(&dst.inner.axes);
231
232        // Convert axes to FMapping.
233        let src_fmapping = FMapping::from_axes(RSlice::from(reduced.axes.as_slice()));
234        let dst_fmapping = FMapping::from_axes(RSlice::from(dst.inner.axes.as_slice()));
235
236        // Calculate broadcast factor: divide dst by src
237        let broadcast = dst_fmapping
238            .divide_relaxed(src_fmapping.clone())
239            .exact()
240            .expect("Failed to calculate broadcast factor")
241            .dividend_residue;
242
243        // Fill the result.
244        dst.inner
245            .write_broadcast(&reduced, src_fmapping, broadcast, &Index::new(), &Index::new());
246
247        dst
248    }
249
250    /// Performs reduction with a custom function, followed by broadcasting to match destination axes.
251    pub fn reduce_then_broadcast_with<Dst: M>(
252        &self,
253        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
254        identity: Opt<D>,
255    ) -> Tensor<D, Dst> {
256        let mut dst = Tensor::<D, Dst>::uninit();
257
258        // Perform reduction
259        let reduced = self.inner.reduce(&dst.inner.axes, reduce_fn, identity);
260
261        // Convert axes to FMapping.
262        let src_fmapping = FMapping::from_axes(RSlice::from(reduced.axes.as_slice()));
263        let dst_fmapping = FMapping::from_axes(RSlice::from(dst.inner.axes.as_slice()));
264
265        // Calculate broadcast factor: divide dst by src
266        let broadcast = dst_fmapping
267            .divide_relaxed(src_fmapping.clone())
268            .exact()
269            .expect("Failed to calculate broadcast factor")
270            .dividend_residue;
271
272        // Fill the result.
273        dst.inner
274            .write_broadcast(&reduced, src_fmapping, broadcast, &Index::new(), &Index::new());
275
276        dst
277    }
278
279    /// Scatter elements from self into dst at positions given by index tensor.
280    ///
281    /// When `scaled=true`, indices are byte offsets; when `false`, logical positions.
282    pub fn write_scatter<Key: M, Dst: M, Idx: M>(
283        &self,
284        dst: &mut Tensor<D, Dst>,
285        index: &Tensor<i32, Idx>,
286        scaled: bool,
287    ) {
288        let src_fmapping = Mapping::to_value().factorize();
289        let dst_fmapping = Dst::to_value().factorize();
290        let key = Key::to_value().factorize();
291
292        let index_stride = if scaled {
293            let payload = src_fmapping
294                .clone()
295                .divide_relaxed(key.clone())
296                .exact()
297                .expect("Src must contain scatter key")
298                .dividend_residue;
299            payload.remove_padding().size() * std::mem::size_of::<D>()
300        } else {
301            1
302        };
303
304        let indices: Vec<usize> = index
305            .to_buf()
306            .into_iter()
307            .map(|opt| {
308                let Opt::Init(v) = opt else {
309                    panic!("Scatter index must be initialized")
310                };
311                v as usize / index_stride
312            })
313            .collect();
314
315        dst.inner
316            .write_scatter(&self.inner, src_fmapping, dst_fmapping, key, &indices);
317    }
318
319    /// Performs contraction between two tensors.
320    pub fn contraction<Union: M, Lhs: M, Rhs: M>(lhs: &Tensor<D, Lhs>, rhs: &Tensor<D, Rhs>) -> Self {
321        {
322            lhs.transpose::<Union>(true)
323                .zip_with(&rhs.transpose(true), |a, b| a * b)
324                .reduce_add()
325        }
326    }
327}