use ndarray::{ArrayD, Axis, IxDyn};
use num_traits::Zero;
use std::{fmt::Debug, marker::PhantomData};

use abi_stable::std_types::RResult;
use furiosa_mapping::*;

use super::scalar::*;
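
/// Expands `Mapping` into the ordered list of axis `Term`s it describes, by
/// registering the mapping on a fresh `Index` and finalizing it (the per-term
/// index values are discarded).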
pub(crate) fn gen_axes<Mapping: M>() -> Vec<Term> {
    let mut index = Index::new();
    index.add_mapping::<Mapping>(0);
    index
        .finalize()
        .expect("Invalid mapping")
        .into_iter()
        .map(|(term, _)| term)
        .collect::<Vec<_>>()
}
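
/// A dense n-dimensional tensor whose axes are the `Term`s of a mapping.
///
/// Every element is an `Opt<D>`, so an entry either holds an initialized scalar
/// (`Opt::Init`) or remains uninitialized (`Opt::Uninit`).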
#[derive(Debug, Clone, PartialEq)]
pub struct RawTensor<D: Scalar> {
    pub(crate) axes: Vec<Term>,
    pub(crate) data: ArrayD<Opt<D>>,
    pub(crate) _marker: PhantomData<D>,
}

impl<D: Scalar> Eq for RawTensor<D> where D: Eq {}

impl<D: Scalar> RawTensor<D> {
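    /// Builds a tensor shaped by `Mapping` with every element set to `elem`.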
    pub fn from_elem<Mapping: M>(elem: Opt<D>) -> Self {
        let axes = gen_axes::<Mapping>();

        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
        let data = ArrayD::from_elem(shape, elem);

        Self {
            axes,
            data,
            _marker: PhantomData,
        }
    }
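
    /// Builds a tensor shaped by `Mapping` from a flat vector in row-major order.
    ///
    /// Panics if `data.len()` does not match the product of the mapping's moduli.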
    pub fn from_vec<Mapping: M>(data: Vec<Opt<D>>) -> Self {
        let axes = gen_axes::<Mapping>();

        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
        let data = ArrayD::from_shape_vec(shape, data).expect("Data length does not match tensor shape.");

        Self {
            axes,
            data,
            _marker: PhantomData,
        }
    }
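
    /// Builds a tensor shaped by `Mapping`, computing each element from the axis
    /// terms and the element's multi-dimensional index.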
    pub fn from_fn<Mapping: M, F>(mut f: F) -> Self
    where
        F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
    {
        let axes = gen_axes::<Mapping>();

        let shape = axes.iter().map(|term| term.modulo).collect::<Vec<usize>>();
        let data = ArrayD::from_shape_fn(shape, |idx| f(&axes, &idx));

        Self {
            axes,
            data,
            _marker: PhantomData,
        }
    }
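
    /// Applies `f` to every element, producing a tensor with the same axes but a
    /// possibly different scalar type.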
    pub fn map<D2: Scalar, F>(&self, f: F) -> RawTensor<D2>
    where
        F: FnMut(&Opt<D>) -> Opt<D2>,
    {
        let data = self.data.map(f);
        RawTensor {
            axes: self.axes.clone(),
            data,
            _marker: PhantomData,
        }
    }
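
    /// Combines two tensors element-wise with `f`.
    ///
    /// Panics if the tensors do not have exactly the same axes. See
    /// `test_tensor_zip_with` in the tests module for a usage example.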
    pub fn zip_with<D2: Scalar, D3: Scalar, F>(&self, other: &RawTensor<D2>, f: F) -> RawTensor<D3>
    where
        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
    {
        assert_eq!(
            self.axes, other.axes,
            "Tensors must have the same axes for element-wise binary operations"
        );

        let data = ndarray::Zip::from(&self.data)
            .and(&other.data)
            .map_collect(|&a, &b| f(a, b));

        RawTensor {
            axes: self.axes.clone(),
            data,
            _marker: PhantomData,
        }
    }
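
    /// Folds away every axis that is not listed in `retain_axes`, combining elements
    /// with `reduce_fn` starting from `identity`.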
    pub fn reduce(
        &self,
        retain_axes: &[Term],
        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
        identity: Opt<D>,
    ) -> RawTensor<D> {
        let reduce_axes: Vec<Term> = self
            .axes
            .iter()
            .filter(|src_term| !retain_axes.contains(src_term))
            .cloned()
            .collect();
        self.reduce_for(&reduce_axes, reduce_fn, identity)
    }
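
    /// Sums over every axis that is not listed in `retain_axes`.
    ///
    /// A minimal usage sketch (mirrors `test_tensor_reduce` below):
    ///
    /// ```ignore
    /// axes![A = 2, B = 3];
    /// let t = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect());
    /// let summed = t.reduce_add(&gen_axes::<m![A]>()); // keeps A, sums over B
    /// ```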
    pub fn reduce_add(&self, retain_axes: &[Term]) -> RawTensor<D> {
        self.reduce(retain_axes, |a, b| a + b, Opt::zero())
    }
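
    /// Folds away exactly the axes in `reduce_axes`, keeping the remaining axes in
    /// their original order.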
    fn reduce_for(
        &self,
        reduce_axes: &[Term],
        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
        identity: Opt<D>,
    ) -> RawTensor<D> {
        let reduce_indices: Vec<usize> = reduce_axes
            .iter()
            .filter_map(|reduce_term| self.axes.iter().position(|axis_term| axis_term == reduce_term))
            .collect();

        let axes: Vec<Term> = self
            .axes
            .iter()
            .enumerate()
            .filter_map(|(idx, term)| {
                if reduce_indices.contains(&idx) {
                    None
                } else {
                    Some(term.clone())
                }
            })
            .collect();

        let data = if reduce_indices.is_empty() {
            self.data.clone()
        } else if reduce_indices.len() == self.axes.len() {
            // Reducing over every axis yields a zero-dimensional tensor. Fold with the
            // caller-supplied function rather than assuming summation.
            let folded = self.data.fold(identity, |acc, &val| reduce_fn(acc, val));
            ArrayD::from_shape_vec(IxDyn(&[]), vec![folded]).expect("Failed to create scalar array")
        } else {
            let mut sorted_indices = reduce_indices.clone();
            sorted_indices.sort_unstable();
            sorted_indices.dedup();

            for &idx in &sorted_indices {
                if idx >= self.axes.len() {
                    panic!(
                        "[RawTensor::reduce] Invalid axis index: {} (tensor has {} axes)\n\
                        axes: {:?}\n\
                        reduce_axes: {:?}\n\
                        reduce_indices: {:?}",
                        idx,
                        self.axes.len(),
                        self.axes,
                        reduce_axes,
                        sorted_indices
                    );
                }
            }

            // Fold the reduced axes from the innermost outwards so earlier folds do not
            // shift the positions of the axes still to be reduced.
            let mut data = self.data.clone();
            for &axis_idx in sorted_indices.iter().rev() {
                if axis_idx >= data.ndim() {
                    panic!(
                        "[RawTensor::reduce] axis_idx {} >= data.ndim() {}\n\
                        current data shape: {:?}\n\
                        axes: {:?}\n\
                        reduce_indices: {:?}",
                        axis_idx,
                        data.ndim(),
                        data.shape(),
                        self.axes,
                        sorted_indices
                    );
                }
                data = data.fold_axis(Axis(axis_idx), identity, |&acc, &val| reduce_fn(acc, val));
            }
            data
        };

        Self {
            axes,
            data,
            _marker: PhantomData,
        }
    }
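
    /// Copies values from `src` into `self`: for every index generated by `unicast`,
    /// the value at `index + src_offset` in `src` is written to every index that
    /// `broadcast` expands from `index + dst_offset`.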
    pub fn write_broadcast(
        &mut self,
        src: &Self,
        unicast: FMapping,
        broadcast: FMapping,
        src_offset: &Index,
        dst_offset: &Index,
    ) {
        for index in Index::new().gen_indexes(unicast) {
            let mut src_index = index.clone();
            src_index.add(src_offset.clone());
            let value = src.read_index(src_index);

            let mut dst_index = index;
            dst_index.add(dst_offset.clone());

            for dst_index in dst_index.gen_indexes(broadcast.clone()) {
                self.write_index(dst_index, value);
            }
        }
    }
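
    /// Scatters values from `src` into `self`: the source mapping is split into the
    /// `key` part and the remaining payload part, and for each payload index the value
    /// at every key position is written to the destination axis entry selected by
    /// `indices[key_pos]`.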
    pub(crate) fn write_scatter(
        &mut self,
        src: &Self,
        src_mapping: FMapping,
        dst_mapping: FMapping,
        key: FMapping,
        indices: &[usize],
    ) {
        let payload = src_mapping
            .clone()
            .divide_relaxed(key.clone())
            .exact()
            .unwrap_or_else(|e| panic!("Scatter key `{key:?}` not found in source `{src_mapping:?}`: {e:?}"))
            .dividend_residue;
        let dst_residue = dst_mapping
            .clone()
            .divide_relaxed(payload.clone())
            .exact()
            .unwrap_or_else(|e| panic!("Destination `{dst_mapping:?}` missing payload axes `{payload:?}`: {e:?}"))
            .dividend_residue;
        let (dst_term, _) = dst_residue
            .into_inner()
            .into_iter()
            .find_map(|f| match f {
                Factor::Term { inner, resize } => Some((inner, resize)),
                Factor::Padding { .. } => None,
            })
            .expect("Destination has no scatter target axis after removing payload");

        for payload_index in Index::new().gen_indexes(payload) {
            for (key_pos, key_index) in Index::new().gen_indexes(key.clone()).into_iter().enumerate() {
                let mut src_index = payload_index.clone();
                src_index.add(key_index);
                let value = src.read_index(src_index);

                let mut dst_index = payload_index.clone();
                dst_index.add_term(dst_term.clone(), indices[key_pos]);
                self.write_index(dst_index, value);
            }
        }
    }
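
    /// Reads the element addressed by `index`, returning `Opt::Uninit` when the index
    /// fails to finalize.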
    pub(crate) fn read_index(&self, index: Index) -> Opt<D> {
        let RResult::ROk(index) = index.finalize() else {
            return Opt::Uninit;
        };
        assert!(
            self.axes.iter().zip(index.iter()).all(|(a, (b, _))| a == b),
            "Index terms ({:?}) do not match tensor axes ({:?}).",
            index,
            self.axes
        );

        *self
            .data
            .get(index.into_iter().map(|(_, v)| v).collect::<Vec<usize>>().as_slice())
            .expect("Index out of bounds.")
    }
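
    /// Writes `value` at `index`; indices that fail to finalize are silently skipped.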
    pub(crate) fn write_index(&mut self, index: Index, value: Opt<D>) {
        let RResult::ROk(index) = index.finalize() else {
            return;
        };
        assert!(
            self.axes.iter().zip(index.iter()).all(|(a, (b, _))| a == b),
            "Index terms do not match tensor axes."
        );

        *self
            .data
            .get_mut(index.into_iter().map(|(_, v)| v).collect::<Vec<usize>>().as_slice())
            .expect("Index out of bounds.") = value;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_tensor_zip_with() {
        axes![A = 2, B = 3];

        let t1 = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());
        let t2 = RawTensor::<i32>::from_vec::<m![A, B]>((2..8).map(Opt::Init).collect::<Vec<_>>());
        let result = t1.zip_with(&t2, |a, b| a * b);
        let expected = RawTensor::<i32>::from_vec::<m![A, B]>(
            [2, 6, 12, 20, 30, 42].into_iter().map(Opt::Init).collect::<Vec<_>>(),
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_reduce() {
        axes![A = 2, B = 3];

        let t = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());

        let retain_axes = gen_axes::<m![A]>();
        let result = t.reduce_add(&retain_axes);

        let expected = RawTensor::<i32>::from_vec::<m![A]>([6, 15].into_iter().map(Opt::Init).collect::<Vec<_>>());
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_reduce_multiple_axes() {
        axes![A = 2, B = 2, C = 2];

        let t = RawTensor::<i32>::from_vec::<m![A, B, C]>((1..9).map(Opt::Init).collect::<Vec<_>>());

        let retain_axes = gen_axes::<m![A]>();
        let result = t.reduce_add(&retain_axes);

        let expected = RawTensor::<i32>::from_vec::<m![A]>([10, 26].into_iter().map(Opt::Init).collect::<Vec<_>>());
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_map_elements() {
        axes![A = 2, B = 3];

        let t = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());
        let result = t.map(|&x| x * Opt::Init(2));

        let expected =
            RawTensor::<i32>::from_vec::<m![A, B]>([2, 4, 6, 8, 10, 12].into_iter().map(Opt::Init).collect::<Vec<_>>());
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_zip_with_custom_operation() {
        axes![A = 2, B = 3];

        let t1 = RawTensor::<i32>::from_vec::<m![A, B]>((1..7).map(Opt::Init).collect::<Vec<_>>());
        let t2 = RawTensor::<i32>::from_vec::<m![A, B]>((2..8).map(Opt::Init).collect::<Vec<_>>());

        let result = t1.zip_with(&t2, |a, b| match (a, b) {
            (Opt::Init(x), Opt::Init(y)) => Opt::Init(x * y + 1),
            _ => Opt::Uninit,
        });

        let expected = RawTensor::<i32>::from_vec::<m![A, B]>(
            [3, 7, 13, 21, 31, 43].into_iter().map(Opt::Init).collect::<Vec<_>>(),
        );
        assert_eq!(result, expected);
    }
}
450}