Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some checks for capabilities #1551

Merged
merged 20 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions bindings/nodejs/lib/types/block/address.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class RestrictedAddress extends Address {
/**
* The allowed capabilities bitflags.
*/
private allowedCapabilities: HexEncodedString = '0x';
private allowedCapabilities?: HexEncodedString;
/**
* @param address An address.
*/
Expand All @@ -227,7 +227,7 @@ class RestrictedAddress extends Address {
allowedCapabilities.byteLength,
).toString('hex');
} else {
this.allowedCapabilities = '0x';
this.allowedCapabilities = undefined;
}
}

Expand All @@ -239,13 +239,20 @@ class RestrictedAddress extends Address {
}

getAllowedCapabilities(): Uint8Array {
return Uint8Array.from(
Buffer.from(this.allowedCapabilities.substring(2), 'hex'),
);
return this.allowedCapabilities !== undefined
? Uint8Array.from(
Buffer.from(this.allowedCapabilities.substring(2), 'hex'),
Thoralf-M marked this conversation as resolved.
Show resolved Hide resolved
)
: new Uint8Array();
}

toString(): string {
return this.address.toString() + this.allowedCapabilities.substring(2);
return (
this.address.toString() +
(this.allowedCapabilities !== undefined
? this.allowedCapabilities.substring(2)
: '')
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Transaction {

readonly allotments: ManaAllotment[];

private capabilities: HexEncodedString = '0x';
private capabilities?: HexEncodedString;

@Type(() => Payload, {
discriminator: PayloadDiscriminator,
Expand Down Expand Up @@ -91,7 +91,7 @@ class Transaction {
capabilities.byteLength,
).toString('hex');
} else {
this.capabilities = '0x';
this.capabilities = undefined;
}
}

Expand All @@ -102,9 +102,11 @@ class Transaction {

/** Get the capability bitflags of the transaction. */
getCapabilities(): Uint8Array {
return Uint8Array.from(
Buffer.from(this.capabilities.substring(2), 'hex'),
);
return this.capabilities !== undefined
? Uint8Array.from(
Buffer.from(this.capabilities.substring(2), 'hex'),
)
: new Uint8Array();
}
}

Expand Down
9 changes: 6 additions & 3 deletions bindings/python/iota_sdk/types/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from enum import IntEnum
from dataclasses import dataclass, field
from typing import Any, Dict, List, TypeAlias, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union
from iota_sdk.types.common import HexStr, json


Expand Down Expand Up @@ -118,7 +118,7 @@ class RestrictedAddress:
allowed_capabilities: The allowed capabilities bitflags.
"""
address: Union[Ed25519Address, AccountAddress, NFTAddress]
allowed_capabilities: HexStr = field(default='0x', init=False)
allowed_capabilities: Optional[HexStr] = field(default=None, init=False)
type: int = field(default_factory=lambda: int(
AddressType.RESTRICTED), init=False)

Expand All @@ -127,7 +127,10 @@ def with_allowed_capabilities(self, capabilities: bytes):
Attributes:
capabilities: The allowed capabilities bitflags.
"""
self.allowed_capabilities = '0x' + capabilities.hex()
if any(c != 0 for c in capabilities):
self.allowed_capabilities = '0x' + capabilities.hex()
else:
self.allowed_capabilities = None


@json
Expand Down
7 changes: 5 additions & 2 deletions bindings/python/iota_sdk/types/essence.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class RegularTransactionEssence:
outputs: List[Output]
context_inputs: Optional[List[ContextInput]] = None
allotments: Optional[List[ManaAllotment]] = None
capabilities: HexStr = field(default='0x', init=False)
capabilities: Optional[HexStr] = field(default=None, init=False)
payload: Optional[Payload] = None
type: int = field(
default_factory=lambda: int(EssenceType.RegularTransactionEssence),
Expand All @@ -62,7 +62,10 @@ def with_capabilities(self, capabilities: bytes):
Attributes:
capabilities: The transaction capabilities bitflags.
"""
self.capabilities = '0x' + capabilities.hex()
if any(c != 0 for c in capabilities):
self.capabilities = '0x' + capabilities.hex()
else:
self.capabilities = None


TransactionEssence: TypeAlias = RegularTransactionEssence
18 changes: 4 additions & 14 deletions sdk/src/types/block/address/restricted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,18 @@ pub type AddressCapabilities = Capabilities<AddressCapabilityFlag>;

#[cfg(feature = "serde")]
pub(crate) mod dto {
use alloc::boxed::Box;

use serde::{Deserialize, Serialize};

use super::*;
use crate::utils::serde::prefix_hex_bytes;

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RestrictedAddressDto {
#[serde(rename = "type")]
kind: u8,
pub address: Address,
#[serde(with = "prefix_hex_bytes")]
pub allowed_capabilities: Box<[u8]>,
#[serde(default, skip_serializing_if = "AddressCapabilities::is_none")]
pub allowed_capabilities: AddressCapabilities,
}

impl core::ops::Deref for RestrictedAddressDto {
Expand All @@ -187,7 +184,7 @@ pub(crate) mod dto {
Self {
kind: RestrictedAddress::KIND,
address: value.address.clone(),
allowed_capabilities: value.allowed_capabilities.iter().copied().collect(),
allowed_capabilities: value.allowed_capabilities.clone(),
}
}
}
Expand All @@ -196,14 +193,7 @@ pub(crate) mod dto {
type Error = Error;

fn try_from(value: RestrictedAddressDto) -> Result<Self, Self::Error> {
Ok(
Self::new(value.address)?.with_allowed_capabilities(AddressCapabilities::from_bytes(
value
.allowed_capabilities
.try_into()
.map_err(Error::InvalidCapabilitiesCount)?,
)),
)
Ok(Self::new(value.address)?.with_allowed_capabilities(value.allowed_capabilities))
}
}

Expand Down
69 changes: 51 additions & 18 deletions sdk/src/types/block/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use core::marker::PhantomData;

use derive_more::Deref;
use packable::{
error::UnpackErrorExt,
error::{UnpackError, UnpackErrorExt},
prefix::{BoxedSlicePrefix, UnpackPrefixError},
Packable,
};

use crate::types::block::Error;

/// A list of bitflags that represent capabilities.
#[derive(Debug, Deref)]
#[repr(transparent)]
Expand All @@ -21,11 +23,17 @@ pub struct Capabilities<Flag> {
}

impl<Flag> Capabilities<Flag> {
pub(crate) fn from_bytes(bytes: BoxedSlicePrefix<u8, u8>) -> Self {
Self {
/// Try to create capabilities from serialized bytes. Bytes with trailing zeroes are invalid.
pub(crate) fn from_bytes(bytes: BoxedSlicePrefix<u8, u8>) -> Result<Self, Error> {
if let Some(idx) = bytes.iter().rposition(|c| 0.ne(c)) {
if idx + 1 < bytes.len() {
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
return Err(Error::TrailingCapabilityBytes);
}
}
Ok(Self {
bytes,
_flag: PhantomData,
}
})
}

/// Returns a [`Capabilities`] with every possible flag disabled.
Expand All @@ -37,6 +45,12 @@ impl<Flag> Capabilities<Flag> {
pub fn is_none(&self) -> bool {
self.iter().all(|b| 0.eq(b))
}

/// Disables every possible flag.
pub fn set_none(&mut self) -> &mut Self {
*self = Default::default();
self
}
}

impl<Flag: CapabilityFlag> Capabilities<Flag> {
Expand All @@ -60,12 +74,6 @@ impl<Flag: CapabilityFlag> Capabilities<Flag> {
self
}

/// Disables every possible flag.
pub fn set_none(&mut self) -> &mut Self {
*self = Default::default();
self
}

/// Enables a given flag.
pub fn add_capability(&mut self, flag: Flag) -> &mut Self {
if self.bytes.len() <= flag.index() {
Expand Down Expand Up @@ -173,25 +181,22 @@ impl<Flag: 'static> Packable for Capabilities<Flag> {
type UnpackVisitor = ();

fn pack<P: packable::packer::Packer>(&self, packer: &mut P) -> Result<(), P::Error> {
if !self.is_none() {
self.bytes.pack(packer)?;
} else {
0_u8.pack(packer)?;
}
self.bytes.pack(packer)?;
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
Ok(())
}

fn unpack<U: packable::unpacker::Unpacker, const VERIFY: bool>(
unpacker: &mut U,
visitor: &Self::UnpackVisitor,
) -> Result<Self, packable::error::UnpackError<Self::UnpackError, U::Error>> {
Ok(Self::from_bytes(
) -> Result<Self, UnpackError<Self::UnpackError, U::Error>> {
Self::from_bytes(
BoxedSlicePrefix::unpack::<_, VERIFY>(unpacker, visitor)
.map_packable_err(|e| match e {
UnpackPrefixError::Item(i) | UnpackPrefixError::Prefix(i) => i,
})
.coerce()?,
))
)
.map_err(UnpackError::Packable)
}
}

Expand All @@ -207,3 +212,31 @@ pub trait CapabilityFlag {
/// Returns an iterator over all flags.
fn all() -> Self::Iterator;
}

#[cfg(feature = "serde")]
mod serde {
use ::serde::{Deserialize, Serialize};

use super::*;

impl<Flag> Serialize for Capabilities<Flag> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
crate::utils::serde::boxed_slice_prefix_hex_bytes::serialize(&self.bytes, serializer)
}
}

impl<'de, Flag> Deserialize<'de> for Capabilities<Flag> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
Self::from_bytes(crate::utils::serde::boxed_slice_prefix_hex_bytes::deserialize(
deserializer,
)?)
.map_err(::serde::de::Error::custom)
}
}
}
2 changes: 2 additions & 0 deletions sdk/src/types/block/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pub enum Error {
created: EpochIndex,
target: EpochIndex,
},
TrailingCapabilityBytes,
}

#[cfg(feature = "std")]
Expand Down Expand Up @@ -397,6 +398,7 @@ impl fmt::Display for Error {
Self::InvalidEpochDelta { created, target } => {
write!(f, "invalid epoch delta: created {created}, target {target}")
}
Self::TrailingCapabilityBytes => write!(f, "capabilities bytes has trailing zeroes"),
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down
24 changes: 8 additions & 16 deletions sdk/src/types/block/payload/signed_transaction/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,20 +532,14 @@ pub type TransactionCapabilities = Capabilities<TransactionCapabilityFlag>;

#[cfg(feature = "serde")]
pub(crate) mod dto {
use alloc::{
boxed::Box,
string::{String, ToString},
};
use alloc::string::{String, ToString};

use serde::{Deserialize, Serialize};

use super::*;
use crate::{
types::{
block::{mana::ManaAllotmentDto, output::dto::OutputDto, payload::dto::PayloadDto, Error},
TryFromDto,
},
utils::serde::prefix_hex_bytes,
use crate::types::{
block::{mana::ManaAllotmentDto, output::dto::OutputDto, payload::dto::PayloadDto, Error},
TryFromDto,
};

#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
Expand All @@ -556,8 +550,8 @@ pub(crate) mod dto {
pub context_inputs: Vec<ContextInput>,
pub inputs: Vec<Input>,
pub allotments: Vec<ManaAllotmentDto>,
#[serde(with = "prefix_hex_bytes")]
pub capabilities: Box<[u8]>,
#[serde(default, skip_serializing_if = "TransactionCapabilities::is_none")]
pub capabilities: TransactionCapabilities,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub payload: Option<PayloadDto>,
pub outputs: Vec<OutputDto>,
Expand All @@ -571,7 +565,7 @@ pub(crate) mod dto {
context_inputs: value.context_inputs().to_vec(),
inputs: value.inputs().to_vec(),
allotments: value.mana_allotments().iter().map(Into::into).collect(),
capabilities: value.capabilities().iter().copied().collect(),
capabilities: value.capabilities().clone(),
payload: match value.payload() {
Some(p @ Payload::TaggedData(_)) => Some(p.into()),
Some(_) => unimplemented!(),
Expand Down Expand Up @@ -607,9 +601,7 @@ pub(crate) mod dto {
.with_context_inputs(dto.context_inputs)
.with_inputs(dto.inputs)
.with_mana_allotments(mana_allotments)
.with_capabilities(Capabilities::from_bytes(
dto.capabilities.try_into().map_err(Error::InvalidCapabilitiesCount)?,
))
.with_capabilities(dto.capabilities)
.with_outputs(outputs);

builder = if let Some(p) = dto.payload {
Expand Down
Loading