From e6b0b0f35bc9dda1d3a02d314d9acbd1196a27ca Mon Sep 17 00:00:00 2001
From: Pavlo Khrystenko
Date: Wed, 11 Sep 2024 10:05:12 +0200
Subject: [PATCH] Add limit for codec indexes
closes #507
---
README.md | 2 +-
derive/src/decode.rs | 23 +++++---
derive/src/encode.rs | 27 ++++++----
derive/src/utils.rs | 113 +++++++++++++++++++++++++++++++---------
tests/variant_number.rs | 4 +-
5 files changed, 125 insertions(+), 44 deletions(-)
diff --git a/README.md b/README.md
index 8b079051..69d32d40 100644
--- a/README.md
+++ b/README.md
@@ -216,7 +216,7 @@ The derive implementation supports the following attributes:
- `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being
encoded by using `OtherType`.
- `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given
- index when encoded. By default the index is determined by counting from `0` beginning wth the
+ index when encoded. By default the index is determined by counting from `0` beginning with the
first variant.
- `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take
in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for
diff --git a/derive/src/decode.rs b/derive/src/decode.rs
index 7f2d08b2..59a42ce0 100644
--- a/derive/src/decode.rs
+++ b/derive/src/decode.rs
@@ -15,7 +15,7 @@
use proc_macro2::{Ident, Span, TokenStream};
use syn::{spanned::Spanned, Data, Error, Field, Fields};
-use crate::utils;
+use crate::utils::{self, UsedIndexes};
/// Generate function block for function `Decode::decode`.
///
@@ -57,9 +57,17 @@ pub fn quote(
.to_compile_error();
}
- let recurse = data_variants().enumerate().map(|(i, v)| {
+ let mut used_indexes = match UsedIndexes::from_iter(data_variants()) {
+ Ok(index) => index,
+ Err(e) => return e.into_compile_error(),
+ };
+ let mut items = vec![];
+ for v in data_variants() {
let name = &v.ident;
- let index = utils::variant_index(v, i);
+ let index = match used_indexes.variant_index(v) {
+ Ok(index) => index,
+ Err(e) => return e.into_compile_error(),
+ };
let create = create_instance(
quote! { #type_name #type_generics :: #name },
@@ -69,7 +77,7 @@ pub fn quote(
crate_path,
);
- quote_spanned! { v.span() =>
+ let item = quote_spanned! { v.span() =>
#[allow(clippy::unnecessary_cast)]
__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
// NOTE: This lambda is necessary to work around an upstream bug
@@ -80,8 +88,9 @@ pub fn quote(
#create
})();
},
- }
- });
+ };
+ items.push(item);
+ }
let read_byte_err_msg =
format!("Could not decode `{type_name}`, failed to read variant byte");
@@ -91,7 +100,7 @@ pub fn quote(
match #input.read_byte()
.map_err(|e| e.chain(#read_byte_err_msg))?
{
- #( #recurse )*
+ #( #items )*
_ => {
#[allow(clippy::redundant_closure_call)]
return (move || {
diff --git a/derive/src/encode.rs b/derive/src/encode.rs
index 142bb439..2c3600ea 100644
--- a/derive/src/encode.rs
+++ b/derive/src/encode.rs
@@ -17,7 +17,7 @@ use std::str::from_utf8;
use proc_macro2::{Ident, Span, TokenStream};
use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, Data, Error, Field, Fields};
-use crate::utils;
+use crate::{utils, utils::UsedIndexes};
type FieldsList = Punctuated;
@@ -313,12 +313,18 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
if data_variants().count() == 0 {
return quote!();
}
-
- let recurse = data_variants().enumerate().map(|(i, f)| {
+ let mut used_indexes = match UsedIndexes::from_iter(data_variants()) {
+ Ok(index) => index,
+ Err(e) => return e.into_compile_error(),
+ };
+ let mut items = vec![];
+ for f in data_variants() {
let name = &f.ident;
- let index = utils::variant_index(f, i);
-
- match f.fields {
+ let index = match used_indexes.variant_index(f) {
+ Ok(index) => index,
+ Err(e) => return e.into_compile_error(),
+ };
+ let item = match f.fields {
Fields::Named(ref fields) => {
let fields = &fields.named;
let field_name = |_, ident: &Option| quote!(#ident);
@@ -396,11 +402,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
[hinting, encoding]
},
- }
- });
+ };
+ items.push(item)
+ }
- let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
- let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
+ let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
+ let recurse_encoding = items.iter().map(|[_, encoding]| encoding);
let hinting = quote! {
// The variant index uses 1 byte.
diff --git a/derive/src/utils.rs b/derive/src/utils.rs
index 091a45ee..07f89d16 100644
--- a/derive/src/utils.rs
+++ b/derive/src/utils.rs
@@ -17,13 +17,14 @@
//! NOTE: attributes finder must be checked using check_attribute first,
//! otherwise the macro can panic.
-use std::str::FromStr;
+use std::{collections::HashSet, str::FromStr};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput,
- Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant,
+ ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path,
+ Variant,
};
fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option
@@ -37,32 +38,96 @@ where
})
}
-/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
-/// is found, fall back to the discriminant or just the variant index.
-pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
- // first look for an attribute
- let index = find_meta_item(v.attrs.iter(), |meta| {
- if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
- if nv.path.is_ident("index") {
- if let Lit::Int(ref v) = nv.lit {
- let byte = v
- .base10_parse::()
- .expect("Internal error, index attribute must have been checked");
- return Some(byte);
+pub struct UsedIndexes {
+ used_set: HashSet,
+ current: u8,
+}
+
+impl UsedIndexes {
+ /// Build a Set of used indexes for use with #[scale(index = $int)] attribute on variant
+ pub fn from_iter<'a, I: Iterator- >(values: I) -> syn::Result {
+ let mut set = HashSet::new();
+ for (i, v) in values.enumerate() {
+ if let Some((index, nv)) = find_meta_item(v.attrs.iter(), |meta| {
+ if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
+ if nv.path.is_ident("index") {
+ if let Lit::Int(ref v) = nv.lit {
+ let byte = v
+ .base10_parse::()
+ .expect("Internal error, index attribute must have been checked");
+ return Some((byte, nv.span()));
+ }
+ }
+ }
+ None
+ }) {
+ if !set.insert(index) {
+ return Err(syn::Error::new(nv.span(), "Duplicate variant index. qed"))
+ }
+ set.insert(i.try_into().expect("Will never happen. qed"));
+ } else {
+ match v.discriminant.as_ref() {
+ Some((
+ _,
+ expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }),
+ )) => {
+ let index = lit_int
+ .base10_parse::()
+ .expect("Internal error, index attribute must have been checked");
+ if !set.insert(index) {
+ return Err(syn::Error::new(expr.span(), "Duplicate variant index. qed"))
+ }
+ set.insert(i.try_into().expect("Will never happen. qed"));
+ },
+ _ => (),
}
}
}
+ Ok(Self { current: 0, used_set: set })
+ }
- None
- });
-
- // then fallback to discriminant or just index
- index.map(|i| quote! { #i }).unwrap_or_else(|| {
- v.discriminant
- .as_ref()
- .map(|(_, expr)| quote! { #expr })
- .unwrap_or_else(|| quote! { #i })
- })
+ /// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
+ /// is found, fall back to the discriminant or just the variant index.
+ pub fn variant_index(&mut self, v: &Variant) -> syn::Result {
+ // first look for an attribute
+ let index = find_meta_item(v.attrs.iter(), |meta| {
+ if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
+ if nv.path.is_ident("index") {
+ if let Lit::Int(ref v) = nv.lit {
+ let byte = v
+ .base10_parse::()
+ .expect("Internal error, index attribute must have been checked");
+ return Some(byte);
+ }
+ }
+ }
+
+ None
+ });
+
+ index.map_or_else(
+ || match v.discriminant.as_ref() {
+ Some((_, expr)) => return Ok(quote! { #expr }),
+ None => {
+ let idx = self.next_index();
+ return Ok(quote! { #idx })
+ },
+ },
+ |i| Ok(quote! { #i }),
+ )
+ }
+
+ fn next_index(&mut self) -> u8 {
+ loop {
+ if self.used_set.contains(&self.current) {
+ self.current += 1;
+ } else {
+ let index = self.current;
+ self.current += 1;
+ return index
+ }
+ }
+ }
}
/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
diff --git a/tests/variant_number.rs b/tests/variant_number.rs
index 9bdaba0a..c38fd748 100644
--- a/tests/variant_number.rs
+++ b/tests/variant_number.rs
@@ -10,7 +10,7 @@ fn discriminant_variant_counted_in_default_index() {
}
assert_eq!(T::A.encode(), vec![1]);
- assert_eq!(T::B.encode(), vec![1]);
+ assert_eq!(T::B.encode(), vec![2]);
}
#[test]
@@ -36,5 +36,5 @@ fn index_attr_variant_counted_and_reused_in_default_index() {
}
assert_eq!(T::A.encode(), vec![1]);
- assert_eq!(T::B.encode(), vec![1]);
+ assert_eq!(T::B.encode(), vec![2]);
}