pallet_evm_precompile_curve25519/
lib.rs

1// This file is part of Frontier.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18#![cfg_attr(not(feature = "std"), no_std)]
19#![warn(unused_crate_dependencies)]
20
21extern crate alloc;
22
23use alloc::vec::Vec;
24use core::marker::PhantomData;
25use curve25519_dalek::{
26	ristretto::{CompressedRistretto, RistrettoPoint},
27	scalar::Scalar,
28	traits::Identity,
29};
30use fp_evm::{
31	ExitError, ExitSucceed, Precompile, PrecompileFailure, PrecompileHandle, PrecompileOutput,
32	PrecompileResult,
33};
34use frame_support::weights::Weight;
35use pallet_evm::GasWeightMapping;
36
37// Weight provider trait expected by these precompiles. Implementations should return Substrate Weights.
38pub trait WeightInfo {
39	fn curve25519_add_n_points(n: u32) -> Weight;
40	fn curve25519_scaler_mul() -> Weight;
41}
42
43// Default weights from benchmarks run on a laptop, do not use them in production !
44impl WeightInfo for () {
45	/// The range of component `n` is `[1, 10]`.
46	fn curve25519_add_n_points(n: u32) -> Weight {
47		// Proof Size summary in bytes:
48		//  Measured:  `0`
49		//  Estimated: `0`
50		// Minimum execution time: 10_000_000 picoseconds.
51		Weight::from_parts(5_399_134, 0)
52			.saturating_add(Weight::from_parts(0, 0))
53			// Standard Error: 8_395
54			.saturating_add(Weight::from_parts(5_153_957, 0).saturating_mul(n.into()))
55	}
56	fn curve25519_scaler_mul() -> Weight {
57		// Proof Size summary in bytes:
58		//  Measured:  `0`
59		//  Estimated: `0`
60		// Minimum execution time: 81_000_000 picoseconds.
61		Weight::from_parts(87_000_000, 0).saturating_add(Weight::from_parts(0, 0))
62	}
63}
64
65// Adds at most 10 curve25519 points and returns the CompressedRistretto bytes representation
66pub struct Curve25519Add<R, WI>(PhantomData<(R, WI)>);
67
68impl<R, WI> Precompile for Curve25519Add<R, WI>
69where
70	R: pallet_evm::Config,
71	WI: WeightInfo,
72{
73	fn execute(handle: &mut impl PrecompileHandle) -> PrecompileResult {
74		let n_points = (handle.input().len() / 32) as u32;
75		let weight = WI::curve25519_add_n_points(n_points);
76		let gas = R::GasWeightMapping::weight_to_gas(weight);
77		handle.record_cost(gas)?;
78		let (exit_status, output) = Self::execute_inner(handle.input(), gas)?;
79		Ok(PrecompileOutput {
80			exit_status,
81			output,
82		})
83	}
84}
85
86impl<R, WI> Curve25519Add<R, WI>
87where
88	WI: WeightInfo,
89{
90	pub fn execute_inner(
91		input: &[u8],
92		_: u64,
93	) -> Result<(ExitSucceed, Vec<u8>), PrecompileFailure> {
94		if input.len() % 32 != 0 {
95			return Err(PrecompileFailure::Error {
96				exit_status: ExitError::Other("input must contain multiple of 32 bytes".into()),
97			});
98		};
99
100		if input.len() > 320 {
101			return Err(PrecompileFailure::Error {
102				exit_status: ExitError::Other(
103					"input cannot be greater than 320 bytes (10 compressed points)".into(),
104				),
105			});
106		};
107
108		let mut points = Vec::new();
109		let mut temp_buf = <&[u8]>::clone(&input);
110		while !temp_buf.is_empty() {
111			let mut buf = [0; 32];
112			buf.copy_from_slice(&temp_buf[0..32]);
113			let point = CompressedRistretto(buf);
114			points.push(point);
115			temp_buf = &temp_buf[32..];
116		}
117
118		let sum = points.iter().try_fold(
119			RistrettoPoint::identity(),
120			|acc, point| -> Result<RistrettoPoint, PrecompileFailure> {
121				let pt = point.decompress().ok_or_else(|| PrecompileFailure::Error {
122					exit_status: ExitError::Other("invalid compressed Ristretto point".into()),
123				})?;
124				Ok(acc + pt)
125			},
126		)?;
127
128		Ok((ExitSucceed::Returned, sum.compress().to_bytes().to_vec()))
129	}
130}
131
132// Multiplies a scalar field element with an elliptic curve point
133pub struct Curve25519ScalarMul<R, WI>(PhantomData<(R, WI)>);
134
135impl<R, WI> Precompile for Curve25519ScalarMul<R, WI>
136where
137	R: pallet_evm::Config,
138	WI: WeightInfo,
139{
140	fn execute(handle: &mut impl PrecompileHandle) -> PrecompileResult {
141		let weight = WI::curve25519_scaler_mul();
142		let gas = R::GasWeightMapping::weight_to_gas(weight);
143		handle.record_cost(gas)?;
144		let (exit_status, output) = Self::execute_inner(handle.input(), gas)?;
145		Ok(PrecompileOutput {
146			exit_status,
147			output,
148		})
149	}
150}
151
152impl<R, WI> Curve25519ScalarMul<R, WI>
153where
154	WI: WeightInfo,
155{
156	pub fn execute_inner(
157		input: &[u8],
158		_: u64,
159	) -> Result<(ExitSucceed, Vec<u8>), PrecompileFailure> {
160		if input.len() != 64 {
161			return Err(PrecompileFailure::Error {
162				exit_status: ExitError::Other(
163					"input must contain 64 bytes (scalar - 32 bytes, point - 32 bytes)".into(),
164				),
165			});
166		};
167
168		// first 32 bytes is for the scalar value
169		let mut scalar_buf = [0; 32];
170		scalar_buf.copy_from_slice(&input[0..32]);
171		let scalar = Scalar::from_bytes_mod_order(scalar_buf);
172
173		// second 32 bytes is for the compressed ristretto point bytes
174		let mut pt_buf = [0; 32];
175		pt_buf.copy_from_slice(&input[32..64]);
176		let point =
177			CompressedRistretto(pt_buf)
178				.decompress()
179				.ok_or_else(|| PrecompileFailure::Error {
180					exit_status: ExitError::Other("invalid compressed Ristretto point".into()),
181				})?;
182
183		let scalar_mul = scalar * point;
184		Ok((
185			ExitSucceed::Returned,
186			scalar_mul.compress().to_bytes().to_vec(),
187		))
188	}
189}
190
#[cfg(test)]
mod tests {
	use super::*;
	use curve25519_dalek::constants;

	/// Gas value handed to `execute_inner`; both implementations ignore it.
	const COST: u64 = 1;

	/// Result type shared by both `execute_inner` functions.
	type InnerResult = Result<(ExitSucceed, Vec<u8>), PrecompileFailure>;

	/// Compressed encoding of the Ristretto basepoint.
	fn basepoint_bytes() -> [u8; 32] {
		constants::RISTRETTO_BASEPOINT_POINT.compress().to_bytes()
	}

	/// 32 bytes that do not decode to a Ristretto point: all zero except the last byte,
	/// which is set to 0xFF.
	fn invalid_point_bytes() -> [u8; 32] {
		let mut bytes = [0u8; 32];
		bytes[31] = 0xFF;
		bytes
	}

	/// Unwraps the output bytes of a call that is expected to succeed.
	#[track_caller]
	fn expect_output(result: InnerResult) -> Vec<u8> {
		match result {
			Ok((_, out)) => out,
			Err(e) => panic!("Test not expected to fail: {e:?}"),
		}
	}

	/// Asserts that a call failed with `ExitError::Other(message)`.
	#[track_caller]
	fn expect_other_error(result: InnerResult, message: &'static str) {
		match result {
			Ok(_) => panic!("Test not expected to work"),
			Err(e) => assert_eq!(
				e,
				PrecompileFailure::Error {
					exit_status: ExitError::Other(message.into())
				}
			),
		}
	}

	#[test]
	fn test_sum() {
		let p1 = constants::RISTRETTO_BASEPOINT_POINT * Scalar::from(999u64);
		let p2 = constants::RISTRETTO_BASEPOINT_POINT * Scalar::from(333u64);

		let mut input = vec![];
		input.extend_from_slice(&p1.compress().to_bytes());
		input.extend_from_slice(&p2.compress().to_bytes());

		let expected: RistrettoPoint = [p1, p2].iter().sum();

		let out = expect_output(Curve25519Add::<(), ()>::execute_inner(&input, COST));
		assert_eq!(out, expected.compress().to_bytes());
	}

	#[test]
	fn test_empty() {
		// Test that sum works for the empty iterator
		let out = expect_output(Curve25519Add::<(), ()>::execute_inner(&[], COST));
		assert_eq!(out, RistrettoPoint::identity().compress().to_bytes());
	}

	#[test]
	fn test_scalar_mul() {
		let s1 = Scalar::from(999u64);
		let s2 = Scalar::from(333u64);
		let p1 = constants::RISTRETTO_BASEPOINT_POINT * s1;
		let p2 = constants::RISTRETTO_BASEPOINT_POINT * s2;

		let mut input = vec![];
		input.extend_from_slice(&s1.to_bytes());
		input.extend_from_slice(&basepoint_bytes());

		let out = expect_output(Curve25519ScalarMul::<(), ()>::execute_inner(&input, COST));
		assert_eq!(out, p1.compress().to_bytes());
		assert_ne!(out, p2.compress().to_bytes());
	}

	#[test]
	fn test_scalar_mul_empty_error() {
		expect_other_error(
			Curve25519ScalarMul::<(), ()>::execute_inner(&[], COST),
			"input must contain 64 bytes (scalar - 32 bytes, point - 32 bytes)",
		);
	}

	#[test]
	fn test_point_addition_bad_length() {
		let input = [0u8; 33];

		expect_other_error(
			Curve25519Add::<(), ()>::execute_inner(&input, COST),
			"input must contain multiple of 32 bytes",
		);
	}

	#[test]
	fn test_point_addition_too_many_points() {
		// 11 points: one more than the 10 the precompile accepts.
		let input = basepoint_bytes().repeat(11);

		expect_other_error(
			Curve25519Add::<(), ()>::execute_inner(&input, COST),
			"input cannot be greater than 320 bytes (10 compressed points)",
		);
	}

	#[test]
	fn test_point_addition_invalid_point() {
		let input = invalid_point_bytes();

		expect_other_error(
			Curve25519Add::<(), ()>::execute_inner(&input, COST),
			"invalid compressed Ristretto point",
		);
	}

	#[test]
	fn test_scalar_mul_invalid_point() {
		let scalar = [1u8; 32];
		let mut input = vec![];
		input.extend_from_slice(&scalar);
		input.extend_from_slice(&invalid_point_bytes());

		expect_other_error(
			Curve25519ScalarMul::<(), ()>::execute_inner(&input, COST),
			"invalid compressed Ristretto point",
		);
	}

	#[test]
	fn test_point_addition_mixed_valid_invalid() {
		// A valid point followed by an invalid one must still be rejected.
		let mut input = vec![];
		input.extend_from_slice(&basepoint_bytes());
		input.extend_from_slice(&invalid_point_bytes());

		expect_other_error(
			Curve25519Add::<(), ()>::execute_inner(&input, COST),
			"invalid compressed Ristretto point",
		);
	}
}