1use std::marker::PhantomData;
2
3use abi_stable::std_types::RSlice;
4use furiosa_mapping::*;
5use ndarray::ArrayD;
6use ndarray::IxDyn;
7use rand::Rng;
8use rand::distr::StandardUniform;
9
10use super::raw_tensor::*;
11use super::scalar::*;
12use super::tensor_view::*;
13
14#[derive(Debug, Clone)]
16pub struct Tensor<D: Scalar, Mapping: M> {
17 inner: RawTensor<D>,
18 _marker: PhantomData<Mapping>,
19}
20
21impl<D: Scalar, Mapping: M> Tensor<D, Mapping> {
22 #[furiosa_mapping_macro::primitive(Tensor::from_buf)]
24 pub fn from_buf(data: Vec<Opt<D>>) -> Self {
25 <crate::runtime::CurrentBackend as crate::runtime::Backend>::from_buf::<D, Mapping>(data)
26 }
27
28 #[furiosa_mapping_macro::primitive(Tensor::to_buf)]
30 pub fn to_buf(&self) -> Vec<Opt<D>> {
31 <crate::runtime::CurrentBackend as crate::runtime::Backend>::to_buf::<D, Mapping>(self)
32 }
33
34 pub(crate) fn from_raw(inner: RawTensor<D>) -> Self {
35 Self {
36 inner,
37 _marker: PhantomData,
38 }
39 }
40
41 pub(crate) fn raw(&self) -> &RawTensor<D> {
42 &self.inner
43 }
44
45 pub fn from_fn<F>(f: F) -> Self
47 where
48 F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
49 {
50 Self {
51 inner: RawTensor::<D>::from_fn::<Mapping, F>(f),
52 _marker: PhantomData,
53 }
54 }
55
56 pub fn zero() -> Self {
58 Self::from_fn(|_, _| Opt::Init(D::zero()))
59 }
60
61 pub fn rand(rng: &mut impl Rng) -> Self
63 where
64 StandardUniform: rand::distr::Distribution<D>,
65 {
66 Self::from_fn(|_, _| Opt::Init(rng.random::<D>()))
67 }
68
69 pub fn uninit() -> Self {
71 Self::from_fn(|_, _| Opt::Uninit)
72 }
73
74 pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Mapping> {
76 TensorViewMut::new(&mut self.inner)
77 }
78
79 pub fn view<'l>(&'l self) -> TensorView<'l, D, Mapping> {
81 TensorView::new(&self.inner)
82 }
83
84 pub(crate) fn data(&self) -> &ndarray::ArrayD<Opt<D>> {
86 &self.inner.data
87 }
88
89 pub fn data_mut(&mut self) -> &mut ArrayD<Opt<D>> {
91 &mut self.inner.data
92 }
93
94 pub unsafe fn transmute<Mapping2: M>(self) -> Tensor<D, Mapping2> {
100 Tensor {
101 inner: self.inner,
102 _marker: PhantomData,
103 }
104 }
105
106 pub unsafe fn reshape<Src: M, Dst: M, Mapping2: M>(self) -> Tensor<D, Mapping2> {
113 if Src::to_value() == Dst::to_value() {
114 assert_eq!(Mapping::to_value(), Mapping2::to_value());
115
116 return Tensor {
117 inner: self.inner,
118 _marker: PhantomData,
119 };
120 }
121
122 assert_eq!(Src::SIZE, Dst::SIZE);
123 assert_eq!(
124 Src::to_value().factorize().into_inner().len(),
125 Dst::to_value().factorize().into_inner().len(),
126 "TODO: when Src/Dst have different length, src/dst quotient have different paddings"
127 );
128
129 let src_quotient = Mapping::to_value()
130 .divide_relaxed(&Src::to_value())
131 .exact()
132 .expect("[reshape] failed to divide by the mapping by reshape source expression")
133 .dividend_residue;
134
135 let dst_quotient = Mapping2::to_value()
136 .divide_relaxed(&Dst::to_value())
137 .exact()
138 .expect("[reshape] failed to divide by the output mapping by reshape destination expression")
139 .dividend_residue;
140
141 assert_eq!(
142 src_quotient, dst_quotient,
143 "[reshape] inconsistent reshape: quotient parts do not match"
144 );
145
146 let mut output = Tensor::<D, Mapping2>::uninit();
147
148 for idx in 0..Src::SIZE {
149 let src_offset = {
150 let mut offset = Index::new();
151 offset.add_mapping::<Src>(idx);
152 offset
153 };
154 let dst_offset = {
155 let mut offset = Index::new();
156 offset.add_mapping::<Dst>(idx);
157 offset
158 };
159
160 output.inner.write_broadcast(
161 &self.inner,
162 src_quotient.clone(),
163 FMapping::new(),
164 &src_offset,
165 &dst_offset,
166 );
167 }
168
169 output
170 }
171
172 pub fn map<D2: Scalar, F>(&self, f: F) -> Tensor<D2, Mapping>
174 where
175 F: Fn(&Opt<D>) -> Opt<D2>,
176 {
177 Tensor {
178 inner: self.inner.map(f),
179 _marker: PhantomData,
180 }
181 }
182
183 pub fn zip_with<D2: Scalar, D3: Scalar, F>(&self, other: &Tensor<D2, Mapping>, f: F) -> Tensor<D3, Mapping>
185 where
186 F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
187 {
188 Tensor {
189 inner: self.inner.zip_with(&other.inner, f),
190 _marker: PhantomData,
191 }
192 }
193
194 pub fn reduce_add<Dst: M>(&self) -> Tensor<D, Dst> {
196 Tensor {
197 inner: self.inner.reduce_add(&gen_axes::<Dst>()),
198 _marker: PhantomData,
199 }
200 }
201
202 pub fn reduce<Dst: M>(&self, reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>, identity: Opt<D>) -> Tensor<D, Dst> {
204 Tensor {
205 inner: self.inner.reduce(&gen_axes::<Dst>(), reduce_fn, identity),
206 _marker: PhantomData,
207 }
208 }
209
210 pub fn transpose<Dst: M>(&self, allow_broadcast: bool) -> Tensor<D, Dst> {
212 let mut dst = Tensor::<D, Dst>::uninit();
213 dst.view_mut().write_transpose(self.view(), allow_broadcast);
214 dst
215 }
216
217 pub fn reduce_then_broadcast<Dst: M>(&self) -> Tensor<D, Dst> {
227 let mut dst = Tensor::<D, Dst>::uninit();
228
229 let reduced = self.inner.reduce_add(&dst.inner.axes);
231
232 let src_fmapping = FMapping::from_axes(RSlice::from(reduced.axes.as_slice()));
234 let dst_fmapping = FMapping::from_axes(RSlice::from(dst.inner.axes.as_slice()));
235
236 let broadcast = dst_fmapping
238 .divide_relaxed(src_fmapping.clone())
239 .exact()
240 .expect("Failed to calculate broadcast factor")
241 .dividend_residue;
242
243 dst.inner
245 .write_broadcast(&reduced, src_fmapping, broadcast, &Index::new(), &Index::new());
246
247 dst
248 }
249
250 pub fn reduce_then_broadcast_with<Dst: M>(
252 &self,
253 reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
254 identity: Opt<D>,
255 ) -> Tensor<D, Dst> {
256 let mut dst = Tensor::<D, Dst>::uninit();
257
258 let reduced = self.inner.reduce(&dst.inner.axes, reduce_fn, identity);
260
261 let src_fmapping = FMapping::from_axes(RSlice::from(reduced.axes.as_slice()));
263 let dst_fmapping = FMapping::from_axes(RSlice::from(dst.inner.axes.as_slice()));
264
265 let broadcast = dst_fmapping
267 .divide_relaxed(src_fmapping.clone())
268 .exact()
269 .expect("Failed to calculate broadcast factor")
270 .dividend_residue;
271
272 dst.inner
274 .write_broadcast(&reduced, src_fmapping, broadcast, &Index::new(), &Index::new());
275
276 dst
277 }
278
279 pub fn write_scatter<Key: M, Dst: M, Idx: M>(
283 &self,
284 dst: &mut Tensor<D, Dst>,
285 index: &Tensor<i32, Idx>,
286 scaled: bool,
287 ) {
288 let src_fmapping = Mapping::to_value().factorize();
289 let dst_fmapping = Dst::to_value().factorize();
290 let key = Key::to_value().factorize();
291
292 let index_stride = if scaled {
293 let payload = src_fmapping
294 .clone()
295 .divide_relaxed(key.clone())
296 .exact()
297 .expect("Src must contain scatter key")
298 .dividend_residue;
299 payload.remove_padding().size() * std::mem::size_of::<D>()
300 } else {
301 1
302 };
303
304 let indices: Vec<usize> = index
305 .to_buf()
306 .into_iter()
307 .map(|opt| {
308 let Opt::Init(v) = opt else {
309 panic!("Scatter index must be initialized")
310 };
311 v as usize / index_stride
312 })
313 .collect();
314
315 dst.inner
316 .write_scatter(&self.inner, src_fmapping, dst_fmapping, key, &indices);
317 }
318
319 pub fn contraction<Union: M, Lhs: M, Rhs: M>(lhs: &Tensor<D, Lhs>, rhs: &Tensor<D, Rhs>) -> Self {
321 {
322 lhs.transpose::<Union>(true)
323 .zip_with(&rhs.transpose(true), |a, b| a * b)
324 .reduce_add()
325 }
326 }
327}