// furiosa_visa_std/raw_tensor.rs

1use ndarray::{ArrayD, Axis, IxDyn};
2use num_traits::Zero;
3use std::{fmt::Debug, marker::PhantomData};
4
5use abi_stable::std_types::RResult;
6use furiosa_mapping::*;
7
8use super::scalar::*;
9
10/// Generates axes from a mapping.
11pub(crate) fn gen_axes<Mapping: M>() -> Vec<Term> {
12    let mut index = Index::new();
13    index.add_mapping::<Mapping>(0);
14    index
15        .finalize()
16        .expect("Invalid mapping")
17        .into_iter()
18        .map(|(term, _)| term)
19        .collect::<Vec<_>>()
20}
21
/// Tensor with scalar type `D`.
///
/// `axes` and `data` are positionally aligned: dimension `i` of `data` has
/// extent `axes[i].modulo` (see the constructors, which build the shape by
/// mapping each term to its `modulo`).
#[derive(Debug, Clone, PartialEq)]
pub struct RawTensor<D: Scalar> {
    /// The axes of the tensor represented as a sorted vector of `Term`.
    pub(crate) axes: Vec<Term>,
    /// The multi-dimensional array holding the tensor data, where each element may be uninitialized (`Opt::Uninit`).
    pub(crate) data: ArrayD<Opt<D>>,
    // NOTE(review): `D` already occurs in `data`, so this marker looks redundant —
    // confirm nothing external relies on it before removing (it is `pub(crate)`).
    pub(crate) _marker: PhantomData<D>,
}
31
// Full equality is available whenever the scalar type itself is `Eq`.
impl<D: Scalar> Eq for RawTensor<D> where D: Eq {}
33
34impl<D: Scalar> RawTensor<D> {
35    /// Similar to `from_vec`, but creates a new tensor with uninit elements.
36    pub fn from_elem<Mapping: M>(elem: Opt<D>) -> Self {
37        // Construct axes from mapping.
38        let axes = gen_axes::<Mapping>();
39
40        // Construct data array from vector.
41        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
42        let data = ArrayD::from_elem(shape, elem);
43
44        Self {
45            axes,
46            data,
47            _marker: PhantomData,
48        }
49    }
50
51    /// Creates a new tensor from a vector.
52    ///
53    /// `Mapping` determines the axes of the tensor, and `data` is a flat vector containing the tensor elements in the order sorted by the axes.
54    pub fn from_vec<Mapping: M>(data: Vec<Opt<D>>) -> Self {
55        // Construct axes from mapping.
56        let axes = gen_axes::<Mapping>();
57
58        // Construct data array from vector.
59        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
60        let data = ArrayD::from_shape_vec(shape, data).expect("Data length does not match tensor shape.");
61
62        Self {
63            axes,
64            data,
65            _marker: PhantomData,
66        }
67    }
68
69    /// Creates a new tensor from a mapping and a function.
70    pub fn from_fn<Mapping: M, F>(mut f: F) -> Self
71    where
72        F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
73    {
74        // Construct axes from mapping.
75        let axes = gen_axes::<Mapping>();
76
77        // Construct data array from function.
78        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
79        let data = ArrayD::from_shape_fn(shape, |idx| f(&axes, &idx));
80
81        Self {
82            axes,
83            data,
84            _marker: PhantomData,
85        }
86    }
87
88    /// Applies a unary function to each element of the tensor.
89    ///
90    /// # Examples
91    /// ```ignore
92    /// let tensor = TensorValue::<i32>::from_vec::<m![A, B]>(...);
93    /// let doubled = tensor.map(|&x| x * Opt::Init(2));
94    /// ```
95    pub fn map<D2: Scalar, F>(&self, f: F) -> RawTensor<D2>
96    where
97        F: FnMut(&Opt<D>) -> Opt<D2>,
98    {
99        let data = self.data.map(f);
100        RawTensor {
101            axes: self.axes.clone(),
102            data,
103            _marker: PhantomData,
104        }
105    }
106
107    /// Applies a binary function element-wise to two tensors with the same shape.
108    ///
109    /// # Examples
110    /// ```ignore
111    /// let a = RawTensor::<i32>::from_vec::<m![A, B]>(...);
112    /// let b = RawTensor::<i32>::from_vec::<m![A, B]>(...);
113    /// let sum = a.zip_with(&b, |x, y| x + y);
114    /// ```
115    pub fn zip_with<D2: Scalar, D3: Scalar, F>(&self, other: &RawTensor<D2>, f: F) -> RawTensor<D3>
116    where
117        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
118    {
119        assert_eq!(
120            self.axes, other.axes,
121            "Tensors must have the same axes for element-wise binary operations"
122        );
123
124        let data = ndarray::Zip::from(&self.data)
125            .and(&other.data)
126            .map_collect(|&a, &b| f(a, b));
127
128        RawTensor {
129            axes: self.axes.clone(),
130            data,
131            _marker: PhantomData,
132        }
133    }
134
135    /// Reduces axes not in `retain_axes` using a custom binary function and identity.
136    pub fn reduce(
137        &self,
138        retain_axes: &[Term],
139        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
140        identity: Opt<D>,
141    ) -> RawTensor<D> {
142        let reduce_axes: Vec<Term> = self
143            .axes
144            .iter()
145            .filter(|src_term| !retain_axes.contains(src_term))
146            .cloned()
147            .collect();
148        self.reduce_for(&reduce_axes, reduce_fn, identity)
149    }
150
151    /// Performs reduction (sum) over axes.
152    /// The axes to retain are specified by their terms.
153    pub fn reduce_add(&self, retain_axes: &[Term]) -> RawTensor<D> {
154        self.reduce(retain_axes, |a, b| a + b, Opt::zero())
155    }
156
157    /// Performs reduction over specified axes using a custom binary function and identity.
158    /// The axes to reduce are specified by their terms.
159    fn reduce_for(
160        &self,
161        reduce_axes: &[Term],
162        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
163        identity: Opt<D>,
164    ) -> RawTensor<D> {
165        // Find indices of axes to reduce.
166        let reduce_indices: Vec<usize> = reduce_axes
167            .iter()
168            .filter_map(|reduce_term| self.axes.iter().position(|axis_term| axis_term == reduce_term))
169            .collect();
170
171        // Compute new axes (excluding reduced axes).
172        let axes: Vec<Term> = self
173            .axes
174            .iter()
175            .enumerate()
176            .filter_map(|(idx, term)| {
177                if reduce_indices.contains(&idx) {
178                    None
179                } else {
180                    Some(term.clone())
181                }
182            })
183            .collect();
184
185        // Perform reduction by summing over the specified axes.
186        let data = if reduce_indices.is_empty() {
187            // No reduction needed
188            self.data.clone()
189        } else if reduce_indices.len() == self.axes.len() {
190            // Reducing all axes to a scalar
191            let sum_value = self.data.sum();
192            ArrayD::from_shape_vec(IxDyn(&[]), vec![sum_value]).expect("Failed to create scalar array")
193        } else {
194            // Partial reduction
195            // Sort and deduplicate indices, then reduce in reverse order
196            let mut sorted_indices = reduce_indices.clone();
197            sorted_indices.sort_unstable();
198            sorted_indices.dedup();
199
200            // Validate indices
201            for &idx in &sorted_indices {
202                if idx >= self.axes.len() {
203                    panic!(
204                        "[TensorValue::reduce] Invalid axis index: {} (tensor has {} axes)\n\
205                         axes: {:?}\n\
206                         reduce_axes: {:?}\n\
207                         reduce_indices: {:?}",
208                        idx,
209                        self.axes.len(),
210                        self.axes,
211                        reduce_axes,
212                        sorted_indices
213                    );
214                }
215            }
216
217            let mut data = self.data.clone();
218            for &axis_idx in sorted_indices.iter().rev() {
219                if axis_idx >= data.ndim() {
220                    panic!(
221                        "[TensorValue::reduce] axis_idx {} >= data.ndim() {}\n\
222                         current data shape: {:?}\n\
223                         axes: {:?}\n\
224                         reduce_indices: {:?}",
225                        axis_idx,
226                        data.ndim(),
227                        data.shape(),
228                        self.axes,
229                        sorted_indices
230                    );
231                }
232                data = data.fold_axis(Axis(axis_idx), identity, |&acc, &val| reduce_fn(acc, val));
233            }
234            data
235        };
236
237        Self {
238            axes,
239            data,
240            _marker: PhantomData,
241        }
242    }
243
244    /// Broadcasts and writes data from another tensor based on the given mappings and offsets.
245    ///
246    /// # Arguments
247    /// * `src` - Source tensor to read data from.
248    /// * `unicast` - Mapping for 1-to-1 mapping dimensions.
249    /// * `broadcast` - Mapping for broadcasted dimensions.
250    /// * `src_offset` - Offset to apply to the source tensor indices.
251    /// * `dst_offset` - Offset to apply to the destination (self) tensor indices.
252    pub fn write_broadcast(
253        &mut self,
254        src: &Self,
255        unicast: FMapping,
256        broadcast: FMapping,
257        src_offset: &Index,
258        dst_offset: &Index,
259    ) {
260        // Iterate over all unicast indexes.
261        for index in Index::new().gen_indexes(unicast) {
262            // Read value from source tensor at the index with offset.
263            let mut src_index = index.clone();
264            src_index.add(src_offset.clone());
265            let value = src.read_index(src_index);
266
267            // Calculate the base destination index with offset.
268            let mut dst_index = index;
269            dst_index.add(dst_offset.clone());
270
271            // Write value to all broadcasted indexes in the destination tensor.
272            for dst_index in dst_index.gen_indexes(broadcast.clone()) {
273                self.write_index(dst_index, value);
274            }
275        }
276    }
277
    /// Scatters elements from `src` into `self` at positions given by `indices`.
    ///
    /// ```text
    /// src:     [N, K, V]
    /// dst:     [N, X, V]
    /// key:     K
    /// indices: [N, K]
    ///
    /// for n, k:
    ///     dst[n][indices[n,k]][v] = src[n][k][v]
    /// ```
    ///
    /// Decomposition via divide:
    /// 1. `src / key`     = payload [N, V] — axes preserved across scatter
    /// 2. `dst / payload` = target  [X]    — scatter target axis
    ///
    /// NOTE(review): the sketch above shows `indices` as `[N, K]`, but the loop
    /// below indexes `indices` by the key position only (independent of the
    /// payload position) — confirm the intended shape with callers.
    ///
    /// # Panics
    /// Panics if `key` does not divide `src_mapping`, if the payload does not
    /// divide `dst_mapping`, or if no non-padding target axis remains in the
    /// destination residue.
    pub(crate) fn write_scatter(
        &mut self,
        src: &Self,
        src_mapping: FMapping,
        dst_mapping: FMapping,
        key: FMapping,
        indices: &[usize],
    ) {
        // Step 1: the payload is what remains of the source once the key is divided out.
        let payload = src_mapping
            .clone()
            .divide_relaxed(key.clone())
            .exact()
            .unwrap_or_else(|e| panic!("Scatter key `{key:?}` not found in source `{src_mapping:?}`: {e:?}"))
            .dividend_residue;
        // Step 2: dividing the payload out of the destination leaves the scatter target axis.
        let dst_residue = dst_mapping
            .clone()
            .divide_relaxed(payload.clone())
            .exact()
            .unwrap_or_else(|e| panic!("Destination `{dst_mapping:?}` missing payload axes `{payload:?}`: {e:?}"))
            .dividend_residue;
        // Take the first non-padding term of the residue as the scatter target axis.
        let (dst_term, _) = dst_residue
            .into_inner()
            .into_iter()
            .find_map(|f| match f {
                Factor::Term { inner, resize } => Some((inner, resize)),
                Factor::Padding { .. } => None,
            })
            .expect("Destination has no scatter target axis after removing payload");

        // For every payload position, route each keyed element to the slot named by `indices`.
        for payload_index in Index::new().gen_indexes(payload) {
            for (key_pos, key_index) in Index::new().gen_indexes(key.clone()).into_iter().enumerate() {
                // Read src[payload + key].
                let mut src_index = payload_index.clone();
                src_index.add(key_index);
                let value = src.read_index(src_index);

                // Write dst[payload + target-axis@indices[key_pos]].
                let mut dst_index = payload_index.clone();
                dst_index.add_term(dst_term.clone(), indices[key_pos]);
                self.write_index(dst_index, value);
            }
        }
    }
334
335    /// Reads the tensor value at the given index.
336    pub(crate) fn read_index(&self, index: Index) -> Opt<D> {
337        // Finalize the index before reading.
338        let RResult::ROk(index) = index.finalize() else {
339            return Opt::Uninit;
340        };
341        assert!(
342            self.axes.iter().zip(index.iter()).all(|(a, (b, _))| a == b),
343            "Index terms ({:?}) do not match tensor axes ({:?}).",
344            index,
345            self.axes
346        );
347
348        // Read the value from the data array.
349        *self
350            .data
351            .get(index.into_iter().map(|(_, v)| v).collect::<Vec<usize>>().as_slice())
352            .expect("Index out of bounds.")
353    }
354
355    /// Writes the tensor value at the given index.
356    pub(crate) fn write_index(&mut self, index: Index, value: Opt<D>) {
357        // Finalize the index before reading.
358        let RResult::ROk(index) = index.finalize() else {
359            return;
360        };
361        assert!(
362            self.axes.iter().zip(index.iter()).all(|(a, (b, _))| a == b),
363            "Index terms do not match tensor axes."
364        );
365
366        // Write the value to the data array.
367        *self
368            .data
369            .get_mut(index.into_iter().map(|(_, v)| v).collect::<Vec<usize>>().as_slice())
370            .expect("Index out of bounds.") = value;
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    /// `zip_with` with multiplication combines elements pairwise.
    #[test]
    fn test_tensor_zip_with() {
        axes![A = 2, B = 3];

        let lhs = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());
        let rhs = RawTensor::<i32>::from_vec::<m![A, B]>((2..8).map(Opt::Init).collect::<Vec<_>>());

        let product = lhs.zip_with(&rhs, |a, b| a * b);

        let expected = RawTensor::<i32>::from_vec::<m![A, B]>(
            vec![2, 6, 12, 20, 30, 42].into_iter().map(Opt::Init).collect::<Vec<_>>(),
        );
        assert_eq!(product, expected);
    }

    /// `reduce_add` sums away a single non-retained axis.
    #[test]
    fn test_tensor_reduce() {
        axes![A = 2, B = 3];

        let tensor = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());

        // Retain only A, so B is summed away: [1+2+3, 4+5+6].
        let retained = gen_axes::<m![A]>();
        let summed = tensor.reduce_add(&retained);

        let expected = RawTensor::<i32>::from_vec::<m![A]>(vec![6, 15].into_iter().map(Opt::Init).collect::<Vec<_>>());
        assert_eq!(summed, expected);
    }

    /// `reduce_add` can collapse several axes at once.
    #[test]
    fn test_tensor_reduce_multiple_axes() {
        axes![A = 2, B = 2, C = 2];

        let tensor = RawTensor::<i32>::from_vec::<m![A, B, C]>((1..9).map(Opt::Init).collect::<Vec<_>>());

        // Retain only A, so both B and C are summed away.
        let retained = gen_axes::<m![A]>();
        let summed = tensor.reduce_add(&retained);

        let expected = RawTensor::<i32>::from_vec::<m![A]>(vec![10, 26].into_iter().map(Opt::Init).collect::<Vec<_>>());
        assert_eq!(summed, expected);
    }

    /// `map` applies a unary function to every element.
    #[test]
    fn test_tensor_map_elements() {
        axes![A = 2, B = 3];

        let tensor = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());

        let doubled = tensor.map(|&x| x * Opt::Init(2));

        let expected = RawTensor::<i32>::from_vec::<m![A, B]>(
            vec![2, 4, 6, 8, 10, 12].into_iter().map(Opt::Init).collect::<Vec<_>>(),
        );
        assert_eq!(doubled, expected);
    }

    /// `zip_with` supports arbitrary combining functions, including ones that
    /// handle uninitialized elements explicitly.
    #[test]
    fn test_tensor_zip_with_custom_operation() {
        axes![A = 2, B = 3];

        let lhs = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());
        let rhs = RawTensor::<i32>::from_vec::<m![A, B]>((2..8).map(Opt::Init).collect::<Vec<_>>());

        // Custom operation: (a * b) + 1; anything uninitialized stays uninitialized.
        let combined = lhs.zip_with(&rhs, |a, b| match (a, b) {
            (Opt::Init(x), Opt::Init(y)) => Opt::Init(x * y + 1),
            _ => Opt::Uninit,
        });

        let expected = RawTensor::<i32>::from_vec::<m![A, B]>(
            vec![3, 7, 13, 21, 31, 43].into_iter().map(Opt::Init).collect::<Vec<_>>(),
        );
        assert_eq!(combined, expected);
    }
}