Skip to content

Commit

Permalink
wrapper: Add modular arithmetic functions to ArbitraryPrecisionInteger
Browse files Browse the repository at this point in the history
  • Loading branch information
simonjbeaumont committed Oct 16, 2024
1 parent 21de58f commit 1556601
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 0 deletions.
113 changes: 113 additions & 0 deletions Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,119 @@ extension ArbitraryPrecisionInteger: Numeric {
}
}

// MARK: - Modular arithmetic

extension ArbitraryPrecisionInteger {
@usableFromInline
package func modulo(_ mod: ArbitraryPrecisionInteger, nonNegative: Bool = false) throws -> ArbitraryPrecisionInteger {
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
mod.withUnsafeBignumPointer { modPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
if nonNegative {
CCryptoBoringSSL_BN_nnmod(resultPtr, selfPtr, modPtr, bnCtx)
} else {
CCryptoBoringSSLShims_BN_mod(resultPtr, selfPtr, modPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func inverse(modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
mod.withUnsafeBignumPointer { modPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
CCryptoBoringSSL_BN_mod_inverse(resultPtr, selfPtr, modPtr, bnCtx)
}
}
}
}
guard rc != nil else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}


@usableFromInline
package static func inverse(lhs: ArbitraryPrecisionInteger, modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
try ArbitraryPrecisionInteger(lhs).inverse(modulo: mod)
}

@usableFromInline
package func add(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self + rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
return CCryptoBoringSSL_BN_mod_add(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func sub(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self - rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
CCryptoBoringSSL_BN_mod_sub(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func mul(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self * rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
return CCryptoBoringSSL_BN_mod_mul(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}
}

// MARK: - SignedNumeric

extension ArbitraryPrecisionInteger: SignedNumeric {
Expand Down
118 changes: 118 additions & 0 deletions Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,123 @@ final class ArbitraryPrecisionIntegerTests: XCTestCase {
XCTAssertEqual(try ArbitraryPrecisionInteger(bytes: bytes), integer)
}
}

func testMoudlo() throws {
typealias I = ArbitraryPrecisionInteger
typealias Vector = (input: I, mod: I, expectedResult: (standard: I, nonNegative: I))
for vector: Vector in [
(input: 0, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 1, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 2, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 3, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 4, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 5, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 7, mod: 5, expectedResult: (standard: 2, nonNegative: 2)),
(input: 7, mod: -5, expectedResult: (standard: 2, nonNegative: 2)),
(input: -7, mod: 5, expectedResult: (standard: -2, nonNegative: 3)),
(input: -7, mod: -5, expectedResult: (standard: -2, nonNegative: 3)),
] {
XCTAssertEqual(
try vector.input.modulo(vector.mod, nonNegative: false),
vector.expectedResult.standard,
"\(vector.input) (mod \(vector.mod))"
)
XCTAssertEqual(
try vector.input.modulo(vector.mod, nonNegative: true),
vector.expectedResult.nonNegative,
"\(vector.input) (nnmod \(vector.mod))"
)
}
}

func testModularInverse() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 3, mod: 7, expectedResult: .ok(5)),
(a: 10, mod: 17, expectedResult: .ok(12)),
(a: 7, mod: 26, expectedResult: .ok(15)),
(a: 7, mod: 7, expectedResult: .throwsError),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.inverse(modulo: vector.mod), expectedValue, "inverse(\(vector.a), modulo: \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.inverse(modulo: vector.mod), "inverse(\(vector.a), modulo: \(vector.mod)")
}
}
}

func testModularAddition() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
(a: -4, b: -3, mod: 5, expectedResult: .ok(3)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.add(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) + \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.add(vector.b, modulo: vector.mod), "\(vector.a) + \(vector.b) (mod \(vector.mod))")
}
}
}

func testModularSubtraction() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(1)),
(a: 3, b: 4, mod: 5, expectedResult: .ok(4)),
(a: 3, b: 4, mod: -5, expectedResult: .ok(4)),
(a: -3, b: 4, mod: 5, expectedResult: .ok(3)),
(a: 3, b: -4, mod: 5, expectedResult: .ok(2)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.sub(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) - \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.sub(vector.b, modulo: vector.mod), "\(vector.a) - \(vector.b) (mod \(vector.mod))")
}
}
}

func testModularMultiplication() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
(a: -4, b: -3, mod: 5, expectedResult: .ok(2)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.mul(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) × \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.mul(vector.b, modulo: vector.mod), "\(vector.a) × \(vector.b) (mod \(vector.mod))")
}
}
}
}
#endif // CRYPTO_IN_SWIFTPM && !CRYPTO_IN_SWIFTPM_FORCE_BUILD_API

0 comments on commit 1556601

Please sign in to comment.