diff --git a/README.md b/README.md index 0f39b932..c2cdde2f 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,6 @@ Size V2 Solidity ## TODO before testnet -- implement YieldCurve.getRate as binary search - finish invariant tests - origination fee & loan fee - test for dueDate NOW diff --git a/src/libraries/MathLibrary.sol b/src/libraries/MathLibrary.sol index 88da4d5c..5d868079 100644 --- a/src/libraries/MathLibrary.sol +++ b/src/libraries/MathLibrary.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.20; import {FixedPointMathLib} from "@solady/utils/FixedPointMathLib.sol"; +import {console2 as console} from "forge-std/console2.sol"; uint256 constant PERCENT = 1e18; @@ -27,4 +28,23 @@ library Math { function mulDivDown(uint256 x, uint256 y, uint256 z) public pure returns (uint256) { return FixedPointMathLib.mulDiv(x, y, z); } + + function binarySearch(uint256[] memory array, uint256 value) public view returns (uint256 low, uint256 high) { + low = 0; + high = array.length - 1; + if (value < array[low] || value > array[high]) { + return (type(uint256).max, type(uint256).max); + } + while (low <= high) { + uint256 mid = (low + high) / 2; + if (array[mid] == value) { + return (mid, mid); + } else if (array[mid] < value) { + low = mid + 1; + } else { + high = mid - 1; + } + } + return (high, low); + } } diff --git a/src/libraries/YieldCurveLibrary.sol b/src/libraries/YieldCurveLibrary.sol index 3b9963af..ac302b00 100644 --- a/src/libraries/YieldCurveLibrary.sol +++ b/src/libraries/YieldCurveLibrary.sol @@ -29,33 +29,24 @@ library YieldCurveLibrary { function getRate(YieldCurve memory curveRelativeTime, uint256 dueDate) public view returns (uint256) { if (dueDate < block.timestamp) revert Errors.PAST_DUE_DATE(dueDate); - uint256 deltaT = dueDate - block.timestamp; + uint256 interval = dueDate - block.timestamp; uint256 length = curveRelativeTime.timeBuckets.length; - if (deltaT < curveRelativeTime.timeBuckets[0] || deltaT > curveRelativeTime.timeBuckets[length - 1]) { + if (interval < curveRelativeTime.timeBuckets[0] || interval > curveRelativeTime.timeBuckets[length - 1]) { revert Errors.DUE_DATE_OUT_OF_RANGE( - deltaT, curveRelativeTime.timeBuckets[0], curveRelativeTime.timeBuckets[length - 1] + interval, curveRelativeTime.timeBuckets[0], curveRelativeTime.timeBuckets[length - 1] ); } else { - uint256 minIndex = type(uint256).max; - uint256 maxIndex = type(uint256).max; - for (uint256 i = 0; i < length; ++i) { - if (curveRelativeTime.timeBuckets[i] <= deltaT) { - minIndex = i; - } - if (curveRelativeTime.timeBuckets[i] >= deltaT && maxIndex == type(uint256).max) { - maxIndex = i; - } - } - uint256 x0 = curveRelativeTime.timeBuckets[minIndex]; - uint256 y0 = curveRelativeTime.rates[minIndex]; - uint256 x1 = curveRelativeTime.timeBuckets[maxIndex]; - uint256 y1 = curveRelativeTime.rates[maxIndex]; + (uint256 low, uint256 high) = Math.binarySearch(curveRelativeTime.timeBuckets, interval); + uint256 x0 = curveRelativeTime.timeBuckets[low]; + uint256 y0 = curveRelativeTime.rates[low]; + uint256 x1 = curveRelativeTime.timeBuckets[high]; + uint256 y1 = curveRelativeTime.rates[high]; // @audit Check the rounding direction, as this may lead debt rounding down if (x1 != x0) { if (y1 >= y0) { - return y0 + Math.mulDivDown(y1 - y0, deltaT - x0, x1 - x0); + return y0 + Math.mulDivDown(y1 - y0, interval - x0, x1 - x0); } else { - return y0 - Math.mulDivDown(y0 - y1, deltaT - x0, x1 - x0); + return y0 - Math.mulDivDown(y0 - y1, interval - x0, x1 - x0); } } else { return y0; diff --git a/test/libraries/MathLibrary.t.sol b/test/libraries/MathLibrary.t.sol index ae9b1854..88f83b12 100644 --- a/test/libraries/MathLibrary.t.sol +++ b/test/libraries/MathLibrary.t.sol @@ -69,4 +69,61 @@ contract MathTest is Test { assertEq(Math.mulDivDown(3, 5, 4), 3); assertEq(Math.mulDivDown(4, 5, 4), 5); } + + function test_Math_binarySearch_found() public { + uint256[] memory array = new uint256[](5); + array[0] = 10; + array[1] = 20; + array[2] = 30; + array[3] = 40; + array[4] = 50; + uint256 low; + uint256 high; + for (uint256 i = 0; i < array.length; i++) { + (low, high) = Math.binarySearch(array, array[i]); + assertEq(low, i); + assertEq(high, i); + } + } + + function test_Math_binarySearch_not_found() public { + uint256[] memory array = new uint256[](5); + array[0] = 10; + array[1] = 20; + array[2] = 30; + array[3] = 40; + array[4] = 50; + uint256 low; + uint256 high; + (low, high) = Math.binarySearch(array, 0); + assertEq(low, type(uint256).max); + assertEq(high, type(uint256).max); + (low, high) = Math.binarySearch(array, 13); + assertEq(low, 0); + assertEq(high, 1); + (low, high) = Math.binarySearch(array, 17); + assertEq(low, 0); + assertEq(high, 1); + (low, high) = Math.binarySearch(array, 21); + assertEq(low, 1); + assertEq(high, 2); + (low, high) = Math.binarySearch(array, 29); + assertEq(low, 1); + assertEq(high, 2); + (low, high) = Math.binarySearch(array, 32); + assertEq(low, 2); + assertEq(high, 3); + (low, high) = Math.binarySearch(array, 37); + assertEq(low, 2); + assertEq(high, 3); + (low, high) = Math.binarySearch(array, 42); + assertEq(low, 3); + assertEq(high, 4); + (low, high) = Math.binarySearch(array, 45); + assertEq(low, 3); + assertEq(high, 4); + (low, high) = Math.binarySearch(array, 51); + assertEq(low, type(uint256).max); + assertEq(high, type(uint256).max); + } }