furiosa_visa_std/
tensor_state.rs

//! Type-level encoding of tensor presence.
//!
//! The [`TensorState`] trait encodes at the type level whether a value currently
//! holds a tensor and, if so, what scalar type and memory mapping that tensor has.
//!
//! Two implementations are provided:
//! - [`NoTensor`] — no tensor is present (empty data).
//! - [`HasTensor<D, Mapping>`] — a [`Tensor<D, Mapping>`] is present.
9
10use std::fmt::Debug;
11
12use crate::tensor::Tensor;
13use crate::vector_engine::scalar::VeScalar;
14use furiosa_mapping::M;
15
16/// Marker trait that tracks tensor presence at compile time.
17///
18/// The type parameter `D` ties the stored tensor's scalar type to the pipeline's current
19/// scalar type, ensuring at compile time that tensor reads match the pipeline's `D`.
20///
21/// Implementors either hold no data ([`NoTensor`]) or store a [`Tensor<D, Mapping>`] ([`HasTensor`]).
22pub trait TensorState<D: VeScalar>: Debug {
23    /// Clones the tensor data, transposing to target mapping if needed.
24    fn clone_tensor_as<TargetMapping: M>(&self) -> Option<Tensor<D, TargetMapping>>;
25}
26
27/// No tensor is present.
28#[derive(Debug)]
29pub struct NoTensor;
30impl<D: VeScalar> TensorState<D> for NoTensor {
31    fn clone_tensor_as<TargetMapping: M>(&self) -> Option<Tensor<D, TargetMapping>> {
32        None
33    }
34}
35
36/// A [`Tensor`] with scalar type `D` and memory layout `Mapping` is present.
37#[derive(Debug)]
38pub struct HasTensor<D: VeScalar, Mapping: M> {
39    data: Tensor<D, Mapping>,
40}
41
42impl<D: VeScalar, Mapping: M> HasTensor<D, Mapping> {
43    /// Wraps a tensor into a `HasTensor`.
44    pub fn new(tensor: Tensor<D, Mapping>) -> Self {
45        Self { data: tensor }
46    }
47}
48
49impl<D: VeScalar, Mapping: M> From<Tensor<D, Mapping>> for HasTensor<D, Mapping> {
50    fn from(tensor: Tensor<D, Mapping>) -> Self {
51        Self::new(tensor)
52    }
53}
54
55impl<D: VeScalar, Mapping: M> TensorState<D> for HasTensor<D, Mapping> {
56    fn clone_tensor_as<TargetMapping: M>(&self) -> Option<Tensor<D, TargetMapping>> {
57        let cloned = self.data.clone();
58        Some(cloned.transpose::<TargetMapping>(true))
59    }
60}