1use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{Data, DeriveInput, Fields, Item, Type, parse_macro_input};
7
8#[proc_macro]
21pub fn m(input: TokenStream) -> TokenStream {
22 let input: proc_macro2::TokenStream = input.into();
23 let lexer = furiosa_mapping::parser::Lexer::new(input, furiosa_mapping::parser::LexerMode::Mapping);
24 let parser = furiosa_mapping::parser::MappingParser::new();
25 let mapping = match parser.parse(lexer) {
26 Ok(mapping) => mapping,
27 Err(e) => {
28 let msg = format!("Parse error: {:?}", e);
29 return syn::Error::new(proc_macro2::Span::call_site(), msg)
30 .to_compile_error()
31 .into();
32 }
33 };
34 let expanded = mapping.expand();
35 quote! { #expanded }.into()
36}
37
38#[proc_macro]
50pub fn i(input: TokenStream) -> TokenStream {
51 let input: proc_macro2::TokenStream = input.into();
52 let lexer = furiosa_mapping::parser::Lexer::new(input, furiosa_mapping::parser::LexerMode::Index);
53 let parser = furiosa_mapping::parser::IndexParser::new();
54 let assignments = match parser.parse(lexer) {
55 Ok(assignments) => assignments,
56 Err(e) => {
57 let msg = format!("Parse error: {:?}", e);
58 return syn::Error::new(proc_macro2::Span::call_site(), msg)
59 .to_compile_error()
60 .into();
61 }
62 };
63
64 let expansions = assignments.iter().map(|assignment| assignment.expand());
65
66 quote! {
67 {
68 let mut index = ::furiosa_mapping::Index::new();
69 #(#expansions)*
70 index
71 }
72 }
73 .into()
74}
75
76#[proc_macro_derive(DeviceSend)]
100pub fn device_send(input: TokenStream) -> TokenStream {
101 fn field_types(data: &Data) -> Vec<&Type> {
103 match data {
104 Data::Struct(data) => match &data.fields {
105 Fields::Named(f) => f.named.iter().map(|f| &f.ty).collect(),
106 Fields::Unnamed(f) => f.unnamed.iter().map(|f| &f.ty).collect(),
107 Fields::Unit => vec![],
108 },
109 Data::Enum(_) | Data::Union(_) => vec![],
110 }
111 }
112
113 fn device_send_predicates(field_types: &[&Type]) -> Vec<TokenStream2> {
115 field_types
116 .iter()
117 .map(|ty| quote! { #ty: crate::runtime::DeviceSend })
118 .collect()
119 }
120
121 let input = parse_macro_input!(input as DeriveInput);
122 let name = &input.ident;
123 let fields = field_types(&input.data);
124 let predicates = device_send_predicates(&fields);
125 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
126
127 let expanded = if let Some(wc) = where_clause {
128 quote! {
129 impl #impl_generics crate::runtime::DeviceSend for #name #ty_generics
130 #wc, #(#predicates),*
131 {}
132 }
133 } else {
134 quote! {
135 impl #impl_generics crate::runtime::DeviceSend for #name #ty_generics
136 where #(#predicates),*
137 {}
138 }
139 };
140
141 expanded.into()
142}
143
144#[proc_macro_attribute]
150pub fn device(attr: TokenStream, item: TokenStream) -> TokenStream {
151 fn to_camel(s: &str) -> String {
152 s.split('_')
153 .map(|w| {
154 let mut c = w.chars();
155 c.next()
156 .map_or(String::new(), |ch| ch.to_uppercase().collect::<String>() + c.as_str())
157 })
158 .collect()
159 }
160
161 let attr_str = attr.to_string();
162 let func = match parse_macro_input!(item as Item) {
163 Item::Fn(f) => f,
164 other => {
165 return syn::Error::new_spanned(other, "#[device] can only be applied to functions")
166 .to_compile_error()
167 .into();
168 }
169 };
170
171 let vis = &func.vis;
172 let name = &func.sig.ident;
173 let name_str = name.to_string();
174 let hidden = syn::Ident::new(&format!("__tcp_{name}"), name.span());
175 let struct_name = syn::Ident::new(&to_camel(&name_str), name.span());
176 let syn::Signature {
177 inputs,
178 output,
179 generics,
180 ..
181 } = &func.sig;
182
183 #[derive(Clone, Copy, PartialEq)]
184 enum Kind {
185 Context,
186 Tensor,
187 }
188
189 let params: Vec<_> = inputs
190 .iter()
191 .filter_map(|a| match a {
192 syn::FnArg::Typed(pt) => Some(pt),
193 _ => None,
194 })
195 .enumerate()
196 .map(|(i, pt)| {
197 let name = match pt.pat.as_ref() {
198 syn::Pat::Ident(id) => id.ident.clone(),
199 _ => syn::Ident::new(&format!("__arg_{i}"), proc_macro2::Span::call_site()),
200 };
201 let ty = &pt.ty;
202 let s = quote!(#ty).to_string();
203 let kind = if s.contains("Context") {
206 Kind::Context
207 } else {
208 Kind::Tensor
209 };
210 (name, quote!(#ty), kind)
211 })
212 .collect();
213
214 let types: Vec<_> = params.iter().map(|(_, t, _)| t).collect();
215
216 let (tensor_bufs, tensor_stmts): (Vec<syn::Ident>, Vec<TokenStream2>) = params
220 .iter()
221 .filter(|(_, _, k)| *k == Kind::Tensor)
222 .enumerate()
223 .map(|(i, (name, ty, _))| {
224 let buf = syn::Ident::new(&format!("__tcp_{i}"), proc_macro2::Span::call_site());
225 let is_ref = ty.to_string().starts_with('&');
226 let conv = if is_ref {
227 quote! { let #buf: furiosa_visa_std::runtime::Buffer = (&*#name).into(); }
228 } else {
229 quote! { let #buf: furiosa_visa_std::runtime::Buffer = (&(#name)).into(); }
230 };
231 (buf, conv)
232 })
233 .unzip();
234
235 let run_body = match output {
236 syn::ReturnType::Type(_, ty) => quote! {
237 let __tcp_out = __tcp_kernel.alloc(<#ty>::size());
238 __tcp_kernel.run(&[#(#tensor_bufs),*], &[__tcp_out.clone()]).await;
239 __tcp_out.into()
240 },
241 syn::ReturnType::Default => quote! {
242 __tcp_kernel.run(&[#(#tensor_bufs),*], &[]).await;
243 },
244 };
245
246 let tuple_type = if types.len() == 1 {
247 quote!(#(#types)*)
248 } else {
249 quote!((#(#types),*))
250 };
251 let return_ty = match output {
252 syn::ReturnType::Default => quote!(()),
253 syn::ReturnType::Type(_, ty) => quote!(#ty),
254 };
255 let block = &func.block;
256
257 let param_names: Vec<syn::Ident> = params
261 .iter()
262 .map(|(n, _, k)| match k {
263 Kind::Context => syn::Ident::new(&format!("_{n}"), n.span()),
264 Kind::Tensor => n.clone(),
265 })
266 .collect();
267 let body_destructure = if param_names.len() == 1 {
268 quote!(#(#param_names)*)
269 } else {
270 quote!((#(#param_names),*))
271 };
272
273 let npu_body = quote! {
274 static __TCP_KERNEL: furiosa_visa_std::OnceCell<furiosa_visa_std::runtime::Kernel> =
275 furiosa_visa_std::OnceCell::const_new();
276 let __tcp_kernel = __TCP_KERNEL.get_or_init(|| async {
277 let __tcp_path = furiosa_visa_std::runtime::kernel_path(
278 env!("FURIOSA_OPT_OUT_DIR"),
279 env!("CARGO_PKG_NAME"),
280 module_path!(),
281 #name_str,
282 );
283 furiosa_visa_std::runtime::Kernel::load(&__tcp_path).await
284 }).await;
285 #(#tensor_stmts)*
286 #run_body
287 };
288 let cpu_body = quote! { #hidden(#(#param_names),*) };
289
290 quote! {
291 #[tcp::device = #attr_str]
292 #[allow(dead_code, unused, clippy::too_many_arguments)]
296 fn #hidden #generics (#inputs) #output #block
297
298 #[derive(Debug)]
299 #vis struct #struct_name;
300
301 #[allow(non_upper_case_globals)]
305 #vis const #name: #struct_name = #struct_name;
306
307 impl #generics furiosa_visa_std::runtime::DeviceFn<#tuple_type> for #struct_name {
308 type Output = #return_ty;
309 fn execute(#body_destructure: #tuple_type) -> impl std::future::Future<Output = Self::Output> {
310 async move {
311 #[cfg(furiosa_opt)]
312 { #npu_body }
313 #[cfg(not(furiosa_opt))]
314 { #cpu_body }
315 }
316 }
317 }
318 }
319 .into()
320}