pallet_evm_polkavm_proc_macro/
lib.rs1use proc_macro::TokenStream;
24use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
25use quote::{quote, ToTokens};
26use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg, Ident};
27
28#[proc_macro_attribute]
29pub fn unstable_hostfn(_attr: TokenStream, item: TokenStream) -> TokenStream {
30 let input = syn::parse_macro_input!(item as syn::Item);
31 let expanded = quote! {
32 #[cfg(feature = "unstable-hostfn")]
33 #[cfg_attr(docsrs, doc(cfg(feature = "unstable-hostfn")))]
34 #input
35 };
36 expanded.into()
37}
38
39#[proc_macro_attribute]
50pub fn define_env(attr: TokenStream, item: TokenStream) -> TokenStream {
51 if !attr.is_empty() {
52 let msg = r#"Invalid `define_env` attribute macro: expected no attributes:
53 - `#[define_env]`"#;
54 let span = TokenStream2::from(attr).span();
55 return syn::Error::new(span, msg).to_compile_error().into();
56 }
57
58 let item = syn::parse_macro_input!(item as syn::ItemMod);
59
60 match EnvDef::try_from(item) {
61 Ok(def) => expand_env(&def).into(),
62 Err(e) => e.to_compile_error().into(),
63 }
64}
65
/// Parsed representation of the module annotated with `#[define_env]`.
struct EnvDef {
    /// One entry per `fn` item found directly inside the module, in declaration order.
    host_funcs: Vec<HostFn>,
}
70
/// A single validated host function taken from the environment module.
struct HostFn {
    /// The function item as written, with the `#[mutating]` guard statement already
    /// inserted at the top of its body when that attribute was present.
    item: syn::ItemFn,
    /// `true` when the function carried the `#[stable]` attribute.
    is_stable: bool,
    /// The function identifier as a string; its bytes are used as the syscall symbol.
    name: String,
    /// Classification of the `Ok` type of the function's `Result` return type.
    returns: HostFnReturn,
    /// The function's `#[cfg(..)]` attribute, if any, re-emitted on the generated code.
    cfg: Option<syn::Attribute>,
}
79
/// The success (`Ok`) types a host function is allowed to return.
enum HostFnReturn {
    /// The function returns `Result<(), TrapReason>`.
    Unit,
    /// The function returns `Result<u32, TrapReason>`.
    U32,
    /// The function returns `Result<u64, TrapReason>`.
    U64,
    /// The function returns `Result<ReturnErrorCode, TrapReason>`.
    ReturnCode,
}
86
87impl HostFnReturn {
88 fn map_output(&self) -> TokenStream2 {
89 match self {
90 Self::Unit => quote! { |_| None },
91 _ => quote! { |ret_val| Some(ret_val.into()) },
92 }
93 }
94
95 fn success_type(&self) -> syn::ReturnType {
96 match self {
97 Self::Unit => syn::ReturnType::Default,
98 Self::U32 => parse_quote! { -> u32 },
99 Self::U64 => parse_quote! { -> u64 },
100 Self::ReturnCode => parse_quote! { -> ReturnErrorCode },
101 }
102 }
103}
104
105impl EnvDef {
106 pub fn try_from(item: syn::ItemMod) -> syn::Result<Self> {
107 let span = item.span();
108 let err = |msg| syn::Error::new(span, msg);
109 let items = &item
110 .content
111 .as_ref()
112 .ok_or(err(
113 "Invalid environment definition, expected `mod` to be inlined.",
114 ))?
115 .1;
116
117 let extract_fn = |i: &syn::Item| match i {
118 syn::Item::Fn(i_fn) => Some(i_fn.clone()),
119 _ => None,
120 };
121
122 let host_funcs = items
123 .iter()
124 .filter_map(extract_fn)
125 .map(HostFn::try_from)
126 .collect::<Result<Vec<_>, _>>()?;
127
128 Ok(Self { host_funcs })
129 }
130}
131
132impl HostFn {
133 pub fn try_from(mut item: syn::ItemFn) -> syn::Result<Self> {
134 let err = |span, msg| {
135 let msg = format!("Invalid host function definition.\n{msg}");
136 syn::Error::new(span, msg)
137 };
138
139 let msg = "Only #[stable], #[cfg] and #[mutating] attributes are allowed.";
141 let span = item.span();
142 let mut attrs = item.attrs.clone();
143 attrs.retain(|a| !a.path().is_ident("doc"));
144 let mut is_stable = false;
145 let mut mutating = false;
146 let mut cfg = None;
147 while let Some(attr) = attrs.pop() {
148 let ident = attr.path().get_ident().ok_or(err(span, msg))?.to_string();
149 match ident.as_str() {
150 "stable" => {
151 if is_stable {
152 return Err(err(span, "#[stable] can only be specified once"));
153 }
154 is_stable = true;
155 }
156 "mutating" => {
157 if mutating {
158 return Err(err(span, "#[mutating] can only be specified once"));
159 }
160 mutating = true;
161 }
162 "cfg" => {
163 if cfg.is_some() {
164 return Err(err(span, "#[cfg] can only be specified once"));
165 }
166 cfg = Some(attr);
167 }
168 id => return Err(err(span, &format!("Unsupported attribute \"{id}\". {msg}"))),
169 }
170 }
171
172 if mutating {
173 let stmt = syn::parse_quote! {
174 return Err(SupervisorError::StateChangeDenied.into());
175 };
176 item.block.stmts.insert(0, stmt);
177 }
178
179 let name = item.sig.ident.to_string();
180
181 let msg = "Every function must start with these two parameters: &mut self, memory: &mut M";
182 let special_args = item
183 .sig
184 .inputs
185 .iter()
186 .take(2)
187 .enumerate()
188 .map(|(i, arg)| is_valid_special_arg(i, arg))
189 .fold(0u32, |acc, valid| if valid { acc + 1 } else { acc });
190
191 if special_args != 2 {
192 return Err(err(span, msg));
193 }
194
195 let msg = r#"Should return one of the following:
197 - Result<(), TrapReason>,
198 - Result<ReturnErrorCode, TrapReason>,
199 - Result<u32, TrapReason>,
200 - Result<u64, TrapReason>"#;
201 let ret_ty = match item.clone().sig.output {
202 syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
203 _ => Err(err(span, msg)),
204 }?;
205 match *ret_ty {
206 syn::Type::Path(tp) => {
207 let result = &tp.path.segments.last().ok_or(err(span, msg))?;
208 let (id, span) = (result.ident.to_string(), result.ident.span());
209 id.eq(&"Result".to_string())
210 .then_some(())
211 .ok_or(err(span, msg))?;
212
213 match &result.arguments {
214 syn::PathArguments::AngleBracketed(group) => {
215 if group.args.len() != 2 {
216 return Err(err(span, msg));
217 };
218
219 let arg2 = group.args.last().ok_or(err(span, msg))?;
220
221 let err_ty = match arg2 {
222 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
223 _ => Err(err(arg2.span(), msg)),
224 }?;
225
226 match err_ty {
227 syn::Type::Path(tp) => Ok(tp
228 .path
229 .segments
230 .first()
231 .ok_or(err(arg2.span(), msg))?
232 .ident
233 .to_string()),
234 _ => Err(err(tp.span(), msg)),
235 }?
236 .eq("TrapReason")
237 .then_some(())
238 .ok_or(err(span, msg))?;
239
240 let arg1 = group.args.first().ok_or(err(span, msg))?;
241 let ok_ty = match arg1 {
242 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
243 _ => Err(err(arg1.span(), msg)),
244 }?;
245 let ok_ty_str = match ok_ty {
246 syn::Type::Path(tp) => Ok(tp
247 .path
248 .segments
249 .first()
250 .ok_or(err(arg1.span(), msg))?
251 .ident
252 .to_string()),
253 syn::Type::Tuple(tt) => {
254 if !tt.elems.is_empty() {
255 return Err(err(arg1.span(), msg));
256 };
257 Ok("()".to_string())
258 }
259 _ => Err(err(ok_ty.span(), msg)),
260 }?;
261 let returns = match ok_ty_str.as_str() {
262 "()" => Ok(HostFnReturn::Unit),
263 "u32" => Ok(HostFnReturn::U32),
264 "u64" => Ok(HostFnReturn::U64),
265 "ReturnErrorCode" => Ok(HostFnReturn::ReturnCode),
266 _ => Err(err(arg1.span(), msg)),
267 }?;
268
269 Ok(Self {
270 item,
271 is_stable,
272 name,
273 returns,
274 cfg,
275 })
276 }
277 _ => Err(err(span, msg)),
278 }
279 }
280 _ => Err(err(span, msg)),
281 }
282 }
283}
284
285fn is_valid_special_arg(idx: usize, arg: &FnArg) -> bool {
286 match (idx, arg) {
287 (0, FnArg::Receiver(rec)) => rec.reference.is_some() && rec.mutability.is_some(),
288 (1, FnArg::Typed(pat)) => {
289 let ident = if let syn::Pat::Ident(ref ident) = *pat.pat {
290 &ident.ident
291 } else {
292 return false;
293 };
294 if !(ident == "memory" || ident == "_memory") {
295 return false;
296 }
297 matches!(*pat.ty, syn::Type::Reference(_))
298 }
299 _ => false,
300 }
301}
302
303fn arg_decoder<'a, P, I>(param_names: P, param_types: I) -> TokenStream2
304where
305 P: Iterator<Item = &'a std::boxed::Box<syn::Pat>> + Clone,
306 I: Iterator<Item = &'a std::boxed::Box<syn::Type>> + Clone,
307{
308 const ALLOWED_REGISTERS: usize = 6;
309
310 if param_names.clone().count() > ALLOWED_REGISTERS {
312 panic!("Syscalls take a maximum of {ALLOWED_REGISTERS} arguments");
313 }
314
315 if !param_types.clone().all(|ty| {
318 let syn::Type::Path(path) = &**ty else {
319 panic!("Type needs to be path");
320 };
321 let Some(ident) = path.path.get_ident() else {
322 panic!("Type needs to be ident");
323 };
324 matches!(ident.to_string().as_ref(), "u8" | "u16" | "u32" | "u64")
325 }) {
326 panic!("Only primitive unsigned integers are allowed as arguments to syscalls");
327 }
328
329 let bindings = param_names
331 .zip(param_types)
332 .enumerate()
333 .map(|(idx, (name, ty))| {
334 let reg = quote::format_ident!("__a{}__", idx);
335 quote! {
336 let #name = #reg as #ty;
337 }
338 });
339 quote! {
340 #( #bindings )*
341 }
342}
343
344fn expand_env(def: &EnvDef) -> TokenStream2 {
349 let impls = expand_functions(def);
350 let bench_impls = expand_bench_functions(def);
351 let docs = expand_func_doc(def);
352 let stable_syscalls = expand_func_list(def, false);
353 let all_syscalls = expand_func_list(def, true);
354
355 quote! {
356 pub fn list_syscalls(include_unstable: bool) -> &'static [&'static [u8]] {
357 if include_unstable {
358 #all_syscalls
359 } else {
360 #stable_syscalls
361 }
362 }
363
364 impl<'a, T: Config, H: PrecompileHandle, M: PolkaVmInstance> Runtime<'a, T, H, M> {
365 fn handle_ecall(
366 &mut self,
367 memory: &mut M,
368 __syscall_symbol__: &[u8],
369 ) -> Result<Option<u64>, TrapReason>
370 {
371 #impls
372 }
373 }
374
375 #[cfg(feature = "runtime-benchmarks")]
376 impl<'a, T: Config, H: PrecompileHandle, M: PolkaVmInstance> Runtime<'a, T, H, M> {
377 #bench_impls
378 }
379
380 #[cfg(doc)]
390 pub trait SyscallDoc {
391 #docs
392 }
393 }
394}
395
396fn expand_functions(def: &EnvDef) -> TokenStream2 {
397 let impls = def.host_funcs.iter().map(|f| {
398 let params = f.item.sig.inputs.iter().skip(2);
400 let param_names = params.clone().filter_map(|arg| {
401 let FnArg::Typed(arg) = arg else {
402 return None;
403 };
404 Some(&arg.pat)
405 });
406 let param_types = params.clone().filter_map(|arg| {
407 let FnArg::Typed(arg) = arg else {
408 return None;
409 };
410 Some(&arg.ty)
411 });
412 let arg_decoder = arg_decoder(param_names, param_types);
413 let cfg = &f.cfg;
414 let name = &f.name;
415 let syscall_symbol = Literal::byte_string(name.as_bytes());
416 let body = &f.item.block;
417 let map_output = f.returns.map_output();
418 let output = &f.item.sig.output;
419
420 let wrapped_body_with_trace = {
423 let trace_fmt_args = params.clone().filter_map(|arg| match arg {
424 syn::FnArg::Receiver(_) => None,
425 syn::FnArg::Typed(p) => match *p.pat.clone() {
426 syn::Pat::Ident(ref pat_ident) => Some(pat_ident.ident.clone()),
427 _ => None,
428 },
429 });
430
431 let params_fmt_str = trace_fmt_args
432 .clone()
433 .map(|s| format!("{s}: {{:?}}"))
434 .collect::<Vec<_>>()
435 .join(", ");
436 let trace_fmt_str = format!("{name}({params_fmt_str}) = {{:?}}");
437
438 quote! {
439 let result = (|| #body)();
441 ::log::trace!(target: LOG_TARGET, #trace_fmt_str, #( #trace_fmt_args, )* result);
442 result
443 }
444 };
445
446 quote! {
447 #cfg
448 #syscall_symbol => {
449 (|| #output {
451 #arg_decoder
452 #wrapped_body_with_trace
453 })().map(#map_output)
454 },
455 }
456 });
457
458 quote! {
459 self.charge_polkavm_gas(memory)?;
460
461 self.charge_gas(crate::vm::RuntimeCosts::HostFn).map_err(TrapReason::from)?;
463
464 let (__a0__, __a1__, __a2__, __a3__, __a4__, __a5__) = memory.read_input_regs();
466
467 let result = (|| match __syscall_symbol__ {
469 #( #impls )*
470 _ => Err(TrapReason::SupervisorError(SupervisorError::InvalidSyscall.into()))
471 })();
472
473 result
474 }
475}
476
477fn expand_bench_functions(def: &EnvDef) -> TokenStream2 {
478 let impls = def.host_funcs.iter().map(|f| {
479 let params = f.item.sig.inputs.iter().skip(2);
481 let cfg = &f.cfg;
482 let name = &f.name;
483 let body = &f.item.block;
484 let output = &f.item.sig.output;
485
486 let name = Ident::new(&format!("bench_{name}"), Span::call_site());
487 quote! {
488 #cfg
489 pub fn #name(&mut self, memory: &mut M, #(#params),*) #output {
490 #body
491 }
492 }
493 });
494
495 quote! {
496 #( #impls )*
497 }
498}
499
500fn expand_func_doc(def: &EnvDef) -> TokenStream2 {
501 let docs = def.host_funcs.iter().map(|func| {
502 let func_decl = {
504 let mut sig = func.item.sig.clone();
505 sig.inputs = sig
506 .inputs
507 .iter()
508 .skip(2)
509 .cloned()
510 .collect::<Punctuated<FnArg, Comma>>();
511 sig.output = func.returns.success_type();
512 sig.to_token_stream()
513 };
514 let func_doc = {
515 let func_docs = {
516 let docs = func
517 .item
518 .attrs
519 .iter()
520 .filter(|a| a.path().is_ident("doc"))
521 .map(|d| {
522 let docs = d.to_token_stream();
523 quote! { #docs }
524 });
525 quote! { #( #docs )* }
526 };
527 let availability = if func.is_stable {
528 let info = "\n# Stable API\nThis API is stable and will never change.";
529 quote! { #[doc = #info] }
530 } else {
531 let info =
532 "\n# Unstable API\nThis API is not standardized and only available for testing.";
533 quote! { #[doc = #info] }
534 };
535 quote! {
536 #func_docs
537 #availability
538 }
539 };
540 quote! {
541 #func_doc
542 #func_decl;
543 }
544 });
545
546 quote! {
547 #( #docs )*
548 }
549}
550
551fn expand_func_list(def: &EnvDef, include_unstable: bool) -> TokenStream2 {
552 let docs = def
553 .host_funcs
554 .iter()
555 .filter(|f| include_unstable || f.is_stable)
556 .map(|f| {
557 let name = Literal::byte_string(f.name.as_bytes());
558 quote! {
559 #name.as_slice()
560 }
561 });
562 let len = docs.clone().count();
563
564 quote! {
565 {
566 static FUNCS: [&[u8]; #len] = [#(#docs),*];
567 FUNCS.as_slice()
568 }
569 }
570}