pallet_evm_polkavm_proc_macro/
lib.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Procedural macros used in the contracts module.
19//!
20//! Most likely you should use the [`#[define_env]`][`macro@define_env`] attribute macro which hides
21//! boilerplate of defining external environment for a polkavm module.
22
23use proc_macro::TokenStream;
24use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
25use quote::{quote, ToTokens};
26use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg, Ident};
27
28#[proc_macro_attribute]
29pub fn unstable_hostfn(_attr: TokenStream, item: TokenStream) -> TokenStream {
30	let input = syn::parse_macro_input!(item as syn::Item);
31	let expanded = quote! {
32		#[cfg(feature = "unstable-hostfn")]
33		#[cfg_attr(docsrs, doc(cfg(feature = "unstable-hostfn")))]
34		#input
35	};
36	expanded.into()
37}
38
39/// Defines a host functions set that can be imported by contract polkavm code.
40///
41/// **CAUTION**: Be advised that all functions defined by this macro
42/// cause undefined behaviour inside the contract if the signature does not match.
43///
44/// WARNING: It is CRITICAL for contracts to make sure that the signatures match exactly.
45/// Failure to do so may result in undefined behavior, traps or security vulnerabilities inside the
46/// contract. The runtime itself is unharmed due to sandboxing.
47/// For example, if a function is called with an incorrect signature, it could lead to memory
48/// corruption or unexpected results within the contract.
49#[proc_macro_attribute]
50pub fn define_env(attr: TokenStream, item: TokenStream) -> TokenStream {
51	if !attr.is_empty() {
52		let msg = r#"Invalid `define_env` attribute macro: expected no attributes:
53					- `#[define_env]`"#;
54		let span = TokenStream2::from(attr).span();
55		return syn::Error::new(span, msg).to_compile_error().into();
56	}
57
58	let item = syn::parse_macro_input!(item as syn::ItemMod);
59
60	match EnvDef::try_from(item) {
61		Ok(def) => expand_env(&def).into(),
62		Err(e) => e.to_compile_error().into(),
63	}
64}
65
66/// Parsed environment definition.
67struct EnvDef {
68	host_funcs: Vec<HostFn>,
69}
70
71/// Parsed host function definition.
72struct HostFn {
73	item: syn::ItemFn,
74	is_stable: bool,
75	name: String,
76	returns: HostFnReturn,
77	cfg: Option<syn::Attribute>,
78}
79
/// The success types a host function is allowed to return inside its `Result`.
enum HostFnReturn {
	/// `Result<(), TrapReason>`
	Unit,
	/// `Result<u32, TrapReason>`
	U32,
	/// `Result<u64, TrapReason>`
	U64,
	/// `Result<ReturnErrorCode, TrapReason>`
	ReturnCode,
}
86
87impl HostFnReturn {
88	fn map_output(&self) -> TokenStream2 {
89		match self {
90			Self::Unit => quote! { |_| None },
91			_ => quote! { |ret_val| Some(ret_val.into()) },
92		}
93	}
94
95	fn success_type(&self) -> syn::ReturnType {
96		match self {
97			Self::Unit => syn::ReturnType::Default,
98			Self::U32 => parse_quote! { -> u32 },
99			Self::U64 => parse_quote! { -> u64 },
100			Self::ReturnCode => parse_quote! { -> ReturnErrorCode },
101		}
102	}
103}
104
105impl EnvDef {
106	pub fn try_from(item: syn::ItemMod) -> syn::Result<Self> {
107		let span = item.span();
108		let err = |msg| syn::Error::new(span, msg);
109		let items = &item
110			.content
111			.as_ref()
112			.ok_or(err(
113				"Invalid environment definition, expected `mod` to be inlined.",
114			))?
115			.1;
116
117		let extract_fn = |i: &syn::Item| match i {
118			syn::Item::Fn(i_fn) => Some(i_fn.clone()),
119			_ => None,
120		};
121
122		let host_funcs = items
123			.iter()
124			.filter_map(extract_fn)
125			.map(HostFn::try_from)
126			.collect::<Result<Vec<_>, _>>()?;
127
128		Ok(Self { host_funcs })
129	}
130}
131
132impl HostFn {
133	pub fn try_from(mut item: syn::ItemFn) -> syn::Result<Self> {
134		let err = |span, msg| {
135			let msg = format!("Invalid host function definition.\n{msg}");
136			syn::Error::new(span, msg)
137		};
138
139		// process attributes
140		let msg = "Only #[stable], #[cfg] and #[mutating] attributes are allowed.";
141		let span = item.span();
142		let mut attrs = item.attrs.clone();
143		attrs.retain(|a| !a.path().is_ident("doc"));
144		let mut is_stable = false;
145		let mut mutating = false;
146		let mut cfg = None;
147		while let Some(attr) = attrs.pop() {
148			let ident = attr.path().get_ident().ok_or(err(span, msg))?.to_string();
149			match ident.as_str() {
150				"stable" => {
151					if is_stable {
152						return Err(err(span, "#[stable] can only be specified once"));
153					}
154					is_stable = true;
155				}
156				"mutating" => {
157					if mutating {
158						return Err(err(span, "#[mutating] can only be specified once"));
159					}
160					mutating = true;
161				}
162				"cfg" => {
163					if cfg.is_some() {
164						return Err(err(span, "#[cfg] can only be specified once"));
165					}
166					cfg = Some(attr);
167				}
168				id => return Err(err(span, &format!("Unsupported attribute \"{id}\". {msg}"))),
169			}
170		}
171
172		if mutating {
173			let stmt = syn::parse_quote! {
174				return Err(SupervisorError::StateChangeDenied.into());
175			};
176			item.block.stmts.insert(0, stmt);
177		}
178
179		let name = item.sig.ident.to_string();
180
181		let msg = "Every function must start with these two parameters: &mut self, memory: &mut M";
182		let special_args = item
183			.sig
184			.inputs
185			.iter()
186			.take(2)
187			.enumerate()
188			.map(|(i, arg)| is_valid_special_arg(i, arg))
189			.fold(0u32, |acc, valid| if valid { acc + 1 } else { acc });
190
191		if special_args != 2 {
192			return Err(err(span, msg));
193		}
194
195		// process return type
196		let msg = r#"Should return one of the following:
197				- Result<(), TrapReason>,
198				- Result<ReturnErrorCode, TrapReason>,
199				- Result<u32, TrapReason>,
200				- Result<u64, TrapReason>"#;
201		let ret_ty = match item.clone().sig.output {
202			syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
203			_ => Err(err(span, msg)),
204		}?;
205		match *ret_ty {
206			syn::Type::Path(tp) => {
207				let result = &tp.path.segments.last().ok_or(err(span, msg))?;
208				let (id, span) = (result.ident.to_string(), result.ident.span());
209				id.eq(&"Result".to_string())
210					.then_some(())
211					.ok_or(err(span, msg))?;
212
213				match &result.arguments {
214					syn::PathArguments::AngleBracketed(group) => {
215						if group.args.len() != 2 {
216							return Err(err(span, msg));
217						};
218
219						let arg2 = group.args.last().ok_or(err(span, msg))?;
220
221						let err_ty = match arg2 {
222							syn::GenericArgument::Type(ty) => Ok(ty.clone()),
223							_ => Err(err(arg2.span(), msg)),
224						}?;
225
226						match err_ty {
227							syn::Type::Path(tp) => Ok(tp
228								.path
229								.segments
230								.first()
231								.ok_or(err(arg2.span(), msg))?
232								.ident
233								.to_string()),
234							_ => Err(err(tp.span(), msg)),
235						}?
236						.eq("TrapReason")
237						.then_some(())
238						.ok_or(err(span, msg))?;
239
240						let arg1 = group.args.first().ok_or(err(span, msg))?;
241						let ok_ty = match arg1 {
242							syn::GenericArgument::Type(ty) => Ok(ty.clone()),
243							_ => Err(err(arg1.span(), msg)),
244						}?;
245						let ok_ty_str = match ok_ty {
246							syn::Type::Path(tp) => Ok(tp
247								.path
248								.segments
249								.first()
250								.ok_or(err(arg1.span(), msg))?
251								.ident
252								.to_string()),
253							syn::Type::Tuple(tt) => {
254								if !tt.elems.is_empty() {
255									return Err(err(arg1.span(), msg));
256								};
257								Ok("()".to_string())
258							}
259							_ => Err(err(ok_ty.span(), msg)),
260						}?;
261						let returns = match ok_ty_str.as_str() {
262							"()" => Ok(HostFnReturn::Unit),
263							"u32" => Ok(HostFnReturn::U32),
264							"u64" => Ok(HostFnReturn::U64),
265							"ReturnErrorCode" => Ok(HostFnReturn::ReturnCode),
266							_ => Err(err(arg1.span(), msg)),
267						}?;
268
269						Ok(Self {
270							item,
271							is_stable,
272							name,
273							returns,
274							cfg,
275						})
276					}
277					_ => Err(err(span, msg)),
278				}
279			}
280			_ => Err(err(span, msg)),
281		}
282	}
283}
284
285fn is_valid_special_arg(idx: usize, arg: &FnArg) -> bool {
286	match (idx, arg) {
287		(0, FnArg::Receiver(rec)) => rec.reference.is_some() && rec.mutability.is_some(),
288		(1, FnArg::Typed(pat)) => {
289			let ident = if let syn::Pat::Ident(ref ident) = *pat.pat {
290				&ident.ident
291			} else {
292				return false;
293			};
294			if !(ident == "memory" || ident == "_memory") {
295				return false;
296			}
297			matches!(*pat.ty, syn::Type::Reference(_))
298		}
299		_ => false,
300	}
301}
302
303fn arg_decoder<'a, P, I>(param_names: P, param_types: I) -> TokenStream2
304where
305	P: Iterator<Item = &'a std::boxed::Box<syn::Pat>> + Clone,
306	I: Iterator<Item = &'a std::boxed::Box<syn::Type>> + Clone,
307{
308	const ALLOWED_REGISTERS: usize = 6;
309
310	// too many arguments
311	if param_names.clone().count() > ALLOWED_REGISTERS {
312		panic!("Syscalls take a maximum of {ALLOWED_REGISTERS} arguments");
313	}
314
315	// all of them take one register but we truncate them before passing into the function
316	// it is important to not allow any type which has illegal bit patterns like 'bool'
317	if !param_types.clone().all(|ty| {
318		let syn::Type::Path(path) = &**ty else {
319			panic!("Type needs to be path");
320		};
321		let Some(ident) = path.path.get_ident() else {
322			panic!("Type needs to be ident");
323		};
324		matches!(ident.to_string().as_ref(), "u8" | "u16" | "u32" | "u64")
325	}) {
326		panic!("Only primitive unsigned integers are allowed as arguments to syscalls");
327	}
328
329	// one argument per register
330	let bindings = param_names
331		.zip(param_types)
332		.enumerate()
333		.map(|(idx, (name, ty))| {
334			let reg = quote::format_ident!("__a{}__", idx);
335			quote! {
336				let #name = #reg as #ty;
337			}
338		});
339	quote! {
340		#( #bindings )*
341	}
342}
343
344/// Expands environment definition.
345/// Should generate source code for:
346///  - implementations of the host functions to be added to the polkavm runtime environment (see
347///    `expand_impls()`).
348fn expand_env(def: &EnvDef) -> TokenStream2 {
349	let impls = expand_functions(def);
350	let bench_impls = expand_bench_functions(def);
351	let docs = expand_func_doc(def);
352	let stable_syscalls = expand_func_list(def, false);
353	let all_syscalls = expand_func_list(def, true);
354
355	quote! {
356		pub fn list_syscalls(include_unstable: bool) -> &'static [&'static [u8]] {
357			if include_unstable {
358				#all_syscalls
359			} else {
360				#stable_syscalls
361			}
362		}
363
364		impl<'a, T: Config, H: PrecompileHandle, M: PolkaVmInstance> Runtime<'a, T, H, M> {
365			fn handle_ecall(
366				&mut self,
367				memory: &mut M,
368				__syscall_symbol__: &[u8],
369			) -> Result<Option<u64>, TrapReason>
370			{
371				#impls
372			}
373		}
374
375		#[cfg(feature = "runtime-benchmarks")]
376		impl<'a, T: Config, H: PrecompileHandle, M: PolkaVmInstance> Runtime<'a, T, H, M> {
377			#bench_impls
378		}
379
380		/// Documentation of the syscalls (host functions) available to contracts.
381		///
382		/// Each of the functions in this trait represent a function that is callable
383		/// by the contract. Guests use the function name as the import symbol.
384		///
385		/// # Note
386		///
387		/// This module is not meant to be used by any code. Rather, it is meant to be
388		/// consumed by humans through rustdoc.
389		#[cfg(doc)]
390		pub trait SyscallDoc {
391			#docs
392		}
393	}
394}
395
396fn expand_functions(def: &EnvDef) -> TokenStream2 {
397	let impls = def.host_funcs.iter().map(|f| {
398		// skip the self and memory argument
399		let params = f.item.sig.inputs.iter().skip(2);
400		let param_names = params.clone().filter_map(|arg| {
401			let FnArg::Typed(arg) = arg else {
402				return None;
403			};
404			Some(&arg.pat)
405		});
406		let param_types = params.clone().filter_map(|arg| {
407			let FnArg::Typed(arg) = arg else {
408				return None;
409			};
410			Some(&arg.ty)
411		});
412		let arg_decoder = arg_decoder(param_names, param_types);
413		let cfg = &f.cfg;
414		let name = &f.name;
415		let syscall_symbol = Literal::byte_string(name.as_bytes());
416		let body = &f.item.block;
417		let map_output = f.returns.map_output();
418		let output = &f.item.sig.output;
419
420		// wrapped host function body call with host function traces
421		// see https://github.com/paritytech/polkadot-sdk/tree/master/substrate/frame/contracts#host-function-tracing
422		let wrapped_body_with_trace = {
423			let trace_fmt_args = params.clone().filter_map(|arg| match arg {
424				syn::FnArg::Receiver(_) => None,
425				syn::FnArg::Typed(p) => match *p.pat.clone() {
426					syn::Pat::Ident(ref pat_ident) => Some(pat_ident.ident.clone()),
427					_ => None,
428				},
429			});
430
431			let params_fmt_str = trace_fmt_args
432				.clone()
433				.map(|s| format!("{s}: {{:?}}"))
434				.collect::<Vec<_>>()
435				.join(", ");
436			let trace_fmt_str = format!("{name}({params_fmt_str}) = {{:?}}");
437
438			quote! {
439				// wrap body in closure to make sure the tracing is always executed
440				let result = (|| #body)();
441				::log::trace!(target: LOG_TARGET, #trace_fmt_str, #( #trace_fmt_args, )* result);
442				result
443			}
444		};
445
446		quote! {
447			#cfg
448			#syscall_symbol => {
449				// closure is needed so that "?" can infere the correct type
450				(|| #output {
451					#arg_decoder
452					#wrapped_body_with_trace
453				})().map(#map_output)
454			},
455		}
456	});
457
458	quote! {
459		self.charge_polkavm_gas(memory)?;
460
461		// This is the overhead to call an empty syscall that always needs to be charged.
462		self.charge_gas(crate::vm::RuntimeCosts::HostFn).map_err(TrapReason::from)?;
463
464		// They will be mapped to variable names by the syscall specific code.
465		let (__a0__, __a1__, __a2__, __a3__, __a4__, __a5__) = memory.read_input_regs();
466
467		// Execute the syscall specific logic in a closure so that the gas metering code is always executed.
468		let result = (|| match __syscall_symbol__ {
469			#( #impls )*
470			_ => Err(TrapReason::SupervisorError(SupervisorError::InvalidSyscall.into()))
471		})();
472
473		result
474	}
475}
476
477fn expand_bench_functions(def: &EnvDef) -> TokenStream2 {
478	let impls = def.host_funcs.iter().map(|f| {
479		// skip the context and memory argument
480		let params = f.item.sig.inputs.iter().skip(2);
481		let cfg = &f.cfg;
482		let name = &f.name;
483		let body = &f.item.block;
484		let output = &f.item.sig.output;
485
486		let name = Ident::new(&format!("bench_{name}"), Span::call_site());
487		quote! {
488			#cfg
489			pub fn #name(&mut self, memory: &mut M, #(#params),*) #output {
490				#body
491			}
492		}
493	});
494
495	quote! {
496		#( #impls )*
497	}
498}
499
500fn expand_func_doc(def: &EnvDef) -> TokenStream2 {
501	let docs = def.host_funcs.iter().map(|func| {
502		// Remove auxiliary args: `ctx: _` and `memory: _`
503		let func_decl = {
504			let mut sig = func.item.sig.clone();
505			sig.inputs = sig
506				.inputs
507				.iter()
508				.skip(2)
509				.cloned()
510				.collect::<Punctuated<FnArg, Comma>>();
511			sig.output = func.returns.success_type();
512			sig.to_token_stream()
513		};
514		let func_doc = {
515			let func_docs = {
516				let docs = func
517					.item
518					.attrs
519					.iter()
520					.filter(|a| a.path().is_ident("doc"))
521					.map(|d| {
522						let docs = d.to_token_stream();
523						quote! { #docs }
524					});
525				quote! { #( #docs )* }
526			};
527			let availability = if func.is_stable {
528				let info = "\n# Stable API\nThis API is stable and will never change.";
529				quote! { #[doc = #info] }
530			} else {
531				let info =
532				"\n# Unstable API\nThis API is not standardized and only available for testing.";
533				quote! { #[doc = #info] }
534			};
535			quote! {
536				#func_docs
537				#availability
538			}
539		};
540		quote! {
541			#func_doc
542			#func_decl;
543		}
544	});
545
546	quote! {
547		#( #docs )*
548	}
549}
550
551fn expand_func_list(def: &EnvDef, include_unstable: bool) -> TokenStream2 {
552	let docs = def
553		.host_funcs
554		.iter()
555		.filter(|f| include_unstable || f.is_stable)
556		.map(|f| {
557			let name = Literal::byte_string(f.name.as_bytes());
558			quote! {
559				#name.as_slice()
560			}
561		});
562	let len = docs.clone().count();
563
564	quote! {
565		{
566			static FUNCS: [&[u8]; #len] = [#(#docs),*];
567			FUNCS.as_slice()
568		}
569	}
570}