Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encryption Updates and Tests #26

Merged
merged 13 commits into from
Jan 26, 2024
18 changes: 9 additions & 9 deletions src/AuctionHouse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ abstract contract Router is FeeManager {
///
/// @param params_ Bid parameters
/// @return bidId Bid ID
function bid(BidParams memory params_) external virtual returns (uint256 bidId);
function bid(BidParams memory params_) external virtual returns (uint96 bidId);

/// @notice Cancel a bid on a lot in a batch auction
/// @dev The implementing function must perform the following:
Expand All @@ -139,7 +139,7 @@ abstract contract Router is FeeManager {
///
/// @param lotId_ Lot ID
/// @param bidId_ Bid ID
function cancelBid(uint96 lotId_, uint256 bidId_) external virtual;
function cancelBid(uint96 lotId_, uint96 bidId_) external virtual;

/// @notice Settle a batch auction
/// @notice This function is used for versions with on-chain storage and bids and local settlement
Expand Down Expand Up @@ -261,7 +261,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @return bool True if caller is allowed to purchase/bid on the lot
function _isAllowed(
IAllowlist allowlist_,
uint256 lotId_,
uint96 lotId_,
address caller_,
bytes memory allowlistProof_
) internal view returns (bool) {
Expand Down Expand Up @@ -356,7 +356,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
external
override
isLotValid(params_.lotId)
returns (uint256)
returns (uint96)
{
// Load routing data for the lot
Routing memory routing = lotRouting[params_.lotId];
Expand All @@ -368,7 +368,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {

// Record the bid on the auction module
// The module will determine if the bid is valid - minimum bid size, minimum price, auction status, etc
uint256 bidId;
uint96 bidId;
{
AuctionModule module = _getModuleForId(params_.lotId);
bidId = module.bid(
Expand Down Expand Up @@ -398,7 +398,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @dev This function reverts if:
/// - the lot ID is invalid
/// - the auction module reverts when cancelling the bid
function cancelBid(uint96 lotId_, uint256 bidId_) external override isLotValid(lotId_) {
function cancelBid(uint96 lotId_, uint96 bidId_) external override isLotValid(lotId_) {
// Cancel the bid on the auction module
// The auction module is responsible for validating the bid and authorizing the caller
AuctionModule module = _getModuleForId(lotId_);
Expand Down Expand Up @@ -505,7 +505,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param hooks_ Hooks contract to call (optional)
/// @param permit2Approval_ Permit2 approval data (optional)
function _collectPayment(
uint256 lotId_,
uint96 lotId_,
uint256 amount_,
ERC20 quoteToken_,
IHooks hooks_,
Expand Down Expand Up @@ -574,7 +574,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param payoutAmount_ Amount of payoutToken to collect (in native decimals)
/// @param routingParams_ Routing parameters for the lot
function _collectPayout(
uint256 lotId_,
uint96 lotId_,
uint256 paymentAmount_,
uint256 payoutAmount_,
Routing memory routingParams_
Expand Down Expand Up @@ -634,7 +634,7 @@ contract AuctionHouse is Derivatizer, Auctioneer, Router {
/// @param routingParams_ Routing parameters for the lot
/// @param auctionOutput_ Custom data returned by the auction module
function _sendPayout(
uint256 lotId_,
uint96 lotId_,
address recipient_,
uint256 payoutAmount_,
Routing memory routingParams_,
Expand Down
6 changes: 3 additions & 3 deletions src/interfaces/IHooks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ interface IHooks {

/// @notice Called before payment and payout
/// TODO define expected state, invariants
function pre(uint256 lotId_, uint256 amount_) external;
function pre(uint96 lotId_, uint256 amount_) external;

/// @notice Called after payment and before payout
/// TODO define expected state, invariants
function mid(uint256 lotId_, uint256 amount_, uint256 payout_) external;
function mid(uint96 lotId_, uint256 amount_, uint256 payout_) external;

/// @notice Called after payment and after payout
/// TODO define expected state, invariants
function post(uint256 lotId_, uint256 payout_) external;
function post(uint96 lotId_, uint256 payout_) external;
}
135 changes: 58 additions & 77 deletions src/lib/RSA.sol
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
bytes memory label
) internal view returns (bytes memory message, bytes32 seed) {
// Implements 7.1.2 RSAES-OAEP-DECRYPT as defined in RFC8017: https://www.rfc-editor.org/rfc/rfc8017
// Error messages are intentionally vague to prevent oracle attacks

// 1. Input length validation
// 1. a. If the length of L is greater than the input limitation
Expand Down Expand Up @@ -57,36 +56,38 @@

// 3. b. Separate encoded message into Y (1 byte) | maskedSeed (32 bytes) | maskedDB (cLen - 32 - 1)
bytes1 y = bytes1(encoded);
bytes32 maskedSeed;
uint256 words = (cLen - 33) / 32 + ((cLen - 33) % 32 == 0 ? 0 : 1);
bytes memory maskedDb = new bytes(cLen - 33);

assembly {
bytes memory db;
{
// Scope these local variables to avoid stack too deep later
bytes32 maskedSeed;
// uint256 words = ((cLen - 33) / 32) + (((cLen - 33) % 32) == 0 ? 0 : 1);
uint256 maskLen = cLen - 33;
bytes memory maskedDb = new bytes(maskLen);

// Load a word from the encoded string starting at the 2nd byte (also have to account for length stored in first slot)
maskedSeed := mload(add(encoded, 0x21))
assembly {
maskedSeed := mload(add(encoded, 0x21))
}

// Store the remaining bytes into the maskedDb
for { let i := 0 } lt(i, words) { i := add(i, 1) } {
mstore(
add(add(maskedDb, 0x20), mul(i, 0x20)),
mload(add(add(encoded, 0x41), mul(i, 0x20)))
)
for (uint256 i; i < maskLen; i++) {
maskedDb[i] = encoded[i + 33];
}
}

// 3. c. Calculate seed mask
// 3. d. Calculate seed
{
bytes32 seedMask = bytes32(_mgf(maskedDb, 32));
seed = maskedSeed ^ seedMask;
}
// 3. c. Calculate seed mask
// 3. d. Calculate seed
{
bytes32 seedMask = bytes32(_mgf(maskedDb, 32));
seed = maskedSeed ^ seedMask;
}

// 3. e. Calculate DB mask
bytes memory dbMask = _mgf(abi.encodePacked(seed), cLen - 33);
// 3. e. Calculate DB mask
bytes memory dbMask = _mgf(abi.encodePacked(seed), cLen - 33);

// 3. f. Calculate DB
bytes memory db = _xor(maskedDb, dbMask);
uint256 dbWords = db.length / 32 + db.length % 32 == 0 ? 0 : 1;
// 3. f. Calculate DB
db = _xor(maskedDb, dbMask);
}

// 3. g. Separate DB into an octet string lHash' of length hLen, a
// (possibly empty) padding string PS consisting of octets
Expand All @@ -99,57 +100,37 @@
// Y is nonzero, output "decryption error" and stop.
bytes32 recoveredHash = bytes32(db);
bytes1 one;
assembly {
// Iterate over bytes after the label hash until hitting a non-zero byte
// Skip the first word since it is the recovered hash
// Identify the start index of the message within the db byte string
let m := 0
for { let w := 1 } lt(w, dbWords) { w := add(w, 1) } {
let word := mload(add(db, add(0x20, mul(w, 0x20))))
// Iterate over bytes in the word
for { let i := 0 } lt(i, 0x20) { i := 0x20 } {
switch byte(i, word)
case 0x00 { continue }
case 0x01 {
one := 0x01
m := add(add(i, 1), mul(sub(w, 1), 0x20))
break
}
default {
// Non-zero entry found before 0x01, revert
let p := mload(0x40)
mstore(p, "decryption error")
revert(p, 0x10)
}
}

// If the 0x01 byte has been found, exit the outer loop
switch one
case 0x01 { break }
uint256 m;

// Iterate over bytes after the label hash until hitting a non-zero byte
// Skip the first word since it is the recovered hash
// Identify the start index of the message within the db byte string
for (uint256 i = 32; i < db.length; i++) {
if (db[i] == 0x00) {
// Padding, continue
continue;
} else if (db[i] == 0x01) {
// We found the 0x01 byte, set the one flag and store the index of the next byte
one = 0x01;
m = i + 1;
break;
} else {
// Non-zero entry found before 0x01, revert
revert("decryption error");
}
}

// Check that m is not zero, otherwise revert
switch m
case 0x00 {
let p := mload(0x40)
mstore(p, "decryption error")
revert(p, 0x10)
}
// Check that m was found, otherwise revert
if (m == 0) revert("decryption error");

// Copy the message from the db bytes string
let len := sub(mload(db), m)
let wrds := div(len, 0x20)
switch mod(len, 0x20)
case 0x00 {}
default { wrds := add(wrds, 1) }
for { let w := 0 } lt(w, wrds) { w := add(w, 1) } {
let c := mload(add(db, add(m, mul(w, 0x20))))
let i := add(message, mul(w, 0x20))
mstore(i, c)
}
// Copy the message from the db bytes string
uint256 len = db.length - m;
message = new bytes(len);
for (uint256 i; i < len; i++) {
message[i] = db[m + i];
}

if (one != 0x01 || lhash != recoveredHash || y != 0x00) revert("decryption error");
if (one != bytes1(0x01) || lhash != recoveredHash || y != bytes1(0x00)) revert("final");

// 4. Return the message and seed used for encryption
}
Expand Down Expand Up @@ -185,7 +166,7 @@
// 2. d. Generate random byte string the same length as the hash function
bytes32 rand = sha256(abi.encodePacked(seed));

// 2. e. Let dbMask = MGF(seed, k - hLen - 1).
// 2. e. Let dbMask = MGF(seed, nLen - hLen - 1).
bytes memory dbMask = _mgf(abi.encodePacked(rand), nLen - 33);

// 2. f. Let maskedDB = DB \xor dbMask.
Expand Down Expand Up @@ -221,7 +202,7 @@
return modexp(encoded, e, n);
}

function _mgf(bytes memory seed, uint256 maskLen) internal pure returns (bytes memory) {
function _mgf(bytes memory seed, uint256 maskLen) public pure returns (bytes memory) {

Check warning on line 205 in src/lib/RSA.sol

View workflow job for this annotation

GitHub Actions / Foundry project

'_mgf' should not start with _
// Implements 8.2.1 MGF1 as defined in RFC8017: https://www.rfc-editor.org/rfc/rfc8017

// 1. Check that the mask length is not greater than 2^32 * hash length (32 bytes in this case)
Expand All @@ -240,8 +221,8 @@
// string T:
// T = T || Hash(mgfSeed || C) .

uint256 count = maskLen / 32 + (maskLen % 32 == 0 ? 0 : 1);
for (uint256 c; c < count; c++) {
uint32 count = uint32((maskLen / 32) + ((maskLen % 32) == 0 ? 0 : 1));
for (uint32 c; c < count; c++) {
bytes32 h = sha256(abi.encodePacked(seed, c));
assembly {
let p := add(add(t, 0x20), mul(c, 0x20))
Expand All @@ -253,19 +234,19 @@
return t;
}

function _xor(bytes memory first, bytes memory second) internal pure returns (bytes memory) {
function _xor(bytes memory first, bytes memory second) public pure returns (bytes memory) {

Check warning on line 237 in src/lib/RSA.sol

View workflow job for this annotation

GitHub Actions / Foundry project

'_xor' should not start with _
uint256 fLen = first.length;
uint256 sLen = second.length;
if (fLen != sLen) revert("xor: different lengths");

uint256 words = (fLen / 32) + (fLen % 32 == 0 ? 0 : 1);
uint256 words = (fLen / 32) + ((fLen % 32) == 0 ? 0 : 1);
bytes memory result = new bytes(fLen);

// Iterate through words in the byte strings and xor them one at a time, storing the result
assembly {
for { let i := 0 } lt(i, words) { i := add(i, 1) } {
let f := mload(add(first, mul(i, 0x20)))
let s := mload(add(second, mul(i, 0x20)))
let f := mload(add(add(first, 0x20), mul(i, 0x20)))
let s := mload(add(add(second, 0x20), mul(i, 0x20)))
mstore(add(add(result, 0x20), mul(i, 0x20)), xor(f, s))
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/modules/Auction.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ abstract contract Auction {

error Auction_InvalidLotId(uint96 lotId);

error Auction_InvalidBidId(uint96 lotId, uint256 bidId);
error Auction_InvalidBidId(uint96 lotId, uint96 bidId);

error Auction_OnlyMarketOwner();
error Auction_AmountLessThanMinimum();
Expand Down Expand Up @@ -113,7 +113,7 @@ abstract contract Auction {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) external virtual returns (uint256 bidId);
) external virtual returns (uint96 bidId);

/// @notice Cancel a bid
/// @dev The implementing function should handle the following:
Expand All @@ -127,7 +127,7 @@ abstract contract Auction {
/// @return bidAmount The amount of quote tokens to refund
function cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address bidder_
) external virtual returns (uint256 bidAmount);

Expand Down Expand Up @@ -343,7 +343,7 @@ abstract contract AuctionModule is Auction, Module {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) external override onlyInternal returns (uint256 bidId) {
) external override onlyInternal returns (uint96 bidId) {
// Standard validation
_revertIfLotInvalid(lotId_);
_revertIfBeforeLotStart(lotId_);
Expand Down Expand Up @@ -373,7 +373,7 @@ abstract contract AuctionModule is Auction, Module {
address referrer_,
uint256 amount_,
bytes calldata auctionData_
) internal virtual returns (uint256 bidId);
) internal virtual returns (uint96 bidId);

/// @inheritdoc Auction
/// @dev Implements a basic cancelBid function that:
Expand All @@ -394,7 +394,7 @@ abstract contract AuctionModule is Auction, Module {
/// - Updating the bid data
function cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) external override onlyInternal returns (uint256 bidAmount) {
// Standard validation
Expand All @@ -419,7 +419,7 @@ abstract contract AuctionModule is Auction, Module {
/// @return bidAmount The amount of quote tokens to refund
function _cancelBid(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) internal virtual returns (uint256 bidAmount);

Expand Down Expand Up @@ -551,7 +551,7 @@ abstract contract AuctionModule is Auction, Module {
///
/// @param lotId_ The lot ID
/// @param bidId_ The bid ID
function _revertIfBidInvalid(uint96 lotId_, uint256 bidId_) internal view virtual;
function _revertIfBidInvalid(uint96 lotId_, uint96 bidId_) internal view virtual;

/// @notice Checks that `caller_` is the bid owner
/// @dev Should revert if `caller_` is not the bid owner
Expand All @@ -562,7 +562,7 @@ abstract contract AuctionModule is Auction, Module {
/// @param caller_ The caller
function _revertIfNotBidOwner(
uint96 lotId_,
uint256 bidId_,
uint96 bidId_,
address caller_
) internal view virtual;

Expand All @@ -572,5 +572,5 @@ abstract contract AuctionModule is Auction, Module {
///
/// @param lotId_ The lot ID
/// @param bidId_ The bid ID
function _revertIfBidCancelled(uint96 lotId_, uint256 bidId_) internal view virtual;
function _revertIfBidCancelled(uint96 lotId_, uint96 bidId_) internal view virtual;
}
Loading
Loading