// furiosa_opt_macro/lib.rs

1//! Macros for virtual ISA.
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{Data, DeriveInput, Fields, Item, Type, parse_macro_input};
7
8/// Macro for mapping expressions.
9///
10/// See the documentation for `furiosa-visa-std` crate for details.
11///
12/// # Examples
13///
14/// ```ignore
15/// use furiosa_visa_std::prelude::*;
16/// axes![A = 512, B = 4];
17/// type AB = m![A, B];
18/// assert_eq!(AB::SIZE, 2048);
19/// ```
20#[proc_macro]
21pub fn m(input: TokenStream) -> TokenStream {
22    let input: proc_macro2::TokenStream = input.into();
23    let lexer = furiosa_mapping::parser::Lexer::new(input, furiosa_mapping::parser::LexerMode::Mapping);
24    let parser = furiosa_mapping::parser::MappingParser::new();
25    let mapping = match parser.parse(lexer) {
26        Ok(mapping) => mapping,
27        Err(e) => {
28            let msg = format!("Parse error: {:?}", e);
29            return syn::Error::new(proc_macro2::Span::call_site(), msg)
30                .to_compile_error()
31                .into();
32        }
33    };
34    let expanded = mapping.expand();
35    quote! { #expanded }.into()
36}
37
38/// Macro for index expressions.
39///
40/// See the documentation for `furiosa-visa-std` crate for details.
41///
42/// # Examples
43///
44/// ```ignore
45/// use furiosa_visa_std::prelude::*;
46/// axes![A = 512, B = 64];
47/// let idx = i![A / 32 = 8, B = 10];
48/// ```
49#[proc_macro]
50pub fn i(input: TokenStream) -> TokenStream {
51    let input: proc_macro2::TokenStream = input.into();
52    let lexer = furiosa_mapping::parser::Lexer::new(input, furiosa_mapping::parser::LexerMode::Index);
53    let parser = furiosa_mapping::parser::IndexParser::new();
54    let assignments = match parser.parse(lexer) {
55        Ok(assignments) => assignments,
56        Err(e) => {
57            let msg = format!("Parse error: {:?}", e);
58            return syn::Error::new(proc_macro2::Span::call_site(), msg)
59                .to_compile_error()
60                .into();
61        }
62    };
63
64    let expansions = assignments.iter().map(|assignment| assignment.expand());
65
66    quote! {
67        {
68            let mut index = ::furiosa_mapping::Index::new();
69            #(#expansions)*
70            index
71        }
72    }
73    .into()
74}
75
76/// Derive macro for DeviceSend trait.
77///
78/// Generates implementation with bounds requiring all fields to be `DeviceSend`.
79///
80/// # Compile-time Checks
81///
82/// All fields must implement `DeviceSend`. This ensures:
83/// - Reference fields are rejected (references don't impl DeviceSend)
84/// - Nested types must also be DeviceSend
85///
86/// # Example
87///
88/// ```ignore
89/// #[derive(DeviceSend)]
90/// struct MyTensor<D: Scalar, Chip: M, Element: M> {
91///     inner: Tensor<D, Pair<Chip, Element>>,  // Tensor must impl DeviceSend
92/// }
93/// // Generates:
94/// // impl<...> DeviceSend for MyTensor<...>
95/// // where
96/// //     Tensor<...>: DeviceSend,
97/// // {}
98/// ```
99#[proc_macro_derive(DeviceSend)]
100pub fn device_send(input: TokenStream) -> TokenStream {
101    /// Collect field types from a struct for where bounds.
102    fn field_types(data: &Data) -> Vec<&Type> {
103        match data {
104            Data::Struct(data) => match &data.fields {
105                Fields::Named(f) => f.named.iter().map(|f| &f.ty).collect(),
106                Fields::Unnamed(f) => f.unnamed.iter().map(|f| &f.ty).collect(),
107                Fields::Unit => vec![],
108            },
109            Data::Enum(_) | Data::Union(_) => vec![],
110        }
111    }
112
113    /// Build where predicates requiring fields to be DeviceSend.
114    fn device_send_predicates(field_types: &[&Type]) -> Vec<TokenStream2> {
115        field_types
116            .iter()
117            .map(|ty| quote! { #ty: crate::runtime::DeviceSend })
118            .collect()
119    }
120
121    let input = parse_macro_input!(input as DeriveInput);
122    let name = &input.ident;
123    let fields = field_types(&input.data);
124    let predicates = device_send_predicates(&fields);
125    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
126
127    let expanded = if let Some(wc) = where_clause {
128        quote! {
129            impl #impl_generics crate::runtime::DeviceSend for #name #ty_generics
130            #wc, #(#predicates),*
131            {}
132        }
133    } else {
134        quote! {
135            impl #impl_generics crate::runtime::DeviceSend for #name #ty_generics
136            where #(#predicates),*
137            {}
138        }
139    };
140
141    expanded.into()
142}
143
/// Marks a function as a device entry point for `launch()`.
///
/// Generates a unit struct implementing `DeviceFn` with `execute()`.
/// `cargo <subcommand>`: `execute()` calls the original function body (CPU).
/// `cargo furiosa-opt <subcommand>`: `execute()` loads the compiled EDF and runs on NPU.
#[proc_macro_attribute]
pub fn device(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Convert a snake_case fn name into CamelCase for the generated struct
    // (e.g. `mat_mul` -> `MatMul`). Uppercases the first char of each word.
    fn to_camel(s: &str) -> String {
        s.split('_')
            .map(|w| {
                let mut c = w.chars();
                c.next()
                    .map_or(String::new(), |ch| ch.to_uppercase().collect::<String>() + c.as_str())
            })
            .collect()
    }

    // Attribute args are re-emitted verbatim in `#[tcp::device = ...]` below
    // so downstream tooling can recover them from the expanded item.
    let attr_str = attr.to_string();
    let func = match parse_macro_input!(item as Item) {
        Item::Fn(f) => f,
        other => {
            return syn::Error::new_spanned(other, "#[device] can only be applied to functions")
                .to_compile_error()
                .into();
        }
    };

    let vis = &func.vis;
    let name = &func.sig.ident;
    let name_str = name.to_string();
    // `__tcp_<name>`: the user's original fn, renamed; the CPU branch of
    // `execute()` calls it. The original name becomes a const of the struct type.
    let hidden = syn::Ident::new(&format!("__tcp_{name}"), name.span());
    let struct_name = syn::Ident::new(&to_camel(&name_str), name.span());
    let syn::Signature {
        inputs,
        output,
        generics,
        ..
    } = &func.sig;

    // Classification of each typed parameter; decides `_`-prefixing in
    // `execute()` and whether a DMA Buffer conversion is emitted.
    #[derive(Clone, Copy, PartialEq)]
    enum Kind {
        Context,
        Tensor,
    }

    // (ident, type tokens, kind) per typed parameter. Receiver args (`self`)
    // are dropped; non-ident patterns get synthetic `__arg_{i}` names.
    let params: Vec<_> = inputs
        .iter()
        .filter_map(|a| match a {
            syn::FnArg::Typed(pt) => Some(pt),
            _ => None,
        })
        .enumerate()
        .map(|(i, pt)| {
            let name = match pt.pat.as_ref() {
                syn::Pat::Ident(id) => id.ident.clone(),
                _ => syn::Ident::new(&format!("__arg_{i}"), proc_macro2::Span::call_site()),
            };
            let ty = &pt.ty;
            // Stringify the type for the substring heuristic below.
            let s = quote!(#ty).to_string();
            // Heuristic: Context params (DmaContext, TuContext, etc.) are CPU-side scheduling
            // abstractions that don't exist on device — they'll be prefixed `_` in execute().
            let kind = if s.contains("Context") {
                Kind::Context
            } else {
                Kind::Tensor
            };
            (name, quote!(#ty), kind)
        })
        .collect();

    // All parameter types, in order — forms the `DeviceFn` tuple type param.
    let types: Vec<_> = params.iter().map(|(_, t, _)| t).collect();

    // For each tensor param, convert to a DMA Buffer before passing to Kernel::run().
    // Reference params (`&HbmTensor`): `(&*name).into()` to reborrow.
    // Owned params (`HbmTensorView`): `(&name).into()` since there's nothing to deref.
    let (tensor_bufs, tensor_stmts): (Vec<syn::Ident>, Vec<TokenStream2>) = params
        .iter()
        .filter(|(_, _, k)| *k == Kind::Tensor)
        .enumerate()
        .map(|(i, (name, ty, _))| {
            let buf = syn::Ident::new(&format!("__tcp_{i}"), proc_macro2::Span::call_site());
            let is_ref = ty.to_string().starts_with('&');
            let conv = if is_ref {
                quote! { let #buf: furiosa_visa_std::runtime::Buffer = (&*#name).into(); }
            } else {
                quote! { let #buf: furiosa_visa_std::runtime::Buffer = (&(#name)).into(); }
            };
            (buf, conv)
        })
        .unzip();

    // NPU-side kernel invocation: with a return type, allocate an output
    // buffer and convert it back; with no return type, just run the kernel.
    let run_body = match output {
        syn::ReturnType::Type(_, ty) => quote! {
            let __tcp_out = __tcp_kernel.alloc(<#ty>::size());
            __tcp_kernel.run(&[#(#tensor_bufs),*], &[__tcp_out.clone()]).await;
            __tcp_out.into()
        },
        syn::ReturnType::Default => quote! {
            __tcp_kernel.run(&[#(#tensor_bufs),*], &[]).await;
        },
    };

    // A single param is passed bare, not as a 1-tuple (which would need a
    // trailing comma); same special case for the destructuring pattern below.
    let tuple_type = if types.len() == 1 {
        quote!(#(#types)*)
    } else {
        quote!((#(#types),*))
    };
    let return_ty = match output {
        syn::ReturnType::Default => quote!(()),
        syn::ReturnType::Type(_, ty) => quote!(#ty),
    };
    let block = &func.block;

    // Destructure the tuple param of `execute()`. Context params are prefixed
    // with `_` because the NPU branch doesn't read them (kernels run on-device);
    // the CPU branch uses the _-prefixed names when calling the hidden fn.
    let param_names: Vec<syn::Ident> = params
        .iter()
        .map(|(n, _, k)| match k {
            Kind::Context => syn::Ident::new(&format!("_{n}"), n.span()),
            Kind::Tensor => n.clone(),
        })
        .collect();
    let body_destructure = if param_names.len() == 1 {
        quote!(#(#param_names)*)
    } else {
        quote!((#(#param_names),*))
    };

    // NPU branch: lazily load the compiled kernel once (per process) from the
    // path derived from the build output dir + crate/module/fn names, then
    // convert tensor params to Buffers and run it.
    let npu_body = quote! {
        static __TCP_KERNEL: furiosa_visa_std::OnceCell<furiosa_visa_std::runtime::Kernel> =
            furiosa_visa_std::OnceCell::const_new();
        let __tcp_kernel = __TCP_KERNEL.get_or_init(|| async {
            let __tcp_path = furiosa_visa_std::runtime::kernel_path(
                env!("FURIOSA_OPT_OUT_DIR"),
                env!("CARGO_PKG_NAME"),
                module_path!(),
                #name_str,
            );
            furiosa_visa_std::runtime::Kernel::load(&__tcp_path).await
        }).await;
        #(#tensor_stmts)*
        #run_body
    };
    // CPU branch: just forward all params to the renamed original fn.
    let cpu_body = quote! { #hidden(#(#param_names),*) };

    quote! {
        #[tcp::device = #attr_str]
        // `#[allow]` (not `#[expect]`): the hidden fn may or may not trigger
        // each of these lints depending on how the user defined the device
        // function, and `#[expect]` fails when the lint doesn't fire.
        #[allow(dead_code, unused, clippy::too_many_arguments)]
        fn #hidden #generics (#inputs) #output #block

        #[derive(Debug)]
        #vis struct #struct_name;

        // `#[allow]`: `#[expect]` requires the lint to fire, which it does
        // for names like `my_fn` but NOT for capitalized device-function
        // names like `MatMul`.
        #[allow(non_upper_case_globals)]
        #vis const #name: #struct_name = #struct_name;

        impl #generics furiosa_visa_std::runtime::DeviceFn<#tuple_type> for #struct_name {
            type Output = #return_ty;
            fn execute(#body_destructure: #tuple_type) -> impl std::future::Future<Output = Self::Output> {
                async move {
                    #[cfg(furiosa_opt)]
                    { #npu_body }
                    #[cfg(not(furiosa_opt))]
                    { #cpu_body }
                }
            }
        }
    }
    .into()
}
320}