Skip to content

Commit

Permalink
Merge pull request #26 from Bond-Protocol/rsa-tests
Browse files Browse the repository at this point in the history
Encryption Updates and Tests
  • Loading branch information
0xJem authored Jan 26, 2024
2 parents 6e58aa0 + a7c0fc6 commit 35d515c
Show file tree
Hide file tree
Showing 27 changed files with 330 additions and 258 deletions.
18 changes: 9 additions & 9 deletions src/AuctionHouse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ abstract contract Router is FeeManager {
///
/// @param params_ Bid parameters
/// @return bidId Bid ID
function bid(BidParams memory params_) external virtual returns (uint256 bidId);
function bid(BidParams memory params_) external virtual returns (uint96 bidId);

/// @notice Cancel a bid on a lot in a batch auction
/// @dev The implementing function must perform the following:
Expand All @@ -139,7 +139,7 @@ abstract contract Router is FeeManager {
///
/// @param lotId_ Lot ID
/// @param bidId_ Bid ID
function cancelBid(uint96 lotId_, uint256 bidId_) external virtual;
function cancelBid(uint96 lotId_, uint96 bidId_) external virtual;

/// @notice Settle a batch auction
/// @notice This function is used for versions with on-chain storage and bids and local settlement
Expand Down Expand Up @@ -261,7 +261,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @return bool True if caller is allowed to purchase/bid on the lot
function _isAllowed(
IAllowlist allowlist_,
uint256 lotId_,
uint96 lotId_,
address caller_,
bytes memory allowlistProof_
) internal view returns (bool) {
Expand Down Expand Up @@ -356,7 +356,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
external
override
isLotValid(params_.lotId)
returns (uint256)
returns (uint96)
{
// Load routing data for the lot
Routing memory routing = lotRouting[params_.lotId];
Expand All @@ -368,7 +368,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {

// Record the bid on the auction module
// The module will determine if the bid is valid - minimum bid size, minimum price, auction status, etc
uint256 bidId;
uint96 bidId;
{
AuctionModule module = _getModuleForId(params_.lotId);
bidId = module.bid(
Expand Down Expand Up @@ -398,7 +398,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @dev This function reverts if:
/// - the lot ID is invalid
/// - the auction module reverts when cancelling the bid
function cancelBid(uint96 lotId_, uint256 bidId_) external override isLotValid(lotId_) {
function cancelBid(uint96 lotId_, uint96 bidId_) external override isLotValid(lotId_) {
// Cancel the bid on the auction module
// The auction module is responsible for validating the bid and authorizing the caller
AuctionModule module = _getModuleForId(lotId_);
Expand Down Expand Up @@ -505,7 +505,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param hooks_ Hooks contract to call (optional)
/// @param permit2Approval_ Permit2 approval data (optional)
function _collectPayment(
uint256 lotId_,
uint96 lotId_,
uint256 amount_,
ERC20 quoteToken_,
IHooks hooks_,
Expand Down Expand Up @@ -574,7 +574,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param payoutAmount_ Amount of payoutToken to collect (in native decimals)
/// @param routingParams_ Routing parameters for the lot
function _collectPayout(
uint256 lotId_,
uint96 lotId_,
uint256 paymentAmount_,
uint256 payoutAmount_,
Routing memory routingParams_
Expand Down Expand Up @@ -634,7 +634,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param routingParams_ Routing parameters for the lot
/// @param auctionOutput_ Custom data returned by the auction module
function _sendPayout(
uint256 lotId_,
uint96 lotId_,
address recipient_,
uint256 payoutAmount_,
Routing memory routingParams_,
Expand Down
6 changes: 3 additions & 3 deletions src/interfaces/IHooks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ interface IHooks {

/// @notice Called before payment and payout
/// TODO define expected state, invariants
function pre(uint256 lotId_, uint256 amount_) external;
function pre(uint96 lotId_, uint256 amount_) external;

/// @notice Called after payment and before payout
/// TODO define expected state, invariants
function mid(uint256 lotId_, uint256 amount_, uint256 payout_) external;
function mid(uint96 lotId_, uint256 amount_, uint256 payout_) external;

/// @notice Called after payment and after payout
/// TODO define expected state, invariants
function post(uint256 lotId_, uint256 payout_) external;
function post(uint96 lotId_, uint256 payout_) external;
}
135 changes: 58 additions & 77 deletions src/lib/RSA.sol
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ library RSAOAEP {
bytes memory label
) internal view returns (bytes memory message, bytes32 seed) {
// Implements 7.1.2 RSAES-OAEP-DECRYPT as defined in RFC8017: https://www.rfc-editor.org/rfc/rfc8017
// Error messages are intentionally vague to prevent oracle attacks

// 1. Input length validation
// 1. a. If the length of L is greater than the input limitation
Expand Down Expand Up @@ -57,36 +56,38 @@ library RSAOAEP {

// 3. b. Separate encoded message into Y (1 byte) | maskedSeed (32 bytes) | maskedDB (cLen - 32 - 1)
bytes1 y = bytes1(encoded);
bytes32 maskedSeed;
uint256 words = (cLen - 33) / 32 + ((cLen - 33) % 32 == 0 ? 0 : 1);
bytes memory maskedDb = new bytes(cLen - 33);

assembly {
bytes memory db;
{
// Scope these local variables to avoid stack too deep later
bytes32 maskedSeed;
// uint256 words = ((cLen - 33) / 32) + (((cLen - 33) % 32) == 0 ? 0 : 1);
uint256 maskLen = cLen - 33;
bytes memory maskedDb = new bytes(maskLen);

// Load a word from the encoded string starting at the 2nd byte (also have to account for length stored in first slot)
maskedSeed := mload(add(encoded, 0x21))
assembly {
maskedSeed := mload(add(encoded, 0x21))
}

// Store the remaining bytes into the maskedDb
for { let i := 0 } lt(i, words) { i := add(i, 1) } {
mstore(
add(add(maskedDb, 0x20), mul(i, 0x20)),
mload(add(add(encoded, 0x41), mul(i, 0x20)))
)
for (uint256 i; i < maskLen; i++) {
maskedDb[i] = encoded[i + 33];
}
}

// 3. c. Calculate seed mask
// 3. d. Calculate seed
{
bytes32 seedMask = bytes32(_mgf(maskedDb, 32));
seed = maskedSeed ^ seedMask;
}
// 3. c. Calculate seed mask
// 3. d. Calculate seed
{
bytes32 seedMask = bytes32(_mgf(maskedDb, 32));
seed = maskedSeed ^ seedMask;
}

// 3. e. Calculate DB mask
bytes memory dbMask = _mgf(abi.encodePacked(seed), cLen - 33);
// 3. e. Calculate DB mask
bytes memory dbMask = _mgf(abi.encodePacked(seed), cLen - 33);

// 3. f. Calculate DB
bytes memory db = _xor(maskedDb, dbMask);
uint256 dbWords = db.length / 32 + db.length % 32 == 0 ? 0 : 1;
// 3. f. Calculate DB
db = _xor(maskedDb, dbMask);
}

// 3. g. Separate DB into an octet string lHash' of length hLen, a
// (possibly empty) padding string PS consisting of octets
Expand All @@ -99,57 +100,37 @@ library RSAOAEP {
// Y is nonzero, output "decryption error" and stop.
bytes32 recoveredHash = bytes32(db);
bytes1 one;
assembly {
// Iterate over bytes after the label hash until hitting a non-zero byte
// Skip the first word since it is the recovered hash
// Identify the start index of the message within the db byte string
let m := 0
for { let w := 1 } lt(w, dbWords) { w := add(w, 1) } {
let word := mload(add(db, add(0x20, mul(w, 0x20))))
// Iterate over bytes in the word
for { let i := 0 } lt(i, 0x20) { i := 0x20 } {
switch byte(i, word)
case 0x00 { continue }
case 0x01 {
one := 0x01
m := add(add(i, 1), mul(sub(w, 1), 0x20))
break
}
default {
// Non-zero entry found before 0x01, revert
let p := mload(0x40)
mstore(p, "decryption error")
revert(p, 0x10)
}
}

// If the 0x01 byte has been found, exit the outer loop
switch one
case 0x01 { break }
uint256 m;

// Iterate over bytes after the label hash until hitting a non-zero byte
// Skip the first word since it is the recovered hash
// Identify the start index of the message within the db byte string
for (uint256 i = 32; i < db.length; i++) {
if (db[i] == 0x00) {
// Padding, continue
continue;
} else if (db[i] == 0x01) {
// We found the 0x01 byte, set the one flag and store the index of the next byte
one = 0x01;
m = i + 1;
break;
} else {
// Non-zero entry found before 0x01, revert
revert("decryption error");
}
}

// Check that m is not zero, otherwise revert
switch m
case 0x00 {
let p := mload(0x40)
mstore(p, "decryption error")
revert(p, 0x10)
}
// Check that m was found, otherwise revert
if (m == 0) revert("decryption error");

// Copy the message from the db bytes string
let len := sub(mload(db), m)
let wrds := div(len, 0x20)
switch mod(len, 0x20)
case 0x00 {}
default { wrds := add(wrds, 1) }
for { let w := 0 } lt(w, wrds) { w := add(w, 1) } {
let c := mload(add(db, add(m, mul(w, 0x20))))
let i := add(message, mul(w, 0x20))
mstore(i, c)
}
// Copy the message from the db bytes string
uint256 len = db.length - m;
message = new bytes(len);
for (uint256 i; i < len; i++) {
message[i] = db[m + i];
}

if (one != 0x01 || lhash != recoveredHash || y != 0x00) revert("decryption error");
if (one != bytes1(0x01) || lhash != recoveredHash || y != bytes1(0x00)) revert("final");

// 4. Return the message and seed used for encryption
}
Expand Down Expand Up @@ -185,7 +166,7 @@ library RSAOAEP {
// 2. d. Generate random byte string the same length as the hash function
bytes32 rand = sha256(abi.encodePacked(seed));

// 2. e. Let dbMask = MGF(seed, k - hLen - 1).
// 2. e. Let dbMask = MGF(seed, nLen - hLen - 1).
bytes memory dbMask = _mgf(abi.encodePacked(rand), nLen - 33);

// 2. f. Let maskedDB = DB \xor dbMask.
Expand Down Expand Up @@ -221,7 +202,7 @@ library RSAOAEP {
return modexp(encoded, e, n);
}

function _mgf(bytes memory seed, uint256 maskLen) internal pure returns (bytes memory) {
function _mgf(bytes memory seed, uint256 maskLen) public pure returns (bytes memory) {

Check warning on line 205 in src/lib/RSA.sol

View workflow job for this annotation

GitHub Actions / Foundry project

'_mgf' should not start with _
// Implements 8.2.1 MGF1 as defined in RFC8017: https://www.rfc-editor.org/rfc/rfc8017

// 1. Check that the mask length is not greater than 2^32 * hash length (32 bytes in this case)
Expand All @@ -240,8 +221,8 @@ library RSAOAEP {
// string T:
// T = T || Hash(mgfSeed || C) .

uint256 count = maskLen / 32 + (maskLen % 32 == 0 ? 0 : 1);
for (uint256 c; c < count; c++) {
uint32 count = uint32((maskLen / 32) + ((maskLen % 32) == 0 ? 0 : 1));
for (uint32 c; c < count; c++) {
bytes32 h = sha256(abi.encodePacked(seed, c));
assembly {
let p := add(add(t, 0x20), mul(c, 0x20))
Expand All @@ -253,19 +234,19 @@ library RSAOAEP {
return t;
}

function _xor(bytes memory first, bytes memory second) internal pure returns (bytes memory) {
function _xor(bytes memory first, bytes memory second) public pure returns (bytes memory) {

Check warning on line 237 in src/lib/RSA.sol

View workflow job for this annotation

GitHub Actions / Foundry project

'_xor' should not start with _
uint256 fLen = first.length;
uint256 sLen = second.length;
if (fLen != sLen) revert("xor: different lengths");

uint256 words = (fLen / 32) + (fLen % 32 == 0 ? 0 : 1);
uint256 words = (fLen / 32) + ((fLen % 32) == 0 ? 0 : 1);
bytes memory result = new bytes(fLen);

// Iterate through words in the byte strings and xor them one at a time, storing the result
assembly {
for { let i := 0 } lt(i, words) { i := add(i, 1) } {
let f := mload(add(first, mul(i, 0x20)))
let s := mload(add(second, mul(i, 0x20)))
let f := mload(add(add(first, 0x20), mul(i, 0x20)))
let s := mload(add(add(second, 0x20), mul(i, 0x20)))
mstore(add(add(result, 0x20), mul(i, 0x20)), xor(f, s))
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/modules/Auction.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ abstract contract Auction {

error Auction_InvalidLotId(uint96 lotId);

error Auction_InvalidBidId(uint96 lotId, uint256 bidId);
error Auction_InvalidBidId(uint96 lotId, uint96 bidId);

error Auction_OnlyMarketOwner();
error Auction_AmountLessThanMinimum();
Expand Down Expand Up @@ -113,7 +113,7 @@ abstract contract Auction {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) external virtual returns (uint256 bidId);
) external virtual returns (uint96 bidId);

/// @notice Cancel a bid
/// @dev The implementing function should handle the following:
Expand All @@ -127,7 +127,7 @@ abstract contract Auction {
/// @return bidAmount The amount of quote tokens to refund
function cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address bidder_
) external virtual returns (uint256 bidAmount);

Expand Down Expand Up @@ -343,7 +343,7 @@ abstract contract AuctionModule is Auction, Module {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) external override onlyInternal returns (uint256 bidId) {
) external override onlyInternal returns (uint96 bidId) {
// Standard validation
_revertIfLotInvalid(lotId_);
_revertIfBeforeLotStart(lotId_);
Expand Down Expand Up @@ -373,7 +373,7 @@ abstract contract AuctionModule is Auction, Module {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) internal virtual returns (uint256 bidId);
) internal virtual returns (uint96 bidId);

/// @inheritdoc Auction
/// @dev Implements a basic cancelBid function that:
Expand All @@ -394,7 +394,7 @@ abstract contract AuctionModule is Auction, Module {
/// - Updating the bid data
function cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) external override onlyInternal returns (uint256 bidAmount) {
// Standard validation
Expand All @@ -419,7 +419,7 @@ abstract contract AuctionModule is Auction, Module {
/// @return bidAmount The amount of quote tokens to refund
function _cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) internal virtual returns (uint256 bidAmount);

Expand Down Expand Up @@ -551,7 +551,7 @@ abstract contract AuctionModule is Auction, Module {
///
/// @param lotId_ The lot ID
/// @param bidId_ The bid ID
function _revertIfBidInvalid(uint96 lotId_, uint256 bidId_) internal view virtual;
function _revertIfBidInvalid(uint96 lotId_, uint96 bidId_) internal view virtual;

/// @notice Checks that `caller_` is the bid owner
/// @dev Should revert if `caller_` is not the bid owner
Expand All @@ -562,7 +562,7 @@ abstract contract AuctionModule is Auction, Module {
/// @param caller_ The caller
function _revertIfNotBidOwner(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) internal view virtual;

Expand All @@ -572,5 +572,5 @@ abstract contract AuctionModule is Auction, Module {
///
/// @param lotId_ The lot ID
/// @param bidId_ The bid ID
function _revertIfBidCancelled(uint96 lotId_, uint256 bidId_) internal view virtual;
function _revertIfBidCancelled(uint96 lotId_, uint96 bidId_) internal view virtual;
}
Loading

0 comments on commit 35d515c

Please sign in to comment.