Skip to content

Commit

Permalink
Use Math.binarySearch in YieldCurveLibrary.getRate
Browse files Browse the repository at this point in the history
  • Loading branch information
aviggiano committed Jan 9, 2024
1 parent 20810f6 commit e4fa220
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 20 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ Size V2 Solidity

## TODO before testnet

- implement YieldCurve.getRate as binary search
- finish invariant tests
- origination fee & loan fee
- test for dueDate NOW
Expand Down
20 changes: 20 additions & 0 deletions src/libraries/MathLibrary.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity 0.8.20;

import {FixedPointMathLib} from "@solady/utils/FixedPointMathLib.sol";
import {console2 as console} from "forge-std/console2.sol";

uint256 constant PERCENT = 1e18;

Expand All @@ -27,4 +28,23 @@ library Math {
function mulDivDown(uint256 x, uint256 y, uint256 z) public pure returns (uint256) {
return FixedPointMathLib.mulDiv(x, y, z);
}

function binarySearch(uint256[] memory array, uint256 value) public view returns (uint256 low, uint256 high) {
low = 0;
high = array.length - 1;
if (value < array[low] || value > array[high]) {
return (type(uint256).max, type(uint256).max);
}
while (low <= high) {
uint256 mid = (low + high) / 2;
if (array[mid] == value) {
return (mid, mid);
} else if (array[mid] < value) {
low = mid + 1;
} else {
high = mid - 1;
}
}
return (high, low);
}
}
29 changes: 10 additions & 19 deletions src/libraries/YieldCurveLibrary.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,24 @@ library YieldCurveLibrary {

function getRate(YieldCurve memory curveRelativeTime, uint256 dueDate) public view returns (uint256) {
if (dueDate < block.timestamp) revert Errors.PAST_DUE_DATE(dueDate);
uint256 deltaT = dueDate - block.timestamp;
uint256 interval = dueDate - block.timestamp;
uint256 length = curveRelativeTime.timeBuckets.length;
if (deltaT < curveRelativeTime.timeBuckets[0] || deltaT > curveRelativeTime.timeBuckets[length - 1]) {
if (interval < curveRelativeTime.timeBuckets[0] || interval > curveRelativeTime.timeBuckets[length - 1]) {
revert Errors.DUE_DATE_OUT_OF_RANGE(
deltaT, curveRelativeTime.timeBuckets[0], curveRelativeTime.timeBuckets[length - 1]
interval, curveRelativeTime.timeBuckets[0], curveRelativeTime.timeBuckets[length - 1]
);
} else {
uint256 minIndex = type(uint256).max;
uint256 maxIndex = type(uint256).max;
for (uint256 i = 0; i < length; ++i) {
if (curveRelativeTime.timeBuckets[i] <= deltaT) {
minIndex = i;
}
if (curveRelativeTime.timeBuckets[i] >= deltaT && maxIndex == type(uint256).max) {
maxIndex = i;
}
}
uint256 x0 = curveRelativeTime.timeBuckets[minIndex];
uint256 y0 = curveRelativeTime.rates[minIndex];
uint256 x1 = curveRelativeTime.timeBuckets[maxIndex];
uint256 y1 = curveRelativeTime.rates[maxIndex];
(uint256 low, uint256 high) = Math.binarySearch(curveRelativeTime.timeBuckets, interval);
uint256 x0 = curveRelativeTime.timeBuckets[low];
uint256 y0 = curveRelativeTime.rates[low];
uint256 x1 = curveRelativeTime.timeBuckets[high];
uint256 y1 = curveRelativeTime.rates[high];
// @audit Check the rounding direction, as this may lead debt rounding down
if (x1 != x0) {
if (y1 >= y0) {
return y0 + Math.mulDivDown(y1 - y0, deltaT - x0, x1 - x0);
return y0 + Math.mulDivDown(y1 - y0, interval - x0, x1 - x0);
} else {
return y0 - Math.mulDivDown(y0 - y1, deltaT - x0, x1 - x0);
return y0 - Math.mulDivDown(y0 - y1, interval - x0, x1 - x0);
}
} else {
return y0;
Expand Down
57 changes: 57 additions & 0 deletions test/libraries/MathLibrary.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,61 @@ contract MathTest is Test {
assertEq(Math.mulDivDown(3, 5, 4), 3);
assertEq(Math.mulDivDown(4, 5, 4), 5);
}

function test_Math_binarySearch_found() public {
uint256[] memory array = new uint256[](5);
array[0] = 10;
array[1] = 20;
array[2] = 30;
array[3] = 40;
array[4] = 50;
uint256 low;
uint256 high;
for (uint256 i = 0; i < array.length; i++) {
(low, high) = Math.binarySearch(array, array[i]);
assertEq(low, i);
assertEq(high, i);
}
}

function test_Math_binarySearch_not_found() public {
uint256[] memory array = new uint256[](5);
array[0] = 10;
array[1] = 20;
array[2] = 30;
array[3] = 40;
array[4] = 50;
uint256 low;
uint256 high;
(low, high) = Math.binarySearch(array, 0);
assertEq(low, type(uint256).max);
assertEq(high, type(uint256).max);
(low, high) = Math.binarySearch(array, 13);
assertEq(low, 0);
assertEq(high, 1);
(low, high) = Math.binarySearch(array, 17);
assertEq(low, 0);
assertEq(high, 1);
(low, high) = Math.binarySearch(array, 21);
assertEq(low, 1);
assertEq(high, 2);
(low, high) = Math.binarySearch(array, 29);
assertEq(low, 1);
assertEq(high, 2);
(low, high) = Math.binarySearch(array, 32);
assertEq(low, 2);
assertEq(high, 3);
(low, high) = Math.binarySearch(array, 37);
assertEq(low, 2);
assertEq(high, 3);
(low, high) = Math.binarySearch(array, 42);
assertEq(low, 3);
assertEq(high, 4);
(low, high) = Math.binarySearch(array, 45);
assertEq(low, 3);
assertEq(high, 4);
(low, high) = Math.binarySearch(array, 51);
assertEq(low, type(uint256).max);
assertEq(high, type(uint256).max);
}
}

0 comments on commit e4fa220

Please sign in to comment.